import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import torch
import time
import copy
import sys
sys.path.append('src/')
from src.task_vectors import TaskVector
from src.eval import eval_single_dataset
from src.args import parse_arguments
from utils import *

# Seed the torch (CPU + current CUDA device) and NumPy RNGs with one fixed value
# so repeated runs of this script start from the same random state.
seed = 42
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(seed)


# The same benchmark list is used both to build task vectors and to evaluate
# the merged model (all three names alias one list object).
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']
train_datasets = eval_datasets = exam_datasets

# Backbone to merge; other options mentioned by the author: 'ViT-B-16', 'ViT-L-14'.
model = 'ViT-B-32'

# Command-line arguments, with data/checkpoint locations overridden to point
# at the ../synergy tree for the chosen backbone.
args = parse_arguments()
args.base_dir = '../synergy'
args.data_location = os.path.join(args.base_dir, "data")
args.model = model
args.save = os.path.join(args.base_dir, "checkpoints", model)
args.logs_path = 'logs/' + model
args.pretrained_checkpoint = os.path.join(args.save, 'zeroshot.pt')
args.scaling_coef = 1
args.DATASETS = exam_datasets
args.Target = range(len(train_datasets))

# Log file name embeds the start timestamp and the requested merge method.
str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log = create_log_dir(args.logs_path, f'log_{str_time_}_{args.merge}.txt')
log.info(f"Merge method: {args.merge}, ")
log.info(f"Configure: {args}")
starttime = time.time()

################################################################################
# One TaskVector per training dataset, each built from the shared zero-shot
# checkpoint and that dataset's fine-tuned checkpoint (presumably fine-tuned
# minus pre-trained weights -- see src.task_vectors.TaskVector to confirm).
finetuned_checkpoints = [
    os.path.join(args.base_dir, "checkpoints", model, name, "finetuned.pt")
    for name in train_datasets
]
task_vectors = [TaskVector(args.pretrained_checkpoint, ckpt) for ckpt in finetuned_checkpoints]

# Uniform average of all task vectors; the deepcopy guards against the sum
# aliasing one of the inputs. Kept as one expression so no intermediate sum
# object stays referenced at module level.
task_vector_avg = copy.deepcopy(sum(task_vectors)) * (1 / len(task_vectors))


# Merging methods: map each supported `args.merge` value to its implementation.
# Each is called as fn(task_vector_avg, task_vectors, args) and its return value
# is discarded, so it presumably updates task_vector_avg in place -- TODO confirm
# against the implementations in utils.
merge_methods = {
    "TA": TA, "TIES": TIES, "DARE": DARE, "Consensus_TA": Consensus_TA,
    "TSV-M": TSVM, "ISO-C": iso_c, "ISO-CTS": iso_cts,
    "SEAM-B": SEAM_B, "SEAM-O": SEAM_O,
}
if args.merge not in merge_methods:
    # Explicit check instead of `assert` so validation survives `python -O`;
    # the valid names come from the dispatch table, not a duplicated list.
    raise ValueError(
        "Unknown merge method {!r}; expected one of: {}".format(
            args.merge, ", ".join(merge_methods)
        )
    )
merge_methods[args.merge](task_vector_avg, task_vectors, args)


# Apply the (merged) averaged task vector to the pre-trained checkpoint, then
# evaluate the resulting encoder on every evaluation dataset.
image_encoder = task_vector_avg.apply_to(args.pretrained_checkpoint, scaling_coef=args.scaling_coef)
banner = '*' * 20
log.info(f"{banner}Merge Method:{args.merge!s}{banner}")

accs = []
for dataset in eval_datasets:
    metrics = eval_single_dataset(image_encoder, dataset, args)
    # Top-1 accuracy as a percentage; computed once and used for both log and average.
    top1_pct = metrics.get('top1') * 100
    log.info(f"{dataset!s}:{top1_pct!s}%")
    accs.append(top1_pct)

log.info(f"Avg ACC:{round(np.mean(accs), 2)!s}%")
log.info(f"Time:{time.time() - starttime!s}")