import os
import numpy as np
import torch
import argparse
import copy
import wandb
import json

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from lib.dataset.imagenetr_utils import imagenet_r_transform, imagenet_r_mask, reverse_imagenet_r_mask, imagenet_a_mask, imagenet_o_mask
from lib.dataset.objectnet_dataset import objectnet_mask
from lib.dataset.objectnet_dataset import ObjectNetDataset
from lib.dataset.imagenet_v2 import ImageNetV2

from utils.utils import get_model, set_seed
from cka_pytorch.cka import CKACalculator
from lib.argument import parse_option
from lib.cka import calcuatelate_cka


def _load_weights(model, fpath, device):
    """Load the checkpoint stored at ``fpath`` into ``model`` in place.

    The file may contain either a bare state dict or a dict that wraps it
    under the ``'state_dict'`` key; both layouts are accepted.
    """
    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only
    # host (the script falls back to "cpu" when CUDA is unavailable).
    checkpoint = torch.load(fpath, map_location=device)
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    model.load_state_dict(checkpoint)


def main():
    """Compute CKA between the pretrained model and every saved checkpoint.

    For each ``checkpoint_{i}.pth.tar`` (0 <= i < 1000) present in
    ``<model_dir>/<filename>``, the checkpoint is loaded into a copy of the
    pretrained model and compared against the untouched pretrained model.
    The result dict returned by the CKA helper gets a ``'step'`` entry and,
    when ``--use_wandb`` is set, is logged to Weights & Biases.
    """
    args = parse_option()
    # Folder holding the checkpoints produced by the training run.
    args.model_folder = os.path.join(args.model_dir, args.filename)
    print(args)
    if args.seed is not None:
        set_seed(args.seed)

    num_classes = 1000
    max_checkpoints = 1000  # checkpoint indices scanned: 0 .. max_checkpoints - 1
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # model1 keeps the pretrained weights as the fixed reference;
    # model2 is overwritten with each checkpoint in turn.
    model1, preprocess, _tokenizer = get_model(
        args.model, num_classes, args.patch_size, device,
        arch=args.arch, d_pre=args.d_pre, pretrained=True, reg='', mode=args.mode,
    )
    model2 = copy.deepcopy(model1)

    if args.use_wandb:
        wandb.init(project=args.project_name,
                   config=args,
                   name=args.filename + '_CKA',
                   dir=args.model_folder)

    for i in range(max_checkpoints):
        fpath = os.path.join(args.model_folder, f'checkpoint_{i}.pth.tar')
        if not os.path.exists(fpath):
            continue
        _load_weights(model2, fpath, device)
        print("Load", i)

        # Build the sibling JSON path directly rather than via str.replace,
        # which would also rewrite any '.pth.tar' occurring in the folder name.
        json_path = os.path.join(args.model_folder, f'checkpoint_{i}.json')
        # The CKA helper is imported at the top of this file under the name
        # ``calcuatelate_cka``; ``calculate_cka`` is not defined here.
        results = calcuatelate_cka(model1, model2, preprocess, json_path, args)
        results['step'] = i
        if args.use_wandb:
            try:
                wandb.log(results)
            except Exception as err:
                # Logging is best-effort: keep processing later checkpoints,
                # but report the failure instead of hiding it.
                print(f"wandb.log failed for checkpoint {i}: {err}")
        if device == "cuda":
            torch.cuda.empty_cache()

    if args.use_wandb:
        wandb.finish()


if __name__ == '__main__':
    main()
