#%%
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from naslib.search_spaces import NasBench201SearchSpace
from naslib.utils import get_dataset_api
from naslib.search_spaces.core import Metric
import numpy as np
from naslib.utils import get_dataset_api
from naslib.search_spaces.nasbench201.conversions import *
from torch.optim import lr_scheduler


def Trainer(arg, net, train_loader, val_loader, device, criterion_type='ce', scheduler_flag = True):
    """Train ``net`` with Adam and checkpoint the best model by validation accuracy.

    Each epoch runs one pass over ``train_loader``, optionally steps a StepLR
    schedule (LR halved every 5 epochs), evaluates on ``val_loader`` via
    ``inference`` and saves ``net.state_dict()`` to ``arg.PATH`` whenever the
    validation accuracy strictly improves. Training stops early after
    ``arg.max_val_iter`` consecutive epochs without improvement.

    Args:
        arg: config object; this function reads ``lr``, ``num_epoch``,
            ``PATH`` and ``max_val_iter`` from it.
        net: model to train; it is moved to ``device``.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        device: torch device for the model and the batches.
        criterion_type: ``'bce'`` (one logit per sample, BCEWithLogitsLoss)
            or ``'ce'`` (per-class logits with integer class labels,
            CrossEntropyLoss).
        scheduler_flag: whether to apply the StepLR schedule.

    Returns:
        float: the best validation accuracy observed (0.0 if it never rose
        above 0).

    Raises:
        ValueError: if ``criterion_type`` is not ``'bce'`` or ``'ce'``.
    """
    if criterion_type == 'bce':
        criterion = nn.BCEWithLogitsLoss()
    elif criterion_type == 'ce':
        criterion = nn.CrossEntropyLoss()
    else:
        raise ValueError(f"criterion_type must be 'bce' or 'ce', got {criterion_type!r}")

    # Put the parameters on ``device`` before the first epoch: batches are moved
    # there below, and otherwise the model would only be moved by ``inference``
    # after the first training epoch had already run.
    net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=arg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5) if scheduler_flag else None

    best_val_acc = 0.
    # Initialised up front so a first epoch that does not beat 0.0 cannot hit
    # an unbound counter.
    epochs_without_improvement = 0

    for epoch in range(arg.num_epoch):
        net.train()
        sum_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(device).float()
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            if criterion_type == 'bce':
                targets = targets.float()
                # Reshape rather than squeeze() so a batch of size 1 keeps its
                # batch dimension and matches the target shape.
                outputs = outputs.reshape(targets.shape)
            else:
                # CrossEntropyLoss expects integer class indices, not floats.
                targets = targets.long()
                outputs = outputs.reshape(targets.size(0), -1)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            num_batches += 1

            logits = outputs.detach()
            if criterion_type == 'bce':
                # Outputs are logits (BCEWithLogitsLoss), and
                # sigmoid(x) > 0.5 exactly when x > 0.
                predicted = (logits > 0).float()
            else:
                predicted = logits.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        if scheduler is not None:
            scheduler.step()

        mean_loss = sum_loss / num_batches if num_batches else float('nan')
        train_acc = 100. * correct / total if total else 0.0
        print(
            '==========================[epoch:%d] Loss: %.03f | Acc: %.3f%% '
            % (epoch + 1, mean_loss, train_acc))

        # ---- validation after every training epoch ----
        acc = inference(net, device, val_loader, criterion_type=criterion_type)
        print(f'Val: | Acc: {acc:.5f}')

        if acc > best_val_acc:
            best_val_acc = acc
            epochs_without_improvement = 0
            torch.save(net.state_dict(), arg.PATH)
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= arg.max_val_iter:
            print("Validation accuracy did not improve for the last {} validation runs. Early stopping..."
                  .format(arg.max_val_iter))
            break

    return best_val_acc


def inference(net, device, testloader, criterion_type='ce',):
    """Return the classification accuracy of ``net`` over ``testloader``.

    The model is moved to ``device`` and put in eval mode, and predictions are
    made without gradient tracking.

    Args:
        net: model producing logits.
        device: torch device for the model and the batches.
        testloader: iterable of ``(inputs, targets)`` batches.
        criterion_type: ``'bce'`` for one logit per sample (predicted positive
            when the logit is > 0), or ``'ce'`` for per-class logits compared
            against integer class labels via argmax.

    Returns:
        float: fraction of correctly classified samples in [0, 1]. Returns
        0.0 for an empty loader.

    Raises:
        ValueError: if ``criterion_type`` is not ``'bce'`` or ``'ce'``.
    """
    if criterion_type not in ('bce', 'ce'):
        raise ValueError(f"criterion_type must be 'bce' or 'ce', got {criterion_type!r}")

    net.to(device)
    net.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device).float()
            targets = targets.to(device)
            outputs = net(inputs)

            if criterion_type == 'bce':
                targets = targets.float()
                # Logits from a BCEWithLogitsLoss-trained model: sigmoid(x) > 0.5
                # iff x > 0. Reshape (not squeeze) so batch size 1 still works.
                predicted = (outputs.reshape(targets.shape) > 0).float()
            else:
                targets = targets.long()
                predicted = outputs.reshape(targets.size(0), -1).argmax(dim=1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return correct / total if total else 0.0

def sort_arch(popu,pred_s, pred_t,device):
    """Rank a population with two pairwise predictors by counting wins.

    Every unordered pair ``(i, j)`` with ``i < j`` is fed to each predictor as
    the concatenation ``encode(popu[i]) ++ encode(popu[j])``. A truthy result
    means ``i`` ranks above ``j``; otherwise ``j`` wins. Items are then ordered
    by number of wins, most first. Ties keep population order.

    Args:
        popu: sequence of members exposing ``member.arch.encode()``.
        pred_s: pairwise predictor for the first (source) ranking.
        pred_t: pairwise predictor for the second (target) ranking.
        device: torch device the pair tensors are placed on.

    Returns:
        tuple[list[int], list[int]]: population indices ordered best-first
        according to ``pred_s`` and ``pred_t`` respectively.
    """
    n = len(popu)
    # Encode each architecture once instead of inside the O(n^2) pair loop.
    encodings = [
        torch.as_tensor(np.asarray(member.arch.encode(), dtype=np.float32).ravel()).to(device)
        for member in popu
    ]

    # wins[i, j] == 1 means i was judged better than j.
    wins_s = torch.zeros(n, n, dtype=torch.int)
    wins_t = torch.zeros(n, n, dtype=torch.int)
    # Predictor outputs are only used as booleans, so no autograd graph is needed.
    with torch.no_grad():
        for i in range(n):
            for j in range(i + 1, n):
                pair = torch.cat((encodings[i], encodings[j]))
                if pred_s(pair):
                    wins_s[i, j] = 1
                else:
                    wins_s[j, i] = 1
                if pred_t(pair):
                    wins_t[i, j] = 1
                else:
                    wins_t[j, i] = 1

    # Each pair has exactly one winner, so ordering by wins is the same as
    # ordering by fewest losses. A stable sort makes tie order deterministic,
    # unlike topk.
    order_s = torch.sort(wins_s.sum(dim=1), descending=True, stable=True).indices
    order_t = torch.sort(wins_t.sum(dim=1), descending=True, stable=True).indices
    return order_s.tolist(), order_t.tolist()


def sort_arch_emd(popu, pred, emd_list, device):
    """Rank a population with one embedding-conditioned pairwise predictor.

    For every pair ``(i, j)`` with ``i < j`` the predictor receives
    ``encode(popu[i]) ++ encode(popu[j]) ++ emd``. A truthy result means ``i``
    ranks above ``j``.

    * Source ranking: uses ``emd_list[0]`` only (one vote per pair).
    * Target ranking: sums votes over every embedding in ``emd_list[1:]``.

    Items are ordered by number of wins, most first. Ties keep population order.

    Args:
        popu: sequence of members exposing ``member.arch.encode()``.
        pred: pairwise predictor taking the concatenated 1-D float tensor.
        emd_list: embeddings (lists, tuples or numpy arrays). The first is the
            source embedding and the rest are target embeddings.
        device: torch device the pair tensors are placed on.

    Returns:
        tuple[list[int], list[int]]: population indices ordered best-first
        for the source and target rankings respectively.

    Raises:
        ValueError: if ``emd_list`` is empty but there is at least one pair
            to compare.
    """
    n = len(popu)
    if n >= 2 and len(emd_list) == 0:
        raise ValueError("emd_list must contain at least the source embedding")

    def _as_vector(values):
        # Flatten to a float32 tensor so lists and numpy arrays are concatenated;
        # ``list + ndarray`` would instead add element-wise or fail.
        return torch.as_tensor(np.asarray(values, dtype=np.float32).ravel()).to(device)

    # Build encodings and embeddings once instead of once per pair.
    encodings = [_as_vector(member.arch.encode()) for member in popu]
    embeddings = [_as_vector(emd) for emd in emd_list]

    # wins[i, j] counts how often i was judged better than j.
    wins_s = torch.zeros(n, n, dtype=torch.int)
    wins_t = torch.zeros(n, n, dtype=torch.int)
    # Predictor outputs are only used as booleans, so no autograd graph is needed.
    with torch.no_grad():
        for i in range(n):
            for j in range(i + 1, n):
                arch_pair = torch.cat((encodings[i], encodings[j]))
                if pred(torch.cat((arch_pair, embeddings[0]))):
                    wins_s[i, j] = 1
                else:
                    wins_s[j, i] = 1
                for emd in embeddings[1:]:
                    if pred(torch.cat((arch_pair, emd))):
                        wins_t[i, j] += 1
                    else:
                        wins_t[j, i] += 1

    # For each pair the votes split between (i, j) and (j, i), so ordering by
    # wins matches ordering by fewest losses. A stable sort fixes tie order.
    order_s = torch.sort(wins_s.sum(dim=1), descending=True, stable=True).indices
    order_t = torch.sort(wins_t.sum(dim=1), descending=True, stable=True).indices
    return order_s.tolist(), order_t.tolist()

def generate_embedding(source, similarity, rng=None):
    """Return a random vector whose cosine similarity with ``source`` is ``similarity``.

    The result lies in the plane spanned by ``source`` and a random direction
    orthogonal to it. Its norm is ``abs(similarity) * ||source||``, so
    ``similarity == 0`` yields the zero vector.

    Args:
        source: non-zero array-like reference vector.
        similarity: desired cosine similarity, in [-1, 1].
        rng: optional ``numpy.random.Generator`` for reproducible draws. When
            ``None`` the global ``np.random`` state is used.

    Returns:
        np.ndarray: float array with the same shape as ``source``.

    Raises:
        ValueError: if ``source`` is the zero vector, ``similarity`` is outside
            [-1, 1], or no orthogonal direction exists (e.g. a single-element
            ``source`` with ``abs(similarity) < 1``).
    """
    source = np.asarray(source, dtype=float)
    source_norm = np.linalg.norm(source)
    if source_norm == 0:
        raise ValueError("source must be a non-zero vector")
    if not -1.0 <= similarity <= 1.0:
        raise ValueError(f"similarity must lie in [-1, 1], got {similarity}")
    source_normalized = source / source_norm

    # Angle theta with cos(theta) = similarity. Use the exact sine instead of
    # cos/sin of arccos to avoid the round trip.
    cos_theta = float(similarity)
    sin_theta = float(np.sqrt(max(0.0, 1.0 - cos_theta ** 2)))
    # abs(): a negative length would flip the vector and invert the sign of
    # the resulting cosine similarity.
    length_target = abs(cos_theta) * source_norm

    # Always draw the random direction, even when it ends up unused, so the
    # global RNG stream advances the same way on every call.
    if rng is None:
        ortho = np.random.randn(*source.shape)
    else:
        ortho = rng.standard_normal(source.shape)
    # Gram-Schmidt step; vdot flattens, so this also holds for non-1-D arrays.
    ortho = ortho - np.vdot(ortho, source_normalized) * source_normalized

    direction = cos_theta * source_normalized
    if sin_theta > 0:
        ortho_norm = np.linalg.norm(ortho)
        if ortho_norm == 0:
            raise ValueError("no direction orthogonal to source exists for this similarity")
        direction = direction + sin_theta * (ortho / ortho_norm)

    return length_target * direction

# Benchmark APIs already loaded by ``arch2res``, keyed by (search_space, dataset).
# ``arch2res`` is called once per architecture, so reloading the benchmark on
# every call is wasteful. NOTE(review): assumes callers never mutate the
# returned benchmark object.
_DATASET_API_CACHE = {}


def arch2res(search_space, dataset, arch_en):
    """Look up the final recorded evaluation accuracy of one architecture.

    Args:
        search_space: search-space name passed to ``get_dataset_api``.
        dataset: dataset name. For ``'cifar10'`` the ``'cifar10-valid'``
            record is used; otherwise the record named ``dataset`` is used.
        arch_en: architecture encoding accepted by
            ``convert_op_indices_to_str`` (presumably a sequence of operation
            indices — confirm against callers).

    Returns:
        The last entry of the record's ``"eval_acc1es"`` list.

    Note:
        The lookup goes through ``benchmark_api["nb201_data"]``, so this only
        supports NAS-Bench-201-style benchmark data.
    """
    key = (search_space, dataset)
    benchmark_api = _DATASET_API_CACHE.get(key)
    if benchmark_api is None:
        benchmark_api = get_dataset_api(search_space=search_space, dataset=dataset)
        _DATASET_API_CACHE[key] = benchmark_api

    arch_str = convert_op_indices_to_str(arch_en)
    record = benchmark_api["nb201_data"][arch_str]
    split = 'cifar10-valid' if dataset == 'cifar10' else dataset
    return record[split]["eval_acc1es"][-1]