import os
import gc
import numpy as np
import torch
import argparse
import copy
import wandb
import json

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn as nn

from lib.dataset.imagenetr_utils import imagenet_r_transform, imagenet_r_mask, reverse_imagenet_r_mask, imagenet_a_mask, imagenet_o_mask
from lib.dataset.objectnet_dataset import objectnet_mask
from lib.dataset.objectnet_dataset import ObjectNetDataset
from lib.dataset.imagenet_v2 import ImageNetV2

from utils.utils import get_model, set_seed
from lib.cka_pytorch.cka import CKACalculator
from lib.argument import parse_option

def get_dataset(root, dataset_name, preprocess, batch_size=128, num_workers=4):
    """Return an unshuffled DataLoader over the evaluation dataset ``dataset_name``.

    The dataset lives under ``os.path.join(root, dataset_name)``, with one
    exception: ``'objectnet-v2'`` has no directory of its own and re-reads
    ``<root>/objectnet-1.0`` with ``reindex=True``.

    Args:
        root: directory containing the dataset folders.
        dataset_name: dataset sub-path, e.g. ``'imagenet-r'``. ``'objectnet-1.0'``,
            ``'objectnet-v2'`` and ``'imagenet-v2'`` use dedicated dataset
            classes; anything else is read with ``ImageFolder``.
        preprocess: transform applied to every sample.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``pin_memory=True``.
    """
    dataset_dir = os.path.join(root, dataset_name)

    if dataset_name == 'objectnet-v2':
        # Same images as objectnet-1.0, loaded with the reindexed label mapping.
        ds = ObjectNetDataset(root=os.path.join(root, 'objectnet-1.0'), transform=preprocess, reindex=True)
    elif dataset_name == 'objectnet-1.0':
        ds = ObjectNetDataset(root=dataset_dir, transform=preprocess)
    elif dataset_name == 'imagenet-v2':
        ds = ImageNetV2(dataset_dir, transform=preprocess)
    else:
        ds = ImageFolder(dataset_dir, transform=preprocess)

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )


# Distribution-shift benchmarks evaluated when no explicit ``args.cka_target`` is given.
_CKA_DEFAULT_DATASETS = (
    'imagenet/val', 'imagenet-a', 'imagenet-r', 'imagenet-sketch',
    'imagenet-cartoon', 'imagenet-drawing', 'imagenet-v2', 'objectnet-v2',
)

# ImageNet-C corruption types; each one is evaluated at severities 1 through 5.
_IMAGENET_C_CORRUPTIONS = (
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
    'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 'brightness', 'contrast',
    'elastic_transform', 'pixelate', 'jpeg_compression',
)
_IMAGENET_C_SEVERITIES = range(1, 6)


def _compute_and_save_cka(model1, model2, root, dataset_name, preprocess, layers, results, save_path, args):
    """Compute the CKA matrix of model1 vs. model2 on one dataset and checkpoint it.

    Stores the CPU matrix in ``results[dataset_name]`` (plus the module-name lists
    on first use), then writes the whole ``results`` dict to ``save_path`` so an
    interrupted run resumes from the last finished dataset.
    """
    dataloader = get_dataset(root, dataset_name, preprocess, args.batch_size * 2, args.num_workers)
    calculator = CKACalculator(model1, model2, dataloader, num_epochs=1, hook_layer_types=layers)
    try:
        cka_matrix = calculator.calculate_cka_matrix().detach()
        results[dataset_name] = cka_matrix.cpu()
        if 'module_names_X' not in results:
            results['module_names_X'] = calculator.module_names_X
            results['module_names_Y'] = calculator.module_names_Y
        torch.save(results, save_path)
        print(dataset_name, cka_matrix)
    finally:
        # Reset even when the computation fails, so that no calculator state leaks
        # into the next call on the same models.
        # NOTE(review): presumably reset() removes the forward hooks it registered — confirm.
        calculator.reset()
        del calculator, dataloader
        gc.collect()
        torch.cuda.empty_cache()


def calculate_cka(model1, model2, preprocess, save_path, args):
    """Compute layer-wise CKA between two models on a suite of evaluation datasets.

    Args:
        model1, model2: the models compared by ``CKACalculator``. Activations are
            hooked on ``nn.Linear``, ``nn.LayerNorm`` and ``nn.Conv2d`` layers.
        preprocess: transform applied to every image.
        save_path: checkpoint file. If it exists, previously computed entries are
            loaded and skipped. It is rewritten after every dataset.
        args: namespace providing ``cka_target``, ``batch_size``, ``num_workers``
            and ``root``.

    ``args.cka_target`` selects what is evaluated:
        * ``None``         -- all default benchmarks, then every ImageNet-C split;
        * ``'imagenet-c'`` -- only the ImageNet-C splits;
        * any other value  -- only that single dataset.

    Returns:
        dict mapping each dataset name to its CPU CKA matrix. ImageNet-C splits
        are keyed as ``'imagenet-c/<corruption>/<severity>'``. The dict also holds
        ``'module_names_X'`` and ``'module_names_Y'``.
    """
    cka_target = args.cka_target
    layers = (nn.Linear, nn.LayerNorm, nn.Conv2d)
    results = torch.load(save_path) if os.path.exists(save_path) else {}

    if cka_target is None:
        datasets = list(_CKA_DEFAULT_DATASETS)
    elif cka_target == 'imagenet-c':
        datasets = []
    else:
        datasets = [cka_target]

    # NOTE(review): these datasets are read from the hard-coded 'datasets' directory,
    # while ImageNet-C below uses args.root — confirm this asymmetry is intended.
    for dataset_name in datasets:
        if dataset_name in results:
            continue
        print(dataset_name)
        _compute_and_save_cka(model1, model2, 'datasets', dataset_name, preprocess,
                              layers, results, save_path, args)

    # A single, explicitly requested non-ImageNet-C dataset is all that was asked for.
    if cka_target is not None and cka_target != 'imagenet-c':
        return results

    for corruption in _IMAGENET_C_CORRUPTIONS:
        for severity in _IMAGENET_C_SEVERITIES:
            dataset_name = f'imagenet-c/{corruption}/{severity}'
            if dataset_name in results:
                continue
            print(dataset_name)
            _compute_and_save_cka(model1, model2, args.root, dataset_name, preprocess,
                                  layers, results, save_path, args)

    print(results)
    return results

