from PIL import Image
import numpy as np
import shutil
import torch
import torch.nn.functional as F
import torch.nn as nn
import os
from torch.autograd import grad
from tqdm import tqdm
import cv2
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def gradient_penalty(critic, h_s, h_t):
    """Gradient penalty on random interpolations between two batches.

    Args:
        critic: callable mapping a batch shaped like ``h_s`` to scores.
        h_s, h_t: tensors of identical shape whose first dim is the batch.

    Returns:
        Scalar tensor ``mean((||d critic / d x||_2 - 1) ** 2)`` where ``x`` are
        per-sample interpolations ``h_s + alpha * (h_t - h_s)``. The graph is
        kept (``create_graph=True``) so the penalty can be backpropagated.
    """
    # One mixing coefficient per sample, broadcastable over every non-batch
    # dim and created on the inputs' device/dtype (no hard-coded .cuda()).
    alpha_shape = (h_s.size(0),) + (1,) * (h_s.dim() - 1)
    alpha = torch.rand(alpha_shape, device=h_s.device, dtype=h_s.dtype)
    differences = h_t - h_s
    interpolates = h_s + (alpha * differences)
    interpolates.requires_grad_()
    preds = critic(interpolates)
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     retain_graph=True, create_graph=True)[0]
    # Flatten per sample so the norm covers all non-batch dims (not just dim 1).
    gradient_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    GP = ((gradient_norm - 1) ** 2).mean()

    return GP

def visualize(imgten, path, color=True, threshold = False, size=None, reverse = False, heatmap=False):
    """Min-max normalise a tensor to 0..255 (uint8) and save it as an image.

    Args:
        imgten: tensor to save. With ``color=True`` it must be a 3-channel
            ``[3, H, W]`` tensor or a 2-D map; with ``color=False`` it must be 2-D.
        path: output file path.
        color: use the colour/heatmap path (True) or the grayscale path (False).
        threshold: grayscale only; if truthy, binarise at this 0..255 level.
        size: optional ``(h, w)`` for a bilinear resize before saving.
        reverse: grayscale only; invert the intensities.
        heatmap: colour path only; apply ``cv2.COLORMAP_JET`` and write with cv2.

    Raises:
        ValueError: colour-path input that is neither ``[3, H, W]`` nor 2-D.
    """
    imgten = imgten.detach().float()
    if color:
        if imgten.dim() == 3 and imgten.size(0) == 3:
            if size is not None:
                imgten = F.interpolate(imgten.unsqueeze(dim=0), size=size, mode='bilinear',
                                       align_corners=True)[0]
            imgnp = imgten.permute([1, 2, 0]).cpu().numpy()
        elif imgten.dim() == 2:
            if size is not None:
                # Strip the added batch/channel dims again: PIL rejects [H, W, 1] arrays.
                imgten = F.interpolate(imgten[None, None, ...], size=size, mode='bilinear',
                                       align_corners=True)[0, 0]
            imgnp = imgten.cpu().numpy()
        else:
            # Previously fell through with `imgnp` unbound (NameError further down).
            raise ValueError('visualize(color=True) expects a [3, H, W] or [H, W] tensor, '
                             f'got shape {tuple(imgten.shape)}')
        imgnp = np.interp(imgnp, (imgnp.min(), imgnp.max()), (0, 255)).astype(np.uint8)
        if heatmap:
            imgnp = cv2.applyColorMap(imgnp, cv2.COLORMAP_JET)
            cv2.imwrite(path, imgnp)
            return
        Image.fromarray(imgnp).save(path)
    else:  # grayscale, input should be 2-D
        imgten = imgten.unsqueeze(dim=0).unsqueeze(dim=0)
        if size is not None:
            imgten = F.interpolate(imgten, size=size, mode='bilinear', align_corners=True)
        imgnp = imgten[0, 0].cpu().numpy()
        imgnp = np.interp(imgnp, (imgnp.min(), imgnp.max()), (0, 255)).astype(np.uint8)
        if threshold:
            imgnp[imgnp < threshold] = 0
            imgnp[imgnp >= threshold] = 255
        if reverse:
            imgnp = 255 - imgnp
        Image.fromarray(imgnp).save(path)

def DeleteContent(path):
    """Make sure directory ``path`` exists and remove everything inside it.

    Sub-directories are deleted recursively. Symbolic links, including links
    to directories, are unlinked and not followed, so the link targets are left intact.
    """
    if not os.path.exists(path):
        os.mkdir(path)
    for name in os.listdir(path):
        entry = os.path.join(path, name)
        # shutil.rmtree refuses symlinks; os.path.isdir follows them, so check first.
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            os.remove(entry)

def FlipGT(GTs, H, W):
    """Mirror boxes horizontally for an image of width ``W``.

    Args:
        GTs: iterable of per-image box collections; each box is a 4-element
            tensor ``[x1, y1, x2, y2]``.
        H: image height (not used by this function).
        W: image width.

    Returns:
        List with one ``[N, 4]`` CPU tensor per image. An image without boxes
        yields an empty ``[0, 4]`` tensor instead of raising.
    """
    Out = []
    for GT in GTs:
        flipped = [[W - x2, y1, W - x1, y2] for x1, y1, x2, y2 in (gt.tolist() for gt in GT)]
        # torch.tensor([]) is 1-D, so reshape keeps the [0, 4] shape for empty inputs.
        Out.append(torch.tensor(flipped).reshape(-1, 4))
    return Out


def RotateGT(degree, GTs, H, W):
    """Transform boxes to match an image rotated by ``degree`` quarter turns.

    Args:
        degree: number of 90-degree steps, one of 0, 1, 2, 3.
        GTs: iterable of per-image box collections; each box is a 4-element
            tensor ``[x1, y1, x2, y2]``.
        H, W: height and width of the image *before* rotation.

    Returns:
        List with one ``[N, 4]`` CPU tensor per image (``[0, 4]`` when empty).

    Raises:
        ValueError: if ``degree`` is not 0, 1, 2 or 3.
    """
    # Box mapping per step count, written in terms of the original (H, W);
    # odd step counts swap the roles of width and height.
    mappings = {
        0: lambda x1, y1, x2, y2: [x1, y1, x2, y2],
        1: lambda x1, y1, x2, y2: [y1, W - x2, y2, W - x1],
        2: lambda x1, y1, x2, y2: [W - x2, H - y2, W - x1, H - y1],
        3: lambda x1, y1, x2, y2: [H - y2, x1, H - y1, x2],
    }
    key = int(degree)
    if key != degree or key not in mappings:
        raise ValueError(f'degree must be 0, 1, 2 or 3, got {degree}')
    mapping = mappings[key]
    Out = []
    for GT in GTs:
        rotated = [mapping(*gt.tolist()) for gt in GT]
        # reshape keeps the [0, 4] shape for images without boxes.
        Out.append(torch.tensor(rotated).reshape(-1, 4))
    return Out

def RotateMeta(degree, img_metas):
    """Build image-meta dicts matching an image rotated by ``degree`` quarter turns.

    Only ``img_shape`` and ``pad_shape`` are carried over, as ``(h, w, 3)``
    tuples; height and width are swapped unless ``degree`` is 0 or 2.
    """
    same_orientation = degree == 0 or degree == 2
    rotated = []
    for meta in img_metas:
        entry = {}
        for key in ('img_shape', 'pad_shape'):
            height, width, _ = meta[key]
            entry[key] = (height, width, 3) if same_orientation else (width, height, 3)
        rotated.append(entry)
    return rotated

def SaveCode(base_path, des_path):
    """Recursively copy ``.py`` files under ``base_path`` into ``des_path``.

    Destination paths are ``os.path.join(des_path, p.replace('../', ''))``,
    so the function is meant to be called with relative paths such as ``'../'``.
    Directories whose path contains ``'wandb'`` are skipped. Destination
    directories may already exist, so the copy can be re-run.
    """
    if not os.path.isdir(base_path):
        # Exact suffix: a substring test would also copy .pyc / .pyx / ... files.
        if base_path.endswith('.py'):
            save_path = os.path.join(des_path, base_path.replace('../', ''))
            shutil.copy(base_path, save_path)
        return
    if 'wandb' in base_path:
        return
    for file in os.listdir(base_path):
        longPath = os.path.join(base_path, file)
        if os.path.isdir(longPath):
            if 'wandb' in longPath:
                continue  # skipped entirely; don't leave an empty mirror dir behind
            save_path = os.path.join(des_path, longPath.replace('../', ''))
            os.makedirs(save_path, exist_ok=True)
        SaveCode(longPath, des_path)

def RemoveEmptyDir(path):
    """Delete every empty directory below ``path`` (``path`` itself is kept).

    Children are processed before their parent (post-order), so a directory
    that only contained empty sub-directories is removed too. Symlinks to
    directories are neither followed nor removed.
    """
    for name in os.listdir(path):
        child = os.path.join(path, name)
        if os.path.isdir(child) and not os.path.islink(child):
            RemoveEmptyDir(child)
            if not os.listdir(child):
                os.rmdir(child)

def DrawGT(img, GT, path, GT_labels=None, cls_color=False):
    """Draw box outlines on a copy of ``img`` and save it twice.

    ``path`` receives the outlined image. A second file, ``<path stem>_number.jpg``,
    additionally shows each box's index printed at its (x1, y1) corner.

    Args:
        img: channel-first ``[C, H, W]`` tensor (permuted to ``[H, W, C]`` for
            saving); ``cls_color`` writes channels 0, 1, 2.
        GT: iterable of ``[x1, y1, x2, y2]`` box tensors in pixel coordinates.
        path: output path of the first image.
        GT_labels: optional 1-D tensor of per-box labels. When given, box
            brightness is ``img.max() * (rank / n_unique) ** 2``, where ``rank`` is
            the label's 1-based position among ``GT_labels.unique()``. Otherwise
            box ``i`` uses ``img.max() * 0.8 ** i``.
        cls_color: additionally paint each outline with the fixed RGB colour
            of its label id from ``Colors`` (ids 0..19). Requires ``GT_labels``.

    Raises:
        ValueError: if ``cls_color`` is set without ``GT_labels``.
    """
    Colors = {0:(255,0,0), 1:(255,128,0), 2:(255,255,0), 3:(128,255,0), 4:(0,255,0),
              5:(0,255,128),6:(0,255,255),7:(0,128,255),8:(0,0,255),9:(128,0,255),
              10:(255,0,255),11:(255,0,128),12:(128,128,128),13:(255,255,255),14:(102,0,102),
              15:(123,113,52),16:(244,49,156),17:(79,32,154),18:(211,45,219),19:(105,50,177)}
    if cls_color and GT_labels is None:
        raise ValueError('DrawGT(cls_color=True) requires GT_labels')
    img = img.clone()
    width = 1
    if GT_labels is not None:
        # Label -> 1-based rank among the unique labels; computed once, not per box.
        unique_labels = GT_labels.unique()
        label_rank = {cls.cpu().item(): idx + 1 for idx, cls in enumerate(unique_labels)}
        n_unique = float(len(unique_labels))
    for i, gt in enumerate(GT):
        x1, y1, x2, y2 = list(map(int, gt.tolist()))
        # A negative slice start (x1 - width < 0) would make the left/top strip empty.
        x1 = max(x1, width)
        y1 = max(y1, width)
        if GT_labels is not None:
            color = (label_rank[GT_labels[i].cpu().item()] / n_unique) ** 2
        else:
            color = 0.8 ** i
        # Left, right, top and bottom outline strips, each 2 * width thick.
        strips = (
            (slice(y1, y2), slice(x1 - width, x1 + width)),
            (slice(y1, y2), slice(x2 - width, x2 + width)),
            (slice(y1 - width, y1 + width), slice(x1, x2)),
            (slice(y2 - width, y2 + width), slice(x1, x2)),
        )
        # img.max() is re-read before every write, in the same order as before.
        for rows, cols in strips:
            img[:, rows, cols] = img.max() * color
        if cls_color:
            rgb = Colors[GT_labels[i].item()]
            for rows, cols in strips:
                for ch, value in enumerate(rgb):
                    img[ch, rows, cols] = img.max() * value / 255

    imgnp = img.detach().float().permute([1, 2, 0]).cpu().numpy()
    imgnp = np.interp(imgnp, (imgnp.min(), imgnp.max()), (0, 255)).astype(np.uint8)
    Image.fromarray(imgnp).save(path)

    # Number every box. Enumerating GT alone also works when GT_labels is None
    # (zip(GT, None) used to raise TypeError).
    for i, gt in enumerate(GT):
        x1, y1 = list(map(int, gt.tolist()))[:2]
        cv2.putText(imgnp, str(i), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 255)

    Image.fromarray(imgnp).save(os.path.splitext(path)[0] + '_number.jpg')


def DrawGT_backup(img, GT, path, GT_labels=None):
    """Older DrawGT variant: paint box outlines on a copy of ``img``, save via ``visualize``.

    Outline value is ``img.max()`` times ``(rank / n_unique_labels) ** 2`` when
    ``GT_labels`` is given (rank is 1-based among ``GT_labels.unique()``),
    otherwise ``0.8 ** box_index``.
    """
    canvas = img.clone()
    width = 1
    for box_idx, box in enumerate(GT):
        x1, y1, x2, y2 = [int(v) for v in box.tolist()]
        x1, y1 = max(x1, width), max(y1, width)
        if GT_labels is None:
            scale = 0.8 ** box_idx
        else:
            rank_of = {lab.cpu().item(): rank + 1 for rank, lab in enumerate(GT_labels.unique())}
            scale = (rank_of[GT_labels[box_idx].cpu().item()] / float(len(GT_labels.unique()))) ** 2
        strips = (
            (slice(y1, y2), slice(x1 - width, x1 + width)),
            (slice(y1, y2), slice(x2 - width, x2 + width)),
            (slice(y1 - width, y1 + width), slice(x1, x2)),
            (slice(y2 - width, y2 + width), slice(x1, x2)),
        )
        for rows, cols in strips:
            canvas[:, rows, cols] = canvas.max() * scale

    visualize(canvas, path)

def FuseList(list_of_tensor, selectFirst=False):
    """Concatenate tensors along dim 0 into a buffer typed like the first one.

    With ``selectFirst`` each element is a container and its ``[0]`` entry is
    used. The result has the dtype/device of the first piece; later pieces are
    cast on assignment.
    """
    pieces = [item[0] if selectFirst else item for item in list_of_tensor]
    total = sum(len(piece) for piece in pieces)
    head = pieces[0]
    fused = head.new_full((total,) + tuple(head.size()[1:]), fill_value=0)
    offset = 0
    for piece in pieces:
        count = len(piece)
        fused[offset:offset + count] = piece
        offset += count
    return fused

def MakeWeights(featmap_size, pad_shape, sIdx, device, num_base_anchors=9, strides=(8, 16, 32, 64, 128)):
    """Boolean validity mask over all anchors of one feature level.

    A feature cell is valid if it lies inside the un-padded part of the
    feature map, i.e. within ``ceil(h / stride)`` rows and ``ceil(w / stride)``
    columns. Each cell is then repeated once per base anchor.

    Args:
        featmap_size: ``(feat_h, feat_w)`` of the level.
        pad_shape: image shape whose first two entries are ``(h, w)``.
        sIdx: level index into ``strides``.
        device: device of the returned mask.
        num_base_anchors: anchors per cell (default 9, as before).
        strides: per-level anchor strides (default ``(8, 16, 32, 64, 128)``).

    Returns:
        1-D bool tensor of length ``feat_h * feat_w * num_base_anchors``
        (row-major over cells, anchors innermost).
    """
    feat_h, feat_w = featmap_size
    h, w = pad_shape[:2]
    anchor_stride = strides[sIdx]
    valid_h = min(int(np.ceil(h / anchor_stride)), feat_h)
    valid_w = min(int(np.ceil(w / anchor_stride)), feat_w)
    valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
    valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
    valid_x[:valid_w] = True
    valid_y[:valid_h] = True
    # Row-major outer AND == meshgrid(valid_x, valid_y) followed by `&`.
    valid = (valid_y[:, None] & valid_x[None, :]).reshape(-1)
    return valid[:, None].expand(valid.size(0), num_base_anchors).contiguous().view(-1)

def _meshgrid(x, y, row_major=True):
    """Flattened grid of ``x`` (fastest-varying) and ``y`` coordinates.

    Returns ``(xx, yy)``, or ``(yy, xx)`` when ``row_major`` is False.
    """
    n_rows, n_cols = y.shape[0], x.shape[0]
    xx = x.repeat(n_rows)
    yy = y.view(-1, 1).repeat(1, n_cols).view(-1)
    return (xx, yy) if row_major else (yy, xx)

def VisualizeData(data, model, cfg, **kwargs):
    """Run ``model`` in eval mode on one sample and return ``(loss, *other_outputs)``.

    ``data`` is modified in place. The ``img_metas`` / ``gt_bboxes`` /
    ``gt_labels`` containers are unwrapped via ``.data``, and the image is
    moved to the model's device. When ``kwargs['fromLoader']`` is falsy (a
    sample taken straight from a dataset), the entries are also wrapped into a
    batch of one. ``kwargs`` must contain ``fromLoader`` and is forwarded to
    the model call together with the uncertainty settings read from ``cfg``.
    """
    torch.cuda.empty_cache()
    device = list(model.parameters())[0].device
    # Resolve mismatch from direct sampling to test_pipeline: unwrap containers.
    for key in ('img_metas', 'gt_bboxes', 'gt_labels'):
        data[key] = data[key].data
    if kwargs['fromLoader']:
        data['img'] = data['img'].data
        data['img'][0] = data['img'][0].to(device)
    else:
        # Single dataset sample: add the batch dim / list nesting the model call receives.
        data['img'] = [data['img'].data.to(device)[None, :]]
        data['img_metas'] = [[data['img_metas']]]
        data['gt_bboxes'] = [data['gt_bboxes']]
        data['gt_labels'] = [data['gt_labels']]
    unc_type, unc_pool, unc_pool2 = cfg.uncertainty_type, cfg.uncertainty_pool, cfg.uncertainty_pool2
    model.eval()
    outputs = model(**data, return_loss=False, rescale=True, isEval=False, isUnc=unc_type,
                    uPool=unc_pool, uPool2=unc_pool2, Labeled=True, Pseudo=False, draw=False,
                    saveUnc=True, **kwargs)
    loss, *unc_outs = outputs
    return (loss, *unc_outs)

# Dataset / pipeline config dicts shared by the ConfigDataset* helpers.
dataset_type = 'VOCDataset'
data_root = '/drive1/YH/datasets/VOCdevkit/VOCdevkit/'
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'LoadAnnotations', 'with_bbox': True},
    {'type': 'Resize', 'img_scale': (1000, 600), 'keep_ratio': True},
    {'type': 'RandomFlip', 'flip_ratio': 0.5},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'Pad', 'size_divisor': 32},
    {'type': 'DefaultFormatBundle'},
    {'type': 'Collect', 'keys': ['img', 'gt_bboxes', 'gt_labels']},
]
test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': (1000, 600),
        'flip': False,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlip'},
            {'type': 'Normalize', **img_norm_cfg},
            {'type': 'Pad', 'size_divisor': 32},
            {'type': 'ImageToTensor', 'keys': ['img']},
            {'type': 'Collect', 'keys': ['img']},
        ],
    },
]

def ConfigDatasetAL(ismini=False):
    """Dataset config for the active-learning training pool (uses ``train_pipeline``).

    ``ismini=True`` selects only the VOC2007 ``mini_test`` split. Otherwise
    VOC2007 + VOC2012 ``trainval`` is used, with one image prefix per
    annotation file.
    """
    if not ismini:
        return dict(
            type=dataset_type,
            ann_file=[data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                      data_root + 'VOC2012/ImageSets/Main/trainval.txt'],
            img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
            pipeline=train_pipeline)
    return dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/mini_test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=train_pipeline)

def ConfigDatasetALCustom(custom=False):
    """VOC2007 + VOC2012 training config (``train_pipeline``).

    The ``custom_test`` split lists are used when ``custom`` is truthy, and the
    ``trainval`` lists otherwise.
    """
    split = 'custom_test' if custom else 'trainval'
    return dict(
        type=dataset_type,
        ann_file=[data_root + 'VOC2007/ImageSets/Main/' + split + '.txt',
                  data_root + 'VOC2012/ImageSets/Main/' + split + '.txt'],
        img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
        pipeline=train_pipeline)

def ConfigDatasetTEST(ismini=False):
    """VOC2007 evaluation config (``test_pipeline``): ``mini_test`` or the full ``test`` split."""
    split = 'mini_test' if ismini else 'test'
    return dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/' + split + '.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline)

def ConfigDatasetTESTCustom(custom=False):
    """Evaluation config (``test_pipeline``).

    With ``custom`` truthy, the VOC2007 and VOC2012 ``custom_test`` splits are
    used. Otherwise the VOC2007 ``test`` split is used.
    """
    if custom:
        return dict(
            type=dataset_type,
            ann_file=[data_root + 'VOC2007/ImageSets/Main/custom_test.txt',
                      data_root + 'VOC2012/ImageSets/Main/custom_test.txt'],
            # One prefix per annotation file (as in ConfigDatasetALCustom): the
            # VOC2012 list must resolve under VOC2012/, not VOC2007/.
            img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
            pipeline=test_pipeline)
    return dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline)

def DelJunkSave(work_dir):
    """Delete every entry of ``work_dir`` whose name contains ``'.pth'``."""
    checkpoints = [name for name in os.listdir(work_dir) if '.pth' in name]
    for name in checkpoints:
        os.remove(os.path.join(work_dir, name))

def SamplingViaLoader(data_loader, index):
    """Return the ``index``-th batch produced by ``data_loader`` (``None`` if it yields fewer)."""
    for position, batch in enumerate(data_loader):
        if position == index:
            print(f'The loader is returning {position}st batch.')
            return batch
    return None

def mmm(tensor):
    """Return ``(max, mean, min, std)`` of ``tensor`` as Python numbers."""
    stats = (tensor.max(), tensor.mean(), tensor.min(), tensor.std())
    return tuple(stat.item() for stat in stats)

def ShannonEnt(prob):
    """Element-wise Shannon entropy terms ``-p * log(p)`` (natural log, not summed).

    Zero probabilities contribute 0, the limit of ``-p log p``, instead of NaN.
    """
    # Clamping only affects entries below the smallest normal float, where
    # p * log(p) is already ~0; it turns 0 * log(0) = 0 * -inf = nan into 0.
    safe = prob.clamp_min(torch.finfo(prob.dtype).tiny)
    return -prob * safe.log()

def EditCfg(cfg, new_dict):
    """Overwrite, in place, the entries of ``cfg`` that also appear in ``new_dict``.

    Keys absent from ``cfg`` are ignored. A banner and ``new_dict`` are printed first.
    """
    print(' ===== EditCfg is ... ===== ')
    print(new_dict)
    for key, val in new_dict.items():
        if key not in cfg:
            continue
        cfg[key] = val

def Shape(container):
    """List the ``.shape`` of every element of ``container``."""
    return list(map(lambda item: item.shape, container))

def ExtractAggFunc(type):
    """Parse e.g. ``'objectSum_scaleAvg_classMax'`` into ``{'object': torch.sum, ...}``.

    Each ``_``-separated token containing ``object``, ``scale`` or ``class``
    maps that name to the aggregator named by the remainder of the token
    (``Sum``, ``Avg`` or ``Max``). Later tokens win; unknown aggregators raise KeyError.
    """
    aggregators = {'Sum': torch.sum, 'Avg': torch.mean, 'Max': torch.max}
    tokens = type.split('_')
    result = {}
    for target in ('object', 'scale', 'class'):
        for token in tokens:
            if target not in token:
                continue
            result[target] = aggregators[token.replace(target, '')]
    return result

def StartEnd(mlvl, sIdx):
    """Return the ``[start, end)`` span of level ``sIdx`` when levels are concatenated along dim 1.

    Returns ``None`` if ``sIdx`` is not a valid level index.
    """
    offset = 0
    for level_idx, level in enumerate(mlvl):
        size = level.size(1)
        if level_idx == sIdx:
            return offset, offset + size
        offset += size
    return None

def DeleteImgs(path):
    """Create ``path`` if missing, then delete its entries whose name contains ``'.jpg'``."""
    if not os.path.exists(path):
        os.mkdir(path)
    jpg_files = [name for name in os.listdir(path) if '.jpg' in name]
    for name in jpg_files:
        os.remove(os.path.join(path, name))

def ShowUncZero(Uncs, dataset_al, model, cfg, **kwargs):
    """Visualize every sample whose uncertainty is exactly zero.

    Clears ``*.jpg`` files from ``visualization`` first, then calls
    ``VisualizeData`` on ``dataset_al[idx]`` with ``name=f'zero_{idx}'`` for
    each zero entry of ``Uncs``.

    Args:
        Uncs: 1-D tensor or array-like of per-sample uncertainties.
    """
    DeleteImgs('visualization')
    # Indices of *all* zero entries as plain ints. For a tensor, `.nonzero()`
    # is [N, 1], so `[0]` was only the first hit (and raised when none existed).
    if torch.is_tensor(Uncs):
        zeroIdx = (Uncs == 0).nonzero(as_tuple=True)[0].tolist()
    else:
        zeroIdx = np.nonzero(np.asarray(Uncs) == 0)[0].tolist()
    for idx in tqdm(zeroIdx):
        EasyOne = dataset_al[idx]
        VisualizeData(EasyOne, model, cfg, name=f'zero_{idx}', fromLoader=False, **kwargs)

def FindRunDir(name, path='/drive2/YH/[MX32]Active_Tracking_Project/[MyCodes]/MMdet_study/MIAOD_based_AOD/wandb/'):
    """Return the first entry of ``path`` (in ``os.listdir`` order) whose name contains ``name``, else ``None``."""
    matches = (entry for entry in os.listdir(path) if name in entry)
    return next(matches, None)

def getMaxConf(mlvl_cls_scores, nCls):
    """Highest softmax class probability per image, per level and overall.

    Each level's ``[B, C, H, W]`` scores are reshaped to ``[B, -1, nCls]`` and
    softmaxed over the class axis.

    Returns:
        ``(overall, per_level)``: ``overall`` is a list of B floats (max over
        levels), ``per_level`` a ``[B, n_levels]`` tensor.
    """
    first = mlvl_cls_scores[0]
    batch = first.size(0)
    per_level = torch.zeros(batch, len(mlvl_cls_scores)).to(first.device)
    for level_idx, scores in enumerate(mlvl_cls_scores):
        probs = scores.permute([0, 2, 3, 1]).reshape(batch, -1, nCls).softmax(dim=-1)
        per_level[:, level_idx] = probs.reshape(batch, -1).max(dim=-1).values
    overall = per_level.max(dim=-1).values
    return overall.tolist(), per_level

def ResumeCycle(cfg, currentCycle, fromStartCycle):
    """Load the saved labeled/unlabeled index arrays when resuming.

    Returns ``(False, False)`` while ``currentCycle < fromStartCycle``.
    Otherwise returns the arrays stored in
    ``<cfg.work_dir>/X_L_<fromStartCycle>.npy`` and ``.../X_U_<fromStartCycle>.npy``.
    """
    if not currentCycle >= fromStartCycle:
        return (False, False)
    prefix = cfg.work_dir + '/X_'
    suffix = '_' + str(fromStartCycle) + '.npy'
    labeled = np.load(prefix + 'L' + suffix)
    unlabeled = np.load(prefix + 'U' + suffix)
    return (labeled, unlabeled)

def ResumeCycle_WorkDir(work_dir, currentCycle, fromStartCycle):
    """Same as ``ResumeCycle`` but takes the work directory path directly.

    Returns ``(False, False)`` while ``currentCycle < fromStartCycle``.
    Otherwise loads ``<work_dir>/X_L_<n>.npy`` and ``<work_dir>/X_U_<n>.npy``
    with ``n = fromStartCycle``.
    """
    if not currentCycle >= fromStartCycle:
        return (False, False)
    prefix = work_dir + '/X_'
    suffix = '_' + str(fromStartCycle) + '.npy'
    labeled = np.load(prefix + 'L' + suffix)
    unlabeled = np.load(prefix + 'U' + suffix)
    return (labeled, unlabeled)

def append_dropout(model, rate=0.1):
    """Insert ``nn.Dropout2d(p=rate)`` after every ``nn.ReLU`` in ``model`` (in place).

    Each ReLU child is replaced by ``nn.Sequential(relu, nn.Dropout2d(p=rate))``.
    Nested sub-modules are processed recursively with the same ``rate``.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            append_dropout(module, rate)  # previously dropped `rate` -> nested p was always 0.1
        if isinstance(module, nn.ReLU):
            new = nn.Sequential(module, nn.Dropout2d(p=rate))
            setattr(model, name, new)

def activate_dropout(model):
    """Switch every ``nn.Dropout2d`` nested inside ``model`` to training mode.

    Other modules keep their current mode, so dropout stays active while the
    rest of the model is in eval mode (presumably for MC-dropout sampling).
    The model structure is not modified.
    """
    for module in model.children():
        if len(list(module.children())) > 0:
            # Recurse into *this* function; it used to call append_dropout,
            # which inserted new layers and never activated nested dropout.
            activate_dropout(module)
        if isinstance(module, nn.Dropout2d):
            module.train()

def FlattenNcls(ten, idx = 0, nCls=20):
    """Stack sample ``idx`` of every level into one ``[-1, nCls]`` matrix.

    Each level ``ten[j][idx]`` is a ``[C, H, W]`` map (C a multiple of
    ``nCls``). It is permuted to ``[H, W, C]`` and reshaped to rows of
    ``nCls``; levels are concatenated in order along dim 0.
    """
    pieces = [level[idx].permute(1, 2, 0).reshape(-1, nCls) for level in ten]
    if len(pieces) <= 1:
        return pieces[0]
    return torch.cat(pieces, dim=0)

# def read_data(dataloader, labels=True):
#     if labels:
#         while True:
#             for img, label in dataloader:
#                 yield img, label
#     else:
#         while True:
#             for img, _ in dataloader:
#                 yield img

def read_data(dataloader, labels=True):
    """Endlessly cycle over a loader of ``(img, label, extra)`` triples.

    Yields ``(img, label)`` pairs, or just ``img`` when ``labels`` is False.
    """
    while True:
        for img, label, _ in dataloader:
            yield (img, label) if labels else img

def read_data2(dataloader, labels=True):
    """Endlessly cycle over a loader of ``(img, label)`` pairs.

    Yields the pairs, or just ``img`` when ``labels`` is False.
    """
    while True:
        for img, label in dataloader:
            yield (img, label) if labels else img

def Num2Color(num):
    """Map an id 0..12 (int, float or tensor; tensors are truncated to int) to a colour name."""
    palette = dict(enumerate([
        'red', 'blue', 'yellow', 'green', 'purple', 'cyan', 'black',
        'orange', 'pink', 'skyblue', 'darkkhaki', 'mediumslateblue', 'darkslategray',
    ]))
    key = num.int().item() if torch.is_tensor(num) else num
    return palette[key]

def drawTSNE2(feat1, feat2, label1=None, label2=None, name1 = None, name2 = None):
    """Show a 2-D t-SNE scatter of two feature sets.

    Args:
        feat1, feat2: 2-D feature tensors, concatenated along dim 0 for t-SNE.
        label1, label2: optional per-point labels coloured via ``Num2Color``.
            A missing one defaults to colour id 10 (feat1) / 11 (feat2).
        name1, name2: optional legend names, defaulting to ``'feat1'`` / ``'feat2'``.
    """
    mix_embeddings = torch.cat((feat1, feat2))
    TSNE_model = TSNE(n_components=2)
    TSNE_data = TSNE_model.fit_transform(mix_embeddings.detach().cpu().numpy())
    len1 = len(feat1)

    # `is None`, not `== None`: labels may be arrays, where `==` is element-wise.
    # Each default is applied on its own so a supplied label/name is never discarded.
    if label1 is None:
        label1 = torch.ones(len(feat1)).to(feat1.device) * 10
    if label2 is None:
        label2 = torch.ones(len(feat2)).to(feat2.device) * 11
    if name1 is None:
        name1 = 'feat1'
    if name2 is None:
        name2 = 'feat2'

    plt.scatter(x=TSNE_data[:len1, 0], y=TSNE_data[:len1, 1],
                c=list(map(Num2Color, label1)), marker='o', alpha=1.0, edgecolors='black', label=name1)
    plt.scatter(x=TSNE_data[len1:, 0], y=TSNE_data[len1:, 1],
                c=list(map(Num2Color, label2)), marker='x', alpha=1.0, edgecolors='black', label=name2)
    plt.legend()
    plt.show()

def drawTSNE3(feat1, feat2, feat3, label1=None, label2=None, label3=None, name1=None, name2=None, name3 = None):
    """Show a 2-D t-SNE scatter of three feature sets.

    Args:
        feat1, feat2, feat3: 2-D feature tensors, concatenated along dim 0 for t-SNE.
        label1..label3: optional per-point labels coloured via ``Num2Color``.
            Missing ones default to colour ids 10 / 11 / 12.
        name1..name3: optional legend names, defaulting to ``'feat1'`` / ``'feat2'`` / ``'feat3'``.

    ``feat3`` is drawn first (semi-transparent) so the other two stay visible on top.
    """
    mix_embeddings = torch.cat((feat1, feat2, feat3))
    TSNE_model = TSNE(n_components=2)
    TSNE_data = TSNE_model.fit_transform(mix_embeddings.detach().cpu().numpy())
    len1, len2 = len(feat1), len(feat2)

    # `is None`, not `== None`: labels may be arrays, where `==` is element-wise.
    if label1 is None: label1 = torch.ones(len(feat1)).to(feat1.device) * 10
    if label2 is None: label2 = torch.ones(len(feat2)).to(feat2.device) * 11
    if label3 is None: label3 = torch.ones(len(feat3)).to(feat3.device) * 12

    # Default each name independently so supplied names are kept.
    if name1 is None: name1 = 'feat1'
    if name2 is None: name2 = 'feat2'
    if name3 is None: name3 = 'feat3'

    plt.scatter(x=TSNE_data[len1 + len2:, 0], y=TSNE_data[len1 + len2:, 1],
                c=list(map(Num2Color, label3)), marker='x', alpha=0.4, edgecolors='black', label=name3)
    plt.scatter(x=TSNE_data[:len1, 0], y=TSNE_data[:len1, 1],
                c=list(map(Num2Color, label1)), marker='o', alpha=1.0, edgecolors='black', label=name1)
    plt.scatter(x=TSNE_data[len1:len1+len2, 0], y=TSNE_data[len1:len1+len2, 1],
                c=list(map(Num2Color, label2)), marker='*', alpha=1.0, edgecolors='black', label=name2)

    plt.legend()
    plt.show()

def Lmean(List):
    """Mean of a list of numbers (or scalar tensors) as a Python float."""
    values = torch.tensor(List)
    # torch refuses to average integer tensors; promote ints/bools to float.
    if not values.is_floating_point():
        values = values.float()
    return values.mean().item()

def UL_CLSACC(UL_acc, UL_labels): # Both are numpy arrays
    """Per-class mean of ``UL_acc``, indexed by class id.

    Args:
        UL_acc: per-sample accuracy values (array-like).
        UL_labels: per-sample non-negative integer class ids, same length.

    Returns:
        float array of length ``max(UL_labels) + 1``; classes with no samples
        are NaN. Empty input gives an empty array. Ids need not be contiguous;
        previously the length was the number of distinct ids, which dropped or
        misplaced classes when ids had gaps.
    """
    UL_acc = np.asarray(UL_acc)
    UL_labels = np.asarray(UL_labels)
    if UL_labels.size == 0:
        return np.zeros(0)
    nC = int(UL_labels.max()) + 1
    ClsACC = np.full(nC, np.nan)
    for C in range(nC):
        mask = UL_labels == C
        if mask.any():  # avoid the empty-slice mean warning for absent classes
            ClsACC[C] = UL_acc[mask].mean()

    return ClsACC

def KLD(Prob1, Prob2):
    """Element-wise KL terms ``p1 * (log p1 - log p2)``; sum over classes for KL(p1 || p2).

    Both inputs are clamped to the smallest positive normal float before the
    log, so ``0 * log 0`` contributes 0 instead of NaN. As a consequence, a
    term with ``p2 == 0 < p1`` is large but finite instead of ``inf``.
    """
    tiny = torch.finfo(Prob1.dtype).tiny
    return Prob1 * (Prob1.clamp_min(tiny).log() - Prob2.clamp_min(tiny).log())

def JSD(Prob1, Prob2):
    """Jensen-Shannon divergence over the last dim: mean of KL(p_i || M), M = (p1 + p2) / 2."""
    mixture = 0.5 * (Prob1 + Prob2)
    kl_to_mixture = [KLD(dist, mixture).sum(dim=-1) for dist in (Prob1, Prob2)]
    return 0.5 * (kl_to_mixture[0] + kl_to_mixture[1])

def JSD_torch(Prob1, Prob2):
    """Jensen-Shannon divergence over the last dim using ``F.kl_div``.

    ``F.kl_div(input, target)`` computes ``target * (log target - input)`` and
    expects ``input`` to be log-probabilities, so KL(p || M) is
    ``F.kl_div(log M, p)``. The previous call passed raw probabilities with
    the arguments swapped, which is not a divergence.
    """
    M = 0.5 * (Prob1 + Prob2)
    # Clamp so entries where both inputs are 0 give 0 * finite = 0, not 0 * -inf = nan.
    log_M = M.clamp_min(torch.finfo(M.dtype).tiny).log()
    KL1 = F.kl_div(log_M, Prob1, reduction='none').sum(dim=-1)
    KL2 = F.kl_div(log_M, Prob2, reduction='none').sum(dim=-1)
    JSD = 0.5 * (KL1 + KL2)

    return JSD

def Topk_ProbDiff(probs_L, probs_UL, _k):
    """Sum of |p_L - p_UL| over the ``_k`` classes ranked highest by ``probs_L`` (per row)."""
    top_idx = probs_L.topk(k=_k, dim=1).indices
    gap = probs_L.gather(1, top_idx) - probs_UL.gather(1, top_idx)
    return gap.abs().sum(dim=-1)

def KLD_MVN(mean1, var1, mean2, var2):
    """KL( N(mean1, diag(var1)) || N(mean2, diag(var2)) ) for diagonal Gaussians.

    Args:
        mean1, var1, mean2, var2: length-n tensors of means and per-dim variances.

    Returns:
        ``(1, 1)`` tensor (same shape as before).
    """
    mean1, var1 = mean1.reshape(-1), var1.reshape(-1)
    mean2, var2 = mean2.reshape(-1), var2.reshape(-1)
    n = mean1.size(0)
    diff = mean2 - mean1
    # Diagonal closed form. The old version built n x n matrices, inverted
    # them, and took log(det1 / det2); both determinants under/overflow for
    # large n (e.g. 0.01 ** 512 == 0), giving NaN.
    first = (diff * diff / var2).sum()
    second = (var1 / var2).sum()
    third = (var1.log() - var2.log()).sum()
    out = 0.5 * (first + second - third - n)

    return out.reshape(1, 1)

def ALL_KLD_MVN(L_ALL_GMM_mean, L_ALL_GMM_var, Ul_ALL_GMM_mean, Ul_ALL_GMM_var):
    """Symmetrised KL per component: 0.5 * (KL(L_c || UL_c) + KL(UL_c || L_c)) for each row c."""
    num_components = L_ALL_GMM_mean.size(0)
    out = torch.zeros(num_components).to(L_ALL_GMM_mean.device)
    for comp in range(num_components):
        lab_mean, lab_var = L_ALL_GMM_mean[comp], L_ALL_GMM_var[comp]
        unl_mean, unl_var = Ul_ALL_GMM_mean[comp], Ul_ALL_GMM_var[comp]
        forward_kl = KLD_MVN(lab_mean, lab_var, unl_mean, unl_var).squeeze()
        backward_kl = KLD_MVN(unl_mean, unl_var, lab_mean, lab_var).squeeze()
        out[comp] = 0.5 * (forward_kl + backward_kl)

    return out

def inDict(Dict, elem):
    """Return ``Dict[elem]`` if the key is present, else ``False`` (``Dict`` must be a dict)."""
    assert isinstance(Dict, dict)
    return Dict[elem] if elem in Dict else False

def GMM_LogLikelihood2(GMM_model, embedding, sample_size, pi, mean, logvar, **kwargs):
    """Log-likelihood of each embedding under a GMM: ``logsumexp_i(log p(x|c_i) + log pi_i)``.

    Args:
        GMM_model: object with ``component_size`` (K) and
            ``gaussian_log_prob(x, mean, logvar, **kwargs)``, which may return ``None``.
        embedding: 2-D ``[B, D]`` tensor.
        sample_size: not used by this function.
        pi: ``[K]`` mixture weights.
        mean, logvar: ``[K, D]`` component parameters (shapes must match).
        **kwargs: forwarded to ``gaussian_log_prob``.

    Returns:
        ``[B]`` tensor, or ``None`` when ``gaussian_log_prob`` returns ``None``.
    """
    assert mean.shape == logvar.shape
    batch_size, _ = embedding.shape

    _log_likelihoods = GMM_model.gaussian_log_prob(
        embedding[:, None, :].repeat(1, GMM_model.component_size, 1),  # [B, K, D]
        mean[None, :, :].repeat(batch_size, 1, 1),  # [B, K, D]
        logvar[None, :, :].repeat(batch_size, 1, 1), **kwargs  # [B, K, D]
    )  # [B, K] per the original shape comments

    # Identity check: `== None` on a tensor/array is not a reliable None test.
    if _log_likelihoods is None:
        return None
    log_likelihoods = _log_likelihoods + torch.log(pi[None, :].repeat(batch_size, 1))
    return torch.logsumexp(log_likelihoods, dim=-1)

def GetDistInfo(psi_xl, psi_xul, embedding, GMM_model, labeled_embedding, unlabeled_embedding):
    """Append GMM log-likelihood features to the per-item mean embedding.

    Args:
        psi_xl, psi_xul: ``(pi, mean, var)`` GMM parameters of the labeled /
            unlabeled set.
        embedding: ``[B, S, D]`` tensor (S samples per item). S is read from the
            tensor; it was previously hard-coded to 16.
        GMM_model: object with ``component_size`` (K) and ``gaussian_log_prob``.
        labeled_embedding, unlabeled_embedding: tensors concatenated along
            dim 0; their pooled mean/std set the scale of the appended features.

    Returns:
        ``[B, D + 2K + 2]``: ``embedding.mean(dim=-2)`` followed by, for psi_xl
        then psi_xul, ``log p(x|c_i)p(c_i)`` (K) and ``log p(x)`` (1), each
        averaged over S, standardised, and rescaled to the embeddings' mean/std.
    """
    batch_size, sample_size = embedding.shape[0], embedding.shape[1]
    all_embedding = torch.cat((labeled_embedding, unlabeled_embedding), dim=0)
    all_mean, all_std = all_embedding.mean(dim=0).mean(), all_embedding.std(dim=0).mean()

    def log_joint(pi, mean, var):
        # log p(x|c_i; psi) + log p(c_i) for every item, sample and component: [B, S, K]
        return GMM_model.gaussian_log_prob(
            embedding[:, :, None, :].repeat(1, 1, GMM_model.component_size, 1),
            mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), var
        ) + torch.log(pi[None, None, :].repeat(batch_size, sample_size, 1))

    features = []
    for pi, mean, var in (psi_xl, psi_xul):  # labeled set first, then unlabeled
        joint = log_joint(pi, mean, var)
        # log p(x; psi) = logsumexp_i of the joint terms: [B, S, 1]
        marginal = torch.logsumexp(joint, dim=-1, keepdim=True)
        features += [joint.mean(dim=-2), marginal.mean(dim=-2)]  # [B, K], [B, 1]
    distInfo = torch.cat(features, dim=1)
    normed_distInfo = (distInfo - distInfo.mean()) / (distInfo.std() + 1e-9)
    scaled_distInfo = all_std * normed_distInfo + all_mean

    distCat = torch.cat((embedding.mean(dim=-2), scaled_distInfo), dim=1)
    return distCat

def GetDistInfo_V2(psi_xl, psi_xul, embedding, GMM_model):
    """Like the labeled/unlabeled-scaled variant, but rescales by ``embedding``'s own statistics.

    Args:
        psi_xl, psi_xul: ``(pi, mean, var)`` GMM parameters of the labeled /
            unlabeled set.
        embedding: ``[B, S, D]`` tensor (S samples per item). S is read from the
            tensor; it was previously hard-coded to 16.
        GMM_model: object with ``component_size`` (K) and ``gaussian_log_prob``.

    Returns:
        ``[B, D + 2K + 2]``: ``embedding.mean(dim=-2)`` followed by, for psi_xl
        then psi_xul, ``log p(x|c_i)p(c_i)`` (K) and ``log p(x)`` (1), each
        averaged over S, standardised, and rescaled to ``embedding``'s mean/std.
    """
    batch_size, sample_size = embedding.shape[0], embedding.shape[1]
    tar_mean, tar_std = embedding.mean(dim=0).mean(), embedding.std(dim=0).mean()

    def log_joint(pi, mean, var):
        # log p(x|c_i; psi) + log p(c_i) for every item, sample and component: [B, S, K]
        return GMM_model.gaussian_log_prob(
            embedding[:, :, None, :].repeat(1, 1, GMM_model.component_size, 1),
            mean[None, None, :, :].repeat(batch_size, sample_size, 1, 1), var
        ) + torch.log(pi[None, None, :].repeat(batch_size, sample_size, 1))

    features = []
    for pi, mean, var in (psi_xl, psi_xul):  # labeled set first, then unlabeled
        joint = log_joint(pi, mean, var)
        # log p(x; psi) = logsumexp_i of the joint terms: [B, S, 1]
        marginal = torch.logsumexp(joint, dim=-1, keepdim=True)
        features += [joint.mean(dim=-2), marginal.mean(dim=-2)]  # [B, K], [B, 1]
    distInfo = torch.cat(features, dim=1)
    normed_distInfo = (distInfo - distInfo.mean()) / (distInfo.std() + 1e-9)
    scaled_distInfo = tar_std * normed_distInfo + tar_mean

    distCat = torch.cat((embedding.mean(dim=-2), scaled_distInfo), dim=1)
    return distCat

def MinMaxSqrt(bar):
    """Min-max scale ``bar`` to [0, 1], add 0.01 and take the fifth root (exponent 0.2).

    A constant input maps to all zeros before the root (i.e. ``0.01 ** 0.2``)
    instead of 0/0 = NaN.
    """
    span = bar.max() - bar.min()
    if span == 0:
        span = 1  # constant input: every value is at the minimum
    bar = (bar - bar.min()) / span
    bar = (bar + 0.01) ** 0.2
    return bar

def MinMax(bar):
    """Min-max scale ``bar`` (tensor or ndarray) to [0, 1]; a constant input maps to zeros, not NaN."""
    span = bar.max() - bar.min()
    if span == 0:
        span = 1  # constant input: every value is at the minimum
    bar = (bar - bar.min()) / span
    return bar

def WriteCycle(state_dict, saveName, cycle):
    """Save ``state_dict`` to ``<saveName>_<cycle>cycle.pth``."""
    torch.save(state_dict, '{}_{}cycle.pth'.format(saveName, cycle))

def OverWriteCycle(state_dict, saveName, cycle):
    """Save ``<saveName>_<cycle>cycle.pth``, deleting the previous cycle's file if it exists."""
    previous = '{}_{}cycle.pth'.format(saveName, cycle - 1)
    current = '{}_{}cycle.pth'.format(saveName, cycle)
    if os.path.exists(previous):
        os.remove(previous)
    torch.save(state_dict, current)

def Ranking(data):
    """0-based rank of each element; ties share the lowest rank (left insertion point)."""
    ordered = sorted(data)
    return np.searchsorted(ordered, data, side='left')

def RankingUnion(set1, set2):
    """Rank ``set1`` and ``set2`` jointly and return the ranks split back per set."""
    boundary = len(set1)
    joint_ranks = Ranking(np.concatenate((set1, set2)))
    return joint_ranks[:boundary], joint_ranks[boundary:]

def GaussianRanking(data):
    """Standardise the ranks of ``data`` to zero mean and unit (population) std."""
    ranks = Ranking(data)
    centered = ranks - ranks.mean()
    return centered / ranks.std()

def PlotGMM(psi, embedding, str=''):
    """t-SNE plot of ``embedding`` together with the GMM means ``psi[1]``.

    Embedding points are drawn as silver circles. The means are drawn as large
    stars coloured from a fixed 10-colour list, which is cycled if there are
    more means. The figure is saved to ``tmp/GMM/PlotGMM_<str>.jpg`` (the
    directory is created if needed) and then closed.
    """
    proto = psi[1]
    n_embed = len(embedding)
    mix_embeddings = torch.cat((embedding, proto))

    print("TSNE starts..!")
    TSNE_model = TSNE(n_components=2, perplexity=50, n_iter=1500)
    TSNE_data = TSNE_model.fit_transform(mix_embeddings.detach().cpu().numpy())

    colorList = ['red', 'blue', 'yellow', 'green', 'purple', 'cyan', 'black', 'orange', 'navy', 'skyblue']
    # Split by the actual row counts instead of assuming exactly 10 means.
    embed_xy, proto_xy = TSNE_data[:n_embed], TSNE_data[n_embed:]
    proto_colors = [colorList[k % len(colorList)] for k in range(len(proto_xy))]

    fig = plt.figure(figsize=(16, 10))
    plt.scatter(x=embed_xy[:, 0], y=embed_xy[:, 1], c='silver', marker='o', alpha=1.0,
                edgecolors='black', label='Unlabeled')
    plt.scatter(x=proto_xy[:, 0], y=proto_xy[:, 1], c=proto_colors, marker='*', alpha=1.0, s=500,
                edgecolors='black', label='L_Proto')
    plt.legend()
    plt.title('PlotGMM' + str)
    os.makedirs('tmp/GMM', exist_ok=True)
    plt.savefig(f'tmp/GMM/PlotGMM_{str}.jpg')
    # Release the figure; repeated calls would otherwise keep every figure open.
    plt.close(fig)

def BDMatrix(means1, logvars1, means2, logvars2):
    """Pairwise ``BD`` matrix: entry [r, c] compares component r of set 1 with component c of set 2."""
    n_comp = means1.size(0)
    out = torch.ones((n_comp, n_comp)).to(means1.device)
    log_two = torch.log(torch.tensor(2.0))
    for row in range(n_comp):
        for col in range(n_comp):
            lv1, lv2 = logvars1[row], logvars2[col]
            # Log of the averaged variance: log((exp(lv1) + exp(lv2)) / 2).
            avg_lv = torch.logsumexp(torch.stack((lv1, lv2)), dim=0) - log_two
            out[row, col] = BD(means1[row], lv1, means2[col], lv2, avg_lv)

    return out

def BD(mean1, logvar1, mean2, logvar2, avglogvar, *, use_variance=False):
    """Bhattacharyya distance between two diagonal Gaussians.

    ``D = 1/8 * d^T S^-1 d + 1/2 * (log|S| - 1/2 log|S1| - 1/2 log|S2|)`` where
    ``d = mean1 - mean2`` and ``S`` is the average covariance, given by its
    per-dim log-variance ``avglogvar``.

    By default (``use_variance=False``) all log-variances are replaced by 0,
    i.e. unit covariances, so the result is ``||mean1 - mean2||^2 / 8``. This
    is what the function has always returned. Pass ``use_variance=True`` to
    use the supplied log-variances.
    """
    if not use_variance:
        logvar1, logvar2 = torch.zeros_like(logvar1), torch.zeros_like(logvar2)
        avglogvar = torch.zeros_like(avglogvar)

    diff = mean1 - mean2
    # The quadratic term uses the *inverse* average covariance; for a diagonal
    # matrix that is division by the variances (the old code multiplied).
    term1 = 1/8 * (diff * diff * (-avglogvar).exp()).sum()
    term2 = 1/2 * (avglogvar.sum() - 1/2 * logvar1.sum() - 1/2 * logvar2.sum())

    distance = term1 + term2
    return distance