import os
import sys

import timm
import torchvision

sys.path.append(".")

import torch

import numpy as np
import random
from analysis.torch_cka import CKA
from tools.misc import get_cnn_network, get_vit_network
from datatools.prepare_data import prepare_data_loader

if __name__ == "__main__":

    # Dataset key forwarded to ``prepare_data_loader`` further down.
    dataset = 'oxfordpets'
    # Directory that receives the CKA plot and the exported tensor; created
    # up front (no error if it already exists).
    output_dir = "analysis_results/feature_comparison"
    os.makedirs(output_dir, exist_ok=True)

    def seed_worker(worker_id):
        """Seed NumPy and ``random`` from the current process's torch seed.

        Written in the shape of a DataLoader ``worker_init_fn``: the seed is
        taken from ``torch.initial_seed()`` so that NumPy and the stdlib
        ``random`` module follow whatever seed torch has in this process.

        Args:
            worker_id: Index of the worker; accepted for signature
                compatibility and not otherwise used.
        """
        # Reduce to 32 bits: ``np.random.seed`` rejects larger values.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        # Was ``random.saeed`` (typo), which raised AttributeError on call.
        random.seed(worker_seed)

    # Seed every RNG this script can touch so repeated runs are comparable.
    # The global torch RNG was previously left unseeded even though NumPy and
    # ``random`` were; it is now seeded alongside them.
    g = torch.Generator()
    g.manual_seed(0)
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    # NOTE(review): ``g`` and ``seed_worker`` are never handed to a DataLoader
    # in this file -- the loaders come from ``prepare_data_loader``. Confirm
    # whether that helper should receive them.

    device = "cuda" if torch.cuda.is_available() else "cpu"
    #===============================================================
    model2 = get_cnn_network("resnet101", device=device)
    # ``get_vit_network`` returns an indexable result; only item 0 is used as
    # the model here.
    model1 = get_vit_network("vit_large", device=device)[0]

    # Names of the modules whose outputs are compared: every residual block
    # of the ResNet and every transformer block of the ViT.
    resnet_block_types = (
        torchvision.models.resnet.BasicBlock,
        torchvision.models.resnet.Bottleneck,
    )
    check_layer2 = [
        name
        for name, m in model2.named_modules()
        if isinstance(m, resnet_block_types)
    ]
    check_layer1 = [
        name
        for name, m in model1.named_modules()
        if isinstance(m, timm.models.vision_transformer.Block)
    ]

    cka = CKA(model1, model2,
              model1_name="ViT-Large", model2_name="ResNet101",
              model1_layers=check_layer1, model2_layers=check_layer2,
              device=device)

    # Only the 'test' loader is needed; the second return value is unused.
    loaders, _ = prepare_data_loader(dataset, data_path="../data/vp_data", batch_size=128)
    dataloader = loaders['test']

    cka.compare(dataloader)

    # NOTE(review): the file names say "ViT_b" although the model compared is
    # ViT-Large; kept unchanged so existing consumers of these paths still
    # work -- confirm and rename if nothing depends on them.
    cka.plot_results(save_path=f"{output_dir}/rn101_ViT_b_compare.png")
    result = cka.export()
    torch.save(result['CKA'], f"{output_dir}/rn101_ViT_b_compare.pth")
