import os
import sys
import torch
import torchvision
import numpy as np
from torchvision import datasets, models, transforms
import torch.nn as nn
from sklearn.metrics import roc_auc_score, roc_curve, auc
from torch.autograd import Variable 
import time
from sklearn.metrics import confusion_matrix, roc_curve, auc
import datetime
import shutil

def cprint(color, text, **kwargs):
    """Print ``text`` wrapped in an ANSI colour escape sequence, then flush stdout.

    ``color`` is a one-letter colour key ('a', 'r', 'g', 'y', 'b', 'p', 'c',
    'w' mapping to SGR codes 30-37).  Prefixing the key with '*' adds the
    '1;' (bold) attribute.  Extra keyword arguments are forwarded to
    ``print`` unchanged.
    """
    palette = dict(a='30', r='31', g='32', y='33', b='34', p='35', c='36', w='37')

    bold = color[0] == '*'
    if bold:
        color = color[1:]  # drop the '*' marker, keep the colour key
    attribute = '1;' if bold else ''

    print("\x1b[%s%sm%s\x1b[0m" % (attribute, palette[color], text), **kwargs)
    sys.stdout.flush()

def comp_avg_metrics(test_metrics, sele_metric, ascend=True, avg_num=2, name='loss'):
    '''
    Pick the best ``avg_num`` models according to ``sele_metric`` and report the
    mean/std of their metrics.

    :param test_metrics: sequence of per-model metric dicts with keys
                         'acc', 'sen', 'spe', 'auc'
    :param sele_metric: selection score per model (list, numpy array or torch
                        tensor; any shape, it is flattened)
    :param ascend: True----smallest scores are best (e.g., losses);
                   False----largest scores are best (e.g., accuracies)
    :param avg_num: number of selected models
    :param name: name of the selection score, used only in the printed header
    :return: (acc, sen, spe, auc, index) where the first four are the means over
             the selected models rounded to 2 decimals and ``index`` holds the
             0-based positions of the selected models, best first.  When nothing
             is selected the means are NaN (NumPy emits a RuntimeWarning).
    '''
    # Both orderings accept the same input kinds: tensors are moved to host
    # memory first, everything is flattened so argsort yields plain positions.
    if hasattr(sele_metric, 'detach'):
        sele_metric = sele_metric.detach().cpu().numpy()
    scores = np.array(sele_metric).ravel()

    order = np.argsort(scores)
    if not ascend:
        order = order[::-1]
    # Slicing from the front keeps avg_num=0 meaning "select nothing" in both
    # directions (a trailing [-0:] slice would select everything).
    index = order[:avg_num]

    cprint('r', '\nSelected models indices:', end=' ')   # print selected models
    for model_idx in index:
        print(model_idx + 1, end=' ')  # 1 ~ epoch num
    cprint('r', '\nAverage results based ' + name)

    averages = {}
    for metric in ('acc', 'sen', 'spe', 'auc'):
        cur_metric = [test_metrics[ind][metric] for ind in index]
        mean_value = np.mean(cur_metric)
        print('       {:s} ---- {:.2f} ± {:.2f}'.format(str.upper(metric), mean_value, np.std(cur_metric)))
        averages[metric] = round(mean_value, 2)
    return averages['acc'], averages['sen'], averages['spe'], averages['auc'], index

def metrics_print(metrics_dict):
    """Print every metric as ``NAME: value`` (3 decimals) on one line.

    Entries are written in the dict's iteration order, each followed by a
    single space and no trailing newline.
    """
    template = "{:s}: {:.3f}"
    for metric_name in metrics_dict:
        line = template.format(str.upper(metric_name), metrics_dict[metric_name])
        print(line, end=' ')


def compute_measures(labels, preds, predict_prob):
    '''
    Compute binary-classification metrics from per-sample outputs.

    :param labels: ground truth (tensor of 0/1 class indices, any shape)
    :param preds: predicted label (tensor of 0/1 class indices, any shape)
    :param predict_prob: positive probability (1)
    :return: dict with keys 'acc', 'sen', 'spe', 'auc' (ACC, SEN, SPE, AUC),
             each rounded to 4 decimals
    '''
    # Flatten to 1-D host arrays; callers may pass (N, 1) column buffers.
    labels = labels.detach().cpu().numpy().ravel().astype(int)
    preds = preds.detach().cpu().numpy().ravel().astype(int)
    predict_prob = predict_prob.detach().cpu().numpy().ravel()

    # Pin the label set so the matrix is always 2x2, even when only one class
    # occurs in labels and preds (otherwise cm[1][1] raises IndexError).
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    TP = float(cm[1][1])
    TN = float(cm[0][0])
    FP = float(cm[0][1])
    FN = float(cm[1][0])

    # The small epsilon avoids division by zero when a class is absent.
    sensitivity = TP / ((TP + FN) + 1e-8)
    specificity = TN / (TN + FP + 1e-8)
    accuracy = (TP + TN) / (TP + TN + FP + FN)
    fpr, tpr, _ = roc_curve(labels, predict_prob)
    roc_auc = auc(fpr, tpr)
    index = [round(accuracy,4), round(sensitivity,4), round(specificity,4), round(roc_auc,4)]
    res_index = dict(zip(['acc', 'sen', 'spe', 'auc'], index))
    return res_index

class MlpBlock(nn.Module):
    """Two-layer perceptron acting on the last dimension of its input.

    Layout: Linear(in_dim -> hidden_dim), GELU, Dropout,
    Linear(hidden_dim -> in_dim), Dropout.  The output therefore has the same
    shape as the input.
    """

    def __init__(self, in_dim, hidden_dim, drop_rate=0):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),   # expand
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_dim, in_dim),   # project back
            nn.Dropout(drop_rate),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        out = self.mlp(x)
        return out

class MixerLayer(nn.Module):
    """One Mixer layer: token mixing followed by channel mixing.

    Input and output have shape (batch, ns, nc): ``ns`` tokens with ``nc``
    channels each.  Both sub-blocks are pre-norm residual blocks, i.e.
    ``x + Mlp(LayerNorm(x))``.

    :param ns: number of tokens (size of dimension 1)
    :param nc: number of channels (size of dimension 2)
    :param ds: hidden width of the token-mixing MLP
    :param dc: hidden width of the channel-mixing MLP
    :param drop_rate: dropout probability used inside both MLPs
    """

    def __init__(self,ns,nc,ds,dc,drop_rate):
        super(MixerLayer,self).__init__()
        self.norm1=nn.LayerNorm(nc)
        self.norm2=nn.LayerNorm(nc)
        self.tokenMix=MlpBlock(in_dim=ns,hidden_dim=ds,drop_rate=drop_rate)
        self.channelMix=MlpBlock(in_dim=nc,hidden_dim=dc,drop_rate=drop_rate)

    def forward(self,x):
        # Token mixing: the MLP runs across the token axis, hence the
        # transposes.  The skip connection carries the un-normalised input,
        # matching the channel-mixing half below.
        y=self.norm1(x)
        x=x+self.tokenMix(y.transpose(1,2)).transpose(1,2)
        # Channel mixing: the MLP runs across the channel (last) axis.
        y=self.norm2(x)
        return x+self.channelMix(y)

class Mixer(nn.Module):
    """MLP-Mixer style classifier for single-channel square images.

    :param num_classes: number of output logits
    :param image_size: side length of the input image; must be a multiple of
                       ``patch_size``
    :param patch_size: side length of the square patches
    :param num_layers: number of stacked Mixer layers
    :param embed_dim: channel width of every patch embedding
    :param ds: hidden width of the token-mixing MLPs
    :param dc: hidden width of the channel-mixing MLPs
    """

    def __init__(self,num_classes,image_size,patch_size,num_layers,embed_dim,ds,dc):

        super(Mixer,self).__init__()
        assert image_size%patch_size==0, 'image_size must be divisible by patch_size'
        # Non-overlapping patch embedding: one conv step per patch.
        self.embed = nn.Conv2d(1,embed_dim,kernel_size=patch_size,stride=patch_size)
        ns=(image_size//patch_size)**2  # number of patches (tokens)
        # Build a fresh MixerLayer per depth level.  Repeating one instance
        # would register the same parameters num_layers times, so every layer
        # would share a single set of weights.
        self.mixlayers=nn.Sequential(*[
            MixerLayer(ns=ns,nc=embed_dim,ds=ds,dc=dc,drop_rate=0)
            for _ in range(num_layers)
        ])
        self.norm=nn.LayerNorm(embed_dim)
        self.cls=nn.Linear(embed_dim,num_classes)

    def forward(self,x):
        """Map images (n, 1, image_size, image_size) to logits (n, num_classes)."""
        x=self.embed(x).flatten(2).transpose(1,2) # n c2 hw->n hw c2
        x=self.mixlayers(x)
        x=self.norm(x)
        x=torch.mean(x,dim=1)  # average over tokens
        x=self.cls(x)
        return x

class CNN(nn.Module):
    """Five-stage convolutional classifier with two output logits.

    Every stage is Conv2d(kernel 5, stride 1, padding 2) -> ReLU ->
    MaxPool2d(2), so each stage keeps the spatial size through the convolution
    and halves it in the pooling step.  The final linear layer expects
    40 * 7 * 7 features, i.e. a (1, 224, 224) input image.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        # One conv -> activation -> 2x down-sampling stage.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def __init__(self):
        super().__init__()
        self.conv1 = self._conv_stage(1, 8)     # (1, 224, 224) -> (8, 112, 112)
        self.conv2 = self._conv_stage(8, 16)    # -> (16, 56, 56)
        self.conv3 = self._conv_stage(16, 24)   # -> (24, 28, 28)
        self.conv4 = self._conv_stage(24, 32)   # -> (32, 14, 14)
        self.conv5 = self._conv_stage(32, 40)   # -> (40, 7, 7)
        self.out = nn.Linear(40 * 7 * 7, 2)     # fully connected layer, 2 logits

    def forward(self, x):
        features = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        # Flatten the conv5 output to (batch_size, 40 * 7 * 7).
        flat = features.view(features.size(0), -1)
        return self.out(flat)

def _select_epoch_metrics(epoch_metrics, thresholds=(0.8, 0.75, 0.7, 0.65, 0.6, 0.55)):
    '''
    Keep the epochs whose SEN and SPE both reach the strictest satisfiable threshold.

    Thresholds are tried in the given order and the first one met by at least
    one epoch wins.  Returns an empty list when no epoch reaches even the last
    threshold.
    '''
    for threshold in thresholds:
        kept = [m for m in epoch_metrics
                if m['sen'] >= threshold and m['spe'] >= threshold]
        if kept:
            return kept
    return []


def _run_epoch(model, loader, loss_fn, softmax, device, labels, preds, probs, optimizer=None):
    '''
    Run one pass over ``loader`` and record per-sample outputs in the buffers.

    When ``optimizer`` is given the weights are updated after every batch.
    ``labels``, ``preds`` and ``probs`` are (N, 1) buffers filled from row 0.
    Returns (sample-weighted loss sum, number of samples seen).
    '''
    total_loss = 0.0
    seen = 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        logits = model(X)
        loss = loss_fn(logits, y)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Detach before storing: copying a grad-tracking tensor into the
        # persistent buffer would attach autograd history to it and keep that
        # history growing across batches and epochs.
        batch_probs = softmax(logits.detach())
        _, pred_y = torch.max(batch_probs, 1)

        batch = len(y)
        total_loss += batch * loss.item()
        labels[seen:seen + batch, 0] = y
        preds[seen:seen + batch, 0] = pred_y
        probs[seen:seen + batch, 0] = batch_probs[:, 1]
        seen += batch
    return total_loss, seen


def ini(it0, NPOP, train_load, test_load, epoch_num, LR, image_size, patch_size, trainset, testset):
    '''
    Train one Mixer per row of ``NPOP`` and report the metrics of its best epoch.

    For every candidate the model is trained for ``epoch_num`` epochs with SGD
    and evaluated on ``test_load`` after each epoch.  Among the epochs passing
    the SEN/SPE threshold cascade (0.8 down to 0.55) the one with the highest
    test AUC is reported.  If no epoch passes, the reported values are NaN.

    :param it0: run identifier, used only in the progress output
    :param NPOP: 2-D array, one row per candidate with columns
                 [num_layers, embed_dim, ds, dc]
    :param train_load: iterable of (X, y) training batches
    :param test_load: iterable of (X, y) test batches
    :param epoch_num: number of training epochs per candidate
    :param LR: SGD learning rate
    :param image_size: input image side length passed to Mixer
    :param patch_size: patch side length passed to Mixer
    :param trainset: training dataset; its length sizes the output buffers
    :param testset: test dataset; its length sizes the output buffers
    :return: (f, obj_f, AUC_f) with f[t] = [SEN, SPE], obj_f[t] = ACC and
             AUC_f[t] = AUC of candidate t
    '''
    # ======================================================================================================================
    # device detection
    # ======================================================================================================================
    print("PYTORCH's version is ", torch.__version__)
    # NOTE(review): hard-coded GPU id; presumably this only takes effect if
    # CUDA has not been initialised yet in this process -- confirm.
    os.environ['CUDA_VISIBLE_DEVICES']='4'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('device: ', device)

    n_candidates = len(NPOP)
    f = np.zeros((n_candidates, 2))
    obj_f = np.zeros((n_candidates, 1))
    AUC_f = np.zeros((n_candidates, 1))
    for t in range(n_candidates):
        layer_number = int(NPOP[t, 0])
        embed_dim = int(NPOP[t, 1])
        ds = int(NPOP[t, 2])
        dc = int(NPOP[t, 3])

        # ======================================================================================================================
        # model building
        # ======================================================================================================================
        model = Mixer(num_classes=2,image_size=image_size,patch_size=patch_size,num_layers=layer_number,embed_dim=embed_dim,ds=ds,dc=dc)
        model = model.to(device)

        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
        m = nn.Softmax(dim=1)

        train_labels = torch.zeros(len(trainset), 1)
        train_preds = torch.zeros(len(trainset), 1)
        train_probs = torch.zeros(len(trainset), 1)

        test_labels = torch.zeros(len(testset), 1)
        test_preds = torch.zeros(len(testset), 1)
        test_probs = torch.zeros(len(testset), 1)
        test_metrics = []

        # ======================================================================================================================
        # model training, validation, testing
        # ======================================================================================================================
        for i in range(1, epoch_num + 1):
            print('\n')
            print("\033[1;32m    ini [{}] \033[0m".format(it0))
            print("\033[1;32m    Epoch [{}/{}] \033[0m".format(i, epoch_num))
            print("-" * 45)

            # model training
            model.train()
            cprint('r', "model is training...")
            train_loss, n_train = _run_epoch(model, train_load, loss_fn, m, device,
                                             train_labels, train_preds, train_probs,
                                             optimizer=optimizer)
            # Normalise and score only the rows actually filled this epoch, so
            # a loader that yields fewer samples than len(trainset) does not
            # mix stale buffer rows into the metrics.
            print('train loss :', round(train_loss / max(n_train, 1), 4))
            train_metric = compute_measures(train_labels[:n_train], train_preds[:n_train], train_probs[:n_train])
            metrics_print(train_metric)

            # model testing
            model.eval()
            cprint('r', "\nmodel is testing...")
            with torch.no_grad():
                test_loss, n_test = _run_epoch(model, test_load, loss_fn, m, device,
                                               test_labels, test_preds, test_probs)
            print('test loss :', round(test_loss / max(n_test, 1), 4))
            test_metric = compute_measures(test_labels[:n_test], test_preds[:n_test], test_probs[:n_test])
            test_metrics.append(test_metric)
            metrics_print(test_metric)

        # Restrict the choice to balanced epochs, then take the one with the
        # highest test AUC.
        selected_metrics = _select_epoch_metrics(test_metrics)
        selected_auc = [x['auc'] for x in selected_metrics]
        acc, sen, spe, best_auc, _ = comp_avg_metrics(selected_metrics, selected_auc, ascend=False, avg_num=1,
                                                      name='auc')
        f[t][0] = sen
        f[t][1] = spe
        obj_f[t] = acc
        AUC_f[t] = best_auc

    return f, obj_f, AUC_f