import os


os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn.functional as F


import numpy as np


import time
import sys
sys.path.append('../')

from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

from modeling import ImageClassifier
from heads import get_classification_head
from datasets.ImageNet import SequentialImagenet

from datasets.common import get_dataloader, maybe_dictionarize
from datasets.registry import get_dataset

import utils


def create_log_dir(path, filename='log.txt'):
    """Return a DEBUG-level logger that writes to ``path/filename`` and the console.

    The directory ``path`` is created if it does not exist. The logger is
    looked up by name (``path``), so repeated calls with the same ``path``
    return the same logger object; handlers are only attached when an
    equivalent one is not already present, so calling this function twice
    does not duplicate every log line.

    Args:
        path: Directory that holds the log file; also used as the logger name.
        filename: Name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(path, exist_ok=True)

    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores its target as an absolute path in ``baseFilename``.
    log_file = os.path.abspath(os.path.join(path, filename))

    has_file_handler = False
    has_console_handler = False
    for handler in logger.handlers:
        # FileHandler subclasses StreamHandler, so it must be tested first.
        if isinstance(handler, logging.FileHandler):
            if handler.baseFilename == log_file:
                has_file_handler = True
        elif isinstance(handler, logging.StreamHandler):
            has_console_handler = True

    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)

    return logger

def get_params(model):
    """Concatenate every parameter of ``model`` into a single 1-D tensor.

    Parameters are visited in ``model.parameters()`` order and each one is
    flattened in row-major order before concatenation.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A 1-D tensor holding all parameter values.
    """
    # reshape (not view) so non-contiguous parameters, e.g. transposed or
    # channels_last storage, are flattened instead of raising.
    return torch.cat([param.reshape(-1) for param in model.parameters()])

def get_grads(model):
    """Concatenate the gradients of every parameter of ``model`` into a 1-D tensor.

    Parameters are visited in ``model.parameters()`` order. A parameter whose
    ``.grad`` is ``None`` contributes zeros of the parameter's size, so the
    result always lines up element-for-element with ``get_params(model)``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A 1-D tensor holding all gradient values.
    """
    flat_grads = []
    for param in model.parameters():
        grad = torch.zeros_like(param) if param.grad is None else param.grad
        # reshape (not view) so non-contiguous gradients do not raise.
        flat_grads.append(grad.reshape(-1))

    return torch.cat(flat_grads)


def _flat_abs_vector(task_vector, param_names):
    """Return ``|task_vector|`` restricted to ``param_names`` as one 1-D tensor."""
    return torch.abs(
        torch.cat([task_vector.vector[name].reshape(-1) for name in param_names])
    )


def TATR_merging(args, task_vectors, pretrained_model, exam_datasets):
    """Mask high-interference entries of the task vectors and return their sum.

    For every ordered pair of distinct task vectors ``(i, j)`` the
    element-wise product ``|tv_i| * |tv_j|`` is accumulated into a score
    vector ``Omega``. Entries whose score is strictly below the
    ``int(args.ratio * N)``-th smallest score are kept; all other entries are
    set to zero in every task vector.

    Args:
        args: Namespace providing ``ratio``, a value in ``[0, 1]``.
        task_vectors: Non-empty sequence of objects that expose a ``vector``
            mapping (parameter name -> tensor) and support ``+`` as needed by
            ``sum``. **They are modified in place.**
        pretrained_model: Model whose ``named_parameters()`` supplies the
            parameter names, order and shapes; its values are not read.
        exam_datasets: Unused; kept for interface compatibility.

    Returns:
        ``sum(task_vectors)`` computed after masking.

    Raises:
        ValueError: If ``args.ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= args.ratio <= 1:
        # A ratio > 1 used to produce a negative index that silently wrapped
        # around to an unrelated threshold.
        raise ValueError(f"args.ratio must be within [0, 1], got {args.ratio!r}")

    named_params = list(pretrained_model.named_parameters())
    param_names = [name for name, _ in named_params]

    omega = torch.zeros_like(_flat_abs_vector(task_vectors[0], param_names))

    for i, tv_i in enumerate(task_vectors):
        # Invariant w.r.t. j, so build it once per i. The other operand is
        # still rebuilt per pair to keep peak memory at two flat vectors.
        abs_i = _flat_abs_vector(tv_i, param_names)
        for j, tv_j in enumerate(task_vectors):
            if i != j:
                omega += abs_i * _flat_abs_vector(tv_j, param_names)

    num_elements = omega.numel()
    num_top = int(args.ratio * num_elements)

    if num_top < 1:
        # Nothing is to be kept (previously an IndexError on the sorted values).
        mask = torch.zeros_like(omega, dtype=torch.int)
    else:
        # The num_top-th smallest value equals index (num_elements - num_top)
        # of a descending sort, without sorting the whole vector.
        threshold = torch.kthvalue(omega, num_top).values
        # NOTE(review): the strict '<' also drops the threshold element and
        # anything tied with it, so at most num_top - 1 entries survive.
        # Left unchanged so masks stay reproducible - confirm the intent.
        mask = (omega < threshold).int()

    offset = 0
    for name, param in named_params:
        count = param.numel()
        param_mask = mask[offset:offset + count].view(param.size())
        offset += count
        for task_vector in task_vectors:
            task_vector.vector[name] *= param_mask

    return sum(task_vectors)


# Experiment configuration and evaluation sweep.
# NOTE: this section runs at import time (there is no __main__ guard); the
# module-level names below are left in place for anything that relies on them.

exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']  # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
test_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST','DTD']  # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
model = 'ViT-B-32'
args = parse_arguments()
args.data_location = '../data'
args.model = model
args.device = 'cuda'
args.save = '../checkpoints/checkpoints/' + model
args.logs_path = '../logs/' + model
pretrained_checkpoint = '../checkpoints/checkpoints/' + model + '/zeroshot.pt'

args.ratio = 0.99  # opposite to the paper; overwritten for every sweep value below
args.exp_size = 128
args.is_imageNet = False
str_time_ = time.strftime('%Y%m%d_%H%M%S')
log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic.txt'.format(str_time_))

# TATR_merging only reads parameter names and shapes from this model and never
# modifies it, so it is loaded once instead of once per ratio.
# NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoints.
pretrained_model = torch.load(pretrained_checkpoint).to('cuda')

scaling_coef_ = 0.3

# NOTE(review): the original list ended with a second 0.995, which only
# repeated an identical run; it may have been meant to be 0.999 - confirm.
for conf in [0.95, 0.98, 0.99, 0.995, 0.998]:
    # Task vectors are rebuilt for every ratio because TATR_merging masks
    # them in place.
    task_vectors = [
        TaskVector(pretrained_checkpoint, '../checkpoints/checkpoints/' + model + '/' + dataset_name + '/finetuned.pt')
        for dataset_name in exam_datasets
    ]

    print('################################################################')
    print('######################### Merging :', conf, ' ##############################')
    print('################################################################')

    args.ratio = conf
    # Record the ratio in the log file too, so logged accuracies can be
    # attributed to the run that produced them.
    log.info('ratio:' + str(conf))

    task_vector_sum = TATR_merging(args, task_vectors, pretrained_model, exam_datasets)

    ################################################################
    ######################### Testing ##############################
    ################################################################

    image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef_)
    log.info('*'*20 + 'scaling_coef:' + str(scaling_coef_) + '*'*20)

    accs = []
    for dataset in exam_datasets:
        metrics = eval_single_dataset(image_encoder, dataset, args)
        top1_percent = metrics.get('top1')*100
        log.info(str(dataset) + ':' + str(top1_percent)+'%')
        accs.append(top1_percent)
    log.info('Avg ACC:' + str(np.mean(accs)) + '%')





