import torch.nn as nn
import torch
import torch.nn.functional as F
import math
#for multiview   
class CrossEntropyLoss(nn.Module):
    """Masked binary cross-entropy (with logits) over the ``laban_0`` head.

    Despite the class name, the loss applied is ``nn.BCEWithLogitsLoss``
    element-wise (an independent sigmoid/BCE per element), not a softmax
    cross-entropy.

    ``forward`` reads ``outputs['laban_0']`` (logits),
    ``targets['annotation_laban_0']`` and ``targets['laban_0_mask']``; the
    element-wise loss is multiplied by the mask, so the three are expected to
    have broadcast-compatible shapes.
    """

    def __init__(self):
        super().__init__()
        # Kept for backward compatibility; forward() only uses BCEWithLogits_Loss.
        self.criterion_multi_cls = nn.CrossEntropyLoss(reduction='none')
        self.criterion_binary_cls = nn.BCELoss(reduction='none')
        self.BCEWithLogits_Loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, outputs, targets):
        """Return sum(masked BCE) / sum(mask) for the ``laban_0`` head.

        When the mask sums to zero, the masked sum (0 for finite logits) is
        returned instead of the NaN that 0 / 0 would produce.
        """
        mask = targets['laban_0_mask']
        loss = self.BCEWithLogits_Loss(outputs['laban_0'], targets['annotation_laban_0'])
        masked_sum = torch.sum(loss * mask.float())
        count = torch.sum(mask)
        if bool(count > 0):
            return masked_sum / count
        # Nothing to average over: return the (zero) masked sum so the result
        # stays attached to the graph but does not inject NaN into training.
        return masked_sum

def decode_horizontal(indices):
    """Translate class indices into horizontal direction labels.

    The 26 class indices form three runs: 0-8 and 9-17 each cover
    'Place', 'Forward', ..., 'Right Forward'; 18-25 cover 'Forward', ...,
    'Right Forward' (no 'Place'). Every element of ``indices`` must support
    ``.item()`` (e.g. elements of a 1-D tensor); an index outside 0-25
    raises ``KeyError``.
    """
    directions = ('Place', 'Forward', 'Left Forward', 'Left', 'Left Backward',
                  'Backward', 'Right Backward', 'Right', 'Right Forward')
    # The first two runs use all nine directions; the third one drops 'Place'.
    table = dict(enumerate(directions + directions + directions[1:]))
    return [table[element.item()] for element in indices]

def decode_vertical(indices):
    """Translate class indices into vertical level labels.

    Index ranges: 0-8 -> 'Low', 9-17 -> 'High', 18-25 -> 'Normal'.
    Every element of ``indices`` must support ``.item()`` (e.g. elements of a
    1-D tensor); an index outside 0-25 raises ``KeyError``.
    """
    level_by_index = {}
    for level, first, last in (('Low', 0, 8), ('High', 9, 17), ('Normal', 18, 25)):
        for index in range(first, last + 1):
            level_by_index[index] = level
    labels = []
    for element in indices:
        labels.append(level_by_index[element.item()])
    return labels
def idx_coder(indices):
    """Translate class indices into level-group codes ``'0'``, ``'1'`` or ``'2'``.

    Index ranges: 0-8 -> '0', 9-17 -> '1', 18-25 -> '2' (the same grouping
    that ``decode_vertical`` labels 'Low', 'High', 'Normal'). Elements must
    support ``.item()``; an index outside 0-25 raises ``KeyError``.
    """
    codes = ['0'] * 9 + ['1'] * 9 + ['2'] * 8
    code_by_index = dict(zip(range(len(codes)), codes))
    return [code_by_index[element.item()] for element in indices]
class LabanBiasedloss(nn.Module):
    """Masked BCE-with-logits loss plus label error-rate terms, over Laban heads.

    ``forward`` iterates over heads ``laban_0`` .. ``laban_{len(outputs)-1}``
    and, for each, reads ``outputs['laban_i']`` (logits),
    ``targets['annotation_laban_i']`` and ``targets['laban_i_mask']``. Argmax
    and max are taken over dim 2, so these are assumed to be 3-D with the 26
    classes on the last axis (presumably (batch, seq, 26) — TODO confirm).

    The error-rate terms are computed from argmax indices, so they are plain
    Python floats: they shift the returned loss value but contribute no
    gradient. Only the masked BCE term trains the model.
    """

    # Class index -> horizontal direction id, identical to decode_horizontal():
    # 0=Place, 1=Forward, 2=Left Forward, 3=Left, 4=Left Backward, 5=Backward,
    # 6=Right Backward, 7=Right, 8=Right Forward (indices 18-25 have no Place).
    _HORIZONTAL_CODE = tuple(range(9)) + tuple(range(9)) + tuple(range(1, 9))
    # Class index -> vertical level id, identical to decode_vertical():
    # 0=Low (0-8), 1=High (9-17), 2=Normal (18-25).
    _VERTICAL_CODE = (0,) * 9 + (1,) * 9 + (2,) * 8

    def __init__(self):
        super(LabanBiasedloss, self).__init__()
        # Kept for backward compatibility; forward() only uses BCEWithLogits_Loss.
        self.criterion_multi_cls = nn.MSELoss(reduction='none')
        self.criterion_binary_cls = nn.BCELoss(reduction='none')
        self.BCEWithLogits_Loss = nn.BCEWithLogitsLoss(reduction='none')

    def _error_rates(self, output, target, mask):
        """Return (horizontal error, vertical error, flag agreement) rates for one head.

        A position (element over dims 0-1) is valid when the max of its mask
        over dim 2, cast to int, equals 1. Rates are fractions of valid
        positions; (0.0, 0.0, 0.0) is returned when there are none, instead
        of dividing by zero. Vectorised to avoid one ``.item()`` per element.
        """
        preds = torch.argmax(output, dim=2).reshape(-1)
        labels = torch.argmax(target, dim=2).reshape(-1)
        valid = torch.max(mask, dim=2)[0].int().reshape(-1) == 1
        preds, labels = preds[valid], labels[valid]
        n_valid = preds.numel()
        if n_valid == 0:
            return 0.0, 0.0, 0.0

        h_code = torch.tensor(self._HORIZONTAL_CODE, device=preds.device)
        v_code = torch.tensor(self._VERTICAL_CODE, device=preds.device)
        h_wrong = h_code[preds] != h_code[labels]
        v_wrong = v_code[preds] != v_code[labels]
        # NOTE(review): this counts positions where the horizontal and vertical
        # *mismatch flags agree* (both right OR both wrong), so a perfect
        # prediction scores 1.0 on this term. If a joint error rate was
        # intended it should probably be (h_wrong | v_wrong) — confirm intent.
        flags_agree = h_wrong == v_wrong
        # int / int keeps the exact Python-float division of the original.
        return (h_wrong.sum().item() / n_valid,
                v_wrong.sum().item() / n_valid,
                flags_agree.sum().item() / n_valid)

    def forward(self, outputs, targets):
        """Return the combined loss, averaging the rate terms over heads.

        Result: ``0.5*mean(agreement) + 0.3*mean(h_error) + 0.2*mean(v_error)
        + sum(masked BCE over all heads) / sum(all masks)``.

        Raises:
            ValueError: if ``outputs`` is empty (the head average is undefined).
        """
        n_heads = len(outputs)
        if n_heads == 0:
            raise ValueError('LabanBiasedloss needs at least one laban_i output')

        joint_loss = 0.0
        horizontal_loss = 0.0
        vertical_loss = 0.0
        bce_sum = 0.0
        mask_count = 0
        for i in range(n_heads):
            output = outputs[f'laban_{i}']
            target = targets[f'annotation_laban_{i}']
            mask = targets[f'laban_{i}_mask']

            h_rate, v_rate, agree_rate = self._error_rates(output, target, mask)
            horizontal_loss += h_rate
            vertical_loss += v_rate
            joint_loss += agree_rate

            bce = self.BCEWithLogits_Loss(output, target) * mask.float()
            bce_sum += torch.sum(bce)
            mask_count += torch.sum(mask)

        if bool(mask_count > 0):
            bce_term = bce_sum / mask_count
        else:
            # Everything is masked out: bce_sum is 0 (for finite logits), so
            # return it instead of the NaN produced by 0 / 0.
            bce_term = bce_sum
        return (0.5 * (joint_loss / n_heads)
                + 0.3 * (horizontal_loss / n_heads)
                + 0.2 * (vertical_loss / n_heads)) + bce_term
    
class AdaptiveWingLoss(nn.Module):
    """Adaptive Wing loss, averaged over Laban heads ``laban_0`` .. ``laban_{n-1}``.

    Args:
        alpha, omega, epsilon, theta: Adaptive Wing hyper-parameters (see
            ``criterion``).
        use_target_weight: if True, both prediction and target are multiplied
            by ``targets['laban_i_mask']`` (broadcast over trailing dims)
            before the loss is computed.
        loss_weight: scalar multiplier applied to the returned loss.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=True,
                 loss_weight=1.):
        super(AdaptiveWingLoss, self).__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Return the mean element-wise Adaptive Wing loss between pred and target.

        For ``delta = |target - pred|`` the loss is
        ``omega * log(1 + (delta/epsilon)^(alpha - target))`` when
        ``delta < theta`` and ``A * delta - C`` otherwise; ``A`` and ``C`` are
        chosen so the two branches meet at ``delta == theta``.
        """
        delta = (target - pred).abs()

        A = self.omega * (
            1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target))
        ) * (self.alpha - target) * (torch.pow(
            self.theta / self.epsilon,
            self.alpha - target - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(self.theta / self.epsilon, self.alpha - target))

        losses = torch.where(
            delta < self.theta,
            self.omega *
            torch.log(1 +
                      torch.pow(delta / self.epsilon, self.alpha - target)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self, outputs, targets):
        """Return ``loss_weight`` times the mean per-head Adaptive Wing loss.

        With ``use_target_weight``, the mask may have any shape that is a
        leading-dims prefix of the target's shape, including the full target
        shape (e.g. a per-class mask); it is padded with trailing singleton
        dims so it broadcasts.

        Raises:
            ValueError: if ``outputs`` is empty, or if (with
                ``use_target_weight``) the mask shape is not a prefix of the
                target shape.
        """
        n_heads = len(outputs)
        if n_heads == 0:
            raise ValueError('AdaptiveWingLoss needs at least one laban_i output')

        loss = 0.0
        for i in range(n_heads):
            output = outputs[f'laban_{i}']
            target = targets[f'annotation_laban_{i}']
            if not self.use_target_weight:
                loss += self.criterion(output, target)
                continue

            target_weights = targets[f'laban_{i}_mask']
            # Explicit check instead of `assert`, which is stripped under -O.
            if (target_weights.ndim > target.ndim
                    or target_weights.shape != target.shape[:target_weights.ndim]):
                raise ValueError(
                    'target_weights and target have mismatched shapes '
                    f'{target_weights.shape} v.s. {target.shape}')
            ndim_pad = target.ndim - target_weights.ndim
            target_weights = target_weights.reshape(
                target_weights.shape + (1, ) * ndim_pad)
            loss += self.criterion(output * target_weights,
                                   target * target_weights)

        return loss * self.loss_weight / n_heads
    
class HirerarchicalCrossEntropyLoss(nn.Module):
    """Masked BCE on ``laban_0`` plus a two-level hierarchical likelihood penalty.

    Class indices are grouped into three parent groups by slicing the class
    axis: 0-8, 9-17 and 18 onward (``idx_coder`` codes these '0', '1', '2').
    With ``p = softmax(logits)`` over the class axis and ``y`` the target
    (presumably one-hot — TODO confirm), per position:

    * child = ``sum(p * y)``; parent = total ``p`` of the group containing
      ``argmax(p * y)``;
    * ``loss_h1`` = mean over valid positions of ``-log(parent + delta)``;
    * ``loss_h2`` = mean over valid positions of
      ``-log(child / (parent + delta) + delta)``;
    * ``loss_old`` = BCE-with-logits summed over unmasked elements / mask sum.

    The result is ``loss_h1*exp(-alpha) + loss_h2*exp(-2*alpha) + loss_old``.

    Args:
        alpha: decay controlling the weight of the two hierarchical terms.
        delta: small constant added inside the logs and the parent denominator.
    """

    def __init__(self, alpha, delta=0.001) -> None:
        super(HirerarchicalCrossEntropyLoss, self).__init__()
        self.BCEWithLogits_Loss = nn.BCEWithLogitsLoss(reduction='none')
        self.alpha = alpha
        self.delta = delta
        # Kept for backward compatibility; forward() does not use it.
        self.criterion_binary_cls = nn.BCEWithLogitsLoss()

    def forward(self, outputs, targets):
        """Return the combined hierarchical + masked-BCE loss for ``laban_0``.

        A position (element over dims 0-1) is valid when the max of its mask
        over dim 2, cast to int, is non-zero. Masked-out positions contribute
        nothing to ``loss_h1``/``loss_h2``, and an all-zero mask yields 0
        rather than inf/NaN.
        """
        delta = self.delta
        output = outputs['laban_0']
        target = targets['annotation_laban_0']
        mask = targets['laban_0_mask']  # e.g. (batchsize, 40, 26)

        # Parent-group probabilities stacked in group-code order (0-8, 9-17, 18-).
        probs = F.softmax(output, dim=2)
        group_probs = torch.stack(
            (probs[:, :, :9].sum(dim=2),
             probs[:, :, 9:18].sum(dim=2),
             probs[:, :, 18:].sum(dim=2)),
            dim=2)

        # Child mass and its parent group. index // 9 reproduces idx_coder's
        # 0-8 / 9-17 / 18-25 grouping; the cap keeps it aligned with the
        # open-ended 18: slice.
        pcl = probs * target
        child_prob = pcl.sum(dim=2)
        parent_idx = torch.clamp(torch.argmax(pcl, dim=2) // 9, max=2)
        parent_prob = group_probs.gather(2, parent_idx.unsqueeze(2)).squeeze(2)

        # Per-position integer weight from the mask, and a 0/1 validity factor.
        frame_mask = torch.max(mask, dim=2)[0].int()
        valid = (frame_mask != 0).to(probs.dtype)
        # Integer count, so clamp only changes the all-masked case (avoids x / 0).
        valid_count = frame_mask.sum().clamp(min=1)

        # Multiply by `valid` AFTER the log: weighting inside the log alone
        # would still add a constant -log(delta) for every masked position.
        h2_terms = torch.log(child_prob / (parent_prob + delta) * frame_mask + delta)
        loss_h2 = -(h2_terms * valid).sum() / valid_count
        h1_terms = torch.log(parent_prob * frame_mask + delta)
        loss_h1 = -(h1_terms * valid).sum() / valid_count

        bce = self.BCEWithLogits_Loss(output, target) * mask.float()
        bce_sum = torch.sum(bce)
        total_count = torch.sum(mask)
        # With nothing unmasked, bce_sum is 0 (finite logits); avoid 0 / 0.
        loss_old = bce_sum / total_count if bool(total_count > 0) else bce_sum

        return (loss_h1 * math.exp(-self.alpha * 1)
                + loss_h2 * math.exp(-self.alpha * 2)
                + loss_old)