import os
import torch

import numpy as np
from pathlib import Path

# export PYTHONPATH="$PYTHONPATH:$PWD"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
import sys
root = '/data/common/task-arithmetic'
sys.path.append(root)

from transformers import set_seed
# set_seed(42)

from eval import eval_single_dataset
from args import parse_arguments

from vision_datasets.registry import get_dataset
from localize_utils import *
import pickle
from vision_datasets.registry import split_train_into_train_val, create_k_shot_dataset

def create_log_dir(path, filename='log.txt'):
    """Create ``path`` if needed and return a DEBUG-level logger writing into it.

    The logger is obtained with ``logging.getLogger(path)`` and gets two
    handlers, both at DEBUG level:

    * a ``FileHandler`` writing to ``<path>/<filename>``;
    * a console ``StreamHandler`` (stderr, the ``StreamHandler`` default).

    ``logging.getLogger`` caches loggers by name. Calling this function again
    for the same ``path`` therefore returns the same logger. It only attaches
    handlers that are not already present, so records are not duplicated.

    Args:
        path: Directory that holds the log file; created (with parents) if missing.
        filename: Name of the log file inside ``path``.

    Returns:
        logging.Logger: the configured logger.
    """
    import logging

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores its target as an absolute path in `baseFilename`.
    log_file = os.path.abspath(os.path.join(path, filename))
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so exclude it when looking for the
    # console handler.
    has_console_handler = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )

    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)
    return logger

# Datasets whose fine-tuned checkpoints are localized and stitched below.
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']
# Other subsets used in earlier runs:
#   ['SUN397', 'RESISC45', 'EuroSAT', 'SVHN', 'MNIST', 'DTD']
#   ['GTSRB']

# Every dataset with per-dataset settings. This fixes the keys (and their
# order, which is also the order sparsity levels are logged in) of the
# per-dataset dicts below.
ALL_DATASETS = ('SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD')

args = parse_arguments()

# Backbone: 'ViT-B-32' or 'ViT-L-14'.
model = 'ViT-B-32'

# Per-backbone settings: (graft_args.sigmoid_bias, args.batch_size).
MODEL_SETTINGS = {
    'ViT-L-14': (1, 32),
    'ViT-B-32': (4.5, 128),
}
if model not in MODEL_SETTINGS:
    # Without this check an unknown backbone would silently keep whatever
    # parse_arguments() put in sigmoid_bias / batch_size.
    raise ValueError(f'Unsupported model {model!r}; expected one of {sorted(MODEL_SETTINGS)}')

# When False, masks are not trained: previously saved masks are loaded from
# mask_folder instead.
train_mask = True
args.data_location = root + '/data'
args.model = model
args.save = root + '/task_vectors_checkpoints/' + model
args.log = True          # write per-dataset results to a log file under args.logs_path
args.save_mask = True    # save each trained mask under mask_folder
args.graft_mode = "old"  # "old" -> Localizer_og, "new" -> Localizer
args.valid_frac = 0.2    # only appears in folder_name here (no split uses it in this script)
n_shot = 64              # passed as num_shots to create_k_shot_dataset
pretrained_checkpoint = root + '/task_vectors_checkpoints/' + model + '/zeroshot.pt'

# Local-time timestamp that makes each run's log file name unique.
str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime())

# Mask-training hyper-parameters live in a second, independently parsed namespace.
graft_args = parse_arguments()
graft_args.checkpoint_location = root + '/ckpt'
graft_args.sigmoid_bias, args.batch_size = MODEL_SETTINGS[model]
graft_args.l1_strength = 1
graft_args.learning_rate = 1e7
graft_args.num_train_epochs = 20
graft_args.sparsity = 1e-5  # set to None to use the hand-tuned per-dataset levels below

if graft_args.sparsity is not None:
    # NOTE(review): l1_strength_dict is not referenced anywhere else in this script.
    l1_strength_dict = dict.fromkeys(ALL_DATASETS, 0.01)
    sparsity_level_dict = dict.fromkeys(ALL_DATASETS, graft_args.sparsity)
    folder_name = (f'{model}/{args.graft_mode}/{graft_args.sparsity}_{graft_args.sigmoid_bias}_'
                   f'{graft_args.l1_strength}_{graft_args.num_train_epochs}_{args.valid_frac}')
else:
    sparsity_level_dict = {"SUN397": 1e-5, "Cars": 3e-2, "RESISC45": 1e-5, "EuroSAT": 1e-5,
                           "SVHN": 1e-5, "GTSRB": 1e-5, "MNIST": 1e-5, "DTD": 5e-2}
    # NOTE: there is no '_' between 'vary' and the sigmoid bias. It is kept
    # as-is so directories created by earlier runs still resolve.
    folder_name = (f'{model}/{args.graft_mode}/vary{graft_args.sigmoid_bias}_'
                   f'{graft_args.l1_strength}_{graft_args.num_train_epochs}_{args.valid_frac}')

mask_folder = f'{root}/masks/{n_shot}shot/{folder_name}/'
# Relative path: resolved against the current working directory.
args.logs_path = f'../logs/{n_shot}shot/{folder_name}'

# `log` is only used under `if args.log`; it is None otherwise.
log = create_log_dir(args.logs_path, 'log_{}_localize_stitch.txt'.format(str_time_)) if args.log else None

# --- Localize: train (or load) one mask per fine-tuned checkpoint ----------
# Two independent copies of the zero-shot checkpoint:
#   final_model      - passed to each Localizer; its val_preprocess feeds the datasets.
#   pretrained_model - its state_dict supplies trainable_params.
final_model = torch.load(pretrained_checkpoint)
pretrained_model = torch.load(pretrained_checkpoint)
pretrained_model_dic = pretrained_model.state_dict()

# State-dict entries listed here are excluded from trainable_params.
frozen = ["model.positional_embedding", "model.text_projection", "model.logit_scale",
          "model.token_embedding.weight", "model.ln_final.weight", "model.ln_final.bias"]
trainable_params = {k: v for k, v in pretrained_model_dic.items() if k not in frozen}

# Localizer implementation used for mask training, chosen by args.graft_mode.
# Validated up front: an unknown mode used to fall through and crash with a
# NameError on `localizer` only after checkpoints and data were loaded.
graft_localizers = {"new": Localizer, "old": Localizer_og}
if train_mask and args.graft_mode not in graft_localizers:
    raise ValueError(f"Unknown graft_mode {args.graft_mode!r}; expected one of {sorted(graft_localizers)}")

if args.save_mask:
    Path(mask_folder).mkdir(parents=True, exist_ok=True)

start_time = time.time()
masks, finetuned_models, proportions, tests = [], [], [], []
for dataset_name in exam_datasets:
    graft_args.sparsity_level = sparsity_level_dict[dataset_name]
    finetuned_checkpoint = root + '/task_vectors_checkpoints/' + model + '/' + dataset_name + '/finetuned.pt'
    mask_path = mask_folder + dataset_name + '_mask.pt'

    # SECURITY: both torch.load and pickle.load can execute arbitrary code;
    # only load trusted checkpoints.
    try:
        finetuned_model = torch.load(finetuned_checkpoint)
    except Exception:
        # Fallback for checkpoints torch.load cannot read (presumably saved
        # with plain pickle). If this also fails, Python chains the error to
        # the torch.load one. Unlike the former bare `except:`, Ctrl-C is
        # no longer swallowed.
        with open(finetuned_checkpoint, 'rb') as fh:
            finetuned_model = pickle.load(fh)

    if train_mask:
        base_dataset = get_dataset(dataset_name, final_model.val_preprocess,
                                   location=args.data_location, batch_size=args.batch_size)
        # Masks are trained on a k-shot subset. An earlier variant used a
        # held-out split from split_train_into_train_val(..., val_fraction=args.valid_frac).
        valset = create_k_shot_dataset(base_dataset, num_shots=n_shot)
        print("Total samples used for mask computation: ", len(valset))

        val_dataloader = torch.utils.data.DataLoader(
            valset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=2,
        )

        # NOTE(review): this equals ceil(len(valset) / batch_size), except when
        # len(valset) is an exact multiple of batch_size. Then it is one more
        # than the number of batches per epoch. Kept unchanged; confirm against
        # how Localizer/Localizer_og consume it.
        graft_args.gradient_accumulation_steps = len(valset) // args.batch_size + 1

        localizer = graft_localizers[args.graft_mode](
            trainable_params, final_model, pretrained_model, finetuned_model, dataset_name, args, graft_args)
        mask, proportion, test = localizer.train_graft(val_dataloader, dataset_name)
        if args.save_mask:
            torch.save(mask, mask_path)
    else:
        # NOTE(review): this branch always uses Localizer regardless of
        # args.graft_mode; confirm that is intended.
        localizer = Localizer(trainable_params, final_model, pretrained_model, finetuned_model, dataset_name, args, graft_args)
        mask = torch.load(mask_path)
        localizer.mask = mask
        _, proportion = localizer.interpolate_model(return_mask=True)
        test = eval_single_dataset(localizer.model, dataset_name, args)["top1"]

    masks.append(mask)
    finetuned_models.append(finetuned_model)
    # assumes `proportion` is a 0-d torch tensor (moved to CPU, then unwrapped).
    proportions.append(proportion.cpu().item())
    tests.append(test)

# --- Stitch: merge the masked fine-tuned models into one image encoder -----
localize_end = time.time()
localize_time = localize_end - start_time

# A fresh copy of the zero-shot checkpoint is loaded for stitching, presumably
# because the localize step may have modified the earlier final_model.
final_model = torch.load(pretrained_checkpoint)
stitcher = Stitcher(trainable_params, final_model, pretrained_model, finetuned_models, masks)
image_encoder = stitcher.interpolate_models()
stitch_time = time.time() - localize_end
# To keep the merged encoder, call image_encoder.save(<path>.pt) here.

# --- Evaluate the merged encoder on every dataset ---------------------------
# Each log line: dataset, per-dataset accuracy from the localize step,
# mask proportion, merged-model top1.
accs = []
for dataset, local_test, local_proportion in zip(exam_datasets, tests, proportions):
    top1 = eval_single_dataset(image_encoder, dataset, args).get('top1')
    accs.append(top1)
    if args.log:
        log.info(','.join(map(str, (dataset, local_test, local_proportion, top1))))

if args.log:
    log.info(','.join(map(str, ('Avg', np.mean(tests), np.mean(proportions), np.mean(accs)))))
    log.info('sparsity_level' + ',' + str(list(sparsity_level_dict.values())))
    log.info('Localize time: ' + str(localize_time))
    log.info('Stitch time: ' + str(stitch_time))
