'''Compute the layer-by-layer similarity matrix between two models from their saved per-layer features.'''
import os
import sys
import argparse
import random
import time
import torch
import logging
import numpy as np

from sim_measures.similarity_methods import cosine_similarity
from sim_measures.similarity_methods import pnka
from sim_measures.trace_cka import cka_trace_tensor

from robustness import defaults
from robustness.datasets import DATASETS
from robustness.main import setup_args


def get_args():
    """Parse the command-line arguments of the similarity script.

    In addition to the options defined here, the robustness library's
    CONFIG, MODEL_LOADER, TRAINING and PGD argument groups are added to the
    parser via ``defaults.add_args_to_parser``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--sim_method', required=False, default='cka',
        choices=['cka', 'pnka', 'cos'], type=str, help='Similarity method to be used.')

    # These directories are read from (one l<i>.pt file per layer), not written to.
    parser.add_argument('--features_path_model1', required=True, type=str,
        help='Directory containing the saved per-layer features of model 1 (l<i>.pt)')
    parser.add_argument('--features_path_model2', required=True, type=str,
        help='Directory containing the saved per-layer features of model 2 (l<i>.pt)')

    parser.add_argument('--nb_layers_model1', required=True, type=int,
        help='Number of layers in model1')
    parser.add_argument('--nb_layers_model2', required=False, default=None, type=int,
        help='Number of layers in model2. If None, considers it is the same as model1')

    parser.add_argument('--seed', default=0, type=int,
        help='Seed for the torch, random and numpy RNGs')

    parser = defaults.add_args_to_parser(defaults.CONFIG_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.MODEL_LOADER_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.TRAINING_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.PGD_ARGS, parser)
    args = parser.parse_args()
    return args

def get_path(args):
    """Return the output directory ``args.out_dir``, creating it if it does not exist.

    Args:
        args: Parsed arguments; only ``args.out_dir`` is read.

    Returns:
        str: the output directory path.
    """
    path = os.fspath(args.out_dir)
    print(path)
    # exist_ok=True avoids the race between checking for the directory and
    # creating it, and makes repeated calls harmless.
    os.makedirs(path, exist_ok=True)
    return path

def _load_layer_features(features_dir, layer_num):
    """Load ``<features_dir>/l<layer_num>.pt`` and cast it to a float64 CPU tensor."""
    features_path = os.path.join(features_dir, f'l{layer_num}.pt')
    return torch.load(features_path).type(torch.DoubleTensor)


def _layer_similarity(sim_method, features_x, features_y, path):
    """Return a scalar similarity between two layers' features.

    'cka' returns ``cka_trace_tensor`` directly. Any other method ('cos' or
    'pnka') builds a similarity matrix and returns its mean diagonal entry
    (trace divided by the number of rows).
    """
    if sim_method == 'cka':
        return cka_trace_tensor(features_x, features_y)
    if sim_method == 'cos':
        sim_mx = cosine_similarity(features_x, features_y, path=path)
    else:
        sim_mx = pnka(features_x, features_y, path=None, logger=None)
    return np.trace(sim_mx) / sim_mx.shape[0]


def _model_name(features_dir):
    """Return the last path component of ``features_dir``, ignoring trailing separators."""
    return os.path.basename(os.path.normpath(features_dir))


def main():
    """Compute and save the layer-by-layer similarity matrix between two models.

    Entry ``[i, j]`` of the saved tensor is the similarity between layer ``i``
    of model 1 and layer ``j`` of model 2. The tensor is written to
    ``<out_dir>/sim_<model1_name>_<model2_name>.pt``. Each model name is the
    last component of its features directory.

    Raises:
        ValueError: if either model is given fewer than one layer.
    """
    args = get_args()
    args = setup_args(args)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    nb_layers_model2 = args.nb_layers_model2
    if nb_layers_model2 is None:
        nb_layers_model2 = args.nb_layers_model1
    # Fail early with a clear message instead of an opaque torch.stack error.
    if args.nb_layers_model1 < 1 or nb_layers_model2 < 1:
        raise ValueError(
            f'Number of layers must be >= 1 (got model1={args.nb_layers_model1}, '
            f'model2={nb_layers_model2})')

    path = get_path(args)

    all_sim = []
    for layer_num_model1 in range(args.nb_layers_model1):
        # Model-1 features do not depend on the inner loop: load them once per row.
        features_x = _load_layer_features(args.features_path_model1, layer_num_model1)
        print(f'features_x shape {features_x.shape}')
        sim_row = []
        for layer_num_model2 in range(nb_layers_model2):
            print(f'Layer m1: {layer_num_model1}, Layer m2: {layer_num_model2}')
            features_y = _load_layer_features(args.features_path_model2, layer_num_model2)
            print(f'features_y shape {features_y.shape}')
            sim_row.append(_layer_similarity(args.sim_method, features_x, features_y, path))
        sim_row = torch.tensor(sim_row)
        print(sim_row.shape)
        all_sim.append(sim_row)
    all_sim = torch.stack(all_sim)
    print(all_sim.shape)

    model1_name = _model_name(args.features_path_model1)
    model2_name = _model_name(args.features_path_model2)
    out_file = os.path.join(path, f"sim_{model1_name}_{model2_name}.pt")
    print(out_file)
    torch.save(all_sim, out_file)

if __name__ == '__main__':
    # Run only when executed as a script; importing this module just defines functions.
    main()
