import copy
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import time
import sys
sys.path.append('../')

import torch
import numpy as np

from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

from opt_merging_utils import opt_merging


def create_log_dir(path, filename='log.txt'):
    """Create ``path`` if needed and return a DEBUG-level logger writing into it.

    The logger is named after ``path`` and gets two handlers: a file handler on
    ``os.path.join(path, filename)`` and a stream handler (stderr by default).
    Because ``logging.getLogger`` returns the same object for the same name,
    handlers left over from a previous call with the same ``path`` are closed
    and removed first; otherwise every message would be emitted once per call
    and the old log files would stay open.

    Args:
        path: Directory for the log file; created (with parents) if missing.
        filename: Name of the log file inside ``path``.

    Returns:
        logging.Logger: the configured logger.
    """
    import logging
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)
    # Drop handlers from earlier calls so messages are not duplicated and the
    # previously opened log file is released.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    fh = logging.FileHandler(os.path.join(path, filename))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger


# Datasets whose task vectors are merged and on which the merged model is evaluated.
# exam_datasets = ['SUN397', 'Cars']  # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']  # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
# NOTE(review): test_datasets is not read below; evaluation iterates exam_datasets.
test_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST','DTD']  # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD

args = parse_arguments()


# Active configuration: ViT-B-32. Swap in the commented ViT-L-14 block to run the larger model.
model = 'ViT-B-32'
args.scaling_coef_ = 1.2
args.repeat = 1
args.exp_size = 64

# model = 'ViT-L-14'
# args.scaling_coef_ = 1.5
# args.repeat = 3
# args.exp_size = 64
# args.batch_size = 16


args.data_location = '../data'
args.model = model
args.device = 'cuda'
args.save = '../checkpoints/checkpoints/' + model
args.logs_path = '../logs/' + model
pretrained_checkpoint = '../checkpoints/checkpoints/' + model + '/zeroshot.pt'

str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic.txt'.format(str_time_))


def _zero_leading_params(vector, param_names, conf):
    """Zero the first ``conf`` non-None entries of ``vector``, visited in ``param_names`` order.

    Entries whose value is None are skipped and not counted. Entries are
    scaled in place by 0. Returns the number of non-None entries visited.
    """
    num_params = 0
    for name in param_names:
        if vector[name] is None:
            continue
        if num_params < conf:
            vector[name] *= 0
        num_params += 1
    return num_params


# Ordered parameter names of the pre-trained model. Only the names are needed,
# so load the checkpoint once on CPU instead of re-loading it onto the GPU for
# every conf and run, and release the model right away.
_pretrained_model = torch.load(pretrained_checkpoint, map_location='cpu')
pretrained_param_names = [name for name, _ in _pretrained_model.named_parameters()]
del _pretrained_model

# `conf` = number of leading parameter tensors whose merged task-vector entries are zeroed.
for conf in [20, 40, 60, 80, 100, 120, 140, 160]:

    print('################################################################')
    print('######################### Merging :', conf, ' ##############################')
    print('################################################################')
    print(args)
    # Record conf in the log file too, so the per-dataset accuracies logged
    # below can be attributed to the right sweep value.
    log.info('#' * 20 + ' Merging conf: ' + str(conf) + ' ' + '#' * 20)

    acc_all = []
    for run in range(args.repeat):
        print('######################### Run :', run, ' ##############################')

        # Rebuilt every run so each run starts from freshly loaded task vectors.
        task_vectors = [
            TaskVector(pretrained_checkpoint, '../checkpoints/checkpoints/' + model + '/' + dataset_name + '/finetuned.pt')
            for dataset_name in exam_datasets
        ]

        ################################################################
        ######################### Testing ##############################
        ################################################################
        opt_vector = opt_merging(args, task_vectors, pretrained_checkpoint, exam_datasets)

        num_params = _zero_leading_params(opt_vector.vector, pretrained_param_names, conf)
        print('num_params:', num_params)

        image_encoder = opt_vector.apply_to(pretrained_checkpoint, scaling_coef=args.scaling_coef_)
        log.info('*'*20 + 'scaling_coef:' + str(args.scaling_coef_) + '*'*20)

        accs = []
        for dataset in exam_datasets:
            metrics = eval_single_dataset(image_encoder, dataset, args)
            top1 = metrics.get('top1') * 100
            log.info(str(dataset) + ':' + str(top1) + '%')
            accs.append(top1)
        log.info('Avg ACC:' + str(np.mean(accs)) + '%')

        acc_all.append(accs)

    # ================= Statistics ====================
    acc_all = np.array(acc_all)  # shape: [repeat, len(exam_datasets)]

    # 1. Per-dataset mean and (population) standard deviation across runs.
    dataset_means = np.mean(acc_all, axis=0)
    dataset_stds = np.std(acc_all, axis=0)

    print("\n############# ACC for each dataset #############")
    for dataset, mean_acc, std_acc in zip(exam_datasets, dataset_means, dataset_stds):
        print(f"{dataset}: AVG={mean_acc:.2f}%, STD={std_acc:.2f}%")

    # 2. Mean and standard deviation, over the args.repeat runs, of each run's
    #    average accuracy across all datasets.
    overall_means = np.mean(acc_all, axis=1)  # per-run average over datasets
    overall_std = np.std(overall_means)
    overall_mean = np.mean(overall_means)

    # Header text below means "overall statistics".
    print("\n############# 总体统计结果 #############")
    print(f"Average performance across all datasets: AVG={overall_mean:.2f}%, STD={overall_std:.2f}%")
    log.info('conf {}: Avg ACC over {} run(s): {:.2f}% (STD {:.2f}%)'.format(
        conf, args.repeat, overall_mean, overall_std))





