## Uses nn.BCEWithLogitsLoss() as the loss; entries whose label is -1 are not included in the loss.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import ipdb
import matplotlib.pyplot as plt

class DQNCOSLoss(nn.Module):
    """Symmetric InfoNCE-style loss over a square similarity matrix.

    ``input[i, j]`` is treated as the logit that row ``i`` matches column ``j``;
    the matching pairs sit on the diagonal. Cross-entropy is computed row-wise
    and column-wise and the two terms are averaged.
    """

    def __init__(self):
        super(DQNCOSLoss, self).__init__()

    def forward(self, input):
        """Return the mean of row-wise and column-wise cross-entropy.

        Args:
            input: ``(B, B)`` logits. Both ``input`` and its transpose are scored
                against the diagonal indices ``0..B-1``, so it must be square.

        Returns:
            Scalar loss tensor.
        """
        # Diagonal class indices, created directly on the logits' device
        # (replaces the deprecated Variable(torch.LongTensor(range(B))) + copy).
        target = torch.arange(input.size(0), device=input.device)
        row_loss = F.cross_entropy(input, target)
        col_loss = F.cross_entropy(input.transpose(1, 0), target)
        return (row_loss + col_loss) / 2

class DQNCOS_label_classify_loss(nn.Module):
    """Multi-label classification loss over 24 findings, multiplied by 10.

    Each sample's label strings are converted into a multi-hot target
    (``"k+"`` sets position ``k-1``) and compared with the ``cls_label``
    logits via ``BCEWithLogitsLoss``.
    """

    NUM_CLASSES = 24

    def __init__(self):
        super(DQNCOS_label_classify_loss, self).__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def generate_output_tensor(self, input_list, device='cuda'):
        """Build the multi-hot target for one sample.

        Args:
            input_list: iterable of label strings; an item may hold several
                comma-separated labels (e.g. ``"1+, 2+"``). Items containing
                ``'25'`` or ``'26'`` are skipped entirely; only labels ending in
                ``'+'`` set a position.
            device: device for the returned tensor (default kept as ``'cuda'``
                for existing callers; ``forward`` passes the logits' device).

        Returns:
            Float tensor of length ``NUM_CLASSES`` with 1 at each positive label.
        """
        values = [0.0] * self.NUM_CLASSES
        for item in input_list:
            if '25' in item or '26' in item:
                continue
            # split(',') also covers the single-label case (yields [item]).
            for label in item.split(','):
                if label.endswith('+'):
                    values[int(label[:-1]) - 1] = 1.0
        # Build once instead of writing element-by-element into a device tensor.
        return torch.tensor(values, device=device)

    def forward(self, label_list, cls_label):
        """Return ``10 * BCEWithLogits(cls_label, multi_hot(label_list))``.

        Args:
            label_list: per-sample label collections; only the first
                ``cls_label.size(0)`` entries are used.
            cls_label: logits, presumably ``(B, 24)`` — TODO confirm.
        """
        bs = cls_label.size(0)
        label_tensors = torch.stack([
            self.generate_output_tensor(label_list[i], device=cls_label.device)
            for i in range(bs)
        ])
        return self.bce_loss(cls_label, label_tensors) * 10

class DQNCOS_label_plus_cl_loss(nn.Module):
    """Three-way (positive / neutral / negative) pairwise label loss.

    For every (row, sample) pair a one-hot target over
    ``(positive, neutral, negative)`` is derived from string labels such as
    ``"3+"`` / ``"3-"`` (see ``creat_label``); the pair logits are scored with
    cross-entropy and the result is multiplied by 5.
    """

    def __init__(self):
        super(DQNCOS_label_plus_cl_loss, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

    def merge_and_deduplicate(self, input_list):
        """Flatten each entry's label strings into a de-duplicated list.

        Every item of every entry is stringified and split on ``", "``. The
        order of the returned labels is unspecified (they pass through a set).
        """
        return [
            list({element for item in list_sample for element in str(item).split(', ')})
            for list_sample in input_list
        ]

    @staticmethod
    def _pair_label(row_labels, sample):
        """Return 1 (positive), 0 (neutral) or -1 (negative) for one pair.

        ``row_labels`` is a merged label list (e.g. ``['3+', '5-']``) and
        ``sample`` a ``", "``-separated label string. The pair starts positive;
        a sample item that contradicts the row makes it negative (sticky), and
        any other non-matching item downgrades it to neutral.
        """
        label = 1
        for item in str(sample).split(', '):
            number, sign = item[:-1], item[-1]
            row_has_pos = number + '+' in row_labels
            if sign == '-' and not row_has_pos:
                continue  # sample "-" is not contradicted by the row
            if sign == '+' and row_has_pos:
                continue  # both sides have this finding as "+"
            if (sign == '+' and number + '-' in row_labels) or (sign == '-' and row_has_pos):
                label = -1
            elif label != -1:
                label = 0
        return label

    def creat_label(self, label_sample, label_list):
        """Build the pairwise one-hot target grid.

        Args:
            label_sample: per-column label strings (``", "``-separated).
            label_list: per-row label collections, merged via
                ``merge_and_deduplicate``.

        Returns:
            ``np.ndarray`` of shape ``(len(label_list), len(label_sample), 3)``
            with one-hot ``(positive, neutral, negative)`` vectors. Samples
            whose string contains ``'25'`` or ``'26'`` are neutral, and each
            diagonal entry ``[i, i]`` is forced positive (this requires at least
            as many samples as rows).
        """
        mapping = {1: (1, 0, 0), 0: (0, 1, 0), -1: (0, 0, 1)}
        target_p, target_0, target_n = [], [], []
        for row_labels in self.merge_and_deduplicate(label_list):
            row = [
                mapping[0] if ('25' in sample or '26' in sample)
                else mapping[self._pair_label(row_labels, sample)]
                for sample in label_sample
            ]
            target_p.append([p for p, _, _ in row])
            target_0.append([z for _, z, _ in row])
            target_n.append([n for _, _, n in row])
        for i in range(len(target_p)):
            target_p[i][i], target_0[i][i], target_n[i][i] = mapping[1]
        return np.stack((target_p, target_0, target_n), axis=-1)

    def forward(self, input, label_sample, label_list):
        """Cross-entropy between pair logits and the 3-way targets, times 5.

        Args:
            input: logits whose last dim is 3 and whose total size matches the
                target grid, presumably ``(len(label_list), len(label_sample), 3)``.
            label_sample, label_list: see ``creat_label``.
        """
        # Build the target on the logits' device instead of a hard-coded 'cuda'.
        target = torch.as_tensor(self.creat_label(label_sample, label_list),
                                 dtype=torch.float32, device=input.device)
        target_indices = target.reshape(-1, 3).argmax(dim=1)
        return self.ce_loss(input.reshape(-1, 3), target_indices) * 5

class sent_label_plus_with_cl_loss(nn.Module):
    """3-way pairwise label loss combined with a symmetric InfoNCE term.

    ``forward`` receives two logit tensors: 3-way pair logits scored against
    ``creat_label`` targets (weight 10), and a square similarity matrix scored
    against its diagonal in both directions. The sum is multiplied by 0.25.
    """

    def __init__(self):
        super(sent_label_plus_with_cl_loss, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

    def merge_and_deduplicate(self, input_list):
        """Flatten each entry's label strings into a de-duplicated list.

        Every item of every entry is stringified and split on ``", "``. The
        order of the returned labels is unspecified (they pass through a set).
        """
        return [
            list({element for item in list_sample for element in str(item).split(', ')})
            for list_sample in input_list
        ]

    @staticmethod
    def _pair_label(row_labels, sample):
        """Return 1 (positive), 0 (neutral) or -1 (negative) for one pair.

        ``row_labels`` is a merged label list (e.g. ``['3+', '5-']``) and
        ``sample`` a ``", "``-separated label string. The pair starts positive;
        a sample item that contradicts the row makes it negative (sticky), and
        any other non-matching item downgrades it to neutral.
        """
        label = 1
        for item in str(sample).split(', '):
            number, sign = item[:-1], item[-1]
            row_has_pos = number + '+' in row_labels
            if sign == '-' and not row_has_pos:
                continue  # sample "-" is not contradicted by the row
            if sign == '+' and row_has_pos:
                continue  # both sides have this finding as "+"
            if (sign == '+' and number + '-' in row_labels) or (sign == '-' and row_has_pos):
                label = -1
            elif label != -1:
                label = 0
        return label

    def creat_label(self, label_sample, label_list):
        """Build the pairwise one-hot target grid.

        Args:
            label_sample: per-column label strings (``", "``-separated).
            label_list: per-row label collections, merged via
                ``merge_and_deduplicate``.

        Returns:
            ``np.ndarray`` of shape ``(len(label_list), len(label_sample), 3)``
            with one-hot ``(positive, neutral, negative)`` vectors. Samples
            whose string contains ``'25'`` or ``'26'`` are neutral, and each
            diagonal entry ``[i, i]`` is forced positive (this requires at least
            as many samples as rows).
        """
        mapping = {1: (1, 0, 0), 0: (0, 1, 0), -1: (0, 0, 1)}
        target_p, target_0, target_n = [], [], []
        for row_labels in self.merge_and_deduplicate(label_list):
            row = [
                mapping[0] if ('25' in sample or '26' in sample)
                else mapping[self._pair_label(row_labels, sample)]
                for sample in label_sample
            ]
            target_p.append([p for p, _, _ in row])
            target_0.append([z for _, z, _ in row])
            target_n.append([n for _, _, n in row])
        for i in range(len(target_p)):
            target_p[i][i], target_0[i][i], target_n[i][i] = mapping[1]
        return np.stack((target_p, target_0, target_n), axis=-1)

    def forward(self, input, label_sample, label_list):
        """Return ``0.25 * (10 * CE_3way + CE_rows(input_1) + CE_cols(input_1))``.

        Args:
            input: pair ``(input_1, input_3)``. ``input_3`` holds 3-way pair
                logits (last dim 3, presumably shaped like the ``creat_label``
                grid); ``input_1`` is a square similarity matrix whose matching
                pairs lie on the diagonal.
            label_sample, label_list: see ``creat_label``.
        """
        input_1, input_3 = input[0], input[1]

        # Targets are created on the logits' devices rather than a hard-coded 'cuda'.
        target_3 = torch.as_tensor(self.creat_label(label_sample, label_list),
                                   dtype=torch.float32, device=input_3.device)
        target_indices = target_3.reshape(-1, 3).argmax(dim=1)
        loss = self.ce_loss(input_3.reshape(-1, 3), target_indices) * 10

        target_1 = torch.arange(len(input_1), device=input_1.device)
        loss = loss + self.ce_loss(input_1, target_1)
        loss = loss + self.ce_loss(input_1.transpose(1, 0), target_1)
        return loss * 0.25

class CL_sent_label_plus_loss(nn.Module):
    """Soft InfoNCE loss applied separately to positive / neutral / negative channels.

    ``creat_label`` produces a one-hot ``(positive, neutral, negative)`` grid;
    each channel becomes a soft-label matrix for a symmetric soft InfoNCE loss
    on the matching channel of ``input``. The three losses are summed and
    multiplied by 0.5.
    """

    def __init__(self):
        super(CL_sent_label_plus_loss, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.ce_loss = nn.CrossEntropyLoss()

    def merge_and_deduplicate(self, input_list):
        """Flatten each entry's label strings into a de-duplicated list.

        Every item of every entry is stringified and split on ``", "``. The
        order of the returned labels is unspecified (they pass through a set).
        """
        return [
            list({element for item in list_sample for element in str(item).split(', ')})
            for list_sample in input_list
        ]

    @staticmethod
    def _pair_label(row_labels, sample):
        """Return 1 (positive), 0 (neutral) or -1 (negative) for one pair.

        ``row_labels`` is a merged label list (e.g. ``['3+', '5-']``) and
        ``sample`` a ``", "``-separated label string. The pair starts positive;
        a sample item that contradicts the row makes it negative (sticky), and
        any other non-matching item downgrades it to neutral.
        """
        label = 1
        for item in str(sample).split(', '):
            number, sign = item[:-1], item[-1]
            row_has_pos = number + '+' in row_labels
            if sign == '-' and not row_has_pos:
                continue  # sample "-" is not contradicted by the row
            if sign == '+' and row_has_pos:
                continue  # both sides have this finding as "+"
            if (sign == '+' and number + '-' in row_labels) or (sign == '-' and row_has_pos):
                label = -1
            elif label != -1:
                label = 0
        return label

    def creat_label(self, label_sample, label_list):
        """Build the pairwise one-hot target grid.

        Args:
            label_sample: per-column label strings (``", "``-separated).
            label_list: per-row label collections, merged via
                ``merge_and_deduplicate``.

        Returns:
            ``np.ndarray`` of shape ``(len(label_list), len(label_sample), 3)``
            with one-hot ``(positive, neutral, negative)`` vectors. Samples
            whose string contains ``'25'`` or ``'26'`` are neutral, and each
            diagonal entry ``[i, i]`` is forced positive (this requires at least
            as many samples as rows).
        """
        mapping = {1: (1, 0, 0), 0: (0, 1, 0), -1: (0, 0, 1)}
        target_p, target_0, target_n = [], [], []
        for row_labels in self.merge_and_deduplicate(label_list):
            row = [
                mapping[0] if ('25' in sample or '26' in sample)
                else mapping[self._pair_label(row_labels, sample)]
                for sample in label_sample
            ]
            target_p.append([p for p, _, _ in row])
            target_0.append([z for _, z, _ in row])
            target_n.append([n for _, _, n in row])
        for i in range(len(target_p)):
            target_p[i][i], target_0[i][i], target_n[i][i] = mapping[1]
        return np.stack((target_p, target_0, target_n), axis=-1)

    def soft_cross_entropy_loss(self, input, target):
        """Soft-target cross-entropy, averaged over rows that have a defined target.

        NaN entries in ``target`` (produced by ``soft_infoNCE_loss`` when a row
        of the soft label sums to 0) are excluded. The sum is divided by the
        number of rows with at least one non-NaN entry. Earlier code divided by
        ``valid_elements / target.shape[0]``, which only equals that count for
        square targets. Returns ``0`` when no entry is valid.
        """
        logprobs = F.log_softmax(input, dim=1)
        valid = ~torch.isnan(target)
        if not valid.any():
            return 0
        num_valid_rows = valid.any(dim=1).sum()
        return -(target[valid] * logprobs[valid]).sum() / num_valid_rows

    def soft_infoNCE_loss(self, logits_per_img, soft_label):
        """Symmetric soft InfoNCE: rows and columns of ``soft_label`` are each normalised to sum to 1."""
        image_loss = self.soft_cross_entropy_loss(logits_per_img, soft_label / soft_label.sum(dim=1).unsqueeze(1))
        caption_loss = self.soft_cross_entropy_loss(logits_per_img.T, soft_label.T / soft_label.T.sum(dim=1).unsqueeze(1))
        return (image_loss + caption_loss) / 2

    def forward(self, input, label_sample, label_list):
        """Sum the soft InfoNCE losses of the three channels, times 0.5.

        Args:
            input: logits indexed as ``input[:, :, c]`` for channel
                ``c`` in (positive, neutral, negative); presumably shaped like
                the ``creat_label`` grid.
            label_sample, label_list: see ``creat_label``.
        """
        # Build the target on the logits' device instead of a hard-coded 'cuda'.
        target = torch.as_tensor(self.creat_label(label_sample, label_list),
                                 dtype=torch.float32, device=input.device)
        loss = 0
        for channel in range(3):  # 0: positive, 1: neutral, 2: negative
            loss = loss + self.soft_infoNCE_loss(input[:, :, channel], target[:, :, channel])
        return loss * 0.5

class DQNCOS_global_sent_label_loss(nn.Module):
    """Soft CLIP loss whose soft targets come from label-vector dot products.

    Row and sample labels (``"k+"`` / ``"k-"``) are encoded as 24-element
    +1/-1/0 vectors; their pairwise dot products, softmax-normalised per row
    (and per column for the transposed direction), are the soft targets. The
    loss is divided by 40.
    """

    def __init__(self):
        super(DQNCOS_global_sent_label_loss, self).__init__()

    def merge_and_deduplicate(self, input_list):
        """Flatten each entry's label strings into a de-duplicated list.

        Every item of every entry is stringified and split on ``", "``. The
        order of the returned labels is unspecified (they pass through a set).
        """
        return [
            list({element for item in list_sample for element in str(item).split(', ')})
            for list_sample in input_list
        ]

    def process_array(self, input_list):
        """Encode each list of label strings as a 24-element +1/-1/0 vector.

        ``"k+"`` sets index ``k-1`` to 1 and ``"k-"`` sets it to -1; numbers
        whose index falls outside 0..23 are ignored. A later item for the same
        index overwrites an earlier one.
        """
        output_list = []
        for sample in input_list:
            result = [0] * 24
            for item in sample:
                if '+' in item:
                    value = 1
                elif '-' in item:
                    value = -1
                else:
                    continue
                index = int(item[:-1]) - 1  # label number -> 0-based index
                if 0 <= index < 24:
                    result[index] = value
            output_list.append(result)
        return output_list

    def creat_label(self, label_sample, label_list):
        """Return the ``len(label_list) x len(label_sample)`` dot-product matrix as nested lists."""
        row_vectors = self.process_array(self.merge_and_deduplicate(label_list))
        # Sample strings are ", "-separated; encode them with the same scheme as the rows.
        sample_vectors = self.process_array([str(sample).split(', ') for sample in label_sample])
        return [
            [sum(x * y for x, y in zip(row, col)) for col in sample_vectors]
            for row in row_vectors
        ]

    def _soft_clip_loss(self, logits_per_img, soft_label):
        """Average of the row-direction and column-direction soft cross-entropies."""
        image_loss = self._soft_xent_loss(logits_per_img, F.softmax(soft_label, 1))
        caption_loss = self._soft_xent_loss(logits_per_img.T, F.softmax(soft_label.T, 1))
        return (image_loss + caption_loss) / 2

    def _soft_xent_loss(self, input, target):
        """Cross-entropy with a soft target distribution, summed and divided by the row count."""
        logprobs = F.log_softmax(input, dim=1)
        return -(target * logprobs).sum() / input.shape[0]

    def forward(self, global_input, label_sample, label_list):
        """Return the soft CLIP loss of ``global_input`` against label similarities, divided by 40."""
        # Build the target on the logits' device instead of a hard-coded 'cuda'.
        target = torch.tensor(self.creat_label(label_sample, label_list),
                              dtype=torch.float32, device=global_input.device)
        return self._soft_clip_loss(global_input, target) / 40

class DQNCOS_label_cl_loss(nn.Module):
    """Symmetric soft InfoNCE loss with binary label-agreement targets.

    ``creat_label`` marks a (row, sample) pair 1 when the sample's labels agree
    with the row's merged labels and 0 otherwise. Each row and column of that
    matrix is normalised to sum to 1 and used as a soft target.
    """

    def __init__(self):
        super(DQNCOS_label_cl_loss, self).__init__()
        # Not used by forward(); kept so existing attribute access still works.
        self.bce_loss = nn.BCEWithLogitsLoss()

    def merge_and_deduplicate(self, input_list):
        """Flatten each entry's label strings into a de-duplicated list.

        Every item of every entry is stringified and split on ``", "``. The
        order of the returned labels is unspecified (they pass through a set).
        """
        return [
            list({element for item in list_sample for element in str(item).split(', ')})
            for list_sample in input_list
        ]

    @staticmethod
    def _pair_label(row_labels, sample):
        """Return 1 if every item of ``sample`` agrees with ``row_labels``, else 0.

        An item agrees when it is ``"k-"`` and the row has no ``"k+"``, or it is
        ``"k+"`` and the row has ``"k+"``.
        """
        label = 1
        for item in str(sample).split(', '):
            number, sign = item[:-1], item[-1]
            row_has_pos = number + '+' in row_labels
            if (sign == '-' and not row_has_pos) or (sign == '+' and row_has_pos):
                continue
            label = 0
        return label

    def creat_label(self, label_sample, label_list):
        """Build the binary agreement matrix as nested lists.

        Returns:
            ``len(label_list)`` rows of ``len(label_sample)`` ints. Samples
            whose string contains ``'25'`` or ``'26'`` get 0, and each diagonal
            entry ``[i][i]`` is forced to 1 (requires at least as many samples
            as rows).
        """
        target = [
            [
                0 if ('25' in sample or '26' in sample) else self._pair_label(row_labels, sample)
                for sample in label_sample
            ]
            for row_labels in self.merge_and_deduplicate(label_list)
        ]
        for i in range(len(target)):
            target[i][i] = 1
        return target

    def soft_cross_entropy_loss(self, input, target):
        """Cross-entropy with a soft target distribution, summed and divided by the row count."""
        logprobs = F.log_softmax(input, dim=1)
        return -(target * logprobs).sum() / input.shape[0]

    def soft_infoNCE_loss(self, logits_per_img, soft_label):
        """Symmetric soft InfoNCE: rows and columns of ``soft_label`` are each normalised to sum to 1.

        A row or column summing to 0 would produce NaN; the forced diagonal of
        1s from ``creat_label`` rules that out for square targets.
        """
        image_loss = self.soft_cross_entropy_loss(logits_per_img, soft_label / soft_label.sum(dim=1).unsqueeze(1))
        caption_loss = self.soft_cross_entropy_loss(logits_per_img.T, soft_label.T / soft_label.T.sum(dim=1).unsqueeze(1))
        return (image_loss + caption_loss) / 2

    def forward(self, input, label_sample, label_list):
        """Return the soft InfoNCE loss of ``input`` against the agreement matrix.

        Args:
            input: similarity logits, presumably ``(len(label_list), len(label_sample))``.
            label_sample, label_list: see ``creat_label``.
        """
        # Build the target on the logits' device instead of a hard-coded 'cuda'.
        target = torch.tensor(self.creat_label(label_sample, label_list),
                              dtype=torch.float32, device=input.device)
        return self.soft_infoNCE_loss(input, target)

class MedCLIP_loss(nn.Module):
    """Soft CLIP loss whose soft targets are label-similarity scores.

    Every entry of ``label_list`` is encoded as a 24-element +1/-1/0 vector.
    The matrix of pairwise dot products, softmax-normalised per row (and per
    column for the transposed direction), serves as the soft target.
    """

    def __init__(self):
        super(MedCLIP_loss, self).__init__()

    def merge_and_deduplicate(self, input_list):
        """Flatten each entry's label strings into a de-duplicated list.

        Every item of every entry is stringified and split on ``", "``. The
        order of the returned labels is unspecified (they pass through a set).
        """
        return [
            list({element for item in list_sample for element in str(item).split(', ')})
            for list_sample in input_list
        ]

    def process_array(self, input_list):
        """Encode each list of label strings as a 24-element +1/-1/0 vector.

        Items containing ``'25'`` or ``'26'`` are skipped. ``"k+"`` sets index
        ``k-1`` to 1 and ``"k-"`` sets it to -1; indices outside 0..23 are
        ignored. A later item for the same index overwrites an earlier one.
        """
        output_list = []
        for sample in input_list:
            result = [0] * 24
            for item in sample:
                if '25' in str(item) or '26' in str(item):
                    continue
                if '+' in item:
                    value = 1
                elif '-' in item:
                    value = -1
                else:
                    continue
                index = int(item[:-1]) - 1  # label number -> 0-based index
                if 0 <= index < 24:
                    result[index] = value
            output_list.append(result)
        return output_list

    def creat_label(self, label_list):
        """Return the ``N x N`` float32 CPU tensor of pairwise label-vector dot products."""
        vectors = torch.tensor(self.process_array(self.merge_and_deduplicate(label_list)),
                               dtype=torch.float32)
        return torch.matmul(vectors, vectors.T)

    def _soft_clip_loss(self, logits_per_img, soft_label):
        """Average of the row-direction and column-direction soft cross-entropies."""
        image_loss = self._soft_xent_loss(logits_per_img, F.softmax(soft_label, 1))
        caption_loss = self._soft_xent_loss(logits_per_img.T, F.softmax(soft_label.T, 1))
        return (image_loss + caption_loss) / 2

    def _soft_xent_loss(self, input, target):
        """Cross-entropy with a soft target distribution, summed and divided by the row count."""
        logprobs = F.log_softmax(input, dim=1)
        return -(target * logprobs).sum() / input.shape[0]

    def forward(self, global_input, label_list):
        """Return the soft CLIP loss of ``global_input`` (presumably ``(N, N)``) against label similarities."""
        # Move the target to the logits' device instead of a hard-coded 'cuda'.
        target = self.creat_label(label_list).to(global_input.device)
        return self._soft_clip_loss(global_input, target)