'''Computes similarity matrix'''
import os
import sys
import argparse
import random
import time
import torch
import logging
import jax
import numpy as np

from sim_measures.similarity_methods import cka_notrace, cosine_similarity, l2_distance, pnka, efficient_pnka, efficient_cka
from utils.helper import create_logger

from robustness import defaults
from robustness.datasets import DATASETS
from robustness.main import setup_args


def get_args():
    """Build the command-line parser for this script and parse the arguments.

    The script-specific options are registered first. The parser is then
    extended with the ``robustness`` library's CONFIG, MODEL_LOADER, TRAINING
    and PGD argument groups through ``defaults.add_args_to_parser``.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--sim_method', required=False, default='cka',
        choices=['cka', 'pnka', 'efficient_pnka', 'efficient_cka', 'cos', 'l2'],
        type=str, help='Similarity method to be used.')
    parser.add_argument('--out-dir-complement', default=None, type=str,
        help='Folder name to be added to the out_dir. If None, will be automatically generated.')

    parser.add_argument('--task', default='cv', type=str, choices=['cv', 'nlp'],
        help='Task the features come from; selects how the output folder name '
             'is derived from the feature paths.')

    # Both feature options take one or more paths (nargs='+'), so they are
    # always parsed into lists.
    parser.add_argument('--features_path_model1', required=True, type=str, nargs='+',
        help='Path(s) to the saved features of model 1')
    parser.add_argument('--features_path_model2', required=True, type=str, nargs='+',
        help='Path(s) to the saved features of model 2')
    parser.add_argument('--idx_model1_name_path', required=False, type=int, default=-2, help='Where model1 name is in the path')
    parser.add_argument('--idx_model2_name_path', required=False, type=int, default=-2, help='Where model2 name is in the path')

    parser.add_argument('--index_features_path', required=False, default=None, type=str,
                        help='If not None, will select features based on the index file passed as argument.')

    parser.add_argument('--nb_replicate', required=False, default=None, type=int,
        help='How many points to replicate from features_x to features_y.')
    parser.add_argument('--make_all_equal', required=False, default=False, action='store_true',
        help='If True, makes X=Y for all points, i.e., x1=x2=xn=y1=y2=yn. '
             'If False, makes x1=y1, x2=y2, xn=yn while each x keeps its own value.')
    parser.add_argument('--operation_perturbed', required=False, default=None, type=str,
        choices=['substitute', 'concatenate', 'remove_dimension', 'remove_duplicates'],
        help='What to do with perturbed points (only activated if nb_perturbed is not None).')

    # LANDMARKS
    parser.add_argument('--remove_landmarks_from_features', default=False, action='store_true',
        help='Whether to use landmarks in features (False) or remove them (True)')
    parser.add_argument('--landmarks_path_model1', default=None, type=str, help='Path to landmarks of model 1')
    parser.add_argument('--landmarks_path_model2', default=None, type=str, help='Path to landmarks of model 2')
    parser.add_argument('--landmarks_split', required=False, default='train', type=str,
        choices=['train', 'test'], help='Which split to get the landmarks from.')
    parser.add_argument('--nb_landmarks', required=False, default=None, type=int,
        help='How many landmarks to use. If None, no use of landmarks.')
    parser.add_argument('--landmark_indexes_file', required=False, default=None, type=str,
        help='Landmark indexes file [optional].')

    # PERTURBATIONS
    parser.add_argument('--nb_perturbed', required=False, default=None, type=int,
        help='How many perturbed images to use. Should be <= nb_instances - nb_landmarks.')
    parser.add_argument('--perturb_method', required=False, default=None, type=str,
        choices=['adv', 'blur', 'noise', 'elastic'], help='Which perturbation to add to the images.')
    parser.add_argument('--attacked_params', required=False, default=None, type=str,
        help='Name and parameters of the model that was attacked to generate the images.')

    parser.add_argument('--seed', default=0, type=int, help='Seed')
    parser.add_argument('--ef_batch_size', default=None, type=int,
        help='Batch size for the efficient_pnka / efficient_cka implementations')

    # Arguments defined by the robustness library.
    parser = defaults.add_args_to_parser(defaults.CONFIG_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.MODEL_LOADER_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.TRAINING_ARGS, parser)
    parser = defaults.add_args_to_parser(defaults.PGD_ARGS, parser)
    args = parser.parse_args()
    return args

def _as_path_list(paths):
    """Return ``paths`` as a list, wrapping a single string in a one-element list."""
    if isinstance(paths, str):
        return [paths]
    return list(paths)


def get_path(args):
    """Derive the output directory for this run, create it and return it.

    When ``args.out_dir_complement`` is given the directory is simply
    ``out_dir/out_dir_complement``. Otherwise it is assembled as
    ``out_dir/<foldername>/<landmarks>/<perturbation>/<operation>``, where each
    component is derived from the parsed arguments and may be empty.

    The feature path arguments are lists (``nargs='+'``); all names are derived
    from the first path of each list. A bare string is accepted as well.

    Args:
        args: parsed command-line arguments.

    Returns:
        str: the output directory (created if it did not exist).

    Raises:
        ValueError: if ``landmarks_path_model1`` is given and either it or the
            first feature path has no ``representations`` path component.
    """
    if args.out_dir_complement is not None:
        path = os.path.join(args.out_dir, args.out_dir_complement)
    else:
        paths_m1 = _as_path_list(args.features_path_model1)
        paths_m2 = _as_path_list(args.features_path_model2)
        first_m1, first_m2 = paths_m1[0], paths_m2[0]

        landmarks_path = ''
        if args.nb_landmarks is not None:
            landmarks_from_train = ''
            landmarks_dataset = ''
            # Landmarks may be drawn from the features themselves, in which
            # case there is no separate landmarks file to inspect.
            if args.landmarks_path_model1 is not None:
                if 'train' in args.landmarks_path_model1:
                    landmarks_from_train = 'train'
                landmark_parts = args.landmarks_path_model1.split('/')
                feature_parts = first_m1.split('/')
                dataset_in_landmarks_path = landmark_parts[
                    landmark_parts.index('representations') + 1].split('-')[0]
                dataset_in_features_path = feature_parts[
                    feature_parts.index('representations') + 1].split('-')[0]
                # Only name the landmark dataset when it differs from the
                # dataset the features come from.
                if dataset_in_landmarks_path != dataset_in_features_path:
                    landmarks_dataset = dataset_in_landmarks_path
            landmark_indexes_file = ''
            if args.landmark_indexes_file is not None:
                landmark_indexes_file = args.landmark_indexes_file.split('/')[-1].split('.')[0]
                if 'train' in args.landmark_indexes_file:
                    landmarks_from_train = 'train'
            landmarks_path = os.path.join(
                'landmarks', landmarks_dataset, landmarks_from_train, landmark_indexes_file,
                f'{args.nb_landmarks}_seed{str(args.seed)}')

        operation_perturbed_path = ''
        if args.operation_perturbed == 'concatenate':
            operation_perturbed_path = args.operation_perturbed

        perturb_path = ''
        if args.nb_perturbed is not None:
            if args.perturb_method == 'adv':
                perturb_path = os.path.join('adv', f'{args.nb_perturbed}_adv', args.attacked_params)
            elif args.perturb_method == 'noise':
                perturb_path = os.path.join('noise', f'{args.nb_perturbed}_noise')
            elif args.perturb_method == 'elastic':
                perturb_path = os.path.join('elastic', f'{args.nb_perturbed}_elastic')
            elif args.perturb_method == 'blur':
                perturb_path = os.path.join('blur', f'{args.nb_perturbed}_blur')
            elif args.operation_perturbed == 'remove_dimension':
                perturb_path = os.path.join(args.operation_perturbed, f'{args.nb_perturbed}')

        if args.task == 'nlp':
            # The model name and dataset are read from fixed positions (4 and
            # 5) of the '/'-split feature path.
            parts_m1 = first_m1.split('/')
            parts_m2 = first_m2.split('/')
            model1_name = parts_m1[4]
            model1_properties = '-'.join(parts_m1[-1].split('_')[:-1]).replace('-all', '').replace('all', '')
            model2_name = parts_m2[4]
            model2_properties = '-'.join(parts_m2[-1].split('_')[:-1]).replace('-all', '').replace('all', '')
            dataset = parts_m1[5].replace('_', '')
            print(model1_name, model1_properties)
            print(model2_name, model2_properties)
            if model1_properties == '' and model2_properties == '':
                foldername = f'M1_{model1_name}_M2_{model2_name}_{dataset}'
            elif model1_properties == '':
                foldername = f'M1_{model1_name}_M2_{model2_properties}-{model2_name}_{dataset}'
            elif model2_properties == '':
                foldername = f'M1_{model1_properties}-{model1_name}_M2_{model2_name}_{dataset}'
            else:
                foldername = f'M1_{model1_properties}-{model1_name}_M2_{model2_properties}-{model2_name}_{dataset}'
            if model2_name == model1_name:
                foldername = f'M1_{model1_properties}-{model1_name}_{dataset}'
        else:  # CV
            if len(paths_m1) == 1:  # one element only (most cases)
                model1_name = first_m1.split('/')[args.idx_model1_name_path]
                layer_num1 = first_m1.split('/')[-1].split('.')[0]
                model2_name = first_m2.split('/')[args.idx_model2_name_path]
                layer_num2 = first_m2.split('/')[-1].split('.')[0]

                foldername = f'M1_{model1_name}_{layer_num1}_M2_{model2_name}_{layer_num2}'
                if model2_name == model1_name:
                    foldername = f'M1_{model1_name}_{layer_num1}_{layer_num2}'
            else:
                foldername = ''

        # Substring checks on the path strings; a plain `in` on the argument
        # list would only match an element equal to the whole word.
        if any('train' in features_path for features_path in paths_m1):
            foldername = os.path.join(foldername, 'train')

        if any('tabular' in features_path for features_path in paths_m1):
            foldername = os.path.join(
                'tabular', first_m1.split('/')[-2],
                f"M1_{first_m1.split('/')[-1].split('.')[0]}_M2_{first_m2.split('/')[-1].split('.')[0]}")
        path = os.path.join(args.out_dir, foldername, landmarks_path, perturb_path, operation_perturbed_path)
    print(path)
    # Wait for the user to acknowledge the printed path, but only when there is
    # an interactive terminal; a batch run would otherwise fail or hang here.
    if sys.stdin is not None and sys.stdin.isatty():
        input()
    os.makedirs(path, exist_ok=True)
    return path


def replicate_features(args, features_x, features_y, logger):
    """Make the first ``args.nb_replicate`` points of X and Y identical.

    Both tensors are modified in place and also returned.

    * Default: ``features_y[:n]`` is overwritten with ``features_x[:n]``, so
      ``x_i == y_i`` for the first ``n`` points.
    * ``args.make_all_equal``: the first ``n`` points of both X and Y are
      overwritten with the first point of X.

    Args:
        args: parsed arguments; reads ``nb_replicate`` and ``make_all_equal``.
        features_x: feature tensor of model 1, indexed by point on dim 0.
        features_y: feature tensor of model 2, indexed by point on dim 0.
        logger: logger used for progress messages.

    Returns:
        tuple: ``(features_x, features_y)`` after replication.
    """
    nb_replicate = args.nb_replicate
    logger.info(f'=> Replicating {nb_replicate} samples from x to y')

    if not args.make_all_equal:  # default
        features_y[:nb_replicate] = features_x[:nb_replicate]
        return features_x, features_y

    logger.info('=> Making all points equal, in X and in Y, respectively.')
    # Clone first: the source row is part of the slice being overwritten.
    point_tobe_replicated = features_x[0].clone()
    features_x[:nb_replicate] = point_tobe_replicated
    features_y[:nb_replicate] = point_tobe_replicated

    replicated_x = features_x[:nb_replicate]
    replicated_y = features_y[:nb_replicate]
    # Compare every replicated row against the reference point; comparing a
    # slice with itself would be trivially True.
    x_all_same = bool((replicated_x == point_tobe_replicated).all())
    y_all_same = bool((replicated_y == point_tobe_replicated).all())
    logger.info(
        f'Replicated features x and y are equal to each other {torch.equal(replicated_x, replicated_y)}')
    logger.info(f'Replicated features y are equal within themselves {y_all_same}')
    logger.info(f'Replicated features x are equal within themselves {x_all_same}')
    return features_x, features_y


def get_landmarks(args, features_x, features_y, path, logger):
    """Choose landmark indexes and gather the landmark rows for both models.

    The indexes come from ``args.landmark_indexes_file`` (truncated to
    ``args.nb_landmarks`` entries) when given, otherwise they are sampled with
    ``random.sample`` from ``range(features_x.shape[0])``. They are saved to
    ``<path>/landmark_indexes.pt``.

    For each model the landmark rows are selected either from that model's
    features or, when ``args.landmarks_path_model{1,2}`` is set, from the
    tensor stored in that file (converted to double).

    Returns:
        tuple: ``(landmarks_x, landmarks_y, indexes_other_points)`` where
        ``indexes_other_points`` lists the row indexes of ``features_x`` that
        are not landmark indexes.
    """
    total_points = features_x.shape[0]

    if args.landmark_indexes_file is None:
        logger.info(f'=> Randomly selecting {args.nb_landmarks} points as landmarks')
        sampled = random.sample(range(0, total_points), int(args.nb_landmarks))
        landmark_indexes = torch.tensor(sampled)
    else:
        stored_indexes = torch.load(args.landmark_indexes_file)
        landmark_indexes = torch.tensor(stored_indexes[:args.nb_landmarks])
        logger.info(f'=> Loading landmark_indexes from {args.landmark_indexes_file}')

    # Persist the selection next to the other outputs of the run.
    torch.save(landmark_indexes.cpu().detach(), os.path.join(path, "landmark_indexes.pt"))
    logger.info(f'landmark_indexes shape {landmark_indexes.shape}')

    def _select_landmarks(own_features, external_path):
        # Rows come from the external file when one is given, else from the
        # model's own features.
        source = own_features
        if external_path is not None:
            logger.info(f'=> Loading landmarks from {external_path}')
            source = torch.load(external_path).type(torch.DoubleTensor)
        return torch.index_select(source, 0, landmark_indexes)

    landmarks_x = _select_landmarks(features_x, args.landmarks_path_model1)
    landmarks_y = _select_landmarks(features_y, args.landmarks_path_model2)

    every_index = list(np.arange(total_points))
    indexes_other_points = list(np.delete(every_index, landmark_indexes))

    return landmarks_x, landmarks_y, indexes_other_points


def perturb_features(args, features_x, features_y, path, indexes_other_points, logger=None):
    """Apply the requested perturbation operation to the feature tensors.

    ``args.nb_perturbed`` indexes are sampled (from ``indexes_other_points``
    when given, otherwise from all rows of ``features_x``) and saved to
    ``<path>/perturbed_indexes.pt``. When ``args.perturb_method`` is set, the
    perturbed features are loaded from
    ``<features dir>/<perturb_method>/<attacked_params if adv>/<features file>``
    for the first feature path of each model.

    ``args.operation_perturbed`` then selects what happens:

    * ``substitute``: the sampled rows of X and Y are replaced in place by the
      corresponding perturbed rows.
    * ``concatenate``: the sampled perturbed rows are appended to X and Y.
    * ``remove_dimension``: the column of Y whose index equals ``args.seed`` is
      dropped; output goes to ``<path>/all/neuron_<seed>``.
    * ``remove_duplicates``: duplicate rows are removed from X and from Y;
      output goes to ``<path>/remove_duplicates``.

    Args:
        args: parsed command-line arguments.
        features_x: feature tensor of model 1 (points on dim 0).
        features_y: feature tensor of model 2 (points on dim 0).
        path: current output directory.
        indexes_other_points: candidate indexes to perturb, or None for all.
        logger: optional logger; a module logger is used when None.

    Returns:
        tuple: ``(features_x, features_y, path)`` with the possibly updated
        output directory.

    Raises:
        SystemExit: if more points are requested than there are candidates.
        ValueError: if ``substitute``/``concatenate`` is requested without a
            ``perturb_method`` to load perturbed features from.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    nb_instances = features_x.shape[0]
    if indexes_other_points is None:
        candidates = list(np.arange(nb_instances))
    else:
        candidates = indexes_other_points

    # Validate against the pool that is actually sampled from.
    if args.nb_perturbed > len(candidates):
        raise SystemExit(
            f'Number of perturbed instances ({args.nb_perturbed}) should be lower or equal to '
            f'the number of candidate points ({len(candidates)})')
    logger.info(f'=> Adding {args.nb_perturbed} perturbed instances')
    logger.info(f'Perturbation method chosen: {args.perturb_method}')

    perturbed_indexes = torch.tensor(random.sample(candidates, args.nb_perturbed), dtype=torch.long)
    torch.save(perturbed_indexes.cpu().detach(), os.path.join(path, "perturbed_indexes.pt"))
    logger.info(f'perturbed_indexes shape {perturbed_indexes.shape}')

    perturbed_features_x = None
    perturbed_features_y = None
    if args.perturb_method is not None:
        # Only adversarial perturbations live in a per-attack sub-folder.
        attacked_params_foldername = args.attacked_params if args.perturb_method == 'adv' else ''
        perturbed_features_path_model1 = os.path.join(
            '/'.join(args.features_path_model1[0].split('/')[:-1]),
            args.perturb_method, attacked_params_foldername,
            args.features_path_model1[0].split('/')[-1])
        perturbed_features_path_model2 = os.path.join(
            '/'.join(args.features_path_model2[0].split('/')[:-1]),
            args.perturb_method, attacked_params_foldername,
            args.features_path_model2[0].split('/')[-1])

        logger.info(f'Loading perturbed images for features_x from {perturbed_features_path_model1}')
        logger.info(f'Loading perturbed images for features_y from {perturbed_features_path_model2}')
        perturbed_features_x = torch.load(perturbed_features_path_model1)
        perturbed_features_y = torch.load(perturbed_features_path_model2)

    operation = args.operation_perturbed
    if operation in ('substitute', 'concatenate') and perturbed_features_x is None:
        raise ValueError(
            f"operation_perturbed='{operation}' requires perturb_method to be set, "
            'so that perturbed features can be loaded.')

    logger.info('Adding perturbed instances in original features_x and features_y')
    if operation == 'substitute':
        # Row-by-row assignment casts the perturbed row to the features dtype.
        for perturbed_idx in perturbed_indexes:
            row = int(perturbed_idx)
            features_x[row] = perturbed_features_x[row]
            features_y[row] = perturbed_features_y[row]
    elif operation == 'concatenate':
        selected_perturbed_features_x = torch.index_select(perturbed_features_x, 0, perturbed_indexes)
        selected_perturbed_features_y = torch.index_select(perturbed_features_y, 0, perturbed_indexes)
        features_x = torch.cat((features_x, selected_perturbed_features_x), 0)
        features_y = torch.cat((features_y, selected_perturbed_features_y), 0)
    elif operation == 'remove_dimension':
        # The seed doubles as the index of the single column removed from Y.
        remove_neuron_indexes = torch.tensor(int(args.seed))
        keep_indexes = torch.from_numpy(np.delete(np.arange(features_y.shape[1]), remove_neuron_indexes))
        features_y = torch.index_select(features_y, 1, keep_indexes)
        path = os.path.join(path, 'all', f'neuron_{args.seed}')
        os.makedirs(path, exist_ok=True)
    elif operation == 'remove_duplicates':
        # dim=0 removes duplicate rows; without it torch.unique flattens the
        # tensor and returns unique scalar values.
        # NOTE(review): X and Y are deduplicated independently, so row i of X
        # need not correspond to row i of Y afterwards - confirm this is intended.
        features_x = torch.unique(features_x, dim=0)
        features_y = torch.unique(features_y, dim=0)
        path = os.path.join(path, 'remove_duplicates')
        os.makedirs(path, exist_ok=True)
    return features_x, features_y, path


def _load_feature_file(features_path):
    """Load one feature file: ``BRAIN`` paths via jax.numpy, others via torch.load."""
    if 'BRAIN' in features_path:  # multimodal
        jax_features = jax.numpy.load(features_path)
        return torch.tensor(np.array(jax_features, dtype=np.float64))
    # NOTE: torch.load unpickles the file; only load feature files you trust.
    return torch.load(features_path)


def load_features(path):
    """Load features from one path or from a list of paths.

    A list (or tuple) of paths is loaded file by file and concatenated along
    dim 1. Paths containing ``BRAIN`` are read with ``jax.numpy.load``; all
    other paths are read with ``torch.load``.

    Args:
        path: a single path string, or a list/tuple of path strings.

    Returns:
        The loaded features. Tensors are converted to ``torch.DoubleTensor``;
        any other object returned by ``torch.load`` is passed through as is.
    """
    if isinstance(path, (list, tuple)):
        features = torch.cat([_load_feature_file(features_path) for features_path in path], 1)
    else:
        features = _load_feature_file(path)
    if torch.is_tensor(features):
        features = features.type(torch.DoubleTensor)
    return features

def _run_similarity(args, features_x, features_y, landmarks_x, landmarks_y,
                    indexes_other_points, path, logger, device):
    """Call the similarity routine selected by ``args.sim_method``."""
    method = args.sim_method
    if method == 'cka':
        # Points are only excluded when the landmarks were drawn from the
        # features themselves (no external landmark files).
        landmarks_from_features = (
            args.nb_landmarks is not None
            and args.landmarks_path_model1 is None
            and args.landmarks_path_model2 is None)
        remove = indexes_other_points if landmarks_from_features else None
        cka_notrace(features_x, features_y, remove=remove, path=path, logger=logger)
    elif method == 'pnka':
        pnka(features_x, features_y, path=path, logger=logger)
    elif method in ('efficient_pnka', 'efficient_cka'):
        efficient_measure = efficient_pnka if method == 'efficient_pnka' else efficient_cka
        efficient_measure(
            features_x, features_y, args.ef_batch_size, device,
            landmarks_x=landmarks_x, landmarks_y=landmarks_y,
            path=path, logger=logger)
    elif method == 'cos':
        cosine_similarity(features_x, features_y, path=path)
    elif method == 'l2':
        l2_distance(features_x, features_y, path=path)


def main():
    """Run the whole pipeline: parse arguments, load features, compute similarity.

    Steps, in order: seed the RNGs, resolve the output directory, load both
    feature sets, optionally sub-select / replicate points, optionally pick
    landmarks (and remove them from the features), optionally perturb the
    features, then run the chosen similarity method.
    """
    args = setup_args(get_args())

    # Seed torch, random and numpy with the same value.
    seed = args.seed
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    path = get_path(args)
    logger = create_logger(path)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    features_x = load_features(args.features_path_model1)
    features_y = load_features(args.features_path_model2)
    logger.info(f'features_x shape {features_x.shape}')
    logger.info(f'features_y shape {features_y.shape}')

    # NOTE(review): the feature path arguments are lists (nargs='+'), so the
    # 'tabular' tests below are list-membership tests, not substring tests -
    # confirm whether a substring match on the path was intended.
    if not torch.is_tensor(features_x) and 'tabular' in args.features_path_model1:
        features_x = torch.from_numpy(features_x.features).type(torch.DoubleTensor)
    if not torch.is_tensor(features_y) and 'tabular' in args.features_path_model2:
        features_y = torch.from_numpy(features_y.features).type(torch.DoubleTensor)

    if args.index_features_path is not None:
        print(f'Selecting features based on file {args.index_features_path}')
        index_features = torch.load(args.index_features_path)
        features_x = torch.index_select(features_x, 0, torch.tensor(index_features))
        features_y = torch.index_select(features_y, 0, torch.tensor(index_features))

    if args.nb_replicate is not None:
        features_x, features_y = replicate_features(args, features_x, features_y, logger)

    # Without landmarks, the features themselves act as landmarks.
    if args.nb_landmarks is None:
        landmarks_x, landmarks_y, indexes_other_points = features_x, features_y, None
    else:
        landmarks_x, landmarks_y, indexes_other_points = get_landmarks(
            args, features_x, features_y, path, logger)

    if args.remove_landmarks_from_features and indexes_other_points is not None:
        logger.info('Removing landmarks from features...')
        features_x = torch.index_select(features_x, 0, torch.tensor(indexes_other_points))
        features_y = torch.index_select(features_y, 0, torch.tensor(indexes_other_points))
        # The remaining rows are now all non-landmark points.
        indexes_other_points = None
    logger.info(f'New shapes: features_x {features_x.shape}, features_y {features_y.shape}')

    if args.nb_perturbed is not None:
        features_x, features_y, path = perturb_features(
            args, features_x, features_y, path, indexes_other_points, logger)

    print(landmarks_x.shape, landmarks_y.shape)
    logger.info(f'{features_x.shape}, {features_y.shape}')
    logger.info(f'Similarity method chosen: {args.sim_method}')
    _run_similarity(
        args, features_x, features_y, landmarks_x, landmarks_y,
        indexes_other_points, path, logger, device)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__=='__main__':
    main()
