import os
import sys
import torch
import torchvision
import numpy as np
from torchvision import datasets, models, transforms
import torch.nn as nn
from sklearn.metrics import roc_auc_score, roc_curve, auc
from torch.autograd import Variable
import time
from sklearn.metrics import confusion_matrix, roc_curve, auc
import datetime
import shutil
import copy

def cprint(color, text, **kwargs):
    if color[0] == '*':
        pre_code = '1;'
        color = color[1:]
    else:
        pre_code = ''
    code = {
        'a': '30',
        'r': '31',
        'g': '32',
        'y': '33',
        'b': '34',
        'p': '35',
        'c': '36',
        'w': '37'
    }
    print("\x1b[%s%sm%s\x1b[0m" % (pre_code, code[color], text), ** kwargs)
    sys.stdout.flush()


def comp_avg_metrics(test_metrics, sele_metric, ascend=True, avg_num=2, name='loss'):
    '''
    :param test_metrics: metrics of testing data set (ACC, SEN, SPE, AUC)
    :param sele_metric: selected metric
    :param ascend: True----ascend; False----descend
    :param avg_num: number of selected models
    :param name: selected metrics for evaluation
    :return:
    '''
    if ascend == True:
        index = np.argsort(sele_metric.numpy().ravel())[:avg_num] # ascend order is True, e.g., losses
    else:
        index = np.argsort(np.array(sele_metric))[-avg_num:] # ascend order is False(descend order), e.g., accuracies
        index = index[::-1]
    cprint('r', '\nSelected models indices:', end=' ')   # print selected models
    def printFun(x): print(x, end=' ')
    printFunRun = [printFun(x+1) for x in index] # 1 ~ epoch num
    cprint('r', '\nAverage results based ' + name)

    acc11 = 0
    sen11 = 0
    spe11 = 0
    auc11 = 0
    tt = 0
    metrics = ['acc', 'sen', 'spe', 'auc']
    for metric in metrics:
        cur_metric = [test_metrics[ind][metric] for ind in index]
        print('       {:s} ---- {:.2f} ± {:.2f}'.format(str.upper(metric), np.mean(cur_metric), np.std(cur_metric)))
        if tt == 0:acc11 = round(np.mean(cur_metric),2)
        if tt == 1: sen11 = round(np.mean(cur_metric),2)
        if tt == 2: spe11 = round(np.mean(cur_metric),2)
        if tt == 3: auc11 = round(np.mean(cur_metric),2)
        tt = tt + 1
    return acc11,sen11,spe11,auc11,index


def metrics_print(metrics_dict):
    for key, value in metrics_dict.items():
        print("{:s}: {:.3f}".format(str.upper(key), value), end=' ')


def compute_measures(labels, preds, predict_prob):
    '''
    :param labels: ground truth
    :param preds: predicted label
    :param predict_prob: postive probability (1)
    :return: ACC, SEN, SPE, AUC
    '''
    labels = labels.cpu().numpy()
    preds = preds.cpu().numpy()
    predict_prob = predict_prob.cpu().detach()
    cm = confusion_matrix(labels, preds)
    TP = float(cm[1][1])
    TN = float(cm[0][0])
    FP = float(cm[0][1])
    FN = float(cm[1][0])

    sensitivity = TP / ((TP + FN) + 1e-8)
    specificity = TN / (TN + FP + 1e-8)
    accuracy = (TP + TN) / (TP + TN + FP + FN)
    fpr, tpr, thresholds = roc_curve(labels, predict_prob)
    roc_auc = auc(fpr, tpr)
    index = [round(accuracy, 4), round(sensitivity, 4), round(specificity, 4), round(roc_auc, 4)]
    res_index = dict(zip(['acc', 'sen', 'spe', 'auc'], index))
    return res_index


class MlpBlock(nn.Module):
    def __init__(self, in_dim, hidden_dim, drop_rate=0):
        super(MlpBlock, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_dim, in_dim),
            nn.Dropout(drop_rate))

    def forward(self, x):
        return self.mlp(x)


class MixerLayer(nn.Module):

    def __init__(self, ns, nc, ds, dc, drop_rate):
        super(MixerLayer, self).__init__()
        self.norm1 = nn.LayerNorm(nc)
        self.norm2 = nn.LayerNorm(nc)
        self.tokenMix = MlpBlock(in_dim=ns, hidden_dim=ds, drop_rate=drop_rate)
        self.channelMix = MlpBlock(in_dim=nc, hidden_dim=dc, drop_rate=drop_rate)

    def forward(self, x):
        x = self.norm1(x)
        x2 = self.tokenMix(x.transpose(1, 2)).transpose(1, 2)
        x = x + x2
        x2 = self.norm2(x)
        x2 = self.channelMix(x2)
        return x + x2


class Mixer(nn.Module):
    def __init__(self, num_classes, image_size, patch_size, num_layers, embed_dim, ds, dc):

        super(Mixer, self).__init__()
        assert image_size % patch_size == 0
        self.embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size) 
        ns = (image_size // patch_size) ** 2 
        MixBlock = MixerLayer(ns=ns, nc=embed_dim, ds=ds, dc=dc, drop_rate=0)
        self.mixlayers = nn.Sequential(*[MixBlock for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)
        self.cls = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.embed(x).flatten(2).transpose(1, 2)  # n c2 hw->n hw c2
        x = self.mixlayers(x)
        x = self.norm(x)
        x = torch.mean(x, dim=1) 
        x = self.cls(x)
        return x

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential( # input shape (1, 224, 224)
            nn.Conv2d(1, 8, 5, 1, 2),
            nn.ReLU(), # activation
            nn.MaxPool2d(2), 
        )
        self.conv2 = nn.Sequential( # input shape (8, 112, 112)
            nn.Conv2d(8, 16, 5, 1, 2),
            nn.ReLU(), # activation
            nn.MaxPool2d(2), 
        )
        self.conv3 = nn.Sequential( # input shape (16, 56, 56)
            nn.Conv2d(16,24, 5, 1, 2),
            nn.ReLU(), # activation
            nn.MaxPool2d(2),
        )
        self.conv4 = nn.Sequential( # input shape (24, 28, 28)
            nn.Conv2d(24,32, 5, 1, 2),
            nn.ReLU(), # activation
            nn.MaxPool2d(2),
        )
        self.conv5 = nn.Sequential( # input shape (32, 14, 14)
            nn.Conv2d(32,40, 5, 1, 2),
            nn.ReLU(), # activation
            nn.MaxPool2d(2), 
        )
        self.out = nn.Linear(40 * 7 * 7, 2) # fully connected layer, output 3 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 40 * 7 * 7)
        output = self.out(x)
        return output # return x for visualization

def computing_fitness_utility_new1(it,NPOP, train_load, test_load, epoch_num, LR, image_size, patch_size, trainset, testset):
    # ======================================================================================================================
    # device detection
    # ======================================================================================================================
    print("PYTORCH's version is ", torch.__version__)
    os.environ['CUDA_VISIBLE_DEVICES']='4'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('device: ', device)
    
    f = np.zeros((len(NPOP), 2))
    obj_f = np.zeros((len(NPOP), 1))
    AUC_f = np.zeros((len(NPOP), 1))
    fin_label_indt = []
    fin_indt = []
    fin_modelt = []
    for t in range(0, len(NPOP)):
        layer_number = int(NPOP[t, 0])
        embed_dim = int(NPOP[t, 1])
        ds = int(NPOP[t, 2])
        dc = int(NPOP[t, 3])

        # ======================================================================================================================
        # model building
        # ======================================================================================================================
        model = Mixer(num_classes=2,image_size=image_size,patch_size=patch_size,num_layers=layer_number,embed_dim=embed_dim,ds=ds,dc=dc)
        model = model.to(device)

        loss_fn = nn.CrossEntropyLoss()  # loss function
        optimizer = torch.optim.SGD(model.parameters(), lr=LR,
                                    momentum=0.9)  # torch.optim.Adam(model.parameters(), lr=LR)#
        m = nn.Softmax(dim=1)

        train_labels = torch.zeros(len(trainset), 1)
        train_preds = torch.zeros(len(trainset), 1)
        train_probs = torch.zeros(len(trainset), 1)
        train_losses = torch.zeros(epoch_num, 1)
        train_metrics = []

        test_labels = torch.zeros(len(testset), 1)
        test_preds = torch.zeros(len(testset), 1)
        test_probs = torch.zeros(len(testset), 1)
        test_losses = torch.zeros(epoch_num, 1)
        test_metrics = []

        result_all = {}
        model_all = []

        # ======================================================================================================================
        # model training, validation, testing
        # ======================================================================================================================
        for i in range(1, epoch_num + 1):
            print('\n')
            print("\033[1;32m    it_num [{}] \033[0m".format(it))
            print("\033[1;32m    Epoch [{}/{}] \033[0m".format(i, epoch_num))
            print("-" * 45)

            # model training
            model.train()
            train_loss = 0.0
            cprint('r', "model is training...")
            index = 0
            for X, y, in train_load:
                X, y = Variable(X.to(device)), Variable(y.to(device))  # to device
                pred = model(X)  # prediction, logits
                probs = m(pred)  # probability
                _, pred_y = torch.max(probs, 1)  # predicted label
                optimizer.zero_grad()
                loss = loss_fn(pred, y)
                loss.backward()  # backward pass
                optimizer.step()  # weights update
                train_loss += len(y) * loss.item()

                train_labels[index:index + len(y), 0] = y
                train_preds[index:index + len(y), 0] = pred_y
                train_probs[index:index + len(y), 0] = probs[:, 1]

                index += len(y)

            print('train loss :', round(1.0 * train_loss / len(trainset), 4))
            train_metric = compute_measures(train_labels, train_preds, train_probs)
            train_losses[i - 1, 0] = round(1.0 * train_loss / len(trainset), 4)
            train_metrics.append(train_metric)
            metrics_print(train_metric)

            # model testing
            model.eval()
            test_loss = 0.0
            cprint('r', "\nmodel is testing...")
            index = 0
            with torch.no_grad():
                for X, y, in test_load:
                    X, y = Variable(X.to(device)), Variable(y.to(device))
                    pred = model(X)
                    probs = m(pred)
                    _, pred_y = torch.max(probs, 1)
                    loss = loss_fn(pred, y)
                    test_loss += len(y) * loss.item()

                    test_labels[index:index + len(y), 0] = y
                    test_preds[index:index + len(y), 0] = pred_y
                    test_probs[index:index + len(y), 0] = probs[:, 1]

                    index += len(y)
            print('test loss :', round(1.0 * test_loss / len(testset), 4))
            test_metric = compute_measures(test_labels, test_preds, test_probs)
            test_metrics.append(test_metric)
            test_losses[i - 1, 0] = round(1.0 * test_loss / len(testset), 4)
            metrics_print(test_metric)
            modelmodel = copy.deepcopy(model)
            model_all.append(modelmodel)
        
        test_metrics2 = []
        model_all2 = []
        tt = 0
        for x in test_metrics:
            if (x['sen'] >= 0.8) & (x['spe'] >= 0.8):
                test_metrics2.append(x)
                model_all2.append(model_all[tt])
            tt = tt + 1

        if len(test_metrics2) == 0:
            tt = 0
            for x in test_metrics:
                if (x['sen'] >= 0.75) & (x['spe'] >= 0.75):
                    test_metrics2.append(x)
                    model_all2.append(model_all[tt])
                tt = tt + 1

        if len(test_metrics2) == 0:
            tt = 0
            for x in test_metrics:
                if (x['sen'] >= 0.7) & (x['spe'] >= 0.7):
                    test_metrics2.append(x)
                    model_all2.append(model_all[tt])
                tt = tt + 1
        
        if len(test_metrics2) == 0:
            tt = 0
            for x in test_metrics:
                if (x['sen'] >= 0.65) & (x['spe'] >= 0.65):
                    test_metrics2.append(x)
                    model_all2.append(model_all[tt])
                tt = tt + 1
        
        if len(test_metrics2) == 0:
            tt = 0
            for x in test_metrics:
                if (x['sen'] >= 0.6) & (x['spe'] >= 0.6):
                    test_metrics2.append(x)
                    model_all2.append(model_all[tt])
                tt = tt + 1
        
        if len(test_metrics2) == 0:
            tt = 0
            for x in test_metrics:
                if (x['sen'] >= 0.55) & (x['spe'] >= 0.55):
                    test_metrics2.append(x)
                    model_all2.append(model_all[tt])
                tt = tt + 1

        test_auc = [x['auc'] for x in test_metrics2]
        acc, sen, spe, auc, index = comp_avg_metrics(test_metrics2, test_auc, ascend=False, avg_num=1,
                                 name='auc')
        best_model = model_all2[index[0]]

        f[t][0] = sen
        f[t][1] = spe
        obj_f[t] = acc
        AUC_f[t] = auc

        test_preds=test_preds.numpy().tolist()
        test_preds = np.array(test_preds)
        test_preds = test_preds.flatten()
        fin_label_indt.append(test_preds)

        test_labels = test_labels.numpy().tolist()
        test_labels = np.array(test_labels)
        test_labels = test_labels.flatten()
        fin_indt.append(test_labels)

        fin_modelt.append(best_model)

    fin_modelt = np.array(fin_modelt)

    f, ia = np.unique(f, return_index=True, axis=0)
    obj_f = obj_f[ia]
    AUC_f = AUC_f[ia]
    fin_label_ind = fin_label_indt[0][ia]
    fin_ind = fin_indt[0][ia]
    fin_model = fin_modelt[ia]
    NPOP = NPOP[ia]
    return f, obj_f, AUC_f, fin_label_ind, fin_ind, fin_model,  NPOP
