import os
import json
import torch

from args import parse_theta_lpips_arguments
from eval_functions import evaluate_model

from datasets.registry import get_dataset
from model.modeling import ImageEncoder

from utils import epochs

from loss.l2_sp import SPRegularization
from loss.lpips import LPIPS

from datasets.common import maybe_dictionarize
from tqdm import tqdm

from eval_functions import get_image_encoder

# Validation-split dataset names on which LPIPS is measured for every
# checkpoint. Each entry is a dataset-registry name handed to `get_dataset`;
# all of them share the `Val` suffix, so the list is built from the base names.
eval_datasets = [
    f'{base_name}Val'
    for base_name in (
        'Cars',
        'CIFAR10',
        'CIFAR100',
        'DTD',
        'EuroSAT',
        'GTSRB',
        'MNIST',
        'RESISC45',
        'SVHN',
    )
]

# Fine-tuning losses whose checkpoints follow the "<zs|lp>_init_<loss>" layout.
_CE_LOSSES = ('ce', 'l2sp')
# Fine-tuning losses whose checkpoints follow the FLYP layout.
_FLYP_LOSSES = ('flyp', 'flyp_ce')


def _unsupported_loss(args):
    """Build the error raised when ``args.finetune_loss`` has no known layout."""
    return ValueError(
        f"Unsupported finetune_loss {args.finetune_loss!r}; "
        f"expected one of {_CE_LOSSES + _FLYP_LOSSES}"
    )


def _init_and_head_tags(args):
    """Return the ``('zs'|'lp', '_fzhd'|'')`` name fragments for CE-style runs."""
    init_tag = 'zs' if args.zs_init else 'lp'
    head_tag = '_fzhd' if args.freeze_head else ''
    return init_tag, head_tag


def build_checkpoint_dir(args):
    """Return the directory that holds the fine-tuning checkpoints for ``args``.

    Layout: ``<model_location>/<model>/<train_dataset>Val/<run>``, where
    ``<run>`` is ``<zs|lp>_init_<loss>[_fzhd]`` for CE-style losses and
    ``flyp`` / ``flypce`` for the FLYP losses.

    Raises:
        ValueError: if ``args.finetune_loss`` is not a supported loss.
    """
    base_dir = os.path.join(args.model_location, args.model, f'{args.train_dataset}Val')
    if args.finetune_loss in _CE_LOSSES:
        init_tag, head_tag = _init_and_head_tags(args)
        return os.path.join(base_dir, f"{init_tag}_init_{args.finetune_loss}{head_tag}")
    if args.finetune_loss in _FLYP_LOSSES:
        return os.path.join(base_dir, "flyp" if args.finetune_loss == 'flyp' else "flypce")
    raise _unsupported_loss(args)


def build_checkpoint_file_name(args, iteration):
    """Return the checkpoint file name saved at training step ``iteration``.

    Step 0 is the pretrained / initial checkpoint.

    Raises:
        ValueError: if ``args.finetune_loss`` is not a supported loss.
    """
    if args.finetune_loss in _CE_LOSSES:
        init_tag, head_tag = _init_and_head_tags(args)
        return f"finetuned_{init_tag}_init_{args.finetune_loss}_{iteration}{head_tag}.pt"
    if args.finetune_loss in _FLYP_LOSSES:
        return f"finetuned_{args.finetune_loss}_{iteration}.pt"
    raise _unsupported_loss(args)


def build_results_file_name(args):
    """Return the JSON file name under which the metrics for ``args`` are stored.

    Raises:
        ValueError: if ``args.finetune_loss`` is not a supported loss.
    """
    if args.finetune_loss in _CE_LOSSES:
        init_tag, head_tag = _init_and_head_tags(args)
        return f"theta_lpips_{args.train_dataset}_{init_tag}_init_{args.finetune_loss}{head_tag}.json"
    if args.finetune_loss in _FLYP_LOSSES:
        return f"theta_lpips_{args.train_dataset}_{args.finetune_loss}.json"
    raise _unsupported_loss(args)


def summarize_lpips(chunks):
    """Concatenate per-batch LPIPS tensors along dim 0.

    Returns:
        ``(mean, values)``: the mean over all entries as a Python float, and
        the entries as a flat Python list. An empty ``chunks`` (empty loader)
        yields ``(nan, [])`` instead of crashing a long-running job at the end.
    """
    if not chunks:
        return float('nan'), []
    values = torch.cat(chunks, dim=0)
    return values.mean().item(), values.cpu().tolist()


def compute_lpips_over_loader(lpips_fn, loader, device):
    """Apply ``lpips_fn`` to every image batch of ``loader`` without autograd.

    Returns the ``(mean, values)`` pair produced by ``summarize_lpips``.
    """
    chunks = []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = maybe_dictionarize(batch)
            chunks.append(lpips_fn(batch['images'].to(device)))
    return summarize_lpips(chunks)


def load_image_encoder(args, path, device):
    """Load a pickled checkpoint from ``path`` and return its image encoder on ``device``.

    The encoder is put into inference mode so that checkpoints saved mid-training
    (possibly still in train mode) are compared deterministically.
    """
    # map_location lets checkpoints saved on GPU be loaded on the CPU fallback.
    # torch.load unpickles arbitrary objects: only use it on trusted checkpoints.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
    # rejects pickled modules; pass weights_only=False here if loading fails.
    model = torch.load(path, map_location=device).to(device)
    encoder = get_image_encoder(args, model).to(device)
    encoder.eval()
    return encoder


if __name__ == '__main__':
    args = parse_theta_lpips_arguments()
    args.batch_size = 512
    print(args)

    # Fails fast (ValueError) on an unsupported finetune_loss.
    ckp_path = build_checkpoint_dir(args)

    # The ImageEncoder is only needed for its validation preprocessing transform.
    model = ImageEncoder(args, keep_lang=False)
    val_preprocess = model.val_preprocess
    del model
    train_set = get_dataset(f'{args.train_dataset}Val', val_preprocess, location=args.data_location, batch_size=128)

    # Checkpoint files are indexed by training step. Presumably the training
    # script saved one every 1% of the total steps; recompute that cadence here.
    num_batches = len(train_set.train_loader)
    num_steps = num_batches * epochs[args.train_dataset]
    save_every = num_steps // 100

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pretrained_image_encoder = load_image_encoder(
        args, os.path.join(ckp_path, build_checkpoint_file_name(args, 0)), device)

    # Datasets are loop-invariant: build each once instead of once per checkpoint.
    # The train dataset's Val entry is exactly the dataset already loaded above.
    dataset_cache = {f'{args.train_dataset}Val': train_set}

    res_dict = {}
    for percentage in range(0, 101, 5):
        iteration = percentage * save_every
        if iteration in res_dict:
            # save_every == 0 (fewer than 100 steps) maps every percentage to step 0.
            continue
        print(f'Computing metrics for training percentage: {percentage}...')
        entry = {}
        model_image_encoder = load_image_encoder(
            args, os.path.join(ckp_path, build_checkpoint_file_name(args, iteration)), device)

        # Weight-space distance between the pretrained and current encoder.
        with torch.no_grad():
            entry['theta_diff'] = SPRegularization(pretrained_image_encoder, model_image_encoder)().item()

        # LPIPS-style feature distance on every evaluation dataset.
        lpips_ls = LPIPS(pretrained_image_encoder, model_image_encoder, device)
        entry['lpips'] = {}
        entry['lpips_list'] = {}
        for eval_dataset_name in eval_datasets:
            print(f'Computing LPIPS on {eval_dataset_name}....')
            eval_dataset = dataset_cache.get(eval_dataset_name)
            if eval_dataset is None:
                eval_dataset = get_dataset(eval_dataset_name, val_preprocess, location=args.data_location, batch_size=128)
                dataset_cache[eval_dataset_name] = eval_dataset

            lpips_mean, lpips_values = compute_lpips_over_loader(lpips_ls, eval_dataset.test_loader, device)
            entry['lpips'][eval_dataset_name] = lpips_mean
            entry['lpips_list'][eval_dataset_name] = lpips_values

            # The fine-tuning dataset additionally gets its training split measured.
            if eval_dataset_name == f'{args.train_dataset}Val':
                print('Computing LPIPS on train set too ...')
                lpips_mean, lpips_values = compute_lpips_over_loader(lpips_ls, eval_dataset.train_loader, device)
                entry['lpips'][f'{eval_dataset_name}_train'] = lpips_mean
                entry['lpips_list'][f'{eval_dataset_name}_train'] = lpips_values

        res_dict[iteration] = entry
        # Release this checkpoint before the next one is loaded to lower peak memory.
        del model_image_encoder, lpips_ls

    res_save_path = os.path.join(args.res_store_path, build_results_file_name(args))
    if args.res_store_path:
        os.makedirs(args.res_store_path, exist_ok=True)
    with open(res_save_path, 'w') as fp:
        json.dump(res_dict, fp)

    print("Results stored!")