import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import numpy as np
import copy


from modeling import ImageClassifier
from heads import get_classification_head
from datasets.ImageNet import SequentialImagenet

from datasets.common import get_dataloader, maybe_dictionarize
from sklearn.linear_model import LinearRegression
from datasets.registry import get_dataset
from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

import sys
sys.path.append('../')

# Global figure style applied to every plot saved by this script.
matplotlib.rcParams.update({
    'figure.dpi': 300,
    'figure.figsize': (7.0, 6.0),
    'font.family': 'Times New Roman',
    'font.size': 16,
})

# Loss evaluated at each point of the landscape grid.
loss_fn = F.cross_entropy

# NOTE(review): not referenced by any code in this file; possibly intended as contour levels.
levels = [0.0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 0.8, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 2.0, 10.0]

def get_params(model):
    """Return every parameter of ``model`` flattened and concatenated into one 1-D tensor.

    Parameters appear in ``model.parameters()`` order; each is flattened with
    ``view(-1)``, so it must be viewable as 1-D (e.g. contiguous).
    """
    return torch.cat([p.view(-1) for p in model.parameters()])

def get_grads(model):
    """Return the gradients of all parameters of ``model`` as one 1-D tensor.

    Uses ``model.parameters()`` order, matching ``get_params``. A parameter whose
    ``.grad`` is ``None`` contributes zeros of its own size.
    """
    flat = []
    for p in model.parameters():
        g = torch.zeros_like(p) if p.grad is None else p.grad
        flat.append(g.view(-1))
    return torch.cat(flat)


def _pretrained_plane_coords(net_0, net_1, net_2, pretrained_model):
    """Return the (x, y) coordinates of ``pretrained_model`` in the landscape plane.

    The plane is the affine span ``net_2 + x * (net_0 - net_2) + y * (net_1 - net_2)``,
    so net_2 sits at (0, 0), net_0 at (1, 0) and net_1 at (0, 1). The pretrained
    weights are orthogonally projected onto that plane via the 2x2 normal equations,
    which puts the marker in the same coordinates as the loss grid.
    """
    with torch.no_grad():
        origin = get_params(net_2)
        u = get_params(net_0) - origin
        v = get_params(net_1) - origin
        w = get_params(pretrained_model) - origin
        uu, uv, vv = torch.dot(u, u).item(), torch.dot(u, v).item(), torch.dot(v, v).item()
        uw, vw = torch.dot(u, w).item(), torch.dot(v, w).item()

    gram = torch.tensor([[uu, uv], [uv, vv]], dtype=torch.float64)
    rhs = torch.tensor([[uw], [vw]], dtype=torch.float64)
    # 'gelsy' (CPU) handles rank-deficient systems, e.g. collinear or zero directions.
    solution = torch.linalg.lstsq(gram, rhs, driver='gelsy').solution
    return float(solution[0, 0]), float(solution[1, 0])


def _grid_losses(grid, net_0, net_1, net_2, classification_head, batch):
    """Return the loss on ``batch`` for each (x, y) row of ``grid``.

    Each row gives the weights ``net_2 + x * (net_0 - net_2) + y * (net_1 - net_2)``.
    ``batch`` is indexed as ``(inputs, labels)``.
    """
    device = next(net_0.parameters()).device
    inputs, labels = batch[0].to(device), batch[1].to(device)

    # One scratch network suffices: every parameter is overwritten at each grid point,
    # and buffers are the same as net_0's, just like a fresh deepcopy would have.
    net_temp = copy.deepcopy(net_0)
    param_groups = list(zip(net_temp.parameters(), net_0.parameters(),
                            net_1.parameters(), net_2.parameters()))

    losses = []
    with torch.no_grad():
        for pos_x, pos_y in grid.tolist():
            for p_temp, p_0, p_1, p_2 in param_groups:
                p_temp.copy_(p_2 + pos_x * (p_0 - p_2) + pos_y * (p_1 - p_2))
            pred = classification_head(net_temp(inputs))
            losses.append(loss_fn(pred, labels).item())
    return losses


def _draw_landscape(xx, yy, z, coords, file_name):
    """Draw the contours, the three anchor networks and the pretrained marker; save to ``file_name``."""
    plt.contourf(xx, yy, z, alpha=0.3, cmap='RdYlGn_r', extend='both')
    # contour() takes the plural ``linewidths``/``colors``; ``linewidth``/``color``
    # are not contour parameters.
    contours = plt.contour(xx, yy, z, linewidths=0.3, alpha=1.0, extend='both', colors='gray')
    plt.clabel(contours, inline=True, fontsize=13)

    anchors = [
        ((1, 0), 'red', 'Near zero vector', 'orthogonal component'),
        ((0, 1), 'green', 'Positive vector', 'positive component'),
        ((0, 0), 'purple', 'Negative vector', 'negative component'),
    ]
    for xy, color, label, text in anchors:
        plt.scatter(xy[0], xy[1], s=200, c=color, label=label)
        # Arrow from the pretrained model's position to this anchor.
        plt.annotate(
            text,
            xy=xy,
            xytext=coords,
            arrowprops=dict(
                facecolor=color,
                shrink=0.15,
                width=1,
                headwidth=10,
                headlength=15,
            )
        )

    plt.scatter(coords[0], coords[1], s=200, c='blue', label='pretrained extractor')

    plt.legend()
    plt.axis('off')

    plt.savefig(file_name)
    plt.clf()


def plot_loss(test_data, net_0, net_1, net_2, pretrained_model, exam_datasets, classifiers, loss_t, file_name):
    """Plot the loss of task ``loss_t`` over the plane spanned by three networks.

    Grid point (x, y) uses the weights ``net_2 + x * (net_0 - net_2) + y * (net_1 - net_2)``.
    So ``net_2`` (negative vector) is at (0, 0), ``net_0`` (near-zero vector) at
    (1, 0) and ``net_1`` (positive vector) at (0, 1). The pretrained model is
    marked at its orthogonal projection onto that plane.

    Args:
        test_data: per-task batches; ``test_data[loss_t]`` is indexed as ``(inputs, labels)``.
        net_0, net_1, net_2: networks with identical parameter layouts.
        pretrained_model: network used only to place the "pretrained extractor" marker.
        exam_datasets: unused; kept for interface compatibility.
        classifiers: per-task heads; ``classifiers[loss_t]`` maps network features to logits.
        loss_t: index of the task whose loss is plotted.
        file_name: path the figure is saved to.
    """
    coords = _pretrained_plane_coords(net_0, net_1, net_2, pretrained_model)
    print("Estimated position of pretrained_model: ", coords)

    x_min, x_max = min(0.0, coords[0]) - 0.2, max(1.0, coords[0]) + 1.2
    y_min, y_max = min(0.0, coords[1]) - 0.2, max(1.0, coords[1]) + 1.2
    h = 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    grid = np.c_[xx.ravel(), yy.ravel()]

    losses = _grid_losses(grid, net_0, net_1, net_2, classifiers[loss_t], test_data[loss_t])
    z = np.array(losses).reshape(xx.shape)
    print(z)

    _draw_landscape(xx, yy, z, coords, file_name)

def _tatr_dataloader(args, exam_dataset, pretrained_model, classification_head):
    """Return the training dataloader used to estimate gradients on ``exam_dataset``.

    ImageNet goes through ``SequentialImagenet`` (hard-coded root, second argument 32).
    Every other dataset is built with the classifier's ``val_preprocess`` and
    ``args.batch_size``.
    """
    model = ImageClassifier(pretrained_model, classification_head)
    if exam_dataset == 'ImageNet':
        dataset = SequentialImagenet('/home/sunwenju/dataset/', 32)
        dataloader, _ = dataset.get_data_loaders()
        return dataloader
    dataset = get_dataset(
        exam_dataset,
        model.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size
    )
    return get_dataloader(dataset, is_train=True, args=args, image_encoder=None)


def _mean_per_sample_grads(pretrained_model, classification_head, dataloader, max_samples):
    """Average per-sample cross-entropy gradients w.r.t. ``pretrained_model``'s parameters.

    Draws batches until ``max_samples`` examples have been used. Each batch is
    truncated to the remaining budget, so the total never exceeds it. Every example
    is back-propagated separately; both |grad| and the signed grad are accumulated.

    Returns:
        (mean_abs_grad, mean_grad): 1-D CPU tensors laid out like ``get_params``.

    Raises:
        ValueError: if ``max_samples`` is not positive or no example was drawn.
    """
    if max_samples <= 0:
        raise ValueError(f"exp_size must be positive, got {max_samples}")

    grad = torch.zeros_like(get_params(pretrained_model))
    grad_noabs = torch.zeros_like(grad)
    seen = 0
    for data in dataloader:
        remaining = max_samples - seen
        inputs, labels = data[0][:remaining].to('cuda'), data[1][:remaining].to('cuda')

        output = classification_head(pretrained_model(inputs))
        batch_size = output.size(0)
        for i in range(batch_size):
            pretrained_model.zero_grad()
            loss = F.cross_entropy(output[i].unsqueeze(0), labels[i].unsqueeze(0))
            # The forward graph is shared by the batch; free it after the last example.
            loss.backward(retain_graph=i < batch_size - 1)
            sample_grad = get_grads(pretrained_model).detach()
            grad += sample_grad.abs()
            grad_noabs += sample_grad

        seen += batch_size
        if seen >= max_samples:
            break

    if seen == 0:
        raise ValueError("the dataloader yielded no samples")
    return (grad / seen).to('cpu'), (grad_noabs / seen).to('cpu')


def _flatten_task_vector(task_vector, reference_model):
    """Concatenate ``task_vector.vector`` entries in ``reference_model.named_parameters()`` order."""
    return torch.cat([task_vector.vector[name].view(-1) for name, _ in reference_model.named_parameters()])


def TATR_merging(args, task_vectors, pretrained_model, exam_datasets, order=1, target_t=None):
    """Score parameter coordinates for task ``target_t`` and build a keep-mask.

    The mean per-sample gradient of the target task's loss on ``pretrained_model``
    is multiplied element-wise with every *other* task vector tau_j:

        Omega       = sum_{j != target_t} mean(|grad|) * |tau_j|
        Omega_noabs = sum_{j != target_t} mean(grad)   * tau_j

    Args:
        args: namespace. Reads ``is_imageNet``, ``exp_size`` (max examples for the
            gradient estimate) and ``ratio``, plus options forwarded to the loaders.
        task_vectors: objects whose ``.vector`` is keyed by ``pretrained_model``'s
            parameter names.
        pretrained_model: encoder to differentiate. All its parameters get
            ``requires_grad=True``, and the last sample's gradients stay on it.
        exam_datasets: dataset names indexed like ``task_vectors``
            (``'ImageNet'`` is appended when ``args.is_imageNet``).
        order: unused; kept for interface compatibility.
        target_t: index of the target task. ``None`` (default) falls back to the
            module-level ``target_t`` set by the driver script.

    Returns:
        (Omega, Omega_noabs, mask). ``Omega`` and ``Omega_noabs`` are 1-D tensors
        laid out like ``get_params(pretrained_model)``. ``mask`` is int32: 1 where
        ``Omega`` is strictly below its ``num_top``-th smallest value
        (``num_top = int(args.ratio * numel)``), 0 elsewhere. It is all zeros when
        ``num_top`` is 0.
    """
    if target_t is None:
        # Backward compatibility with callers that rely on the driver's global.
        target_t = globals()['target_t']

    ntk_dataset = (exam_datasets + ['ImageNet']) if args.is_imageNet else exam_datasets

    for pp in pretrained_model.parameters():
        pp.requires_grad = True

    exam_dataset = ntk_dataset[target_t]
    classification_head = get_classification_head(args, exam_dataset).to('cuda')
    dataloader = _tatr_dataloader(args, exam_dataset, pretrained_model, classification_head)
    grad, grad_noabs = _mean_per_sample_grads(pretrained_model, classification_head, dataloader, args.exp_size)

    Omega = torch.zeros_like(_flatten_task_vector(task_vectors[0], pretrained_model))
    Omega_noabs = torch.zeros_like(Omega)
    for j, task_vector in enumerate(task_vectors):
        if j == target_t:
            continue
        flat_tv = _flatten_task_vector(task_vector, pretrained_model)  # flattened once, used twice
        Omega += grad * flat_tv.abs()
        Omega_noabs += grad_noabs * flat_tv

    num_elements = Omega.numel()
    num_top = min(int(args.ratio * num_elements), num_elements)
    if num_top <= 0:
        mask = torch.zeros_like(Omega, dtype=torch.int32)
    else:
        # The descending-sorted value at index (num_elements - num_top) is the
        # num_top-th smallest value; kthvalue finds it without a full sort.
        threshold = torch.kthvalue(Omega.flatten(), num_top).values
        mask = (Omega < threshold).int()

    return Omega, Omega_noabs, mask



exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
model = 'ViT-B-32'
args = parse_arguments()
args.data_location = '../data'
args.model = model
args.device = 'cuda'
args.save = '../checkpoints/checkpoints/' + model
args.exp_size = 100
args.ratio = 0.99
args.logs_path = '../logs/' + model
pretrained_checkpoint = '../checkpoints/checkpoints/'+model+'/zeroshot.pt'


def _masked_task_vector(task_vector, flat_mask, reference_model):
    """Return a deep copy of ``task_vector`` with each entry multiplied by its slice of ``flat_mask``.

    ``flat_mask`` has one contiguous slice per parameter of ``reference_model``,
    in ``named_parameters()`` order (the ``get_params`` layout).
    """
    masked = copy.deepcopy(task_vector)
    offset = 0
    for name, param in reference_model.named_parameters():
        numel = param.numel()
        masked.vector[name] *= flat_mask[offset:offset + numel].view(param.size())
        offset += numel
    return masked


# Index into exam_datasets / task_vectors of the task whose landscape is plotted.
# TATR_merging reads this module-level name.
target_t = 1
args.is_imageNet = False

# Task vectors depend only on the checkpoints, not on the ratio, so load them once.
# NOTE(review): assumes TaskVector addition returns a new object and never mutates
# its operands; confirm in task_vectors.TaskVector.
task_vectors = [
    TaskVector(pretrained_checkpoint, '../checkpoints/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets
]

# plt.savefig does not create missing directories.
os.makedirs('img', exist_ok=True)

for conf in [0.95, 0.98, 0.99, 0.995, 0.998, 0.999]:
    print('################################################################')
    print('######################### Merging :', conf, ' ##############################')
    print('################################################################')
    args.ratio = conf

    # Reloaded every round: TATR_merging enables requires_grad and leaves gradients on it.
    pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
    NTK, NTK_noabs, mask = TATR_merging(args, task_vectors, pretrained_model, exam_datasets)
    # The per-sample gradients are no longer needed; free them before building three more models.
    pretrained_model.zero_grad(set_to_none=True)

    masked_positions = mask == 0
    positive_mask = (NTK_noabs > 0.0).int() * masked_positions
    negative_mask = (NTK_noabs < 0.0).int() * masked_positions

    task_vector_sum = sum(tv for i, tv in enumerate(task_vectors) if i != target_t)

    net_nearzero = _masked_task_vector(task_vector_sum, mask, pretrained_model).apply_to(pretrained_checkpoint, scaling_coef=0.3).to('cuda')
    net_p = _masked_task_vector(task_vector_sum, positive_mask, pretrained_model).apply_to(pretrained_checkpoint, scaling_coef=0.3).to('cuda')
    net_n = _masked_task_vector(task_vector_sum, negative_mask, pretrained_model).apply_to(pretrained_checkpoint, scaling_coef=0.3).to('cuda')

    # One training batch plus its classification head per task, indexed like exam_datasets.
    test_data = []
    classifiers = []
    for exam_dataset in exam_datasets:
        classification_head = get_classification_head(args, exam_dataset).to('cuda')
        eval_classifier = ImageClassifier(pretrained_model, classification_head)
        dataset = get_dataset(
            exam_dataset,
            eval_classifier.val_preprocess,
            location=args.data_location,
            batch_size=args.batch_size
        )
        dataloader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
        # An empty loader raises here instead of silently misaligning the per-task lists.
        test_data.append(next(iter(dataloader)))
        classifiers.append(classification_head)

    plot_loss(test_data, net_nearzero, net_p, net_n, pretrained_model, exam_datasets, classifiers, loss_t=target_t,
              file_name=f'img/loss_landscape_specific_task_{target_t}_ratio_{args.ratio}.jpg')









