import os

# Expose only GPU 0 to this process. The variable is assigned here, ahead of
# the `import torch` statement further down the file.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys

# Project locations. They are relative paths, so they resolve against the
# current working directory of the process.
src_root_path = '../'
checkpoint_path = '../checkpoints/checkpoints/'
dataset_path = '../data/'
# Add '../' to the module search path used by the imports that follow.
sys.path.append('../')

import math
import time
import tqdm
import torch
from task_vectors import TaskVector
from args import parse_arguments
import pickle
from tatr_merging_utils import NTK_merging, TATR_merging, TATR_mergingnn         # 33version

from eval import eval_single_dataset_preprocess_mapping_head
from datasets.registry import get_dataset
from datasets.common import maybe_dictionarize, get_dataloader_shuffle
from merging_model import ModelWrapper, AlphaWrapper, make_functional


def create_log_dir(path, filename='log.txt'):
    """Return a DEBUG-level logger that writes to ``path/filename`` and the console.

    The directory ``path`` is created when it does not exist. The logger is
    registered under the name ``path``, so repeated calls with the same
    ``path`` return the same logger object. Handlers are attached only when
    an equivalent one is not already present, which keeps a repeated call
    from duplicating every log record.

    Args:
        path: Directory that receives the log file; also used as logger name.
        filename: Name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path, exist_ok=True)

    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores its target as an absolute path in ``baseFilename``.
    log_file = os.path.abspath(path + '/' + filename)
    file_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file
        for handler in logger.handlers
    )
    if not file_attached:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    # FileHandler subclasses StreamHandler, hence the exact type comparison.
    console_attached = any(type(handler) is logging.StreamHandler for handler in logger.handlers)
    if not console_attached:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)

    return logger

# Start from the shared command-line arguments; the assignments below attach
# this script's experiment settings to the same namespace.
args = parse_arguments()

# Config
# Datasets whose fine-tuned checkpoints are turned into task vectors.
# Options: SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST',
                 'DTD']
# Datasets used for the training loop and for the final evaluation.
learn_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']

method_name = 'task_arithmetic'  # choose: weight_averaging | task_arithmetic | ties_merging | tw_adamerging | lw_adamerging | tw_adamergingpp | lw_adamergingpp
model_name = 'ViT-B-32'  # choose: ViT-B-32 | ViT-B-16 | ViT-L-14
batch_size = 16
# args.ratio = 0.99  # ntk33
args.ratio = 0.98  # ntk23

# Alternative settings kept for reference (they select ViT-L-14):
# model_name = 'ViT-L-14'  #  choose: ViT-B-32 | ViT-B-16 | ViT-L-14
# batch_size = 4
# args.batch_size = 16
# args.ratio = 0.995

# Number of passes over `learn_datasets` in the training loop.
iterations = 500
# Not referenced anywhere in the visible part of this script.
eval_iterations = 100


args.layers = 2
args.method_name = method_name
args.model_name = model_name
args.data_location = dataset_path
args.save = checkpoint_path + model_name
args.exp_size = 128

pretrained_checkpoint = checkpoint_path + model_name + '/zeroshot.pt'

args.logs_path = src_root_path + 'src/logs/layers/' + model_name

# The log file name carries the script name, method, model and start time.
# os.path.basename handles the platform's path separator, unlike a manual
# split on "/".
script_name = os.path.basename(__file__).split(".")[0]
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
log = create_log_dir(args.logs_path,
                     'log_{}_{}_{}_{}.txt'.format(script_name, method_name, model_name, timestamp))
log.info(str(args))

# Fine-tuned models are only used as fixed references during training, so each
# checkpoint is read from disk once and reused by every iteration and every
# ratio. The cached copies are kept on the CPU (one model per dataset in host
# memory) and moved to `args.device` only while they are in use.
_finetuned_cache = {}


def _get_finetuned_encoder(dataset_name):
    """Return the fine-tuned model for ``dataset_name``, loading it on first use.

    The checkpoint is read with ``torch.load``; if that fails for any reason
    the file is read with ``pickle.load`` instead. The model is switched to
    eval mode and cached on the CPU.
    """
    if dataset_name not in _finetuned_cache:
        finetuned_file = checkpoint_path + args.model_name + '/' + dataset_name + '/finetuned.pt'
        try:
            encoder = torch.load(finetuned_file)
        except Exception:
            # Best-effort fallback for checkpoints torch.load cannot read.
            # `Exception` (rather than a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            # NOTE: unpickling runs code stored in the file; only load
            # checkpoints from a trusted source.
            with open(finetuned_file, 'rb') as checkpoint_file:
                encoder = pickle.load(checkpoint_file)
        encoder.eval()
        _finetuned_cache[dataset_name] = encoder.cpu()
    return _finetuned_cache[dataset_name]


# Number of samples accumulated into the loss before one optimizer step.
samples_per_step = 16

# For every ratio: build the (masked) task vectors, wrap them in an
# AlphaWrapper, train its trainable parameters so the merged model's features
# match each fine-tuned model's output, then evaluate on `learn_datasets`.
# NOTE(review): 0.995 is listed twice, so that ratio is run twice - confirm
# whether the last entry was meant to be a different value.
for conf in [0.95, 0.98, 0.99, 0.995, 0.998, 0.995]:
    print('################################################################')
    print('######################### Merging :', conf, ' ##############################')
    print('################################################################')
    args.ratio = conf

    # Create the task vectors
    if method_name in ['ties_merging', 'tw_adamergingpp', 'lw_adamergingpp']:
        # TIES Merging
        from ties_merging_utils import (
            check_parameterNamesMatch,
            check_state_dicts_equal,
            state_dict_to_vector,
            ties_merging_split,
            vector_to_state_dict,
        )

        ft_checks = [
            torch.load(checkpoint_path + model_name + '/' + dataset_name + '/finetuned.pt').state_dict()
            for dataset_name in exam_datasets
        ]
        ptm_check = torch.load(pretrained_checkpoint).state_dict()
        check_parameterNamesMatch(ft_checks + [ptm_check])
        remove_keys = []

        flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
        flat_ptm = state_dict_to_vector(ptm_check, remove_keys)

        tv_flat_checks = flat_ft - flat_ptm
        # Sanity checks: flattening and un-flattening must round-trip.
        assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
        assert all(
            check_state_dicts_equal(vector_to_state_dict(flat_vector, ptm_check, remove_keys), ft_check)
            for flat_vector, ft_check in zip(flat_ft, ft_checks)
        )
        selected_entries, merged_tv = ties_merging_split(tv_flat_checks, reset_thresh=20, merge_func="dis-sum", )

        ties_task_vectors = []
        for vector_ in selected_entries:
            t_state_dict = vector_to_state_dict(vector_, ptm_check, remove_keys=remove_keys)
            # A fresh model per vector, so every state dict is independent.
            ref_model = torch.load(pretrained_checkpoint)
            ref_model.load_state_dict(t_state_dict, strict=False)
            ties_task_vectors.append(ref_model.state_dict())

    elif method_name in ['weight_averaging', 'task_arithmetic', 'tw_adamerging', 'lw_adamerging']:
        # Task Vector
        task_vectors = [
            TaskVector(pretrained_checkpoint, checkpoint_path + model_name + '/' + dataset_name + '/finetuned.pt')
            for dataset_name in exam_datasets
        ]

        # args.is_imageNet = False
        # pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
        # mask = TATR_merging(args, task_vectors, pretrained_model, exam_datasets, order=1)

        # 23version
        # pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
        # mask = NTK_merging(args, pretrained_checkpoint, exam_datasets, is_imageNet=False)

        # nnversion
        pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
        mask = TATR_mergingnn(args, task_vectors, pretrained_model)

        # `mask` is consumed as one flat sequence laid out in
        # named_parameters() order: each parameter takes the next numel()
        # entries, reshaped to the parameter's size.
        for task_vector in task_vectors:
            offset = 0
            for name, param in pretrained_model.named_parameters():
                count = param.numel()
                task_vector.vector[name] *= mask[offset: offset + count].view(param.size())
                offset += count

    else:
        print('method name error!')
        sys.exit(-1)

    pretrained_model = torch.load(pretrained_checkpoint)
    pretrained_model_dic = pretrained_model.state_dict()

    model = ModelWrapper(pretrained_model)
    model = model.to(args.device)
    _, names = make_functional(model)

    # First entry: pretrained weights; remaining entries: one per task vector.
    paramslist = [tuple(v.detach().requires_grad_().cpu() for v in pretrained_model_dic.values())]  # pretrain
    if method_name in ['ties_merging', 'tw_adamergingpp', 'lw_adamergingpp']:
        paramslist += [tuple(v.detach().requires_grad_().cpu() for v in sd.values())
                       for sd in ties_task_vectors]  # task vectors
    elif method_name in ['weight_averaging', 'task_arithmetic', 'tw_adamerging', 'lw_adamerging']:
        paramslist += [tuple(v.detach().requires_grad_().cpu() for v in sd.vector.values())
                       for sd in task_vectors]  # task vectors

    torch.cuda.empty_cache()
    alpha_model = AlphaWrapper(paramslist, model, names, exam_datasets, args)

    optimizer = torch.optim.Adam(alpha_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.)
    loss_func = torch.nn.L1Loss()

    # One shuffled loader per training dataset, built once per ratio.
    loaders = []
    for dataset_name in learn_datasets:
        # shuffled test data
        dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location,
                              batch_size=batch_size)
        loaders.append(get_dataloader_shuffle(dataset))

    for iteration in range(iterations):
        for dataset_name, dataloader in zip(learn_datasets, loaders):
            finetuned = _get_finetuned_encoder(dataset_name).to(args.device)

            exp_num = 0
            loss = 0.0
            for data in dataloader:
                data = maybe_dictionarize(data)
                x = data['images'].to(args.device)

                _, features, _, _ = alpha_model(x, dataset_name)
                # The fine-tuned output is a fixed target; no graph is needed.
                with torch.no_grad():
                    finetuned_features = finetuned(x)

                loss += loss_func(features, finetuned_features)
                exp_num += x.shape[0]

                # One optimizer step per dataset per iteration, taken as soon
                # as enough samples have been accumulated.
                if exp_num >= samples_per_step:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    torch.cuda.empty_cache()
                    break

            # Return the cached reference model to host memory.
            finetuned.cpu()

        # Progress report inside the training loop, every 10 iterations.
        if (iteration + 1) % 10 == 0:
            print('iter: ' + str(iteration + 1))

    # Evaluate the trained merged model on every training dataset.
    Total_ACC = 0.
    for dataset_name in learn_datasets:
        image_encoder = alpha_model.get_image_encoder()
        classification_head = alpha_model.get_classification_head(dataset_name)
        down_proj, up_proj = alpha_model.get_feature_mapping_to_head(dataset_name)
        metrics = eval_single_dataset_preprocess_mapping_head(image_encoder, classification_head, dataset_name,
                                                              args, down_proj, up_proj)
        Total_ACC += metrics['top1']
        log.info('dataset: ' + str(dataset_name) + ' ACC: ' + str(
            metrics['top1']))

    log.info('Avg ACC:' + str(Total_ACC / len(learn_datasets)) + '\n')




