import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from tatr_merging_utils import NTK_merging, TATR_merging, TATR_mergingnn         # 33version


from task_vectors import TaskVector
import time
import pickle
import sys
sys.path.append('../')

from eval import eval_single_dataset
from args import parse_arguments



def create_log_dir(path, filename='log.txt'):
    """Create ``path`` if needed and return a logger writing to ``path/filename``.

    The logger is obtained with ``logging.getLogger(path)`` and set to DEBUG. It
    always ends up with exactly two DEBUG-level handlers: a ``FileHandler`` for
    ``os.path.join(path, filename)`` and a console ``StreamHandler`` (stderr).

    Args:
        path: Directory for the log file; created together with any missing
            parent directories.
        filename: Name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # getLogger returns the same object for the same name, so a repeated call
    # with the same ``path`` would otherwise stack extra handlers: every record
    # would be printed once per call and also appended to all earlier log files.
    # Close and detach the handlers installed by any previous call first.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    fh = logging.FileHandler(os.path.join(path, filename))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger

# Explicit imports instead of the former in-loop `from ties_merging_utils import *`,
# which silently supplied `torch` and could overwrite this script's globals.
import torch
from ties_merging_utils import (
    check_parameterNamesMatch,
    check_state_dicts_equal,
    state_dict_to_vector,
    ties_merging,
    vector_to_state_dict,
)

# Sweep over `args.ratio` values (presumably consumed by TATR_merging — confirm).
# For each value: build a TATR mask, apply it to the TIES-merged task vector, add the
# result to the pretrained weights, evaluate on every task and log the metrics.
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
model = 'ViT-B-32'
args = parse_arguments()
args.data_location = '../data'
args.model = model
args.save = '../checkpoints/checkpoints/' + model
args.logs_path = '../logs/' + model
args.ratio = 0.99  # overwritten by each sweep value in the loop below
args.exp_size = 128
pretrained_checkpoint = '../checkpoints/checkpoints/' + model + '/zeroshot.pt'


def _finetuned_checkpoint(dataset_name):
    """Return the path of the fine-tuned checkpoint of `model` for `dataset_name`."""
    return '../checkpoints/checkpoints/' + model + '/' + dataset_name + '/finetuned.pt'


def _load_finetuned_state_dict(dataset_name):
    """Load the fine-tuned checkpoint for `dataset_name` and return its state dict."""
    file_name = _finetuned_checkpoint(dataset_name)
    # Kept from the original special case: this one checkpoint is read with plain
    # pickle instead of torch.load (presumably it was saved differently — confirm).
    # NOTE: unpickling runs arbitrary code; only load trusted checkpoint files.
    if model == 'ViT-B-16' and dataset_name == 'Cars':
        with open(file_name, 'rb') as f:
            return pickle.load(f).state_dict()
    return torch.load(file_name).state_dict()


# Everything from here to the sweep loop does not depend on the sweep value, so it
# is computed once instead of being reloaded from disk for every ratio.
ft_checks = [_load_finetuned_state_dict(dataset_name) for dataset_name in exam_datasets]
ptm_check = torch.load(pretrained_checkpoint).state_dict()
check_parameterNamesMatch(ft_checks + [ptm_check])

remove_keys = []
print("Flattening out Checkpoints")
flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
flat_ptm = state_dict_to_vector(ptm_check, remove_keys)

# One row per task: fine-tuned weights minus pretrained weights.
tv_flat_checks = flat_ft - flat_ptm
# Sanity check: flattening and un-flattening must round-trip exactly.
assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
assert all(
    check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_check)
    for i, ft_check in enumerate(ft_checks)
)

K = 20
merge_func = "dis-sum"
scaling_coef_ = 0.3
merged_tv = ties_merging(tv_flat_checks, reset_thresh=K, merge_func=merge_func)

# Only flat_ptm, ptm_check and merged_tv are needed inside the loop; free the rest.
del ft_checks, flat_ft, tv_flat_checks

# NOTE(review): 0.995 appears twice, so the last run repeats an earlier one; the final
# entry was probably meant to be 0.999 — confirm before relying on this sweep.
for conf in [0.95, 0.98, 0.99, 0.995, 0.998, 0.995]:
    print('################################################################')
    print('######################### Merging :', conf, ' ##############################')
    print('################################################################')
    args.ratio = conf

    str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
    log = create_log_dir(args.logs_path, 'log_{}_ties_merging.txt'.format(str_time_))

    # "33version" mask. Alternatives tried earlier: NTK_merging(args, pretrained_checkpoint,
    # exam_datasets, is_imageNet=False) ("23version") and
    # TATR_mergingnn(args, task_vectors, pretrained_model) ("nnversion").
    # Task vectors and the pretrained model are rebuilt each iteration, as before,
    # in case TATR_merging modifies them.
    task_vectors = [
        TaskVector(pretrained_checkpoint, _finetuned_checkpoint(dataset_name))
        for dataset_name in exam_datasets
    ]
    args.is_imageNet = False
    pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
    mask = TATR_merging(args, task_vectors, pretrained_model, exam_datasets)
    del task_vectors, pretrained_model  # release GPU/host memory before evaluation

    merged_check = flat_ptm + scaling_coef_ * (mask * merged_tv)
    merged_state_dict = vector_to_state_dict(merged_check, ptm_check, remove_keys=remove_keys)

    image_encoder = torch.load(pretrained_checkpoint)
    image_encoder.load_state_dict(merged_state_dict, strict=False)

    total_acc = 0.
    for dataset in exam_datasets:
        metrics = eval_single_dataset(image_encoder, dataset, args)
        total_acc += metrics['top1']
        log.info(str(dataset) + ':' + str(metrics))

    log.info('Final: ' + 'Avg ACC:' + str(total_acc / len(exam_datasets)))
