import os

import torch
import pickle

from merging_cofficient import get_merging_cofficients

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from tatr_merging_utils import NTK_merging, TATR_merging, TATR_mergingnn         # 33version

from task_vectors import TaskVector

import torch.nn as nn
import time
import sys
sys.path.append('../')

from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head
from args import parse_arguments

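'''
Use the precomputed ("pre") merging coefficients: AdaMerging loads them via
get_merging_cofficients instead of learning them; this script only evaluates
the resulting merged model.
'''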

def create_log_dir(path, filename='log.txt'):
    """Create ``path`` if needed and return a DEBUG-level logger for it.

    The logger is named after ``path`` and writes every record both to
    ``path/filename`` and to a console ``StreamHandler``. Calling this again
    for the same ``path`` returns the same logger without attaching duplicate
    handlers, so records are not emitted more than once.

    Args:
        path: directory for the log file; created (with parents) if missing.
        filename: name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    log_file = os.path.abspath(os.path.join(path, filename))
    # logging.getLogger caches loggers by name, so without these checks a
    # repeated call would stack another pair of handlers and duplicate output.
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    has_console_handler = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)
    return logger

# Datasets to merge and evaluate: SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']
model = 'ViT-L-14'

# Experiment settings layered on top of the parsed command-line arguments.
args = parse_arguments()
args.data_location = '../data'
args.model = model
args.batch_size = 16
args.exp_size = 128
args.save = f'../checkpoints/checkpoints/{model}'
args.logs_path = f'../logs/{model}'
args.ratio = 1  # reassigned for every value of the merging sweep below
pretrained_checkpoint = f'../checkpoints/checkpoints/{model}/zeroshot.pt'

# The log file name is stamped with the local start time of the run.
str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log = create_log_dir(args.logs_path, f'log_{str_time_}_Layer_wise_AdaMergingPP.txt')
args.log = log

from ties_merging_utils import *




def del_attr(obj, names):
    """Delete the attribute reached by following the path ``names`` from ``obj``.

    ``names`` is a sequence such as ``['a', 'b', 'c']`` meaning ``obj.a.b.c``;
    every element but the last is traversed with ``getattr`` and the last is
    removed with ``delattr``.
    """
    target = obj
    for part in names[:-1]:
        target = getattr(target, part)
    delattr(target, names[-1])

def set_attr(obj, names, val):
    """Set the attribute reached by following the path ``names`` from ``obj``.

    ``names`` is a sequence such as ``['a', 'b', 'c']`` meaning ``obj.a.b.c``;
    every element but the last is traversed with ``getattr`` and the last is
    assigned ``val`` with ``setattr``.
    """
    target = obj
    for part in names[:-1]:
        target = getattr(target, part)
    setattr(target, names[-1], val)

def make_functional(mod):
    """Strip every parameter attribute from ``mod`` so weights can be injected later.

    Returns:
        A pair ``(orig_params, names)``: ``orig_params`` is a tuple of the module's
        parameters captured before removal, and ``names`` is the list of their
        dotted names in ``named_parameters()`` order.
    """
    orig_params = tuple(mod.parameters())
    names = [dotted for dotted, _ in mod.named_parameters()]
    for dotted in names:
        del_attr(mod, dotted.split("."))
    return orig_params, names

def load_weights(mod, names, params):
    """Attach each tensor in ``params`` to ``mod`` under the matching dotted name.

    ``names`` and ``params`` are paired positionally; pairing stops at the
    shorter of the two sequences.
    """
    for dotted_name, tensor in zip(names, params):
        attr_path = dotted_name.split(".")
        set_attr(mod, attr_path, tensor)


class ModelWrapper(torch.nn.Module):
    """Thin ``nn.Module`` wrapper that forwards its input to ``model``.

    If the wrapped model has a ``transformer`` attribute it is deleted, so that
    submodule is no longer part of the wrapper. ``initial_weights`` is accepted
    for call compatibility but is not used.
    """

    def __init__(self, model, initial_weights=None):
        super().__init__()
        self.model = model
        # Remove the ``transformer`` attribute (if present) from the wrapped model.
        if hasattr(self.model, 'transformer'):
            delattr(self.model, 'transformer')

    def forward(self, images):
        """Return ``self.model(images)``."""
        return self.model(images)

from heads import get_classification_head
class AdaMerging(torch.nn.Module):
    """Merge pretrained weights and task vectors with layer-wise coefficients.

    ``paramslist[0]`` holds the pretrained tensors and each following entry holds
    one task's tensors, all in the same order as ``names``. The merged tensor
    ``j`` is ``sum_k paramslist[k][j] * lambdas()[j, k]``. Column 0 (the
    pretrained weight) is fixed to 1. The task columns are ``lambdas_raw``
    clamped to [0, 1].

    Args:
        paramslist: list of parameter tuples (pretrained first, then one per task).
        model: module whose parameters were removed (see ``make_functional``) and
            into which the merged tensors are loaded.
        names: dotted parameter names paired positionally with each tuple.
        exam_datasets: dataset names whose classification heads are attached
            as submodules ``classifier_<name>``.
    """

    def __init__(self, paramslist, model, names, exam_datasets):
        super().__init__()
        self.paramslist = paramslist
        self.model = model
        self.names = names
        # Fixed coefficient 1 for the pretrained tensors, one row per parameter.
        self.pretrain_lambdas = torch.ones(len(paramslist[0]), 1)

        # The precomputed coefficients must match the backbone being merged (one
        # row per parameter tensor). Look them up for the configured model
        # rather than a hard-coded architecture name.
        lambda_pre = get_merging_cofficients(method='lw_adamergingpp', model=args.model)
        self.lambdas_raw = torch.nn.Parameter(torch.Tensor(lambda_pre))

        self.classifier = []
        for dataset_name in exam_datasets:
            classification_head = get_classification_head(args, dataset_name)
            layer_name = 'classifier_{}'.format(dataset_name)
            self.add_module(layer_name, classification_head.to(args.device))
            self.classifier.append(layer_name)

    def lambdas(self):
        """Return the ``(num_params, 1 + num_tasks)`` coefficient matrix.

        Column 0 is the fixed pretrained coefficient. The remaining columns are
        ``lambdas_raw`` clamped to [0, 1].
        """
        task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0)
        return torch.cat((self.pretrain_lambdas.to(task_lambdas), task_lambdas), 1)

    def collect_trainable_params(self):
        """Return the list of trainable tensors (only the raw coefficients)."""
        return [self.lambdas_raw]

    def get_classification_head(self, dataset_name):
        """Return the classification head registered for ``dataset_name``."""
        return getattr(self, 'classifier_{}'.format(dataset_name))

    def _merged_params(self):
        """Return the merged tensors as a tuple, ordered like ``self.names``.

        Row ``j`` of ``lambdas()`` is moved to CPU and weights the ``j``-th tensor
        of every entry in ``paramslist``. The stored tensors are expected on CPU.
        """
        alph = self.lambdas()
        return tuple(
            sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu())))
            for j, p in enumerate(zip(*self.paramslist))
        )

    def get_image_encoder(self):
        """Load the merged tensors (on CUDA device 0) into ``self.model`` and return it."""
        params = tuple(p.cuda(0) for p in self._merged_params())
        load_weights(self.model, self.names, params)
        return self.model

    def forward(self, inp, dataset_name):
        """Run the merged encoder on ``inp`` and apply ``dataset_name``'s head."""
        params = tuple(p.cuda() for p in self._merged_params())
        load_weights(self.model, self.names, params)
        feature = self.model(inp)
        return self.get_classification_head(dataset_name)(feature)

def softmax_entropy(x):
    """Return the entropy (natural log) of ``softmax(x)`` taken over dim 1.

    The result has dim 1 reduced away, e.g. one value per row for 2-D input.
    """
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    return -(probs * log_probs).sum(1)



K = 20  # passed to ties_merging_split as reset_thresh
merge_func = "dis-sum"  # passed to ties_merging_split as merge_func

# Sweep over `conf`: each iteration rebuilds the merged model from the
# checkpoints on disk and logs per-dataset and average accuracy.
for conf in [0.9, 0.95, 0.98, 0.99, 0.995, 0.998, 0.999]:
    print('################################################################')
    print('######################### Merging :', conf, ' ##############################')
    print('################################################################')
    # NOTE(review): presumably consumed by TATR_mergingnn through `args`; confirm.
    args.ratio = conf

    # --- Load fine-tuned and pretrained checkpoints as state dicts. ---
    ft_checks = []
    for dataset_name in exam_datasets:
        file_name = '../checkpoints/checkpoints/' + model + '/' + dataset_name + '/finetuned.pt'
        # Special case for the ViT-B-16 Cars checkpoint, which is read with plain
        # pickle; this path never matches when `model` is another backbone.
        if file_name == '../checkpoints/checkpoints/' + 'ViT-B-16' + '/' + 'Cars' + '/finetuned.pt':
            # Context manager closes the file handle deterministically.
            with open(file_name, 'rb') as fh:
                ft_checks.append(pickle.load(fh).state_dict())
        else:
            ft_checks.append(torch.load(file_name).state_dict())
    ptm_check = torch.load(pretrained_checkpoint).state_dict()

    check_parameterNamesMatch(ft_checks + [ptm_check])

    # 33version
    # task_vectors = [
    #     TaskVector(pretrained_checkpoint, '../checkpoints/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets
    # ]
    # args.is_imageNet = False
    # pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
    # mask = TATR_merging(args, task_vectors, pretrained_model, exam_datasets, order=1)

    # nnversion: compute the TATR mask from the task vectors.
    task_vectors = [
        TaskVector(pretrained_checkpoint, '../checkpoints/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets
    ]
    args.is_imageNet = False
    pretrained_model = torch.load(pretrained_checkpoint).to('cuda')
    mask = TATR_mergingnn(args, task_vectors, pretrained_model)

    # --- TIES merging on flattened task vectors. ---
    remove_keys = []
    print("Flattening out Checkpoints")
    flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
    flat_ptm = state_dict_to_vector(ptm_check, remove_keys)

    tv_flat_checks = flat_ft - flat_ptm

    # Sanity checks: flattening and unflattening must round-trip exactly.
    assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
    assert all(
        check_state_dicts_equal(vector_to_state_dict(flat_row, ptm_check, remove_keys), ft_check)
        for flat_row, ft_check in zip(flat_ft, ft_checks)
    )

    selected_entries, merged_tv = ties_merging_split(tv_flat_checks, reset_thresh=K, merge_func=merge_func,)

    # Turn each task's selected entries back into a full state dict. A fresh
    # model is loaded per task because state_dict() returns references to the
    # model's own tensors, which the next load_state_dict would overwrite.
    ties_task_vectors = []
    for vector_ in selected_entries:
        t_state_dict = vector_to_state_dict(vector_, ptm_check, remove_keys=remove_keys)
        ref_model = torch.load(pretrained_checkpoint)
        ref_model.load_state_dict(t_state_dict, strict=False)
        ties_task_vectors.append(ref_model.state_dict())

    # Apply the flat TATR mask parameter by parameter (in place).
    # NOTE(review): assumes the mask is laid out in
    # pretrained_model.named_parameters() order; confirm against TATR_mergingnn.
    for task_vector in ties_task_vectors:
        progress = 0
        for name, pp in pretrained_model.named_parameters():
            numel = pp.numel()
            mask_params = mask[progress: progress + numel].view(pp.size())
            progress += numel
            task_vector[name] *= mask_params

    # --- Build the functional AdaMerging model and evaluate it. ---
    pretrained_model = torch.load(pretrained_checkpoint)
    pretrained_model_dic = pretrained_model.state_dict()

    # The second argument fills ModelWrapper's unused `initial_weights` slot.
    model_w = ModelWrapper(pretrained_model, exam_datasets)
    model_w = model_w.to(args.device)
    _, names = make_functional(model_w)

    paramslist = []
    paramslist += [tuple(v.detach().requires_grad_().cpu() for v in pretrained_model_dic.values())]  # pretrain
    paramslist += [tuple(v.detach().requires_grad_().cpu() for v in tv.values()) for tv in ties_task_vectors]  # task vectors
    torch.cuda.empty_cache()
    adamerging_mtl_model = nn.DataParallel(AdaMerging(paramslist, model_w, names, exam_datasets), device_ids=[0])
    adamerging_mtl_model = adamerging_mtl_model.cuda()

    print('collect_trainable_params:')
    print(list(adamerging_mtl_model.module.collect_trainable_params()))

    Total_ACC = 0.
    for dataset_name in exam_datasets:
        image_encoder = adamerging_mtl_model.module.get_image_encoder()
        classification_head = adamerging_mtl_model.module.get_classification_head(dataset_name)
        metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args, is_train=True)
        Total_ACC += metrics['top1']
        log.info('dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
    log.info(' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
