import os

import numpy as np

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
import sys
# Checkout of the task-arithmetic code base. Its modules (task_vectors, eval,
# args) are imported below, and data/checkpoint paths are built under it.
# Set the TASK_ARITHMETIC_ROOT environment variable to run from another location;
# the default keeps the original hard-coded path.
root = os.environ.get('TASK_ARITHMETIC_ROOT', '/data/common/task-arithmetic')
sys.path.append(root)

from task_vectors import TaskVector
from eval import eval_single_dataset
from args import parse_arguments

def create_log_dir(path, filename='log.txt'):
    """Create ``path`` if needed and return a DEBUG-level logger writing there.

    The logger is named after ``path``. It sends every record both to
    ``<path>/<filename>`` and to the console (``logging.StreamHandler`` writes
    to stderr by default).

    ``logging.getLogger`` caches loggers by name, so calling this again with the
    same ``path`` returns the same logger object. Handlers are only attached
    when an equivalent one is not already present, so repeated calls do not
    make each message appear several times.

    Args:
        path: Directory for the log file; created (including parents) if missing.
        filename: Name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # FileHandler records the absolute path it writes to in `baseFilename`.
    log_file = os.path.abspath(os.path.join(path, filename))
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    # FileHandler subclasses StreamHandler, so exclude it when looking for an
    # already-attached console handler.
    has_console_handler = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)
    return logger

# Datasets whose task vectors are merged pairwise, in evaluation order.
# Available: SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
exam_datasets = [
    'SUN397',
    'Cars',
    'RESISC45',
    'EuroSAT',
    'SVHN',
    'GTSRB',
    'MNIST',
    'DTD',
]

# Hard-coded top-1 accuracy of each single-task fine-tuned checkpoint
# (presumably measured for the model configured below — TODO confirm).
# The pairwise loop reports how far each merged model falls short of these.
finetuned_acc = dict(
    SUN397=0.7528463476,
    Cars=0.7766446959,
    RESISC45=0.9611111111,
    EuroSAT=0.9974074074,
    SVHN=0.9746081746,
    GTSRB=0.9873317498,
    MNIST=0.9969,
    DTD=0.7941489362,
)

# Smaller subsets for quick runs:
# exam_datasets = ['Cars', 'RESISC45', 'EuroSAT']
# exam_datasets = ["SUN397"]

model = 'ViT-B-32'

# Parse the CLI arguments, then override the fields this experiment fixes.
args = parse_arguments()
args.data_location = os.path.join(root, 'data')
args.model = model
args.save = os.path.join(root, 'task_vectors_checkpoints', model)
# Relative path: resolved against the current working directory.
args.logs_path = os.path.join('..', 'logs', 'pairwise', model)
pretrained_checkpoint = os.path.join(args.save, 'zeroshot.pt')

# One log file per run, stamped with the local start time.
run_stamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic_1.txt'.format(run_stamp))
# CSV header for the per-pair rows logged by the evaluation loop.
log.info('dataset1,dataset2,acc1,acc2,finetuned_acc1,finetuned_acc2,diff1,diff2')

# Evaluate every unordered pair of datasets: add the two task vectors, apply
# the sum to the pretrained encoder, and measure top-1 accuracy on both tasks.
# Each pair produces one CSV row matching the header logged during setup.
scaling_coef_ = 1  # previously also tried: 0.5

# The last dataset has no partner left, so it is excluded from the outer loop.
for i, dataset_i in enumerate(exam_datasets[:-1]):
    # The first vector of each pair depends only on the outer dataset, so it is
    # built once here instead of once per pair (fewer checkpoint loads).
    # NOTE(review): reuse assumes TaskVector's `+` and apply_to leave their
    # operands unmodified — TODO confirm against task_vectors.py.
    task_vector_i = TaskVector(
        pretrained_checkpoint,
        os.path.join(root, 'task_vectors_checkpoints', model, dataset_i, 'finetuned.pt'),
    )
    for dataset_j in exam_datasets[i + 1:]:
        task_vector_j = TaskVector(
            pretrained_checkpoint,
            os.path.join(root, 'task_vectors_checkpoints', model, dataset_j, 'finetuned.pt'),
        )
        task_vectors = [task_vector_i, task_vector_j]
        task_vector_sum = sum(task_vectors)

        image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef_)

        # Index with [] so a missing 'top1' fails here with a KeyError naming the
        # key, rather than later as an arithmetic error on None.
        acc_i = eval_single_dataset(image_encoder, dataset_i, args)['top1']
        acc_j = eval_single_dataset(image_encoder, dataset_j, args)['top1']
        ft_acc_i = finetuned_acc[dataset_i]
        ft_acc_j = finetuned_acc[dataset_j]

        # dataset1,dataset2,acc1,acc2,finetuned_acc1,finetuned_acc2,diff1,diff2
        row = (dataset_i, dataset_j, acc_i, acc_j, ft_acc_i, ft_acc_j,
               ft_acc_i - acc_i, ft_acc_j - acc_j)
        log.info(','.join(str(field) for field in row))

        # Drop this loop's references to the merged model before the next pair
        # is built, so the old encoder is not held while a new one is created.
        del image_encoder, task_vector_sum
