Keras create custom loss function YOLO v1

I am trying to implement a custom loss function from the yolo v1 paper. I am encountering some Nan value when I use this loss. I decomposed the loss function into 3 loss functions. Two of them return Nan. I was wondering if I a managed well the tensor. ( It is the first time a create a true loss function). I tried to consider the tensor like a numpy array but using the keras backend function. Is it a good idea ?

1) the loss about the bbox seems false. The error seems to come from the K.root. I check the y_true there is no negative value. The y pred is normally above or equal to 0 because of the relu activation function. So I do not understand why I obtained NaN after the second call of train_on_batch.

2) the loss regarding the confidence. I do not know where does the problem comes from. Is there a way to print the value so as to debug the loss function ?

YOLO paper : https://arxiv.org/pdf/1506.02640.pdf

Loss :

enter image description here

Here my custom loss :

def relu_clip(max_value=1.):        
    def relu_custom(x):
        return K.relu(x, max_value=K.cast_to_floatx(max_value))
    return relu_custom

def loss_yolo(self, y_true, y_pred):  
    '''
        dense layer : Sx * Sy * B * ((5) + C) 
        bbox : Sx * Sy * B * 4
        confidence = Sx * Sy * B * 1
        class : Sx * Sy * B * C

        7*7*4 = 196
        7*7*1 = 49
        7*7*20 = 980
    '''
    # reshape into cell
    y_true_bbox = K.reshape(y_true[:, :self.Sx*self.Sy*4*self.B], (-1, self.Sy, self.Sx, self.B, 4))
    y_pred_bbox = K.reshape(y_pred[:, :self.Sx*self.Sy*4*self.B], (-1, self.Sy, self.Sx, self.B, 4))
    y_true_confidence = K.reshape(y_true[:, self.Sx*self.Sy*4*self.B:self.Sx*self.Sy*5*self.B], (-1, self.Sy, self.Sx, self.B))
    y_pred_confidence = K.reshape(y_pred[:, self.Sx*self.Sy*4*self.B:self.Sx*self.Sy*5*self.B], (-1, self.Sy, self.Sx, self.B))
    y_true_class = K.reshape(y_true[:, self.Sx*self.Sy*5*self.B:], (-1, self.Sy, self.Sx, self.C))
    y_pred_class = K.reshape(y_pred[:, self.Sx*self.Sy*5*self.B:], (-1, self.Sy, self.Sx, self.C))

    # keep only boxes which exist in the dataset, if not put 0
    y_pred_bbox = y_pred_bbox * K.cast((y_true_bbox > 0), dtype='float32')

    # compute loss bbox
    loss_bbox = K.reshape(K.square( y_true_bbox[:,:,:,:,0:2] - y_pred_bbox[:,:,:,:,0:2]), (-1, self.Sx*self.Sy*2*self.B)) + K.reshape(K.square( K.sqrt(y_true_bbox[:, :, :, :, 2:]) - K.sqrt(y_pred_bbox[:, :, :, :, 2:])), (-1, self.Sx*self.Sy*2*self.B))
    loss_bbox = K.sum(loss_bbox, axis=1)

    # compute loss confidence
    xmin_true = y_true_bbox[:,:,:,:, 0] - y_true_bbox[:,:,:,:, 2]
    ymin_true = y_true_bbox[:,:,:,:, 1] - y_true_bbox[:,:,:,:, 3]
    xmax_true = y_true_bbox[:,:,:,:, 0] + y_true_bbox[:,:,:,:, 2]
    ymax_true = y_true_bbox[:,:,:,:, 1] + y_true_bbox[:,:,:,:, 3]

    xmin_pred = y_pred_bbox[:,:,:,:, 0] - y_pred_bbox[:,:,:,:, 2]
    ymin_pred = y_pred_bbox[:,:,:,:, 1] - y_pred_bbox[:,:,:,:, 3]
    xmax_pred = y_pred_bbox[:,:,:,:, 0] + y_pred_bbox[:,:,:,:, 2]
    ymax_pred = y_pred_bbox[:,:,:,:, 1] + y_pred_bbox[:,:,:,:, 3]

    print(' Xmin true : ', K.int_shape(xmin_true))

    xA = K.maximum(xmin_true, xmin_pred)
    yA = K.maximum(ymin_true, ymin_pred)
    xB = K.minimum(xmax_true, xmax_pred)
    yB = K.minimum(ymax_true, ymax_pred)
    print('Xa : ', K.int_shape(xA))
    #if xA < xB and yA < yB:
    condition1 = K.cast((xA<xB), dtype='float32')
    condition2 =  K.cast( (yA<yB), dtype='float32')
    condition = condition1 + condition2

    # find which iou to compute
    tocompute = K.cast( K.equal(condition, 2.0), dtype='float32')

        # compute the area of intersection rectangle
    interArea = (xB - xA) * (yB - yA) * tocompute
    # compute the area of both the prediction and ground-truth
    # rectangles
    boxAArea = tocompute * (xmax_pred - xmin_pred) * (ymax_pred - ymin_pred) 
    boxBArea = tocompute * (xmax_true - xmin_true) * (ymax_true - ymin_true)
    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the intersection area
    iou = (interArea / (boxAArea + boxBArea - interArea)) * y_true_confidence * y_pred_confidence
    print('iou shape : ', K.int_shape(iou))
    loss_confidence = K.reshape( K.square(iou - y_true_confidence), (-1, self.Sy*self.Sx*self.B))
    loss_confidence = K.sum(loss_confidence, axis=1)
    print('loss confidence shape :', K.int_shape(loss_confidence))


    # keep only prediction class if there is an object in the cell, else put class to 0
    y_pred_class = (K.reshape(y_true_confidence[:,:,:,0], (-1, self.Sy, self.Sx, 1)) * y_pred_class)

    # compute loss class
    loss_class = K.sum(K.square( y_pred_class  - y_true_class), axis=3)
    loss_class = K.sum(loss_class, axis=2)
    loss_class = K.sum(loss_class, axis=1)
    print(K.int_shape(loss_bbox))
    print(K.int_shape(loss_class))

    #loss = K.mean(loss_class + self.lambda_coord*loss_bbox + self.lambda_noobj * loss_confidence)
    loss = K.mean(loss_bbox) #K.mean(loss_confidence) #K.mean(self.lambda_noobj * loss_confidence)
    return loss