#!/usr/bin/python
# -*- coding: UTF-8 -*-

import numpy as np
from computing_fitnesss_ini import ini
from FNDSF import FNDSF1
from BTSwCDf import BTSwCDf1
from Clonef import Clonef1
from PNmutation import PNmutation1
from computing_fitness_utility_new import computing_fitness_utility_new1
from UPPNSGAIIf import UPPNSGAIIf1
from cal_weight import cal_weight1
from cal_single_probability import cal_single_probability1
from calculating_reliability_new import calculating_reliability_new1
from ERrule import Analytic_ER_rule
from cal_evaluation import cal_evaluation1
from compute_measures import compute_measures1
from Confidence_Calibration import Confidence_Calibration_Calculation
from Entropy import Entropy_Calculation
from torchvision import datasets, models, transforms
import torchvision
import torch
import torch.nn as nn
from torch.autograd import Variable
from itertools import chain
import os

def _build_datasets(image_size, batchsize):
    """Create the COVID-CT train/validation ImageFolder datasets and their loaders.

    Training images are augmented with random rotation and horizontal/vertical
    flips; validation images are only resized. Both are converted to
    single-channel tensors and normalised with mean 0.456 / std 0.225.

    Returns:
        ``(trainset, testset, train_load, test_load)``.
    """
    train_transform = transforms.Compose([transforms.Grayscale(),
                                          transforms.Resize([image_size, image_size]),
                                          transforms.RandomRotation(20),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomVerticalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.456], [0.225])])
    transform = transforms.Compose([transforms.Grayscale(),
                                    transforms.Resize([image_size, image_size]),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.456], [0.225])])

    train_path = r"/home/ljh/Project_BayOpt_MLP/COVID-CT/train/"
    trainset = torchvision.datasets.ImageFolder(train_path, transform=train_transform)
    train_load = torch.utils.data.DataLoader(trainset, batch_size=batchsize, shuffle=True)
    test_path = r"/home/ljh/Project_BayOpt_MLP/COVID-CT/val/"
    testset = torchvision.datasets.ImageFolder(test_path, transform=transform)
    test_load = torch.utils.data.DataLoader(testset, batch_size=batchsize, shuffle=False)
    return trainset, testset, train_load, test_load


def _evolve(EPOP, data_args, pop_num, it_num, CS, mut_rate, patch_size):
    """Evolve the population for ``it_num`` generations and return the last elite set.

    Each generation selects parents with ``BTSwCDf1``, clones them with
    ``Clonef1``, mutates the clones with ``PNmutation1``, evaluates parents
    plus offspring with ``computing_fitness_utility_new1`` and keeps
    ``pop_num`` individuals via ``UPPNSGAIIf1``.

    Args:
        EPOP: initial population, one row per individual.
        data_args: ``(train_load, test_load, epoch_num, LR, image_size,
            patch_size, trainset, testset)``, forwarded unchanged to ``ini``
            and ``computing_fitness_utility_new1``.
        pop_num: number of individuals kept per generation.
        it_num: number of generations.
        CS: clone-size argument forwarded to ``Clonef1``.
        mut_rate: mutation rate forwarded to ``PNmutation1``.
        patch_size: forwarded to ``PNmutation1``.

    Returns:
        ``(EPOP, Epal, E_f, E_AUC, E_label_ind, E_ind, E_model)`` from the last
        ``UPPNSGAIIf1`` call. ``Epal`` is a NumPy array whose columns 0-1 are
        the two objective values and whose column 2 is the ``FNDSF1`` front label.

    Raises:
        ValueError: if ``it_num`` < 1, because no generation would produce results.
    """
    if it_num < 1:
        raise ValueError('it_num must be >= 1')

    pa, _, _ = ini(0, EPOP, *data_args)
    FL = np.array(FNDSF1(pa))
    Clonepa = pa
    Epal = np.zeros((len(pa), 3))
    Epal[:, 0:2] = pa
    Epal[:, 2] = FL

    for it in range(it_num):
        ClonePOP = BTSwCDf1(EPOP, Epal)
        cloneover = Clonef1(ClonePOP, Clonepa, CS)
        cloneover = PNmutation1(cloneover, mut_rate, patch_size)
        NPOP = np.r_[EPOP, cloneover]
        Npa, f_obj, AUC_f, f_label_ind, f_ind, f_model, NPOP = computing_fitness_utility_new1(
            it, NPOP, *data_args)

        # The evaluator returns its own NPOP, which can be smaller than
        # pop_num. Keep mutating and evaluating offspring until it is large
        # enough, dropping duplicate objective vectors after each refill.
        while len(NPOP) < pop_num:
            # Pass patch_size here too, matching the first PNmutation1 call.
            cloneover = PNmutation1(cloneover, mut_rate, patch_size)
            extra = computing_fitness_utility_new1(it, cloneover, *data_args)
            merged = [np.r_[old, new] for old, new in
                      zip((Npa, f_obj, AUC_f, f_label_ind, f_ind, f_model, NPOP), extra)]
            Npa, ia = np.unique(merged[0], return_index=True, axis=0)
            f_obj, AUC_f, f_label_ind, f_ind, f_model, NPOP = (arr[ia] for arr in merged[1:])

        FL = np.array(FNDSF1(Npa))
        EPOP, Epal, E_f, E_AUC, E_label_ind, E_ind, E_model = UPPNSGAIIf1(
            NPOP, Npa, FL, pop_num, f_obj, AUC_f, f_label_ind, f_ind, f_model)
        Epal = np.array(Epal)
        Clonepa = Epal[:, 0:2]

    return EPOP, Epal, E_f, E_AUC, E_label_ind, E_ind, E_model


def _balanced_rows(pareto):
    """Return the indices of the rows of ``pareto`` that pass the balance filter.

    A row ``(a, b)`` is kept when ``a`` and ``b`` are both non-zero and the
    ratio ``a / b`` is neither below 0.5 nor above 1.5.

    Args:
        pareto: 2-D array whose columns 0 and 1 are the objective values.

    Returns:
        List of kept row indices, in their original order.
    """
    keep = []
    for t in range(len(pareto)):
        a, b = pareto[t, 0], pareto[t, 1]
        if a == 0 or b == 0:
            continue  # rejected; also avoids dividing by zero below
        ratio = a / b
        if ratio < 0.5 or ratio > 1.5:
            continue
        keep.append(t)
    return keep


def _select_ensemble(EPOP, Epal, E_f, E_AUC, E_model, min_model):
    """Collect ensemble members front by front until ``min_model`` are found.

    Starts with the rows whose front label (``Epal[:, 2]``) is 0. While fewer
    than ``min_model`` members are selected, it adds the rows of the next
    front. Within every front, only rows accepted by ``_balanced_rows`` are
    kept. If all fronts are used up first, it prints a warning and returns
    what it has rather than looping forever.

    Returns:
        ``(pareto, fin_POP, fin_f, fin_AUC, fin_model)`` restricted to the
        selected rows, in front order.
    """
    front = Epal[:, 2]
    candidates = (Epal[:, 0:2], np.array(EPOP), np.array(E_f),
                  np.array(E_AUC), np.array(E_model))

    def take_front(m):
        # Rows of front m, then the subset that passes the balance filter.
        rows = np.flatnonzero(front == m)
        chosen = [arr[rows] for arr in candidates]
        keep = _balanced_rows(chosen[0])
        return [arr[keep] for arr in chosen]

    selected = take_front(0)
    last_front = front.max()
    m = 0
    while len(selected[0]) < min_model:
        m += 1
        if m > last_front:
            print('warning: only %d model(s) passed the selection, fewer than min_model=%s'
                  % (len(selected[0]), min_model))
            break
        selected = [np.r_[old, new] for old, new in zip(selected, take_front(m))]
    return tuple(selected)


def _evaluate_ensemble(testset, fin_POP, fin_model, w):
    """Score every sample of ``testset`` with the fused ensemble.

    For each sample, ``cal_single_probability1`` produces the members'
    outputs. ``calculating_reliability_new1`` derives their reliabilities,
    and ``Analytic_ER_rule`` combines them using the weights ``w``.

    Returns:
        ``(test_labels, test_preds, test_probs)``, each a ``(len(testset), 1)``
        tensor holding the true label, the arg-max fused class and the fused
        value at index 1.
    """
    test_load = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not yet
    # been initialised in this process. If training above already used the
    # GPU, this assignment is a no-op. Confirm the intended device.
    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    n_samples = len(testset)
    test_labels = torch.zeros(n_samples, 1)
    test_preds = torch.zeros(n_samples, 1)
    test_probs = torch.zeros(n_samples, 1)
    index = 0
    with torch.no_grad():
        for X, y in test_load:
            X, y = X.to(device), y.to(device)
            p = cal_single_probability1(X, fin_POP, fin_model)
            r = calculating_reliability_new1(p)
            # Analytic_ER_rule returns a nested sequence. Flatten it; the result
            # is assumed to be one fused value per class (TODO confirm in ERrule).
            probs = torch.tensor(list(chain.from_iterable(Analytic_ER_rule(r, w, p))))
            _, pred_y = torch.max(probs, 0)
            n = len(y)
            test_labels[index:index + n, 0] = y
            test_preds[index:index + n, 0] = pred_y
            # probs[1] is stored as the positive-class score (binary task assumed).
            test_probs[index:index + n, 0] = probs[1]
            index += n
    return test_labels, test_preds, test_probs


def objective4(params):
    """Evaluate one hyperparameter setting of the evolutionary ensemble search.

    The function:

    1. Builds a random initial population of model configurations.
    2. Evolves the population for a fixed number of generations.
    3. Selects an ensemble from the best fronts.
    4. Fuses the ensemble's predictions on the validation set.
    5. Reports the resulting metrics.

    Args:
        params: mapping with the keys ``'lamda'`` (forwarded to
            ``cal_weight1``), ``'mut_rate'`` (mutation rate for
            ``PNmutation1``) and ``'min_model'`` (minimum number of ensemble
            members to collect).

    Returns:
        ``1 - AUC`` on the validation set, with AUC rounded to 4 decimals.
        Minimising the return value therefore maximises AUC.

    Side effects:
        Appends the hyperparameters and results to ``result.txt``, prints
        progress, and sets ``CUDA_VISIBLE_DEVICES`` in the process environment.
    """
    lamda = params['lamda']
    mut_rate = params['mut_rate']
    min_model = params['min_model']

    # Fixed training / architecture settings.
    batchsize = 32
    epoch_num = 30
    LR = 0.01
    image_size = 224
    patch_size = 16
    # Candidate values per gene. Each value is a one-element list so that the
    # sampled columns concatenate into a (pop_num, 4) population matrix.
    layer_number = [[2], [3], [4]]
    embed_dim = [[round(i * (patch_size ** 2))] for i in [1, 1.2, 1.4, 1.6]]
    ds = [[round(i * ((image_size / patch_size) ** 2))] for i in [2, 3, 4, 5]]
    dc = [[round(i * (patch_size ** 2))] for i in [2, 4, 6, 8, 10]]

    # Search settings.
    pop_num = 15
    it_num = 5
    CS = 20

    # Explicit UTF-8: the report contains the non-ASCII label '权重=' ("weights").
    with open('result.txt', 'a', encoding='utf-8') as ftxt:
        ftxt.write('lamda=' + str(lamda)
                   + ' mut_rate=' + str(mut_rate)
                   + ' min_model=' + str(min_model))

        trainset, testset, train_load, test_load = _build_datasets(image_size, batchsize)

        # Random initial population with columns
        # (layer_number, embed_dim, ds, dc).
        layer_number_par1 = [layer_number[i] for i in np.random.randint(0, len(layer_number), size=pop_num)]
        embed_dim_par3 = [embed_dim[i] for i in np.random.randint(0, len(embed_dim), size=pop_num)]
        ds_par4 = [ds[i] for i in np.random.randint(0, len(ds), size=pop_num)]
        dc_par5 = [dc[i] for i in np.random.randint(0, len(dc), size=pop_num)]
        EPOP = np.concatenate((layer_number_par1, embed_dim_par3, ds_par4, dc_par5), axis=1)

        data_args = (train_load, test_load, epoch_num, LR, image_size, patch_size,
                     trainset, testset)
        EPOP, Epal, E_f, E_AUC, _, _, E_model = _evolve(
            EPOP, data_args, pop_num, it_num, CS, mut_rate, patch_size)

        pareto, fin_POP, fin_f, fin_AUC, fin_model = _select_ensemble(
            EPOP, Epal, E_f, E_AUC, E_model, min_model)

        print('\npareto:\n')
        print(pareto)
        print('\nauc:\n')
        print(fin_AUC)
        print('\nacc:\n')
        print(fin_f)
        w = cal_weight1(pareto, fin_AUC, lamda)
        print('权重:')  # "weights:"
        print(w)

        test_labels, test_preds, test_probs = _evaluate_ensemble(testset, fin_POP, fin_model, w)

        SEN, SPE, AUC, ACC = compute_measures1(test_labels, test_preds, test_probs)
        M = 10  # presumably the number of calibration bins for ECE/MCE
        ECE, MCE = Confidence_Calibration_Calculation(test_probs, test_labels, M)
        E_mean, E_std = Entropy_Calculation(test_probs, test_labels)
        SEN, SPE, AUC, ACC, ECE, MCE, E_mean, E_std = (
            round(v, 4) for v in (SEN, SPE, AUC, ACC, ECE, MCE, E_mean, E_std))

        metrics = [
            ' SEN=' + str(SEN),
            ' SPE=' + str(SPE),
            ' AUC=' + str(AUC),
            ' ACC=' + str(ACC),
            ' ECE=' + str(ECE),
            ' MCE=' + str(MCE),
            ' E_mean=' + str(E_mean),
            ' E_std=' + str(E_std) + '\n',
        ]
        print(*metrics, fin_POP)

        ftxt.write(''.join(metrics))
        ftxt.write(' POP=\n' + str(fin_POP) + '\n')
        ftxt.write(' pareto=\n' + str(pareto) + '\n')
        ftxt.write(' auc=\n' + str(fin_AUC) + '\n')
        ftxt.write(' acc=\n' + str(fin_f) + '\n')
        ftxt.write(' 权重=\n' + str(w) + '\n')
        ftxt.write('\n')

    return 1 - AUC
