import copy
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import time
import sys
sys.path.append('../')

import torch
import numpy as np

from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

from bal_merging_utils import bal_mergingj


def create_log_dir(path, filename='log.txt'):
    """Return a DEBUG-level logger that writes to ``path/filename`` and the console.

    The directory ``path`` is created if it does not exist. The logger is
    registered under the name ``path``, so repeated calls with the same
    ``path`` return the same logger object. Handlers are only attached when
    an equivalent one is not already present, so calling this function more
    than once does not duplicate every log line.

    Args:
        path: Directory that holds the log file; also used as the logger name.
        filename: Name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    # exist_ok avoids the check-then-create race of `exists()` + `makedirs()`.
    os.makedirs(path, exist_ok=True)

    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores the absolute path of its file in `baseFilename`.
    log_file = os.path.abspath(os.path.join(path, filename))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler)
        and getattr(handler, 'baseFilename', None) == log_file
        for handler in logger.handlers
    )
    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    # FileHandler subclasses StreamHandler, so compare the exact type to
    # detect an already attached console handler.
    has_console_handler = any(
        type(handler) is logging.StreamHandler for handler in logger.handlers
    )
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)

    return logger

def normalize(x, dim=0):
    """Min-max scale ``x`` to the range [0, 1] along ``dim``.

    Each slice along ``dim`` is shifted by its minimum and divided by its
    (max - min) span. A constant slice has a span of zero; instead of
    producing NaN from 0/0 it is mapped to all zeros.

    Args:
        x: Input tensor.
        dim: Dimension along which the minimum and maximum are taken.

    Returns:
        Tensor with the same shape as ``x``.
    """
    lowest = torch.min(x, dim=dim, keepdim=True).values
    highest = torch.max(x, dim=dim, keepdim=True).values
    span = highest - lowest
    # Where span == 0 the numerator is 0 as well, so dividing by 1 yields 0
    # rather than NaN. Non-zero spans are left untouched.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - lowest) / safe_span

def clamp(x, min_ratio=0, max_ratio=0):
    """Clamp ``x`` between two of its own order statistics.

    The values are sorted (along dim 0 for a 1-D tensor, along dim 1
    otherwise). The element at position ``int(d * min_ratio)`` becomes the
    lower bound and the element at position ``int(d * (1 - max_ratio) - 1)``
    becomes the upper bound, where ``d`` is the size of the sorted dimension.
    Both positions are limited to ``[0, d - 1]`` so that extreme ratios
    (e.g. ``min_ratio=1`` or ``max_ratio=1``) neither raise an IndexError nor
    wrap around to the other end of the sorted values.

    Args:
        x: Tensor to clamp. For tensors with more than one dimension the
            bounds are computed separately for every index of dim 0.
        min_ratio: Fraction of the smallest values to raise to the lower bound.
        max_ratio: Fraction of the largest values to lower to the upper bound.

    Returns:
        Tensor with the same shape as ``x``.
    """
    is_vector = x.dim() == 1
    sort_dim = 0 if is_vector else 1
    d = x.size(sort_dim)
    sorted_x, _ = torch.sort(x, dim=sort_dim)

    last = d - 1
    low_idx = min(max(int(d * min_ratio), 0), last)
    high_idx = min(max(int(d * (1 - max_ratio) - 1), 0), last)

    if is_vector:
        lower = sorted_x[low_idx]
        upper = sorted_x[high_idx]
    else:
        # Keep a singleton dim 1 so the per-row bounds broadcast against x.
        lower = sorted_x[:, low_idx].unsqueeze(1)
        upper = sorted_x[:, high_idx].unsqueeze(1)
    return torch.clamp(x, lower, upper)

def act(x):
    """Apply the element-wise activation (currently tanh) to ``x``.

    Alternatives noted by the original author: ``x**7`` and ``torch.relu(x)``.
    """
    return x.tanh()


# Datasets handled by this script. Available names:
#   SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
# Quick two-task variant: exam_datasets = ['SUN397', 'Cars']
exam_datasets = [
    'SUN397', 'Cars', 'RESISC45', 'EuroSAT',
    'SVHN', 'GTSRB', 'MNIST', 'DTD',
]
test_datasets = [
    'SUN397', 'Cars', 'RESISC45', 'EuroSAT',
    'SVHN', 'GTSRB', 'MNIST', 'DTD',
]

# Start from the project's parsed arguments and override them in place.
args = parse_arguments()
args.is_imageNet = False
args.repeat = 3  # number of repeated runs

# Alternative settings (disabled) for the ViT-B-32 backbone:
# model = 'ViT-B-32'
# args.scaling_coef_ = 0.4
# args.exp_size = 2
# args.ratio = 4
# args.lambd = 5
# args.ratio_ln = 4
# args.lambd_ln = 10
# args.ratio_bias = 0.25

# Active settings: ViT-L-14 backbone.
model = 'ViT-L-14'
args.scaling_coef_ = 0.6
args.exp_size = 3
args.batch_size = 16
args.ratio = 2
args.lambd = 50
args.ratio_ln = 2
args.lambd_ln = 50
args.ratio_bias = 0.15

# Data, device and checkpoint/log locations (relative to the working directory).
args.data_location = '../data'
args.model = model
args.device = 'cuda'
args.save = f'../checkpoints/checkpoints/{model}'
args.logs_path = f'../logs/{model}'
pretrained_checkpoint = f'../checkpoints/checkpoints/{model}/zeroshot.pt'

# Timestamped log file; strftime without a time tuple uses the current local time.
str_time_ = time.strftime('%Y%m%d_%H%M%S')
log = create_log_dir(args.logs_path, f'log_{str_time_}_task_arithmetic.txt')
# Import only the helpers that are actually used, instead of a wildcard
# import inside the loop, so the names below are traceable.
# NOTE(review): assumes ties_merging_utils does not define its own
# normalize/clamp/act that the wildcard import used to pull in - confirm.
from ties_merging_utils import (
    check_parameterNamesMatch,
    check_state_dicts_equal,
    state_dict_to_vector,
    vector_to_state_dict,
)


def _balanced_merge(tv_flat_checks, att_ratio=0.05, lam=1.2):
    """Merge flattened task vectors into a single flat vector.

    Steps, all on a tensor of shape (n, d) with one row per task:
      1. Clamp the absolute values per row (1% at each end) and restore signs.
      2. ``intra`` = exp(n * normalized_abs ** 2), with the clamped absolute
         values min-max normalized per row.
      3. ``inter`` = act(raw_value * column-sum of the signed normalized values).
      4. ``scale`` = per-row min-max normalization of ``intra * inter`` after
         raising everything below the ``1 - att_ratio`` order statistic.
      5. Return the ``scale``-weighted average over tasks of ``lam`` times the
         clamped task vectors.

    Args:
        tv_flat_checks: 2-D tensor (n, d) of fine-tuned minus pretrained weights.
        att_ratio: Fraction of entries per row kept above the clamp floor in step 4.
        lam: Multiplier applied to every task vector before averaging.

    Returns:
        1-D tensor of length d.
    """
    n, d = tv_flat_checks.shape  # original author's note: d = 113448705
    abs_clamped = clamp(torch.abs(tv_flat_checks), min_ratio=0.01, max_ratio=0.01)
    signs = torch.sign(tv_flat_checks)
    clamped_tvs = signs * abs_clamped

    # Computed once and shared by the intra and inter terms.
    abs_normalized = normalize(abs_clamped, dim=1)
    signed_normalized = signs * abs_normalized

    intra = torch.exp(n * abs_normalized ** 2)
    inter = act(tv_flat_checks * torch.sum(signed_normalized, dim=0))
    balancing = intra * inter
    scale = normalize(clamp(balancing, 1 - att_ratio, 0), dim=1)

    # One coefficient per task, sized from n so the script also works when
    # the dataset list does not have exactly eight entries.
    lams = torch.full((n, 1), lam, device=tv_flat_checks.device)
    tvs = clamped_tvs * lams

    # The denominator is floored to avoid dividing by zero where all weights vanish.
    return torch.sum(tvs * scale, dim=0) / torch.clamp(torch.sum(scale, dim=0), min=1e-12)


def _build_merged_state_dict(pretrained_path, finetuned_paths):
    """Load the checkpoints, merge their task vectors and return a state dict.

    Args:
        pretrained_path: Path to the pretrained (zero-shot) checkpoint.
        finetuned_paths: Paths to the fine-tuned checkpoints, one per task.

    Returns:
        State dict of the pretrained weights plus the merged task vector.
    """
    ft_checks = [torch.load(path).state_dict() for path in finetuned_paths]
    ptm_check = torch.load(pretrained_path).state_dict()
    check_parameterNamesMatch(ft_checks + [ptm_check])

    remove_keys = []
    print("Flattening out Checkpoints")
    flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
    flat_ptm = state_dict_to_vector(ptm_check, remove_keys)

    # Sanity check: flattening followed by un-flattening reproduces each state dict.
    assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
    assert all(
        check_state_dicts_equal(vector_to_state_dict(flat_row, ptm_check, remove_keys), ft_check)
        for flat_row, ft_check in zip(flat_ft, ft_checks)
    )

    merged_tv = _balanced_merge(flat_ft - flat_ptm)
    merged_check = flat_ptm + merged_tv
    return vector_to_state_dict(merged_check, ptm_check, remove_keys=remove_keys)


for conf in [0]:

    args.ratio_ln = conf
    print('################################################################')
    print('######################### Merging :', conf, ' ##############################')
    print('################################################################')
    print(args)

    ################################################################
    ######################### Merging ##############################
    ################################################################

    # The merge does not depend on the run index, so it is computed once per
    # configuration. Doing it inside a function also releases the large
    # (n, d) intermediates before evaluation starts.
    finetuned_paths = [
        '../checkpoints/checkpoints/' + model + '/' + dataset_name + '/finetuned.pt'
        for dataset_name in exam_datasets
    ]
    merged_state_dict = _build_merged_state_dict(pretrained_checkpoint, finetuned_paths)

    acc_all = []
    for run in range(args.repeat):
        print('######################### Run :', run, ' ##############################')

        ################################################################
        ######################### Testing ##############################
        ################################################################

        # Fresh copy of the pretrained encoder for every run.
        image_encoder = torch.load(pretrained_checkpoint)
        image_encoder.load_state_dict(merged_state_dict, strict=False)

        # Task-arithmetic baseline (disabled). Re-enabling it requires building
        # task_vectors = [TaskVector(pretrained_checkpoint, path) for path in finetuned_paths]
        # task_vector_sum = sum(task_vectors)
        # image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=args.scaling_coef_)
        # log.info('*'*20 + 'scaling_coef:' + str(args.scaling_coef_) + '*'*20)

        accs = []
        for dataset in exam_datasets:
            metrics = eval_single_dataset(image_encoder, dataset, args)
            top1_percent = metrics.get('top1') * 100
            log.info(str(dataset) + ':' + str(top1_percent) + '%')
            accs.append(top1_percent)
        log.info('Avg ACC:' + str(np.mean(accs)) + '%')

        acc_all.append(accs)

    # ================= Statistics ====================
    # Convert acc_all to a numpy array of shape [repeat, len(exam_datasets)].
    acc_all = np.array(acc_all)

    # 1. Mean and standard deviation of every dataset across the runs.
    dataset_means = np.mean(acc_all, axis=0)  # per-column mean
    dataset_stds = np.std(acc_all, axis=0)  # per-column standard deviation

    print("\n############# ACC for each dataset #############")
    for i, dataset in enumerate(exam_datasets):
        print(f"{dataset}: AVG={dataset_means[i]:.2f}%, STD={dataset_stds[i]:.2f}%")

    # 2. Mean and standard deviation, across the runs, of the per-run
    #    accuracy averaged over all datasets.
    overall_means = np.mean(acc_all, axis=1)  # per-run average over datasets
    overall_std = np.std(overall_means)  # standard deviation across runs
    overall_mean = np.mean(overall_means)  # mean across runs

    print("\n############# 总体统计结果 #############")
    print(f"Average performance across all datasets: AVG={overall_mean:.2f}%, STD={overall_std:.2f}%")





