import os
import torch

import numpy as np
from pathlib import Path

# export PYTHONPATH="$PYTHONPATH:$PWD"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
import sys
# Root of the task-arithmetic checkout (code, data, checkpoints, masks).
# Can be overridden with the TASK_ARITHMETIC_ROOT environment variable; the
# default keeps the original hard-coded location. It is appended to sys.path
# so the project-local modules imported below (eval, args, ...) resolve.
root = os.environ.get('TASK_ARITHMETIC_ROOT', '/data/common/task-arithmetic')
sys.path.append(root)

from eval import eval_single_dataset
from args import parse_arguments

from vision_datasets.registry import get_dataset
from localize_utils import *
import pickle
from vision_datasets.registry import split_train_into_train_val

def create_log_dir(path, filename='log.txt'):
    """Create ``path`` if needed and return a logger that writes into it.

    The logger is named after ``path`` and set to DEBUG. It gets two
    handlers, both at DEBUG level:

    * a file handler appending to ``<path>/<filename>``;
    * a console ``StreamHandler`` (stderr).

    Calling this again for the same ``path`` returns the same logger without
    attaching duplicate handlers. A different ``filename`` adds one more file
    handler for that file.

    Args:
        path: Directory that will hold the log file. It is created, with any
            missing parents, if it does not exist.
        filename: Name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging

    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(path, exist_ok=True)
    # FileHandler stores the absolute path in ``baseFilename``. Normalising
    # here lets the duplicate check below compare like with like.
    log_file = os.path.abspath(os.path.join(path, filename))

    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # logging.getLogger caches loggers by name. Without these checks, a
    # second call for the same path would attach duplicate handlers and
    # every record would be emitted more than once.
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so exclude it explicitly.
    has_console_handler = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)
    return logger

# Tasks whose fine-tuned models are stitched pairwise further down.
# Available: SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']
# exam_datasets = ['GTSRB']  # single-task run

model = 'ViT-B-32'
train_mask = True

# Evaluation arguments: parsed from the command line, then overridden here.
args = parse_arguments()
args.data_location = f'{root}/data'
args.model = model
args.save = f'{root}/task_vectors_checkpoints/{model}'
# args.log decides whether the results logger is created below.
args.log = True
args.save_mask = True
args.graft_mode = "new"
args.valid_frac = 0.2
pretrained_checkpoint = f'{root}/task_vectors_checkpoints/{model}/zeroshot.pt'

# Top-1 accuracy of each individually fine-tuned model. The stitched model's
# per-task accuracy is reported as a difference from these values.
finetuned_acc = {
    "SUN397": 0.7528463476,
    "Cars": 0.7766446959,
    "RESISC45": 0.9611111111,
    "EuroSAT": 0.9974074074,
    "SVHN": 0.9746081746,
    "GTSRB": 0.9873317498,
    "MNIST": 0.9969,
    "DTD": 0.7941489362,
}


# Timestamp that makes each run's log file name unique.
str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime())

# Hyper-parameters of the mask-training ("graft") run whose masks are loaded;
# they only determine which mask folder is read.
graft_args = parse_arguments()
graft_args.checkpoint_location = f'{root}/ckpt'
graft_args.l1_strength = 1
graft_args.learning_rate = 1e7
graft_args.sigmoid_bias = 3
graft_args.num_train_epochs = 20
graft_args.gradient_accumulation_steps = 128
graft_args.sparsity = 1e-5

# Common tail of the mask folder name: <sigmoid_bias>_<l1>_<epochs>_<valid_frac>.
_hparam_suffix = f'{graft_args.sigmoid_bias}_{graft_args.l1_strength}_{graft_args.num_train_epochs}_{args.valid_frac}'

if graft_args.sparsity is None:
    # Hand-picked sparsity level per task.
    sparsity_level_dict = {"SUN397": 1e-5, "Cars": 3e-2, "RESISC45": 1e-5, "EuroSAT": 1e-5, "SVHN": 1e-5, "GTSRB": 1e-5, "MNIST": 1e-5, "DTD": 5e-2}
    folder_name = f'{model}/{args.graft_mode}/vary{_hparam_suffix}'
else:
    # The same sparsity level for every task.
    sparsity_level_dict = dict.fromkeys(
        ["SUN397", "Cars", "RESISC45", "EuroSAT", "SVHN", "GTSRB", "MNIST", "DTD"],
        graft_args.sparsity,
    )
    folder_name = f'{model}/{args.graft_mode}/{graft_args.sparsity}_{_hparam_suffix}'

mask_folder = f'{root}/masks/{folder_name}/'
args.logs_path = '../logs/pairwise/'

if args.log:
    # NOTE(review): the '1e-5' in this file name is hard-coded and does not
    # follow graft_args.sparsity.
    log = create_log_dir(args.logs_path, f'log_{str_time_}_localize_stitch_1e-5.txt')
    # CSV header for the per-pair rows written by the evaluation loop.
    log.info('dataset1,dataset2,acc1,acc2,finetuned_acc1,finetuned_acc2,diff1,diff2')

def _load_finetuned(checkpoint_path):
    """Load a fine-tuned checkpoint saved with either ``torch.save`` or ``pickle``.

    ``torch.load`` is tried first. If it raises, the file is re-read with
    ``pickle.load``. If that also fails, its exception propagates with the
    ``torch.load`` error attached as context.

    NOTE(review): both paths unpickle arbitrary objects, so only use this on
    trusted checkpoint files.
    """
    try:
        return torch.load(checkpoint_path)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still stop the run. ``with`` closes the file handle.
        with open(checkpoint_path, 'rb') as f:
            return pickle.load(f)


# Load the pre-trained (zero-shot) encoder shared by every pair. Task masks
# are loaded from ``mask_folder`` per pair; they are not trained here.
pretrained_model = torch.load(pretrained_checkpoint)
pretrained_model_dic = pretrained_model.state_dict()

# Parameters excluded from stitching. Every other entry of the pre-trained
# state dict is passed to Stitcher as trainable.
frozen = {"model.positional_embedding", "model.text_projection", "model.logit_scale", "model.token_embedding.weight", "model.ln_final.weight", "model.ln_final.bias"}
trainable_params = {k: v for k, v in pretrained_model_dic.items() if k not in frozen}

# Every unordered pair of tasks, in the same order as the old i < j double loop.
for pos, name_i in enumerate(exam_datasets):
    for name_j in exam_datasets[pos + 1:]:
        pair = (name_i, name_j)

        finetuned_models = [
            _load_finetuned(f'{root}/task_vectors_checkpoints/{model}/{name}/finetuned.pt')
            for name in pair
        ]
        masks = [torch.load(f'{mask_folder}{name}_mask.pt') for name in pair]

        # A fresh copy of the pre-trained encoder is loaded for every pair.
        # Presumably Stitcher modifies it in place — TODO confirm.
        model_ = torch.load(pretrained_checkpoint)
        stitcher = Stitcher(trainable_params, model_, pretrained_model, finetuned_models, masks)
        image_encoder = stitcher.interpolate_models()

        metric_i = eval_single_dataset(image_encoder, name_i, args)
        metric_j = eval_single_dataset(image_encoder, name_j, args)
        acc_i, acc_j = metric_i.get('top1'), metric_j.get('top1')
        ft_i, ft_j = finetuned_acc[name_i], finetuned_acc[name_j]

        # Same columns as the CSV header written when the logger is created.
        row = ','.join(str(v) for v in (name_i, name_j, acc_i, acc_j, ft_i, ft_j, ft_i - acc_i, ft_j - acc_j))
        if args.log:
            log.info(row)
        else:
            # ``log`` only exists when args.log is set. Print instead of
            # raising NameError and losing the evaluation results.
            print(row)
