import argparse
import numpy as np
import pandas as pd
import torch
import os.path
from loguru import logger

from univ.utils import similarity as sim
from univ.utils import load_data as ld
from univ.utils import measures as ms
from univ.utils import model_import as mi
from univ.utils import datasets as ds
from univ.utils import sampling as sa
from univ.utils.attacker import AttackerModel

from typing import List, Dict, Any

# Short keys of the supported measures, in the order they are processed when
# every measure is requested.
MEASURES = ['mag', 'con', 'cos', 'jacc', 'rank', 'rank_jacc', 'proc', 'shape', 'cka']

# Values accepted on the command line: any single measure, or 'all'.
MEASURE_CHOICES = [*MEASURES, 'all']

# Measure key -> stem used in result file names.
MEASURE_NAMES = dict(
    mag='mag',          # Magnitude
    con='con',          # Concentricity
    cos='cos_sim',      # 2nd order cosine similarity
    jacc='jac',         # Jaccard similarity
    rank='rank',        # Rank similarity
    rank_jacc='joint',  # Joint k-NN Jaccard and rank similarity
    proc='proc',        # Orthogonal Procrustes
    shape='shape',      # Shape metric
    cka='cka',          # Centered kernel alignment
)

# Measure key -> sub-folder (with trailing slash) of the results directory.
MEASURE_FOLDERS = dict(
    mag='magnitude/',
    con='concentricity/',
    cos='cos_sim/',
    jacc='jaccard/',
    rank='rank/',
    rank_jacc='joint/',
    proc='procrustes/',
    shape='shape/',
    cka='cka/',
)

# Base directory handed to the dataset-attribute lookup.
BASE_DIR = './'


# Command-line interface of the script. Options declared with `choices=[0, 1]`
# are integer on/off switches. The parsed namespace is bound to the
# module-level name `args` by the `__main__` guard at the bottom of the file.
parser = argparse.ArgumentParser(description='Calculate representational similarity on standard or inverted images.')
parser.add_argument('-a', '--alpha',
                    help='Alpha value for the shift shape metric',
                    type=float,
                    # argparse does not run `type` on non-string defaults, so
                    # the default is spelled as a float to match `type=float`.
                    default=1.0)
parser.add_argument('-b', '--batch',
                    help='Batch size',
                    type=int,
                    default=64)
parser.add_argument('-c', '--center',
                    help='Mean-center representations before calculating similarity',
                    choices=[0, 1],
                    type=int,
                    default=1)
parser.add_argument('-d', '--dataset',
                    help='The dataset to use',
                    choices=['sat6', 'imagenet', 'cifar10', 'snli', 'imagenet100', 'imagenet50'],
                    type=str,
                    default='imagenet')
parser.add_argument('-e', '--exp',
                    help='The experiment number',
                    type=int,
                    default=0)
parser.add_argument('-f', '--function',
                    help='Similarity function to use for nearest neighbor calculation',
                    choices=['euc', 'cos_sim'],
                    type=str,
                    default='cos_sim')
parser.add_argument('-i', '--inv',
                    help='Compute similarity on inverted images or textual adversarial examples',
                    choices=[0, 1],
                    type=int,
                    default=0)
parser.add_argument('-k', '--knn',
                    help='The number of nearest neighbors to consider',
                    type=int,
                    default=500)
parser.add_argument('-m', '--models',
                    help='The pre-trained models to compare',
                    choices=['imagenet', 'cifar10', 'snli', 'imagenet100', 'imagenet50'],
                    type=str,
                    default='imagenet')
parser.add_argument('-n', '--norm',
                    help='Normalize activations for Procrustes',
                    choices=[0, 1],
                    type=int,
                    default=1)
parser.add_argument('-o', '--eps',
                    help='Epsilon for adversarial training',
                    type=str,
                    default='eps3')
parser.add_argument('-p', '--perturbs',
                    help='Number of times activations should be shuffled to calculate a baseline',
                    type=int,
                    default=0)
parser.add_argument('-r', '--rep',
                    help='Representational similarity measure to compute, defaults to all',
                    choices=MEASURE_CHOICES,
                    type=str,
                    default='all')
parser.add_argument('-s', '--sample',
                    help='Sample inverted images or adversarial examples',
                    action='store_true')
parser.add_argument('-v', '--adv',
                    help='Compare models on ImageNet adversarial examples',
                    choices=[0, 1],
                    type=int,
                    default=0)
parser.add_argument('--model-dir',
                    required=True,
                    help='Path to directory that contains all checkpoints of models of interest')
parser.add_argument('--index-file-dir',
                    default=None,
                    help="Directory holding '<dataset>/indices_<exp>.csv'; when omitted, the index file is "
                         "looked up under '<results_dir>magnitude/standard/' instead")
parser.add_argument('--inverted-imgs-dir',
                    required=True,
                    help='Directory passed to the inverted-data loader when --inv or --adv is set')
parser.add_argument('--config-file-path',
                    default=None,
                    help='If given, the parsed arguments are dumped as JSON to this path after the run '
                         '(an existing file is not overwritten; the arguments are only logged)')

def get_activations(
        models: List[AttackerModel],
        layers: List[str],
        device: torch.device,
        dataset_name: str,
        inverted_imgs_dir: str,
        dataset_attr: dict,
        epsilon: str,
        model_names: List[str],
        use_inverted_imgs_or_adv_examples: bool,
        use_imagenet_adv_examples: bool,
        path_to_index_file: str,
        pool_images_across_models: bool,
    ):
    """Collect activations of ``models`` at ``layers`` for the selected inputs.

    When neither flag is set, a dataloader is obtained from ``ld.get_data``
    and passed to ``sim.get_activations``. That path still reads
    ``args.models`` and ``args.batch`` from the module-level ``args``.

    When either flag is set, the inverted data is loaded with
    ``ld.load_inverted_data``, reduced with ``sa.sample_data`` and passed to
    ``sim.inverted_activations``.

    Args:
        models: Loaded models to extract activations from.
        layers: Layer identifiers forwarded to the activation helpers.
        device: Device forwarded to the activation helpers.
        dataset_name: Name of the dataset.
        inverted_imgs_dir: Directory forwarded to ``ld.load_inverted_data``.
        dataset_attr: Dataset attributes; the keys ``'dict_paths'``,
            ``'labels_path'`` and ``'data_path'`` are read here.
        epsilon: Epsilon tag such as ``'eps025'``.
        model_names: Names of the models.
        use_inverted_imgs_or_adv_examples: Use inverted images or adversarial
            examples as inputs.
        use_imagenet_adv_examples: Use ImageNet adversarial examples as inputs.
        path_to_index_file: Index file path forwarded to the data helpers.
        pool_images_across_models: Forwarded to ``sa.sample_data``.

    Returns:
        The activations object produced by the ``sim`` helper that was used.

    Raises:
        ValueError: If one of the flags is set and ``epsilon`` does not start
            with ``'eps'``.
    """
    wants_perturbed_inputs = use_inverted_imgs_or_adv_examples or use_imagenet_adv_examples

    if not wants_perturbed_inputs:
        # TODO
        logger.info("Entering code path that was not refactored. Might have unexpected dependencies on command line args.")
        use_cifar = args.models == 'cifar10'
        _, _, dataloader = ld.get_data(
            dataset_name,
            dataset_attr['labels_path'],
            dataset_attr['data_path'],
            path_to_index_file,
            dataset_attr['dict_paths'],
            args.batch,
            use_cifar=use_cifar,
        )
        return sim.get_activations(models, layers, dataloader, device)

    if not epsilon.startswith("eps"):
        raise ValueError(f"epsilon should start with 'eps' and look like, e.g., 'eps025', but is {epsilon}")

    loaded_imgs = ld.load_inverted_data(model_names, inverted_imgs_dir, dataset_name, epsilon)
    sampled_imgs, _ = sa.sample_data(
        loaded_imgs,
        path_to_index_file,
        use_imagenet_adv_examples,
        pool_images_across_models,
        dataset_name,
    )
    return sim.inverted_activations(
        dataset_attr['dict_paths'],  # only relevant for SNLI models
        model_names,
        models,
        layers,
        sampled_imgs,
        device,
        dataset_name,
    )


def get_statistics(activations: dict, measure: str, model_names: list, path: str, eps: str, stats_type: str = ''):
    """Compute a per-model statistic and write it to a CSV file.

    The inverted/adversarial variant (chosen from the module-level
    ``args.inv`` / ``args.adv``) yields a square model-by-model table; the
    standard variant yields one row per model. The result is written to
    ``{path}{measure}_{stats_type}{eps}.csv``.

    Args:
        activations: Activations forwarded to the ``ms`` statistic helpers.
        measure: Statistic key forwarded to the ``ms`` helpers.
        model_names: Model names used as row (and, if square, column) labels.
        path: Output directory prefix, including the trailing separator.
        eps: Suffix placed before the ``.csv`` extension.
        stats_type: Optional tag inserted between the measure and ``eps``.
    """
    if bool(args.inv) or bool(args.adv):
        n_models = len(model_names)
        raw_stats = ms.get_inverted_rep_sim_statistics(activations, measure)
        square = np.array(raw_stats).reshape((n_models, n_models))
        table = pd.DataFrame(square, columns=model_names, index=model_names)
    else:
        raw_stats = ms.get_rep_sim_statistics(activations, measure)
        table = pd.DataFrame(np.array(raw_stats), index=model_names)

    out_file = f'{path}{measure}_{stats_type}{eps}.csv'
    ms.save_dataframe_to_csv(table, model_names, out_file)


def get_rep_sim(
    activations: dict,
    measure: str,
    model_names: list,
    path: str,
    eps: str,
    device: torch.device,
    sim_type: str,
    normalize_acts_procrustes: bool,
    center_acts: bool,
    baseline_shuffles: int,
    use_inverted_imgs_or_adv_examples: bool,
    use_imagenet_adv_examples: bool,
    knn: int,
    nn_simfunc: str,
):
    """Compute one pairwise similarity measure and store the results as CSV.

    ``'cka'`` is handled by ``sim.row_cka`` (inverted/adversarial inputs) or
    ``sim.pairwise_cka``; every other measure by ``ms.get_inverted_rep_sim``
    or ``ms.get_rep_sim``. ``None`` entries are dropped from the returned
    scores, which are then reshaped to a ``len(model_names)`` square matrix.

    Output files are named
    ``{path}{MEASURE_NAMES[measure]}_{prefix}{sim_type}{eps}.csv`` where
    ``prefix`` contains ``'norm_'`` (Procrustes with normalisation) and/or
    ``'mean_'`` (centering, for every measure except CKA). Non-empty baseline
    scores go to the same name with ``all_`` inserted before ``eps``.

    Args:
        activations: Activations forwarded to the similarity helpers.
        measure: Measure key; must be a key of ``MEASURE_NAMES``.
        model_names: Model names labelling rows and columns.
        path: Output directory prefix, including the trailing separator.
        eps: Suffix placed before the ``.csv`` extension.
        device: Device forwarded to the similarity helpers.
        sim_type: Tag inserted into the file name.
        normalize_acts_procrustes: Forwarded as ``use_norm``.
        center_acts: Forwarded as ``center_columns``.
        baseline_shuffles: Number of permutations; a baseline is requested
            when this is greater than zero.
        use_inverted_imgs_or_adv_examples: Select the inverted-input helpers.
        use_imagenet_adv_examples: Select the inverted-input helpers.
        knn: Forwarded as ``k``.
        nn_simfunc: Forwarded as ``sim_funct``.
    """
    n_models = len(model_names)

    # File-name prefix describing the preprocessing applied to the activations.
    prefix_parts = []
    if measure == 'proc' and normalize_acts_procrustes:
        prefix_parts.append('norm_')
    if measure != 'cka' and center_acts:
        prefix_parts.append('mean_')
    prefix = ''.join(prefix_parts)

    on_perturbed_inputs = use_inverted_imgs_or_adv_examples or use_imagenet_adv_examples
    with_baseline = baseline_shuffles > 0

    if measure == 'cka':
        cka_fn = sim.row_cka if on_perturbed_inputs else sim.pairwise_cka
        scores, baseline_scores = cka_fn(activations, device, with_baseline, baseline_shuffles)
    else:
        rep_sim_fn = ms.get_inverted_rep_sim if on_perturbed_inputs else ms.get_rep_sim
        scores, baseline_scores = rep_sim_fn(
            activations,
            device,
            center_columns=center_acts,
            k=knn,
            sim_funct=nn_simfunc,
            measure=measure,
            permute=with_baseline,
            n_permutations=baseline_shuffles,
            use_norm=normalize_acts_procrustes,
        )

    kept_scores = np.array([score for score in scores if score is not None])
    kept_baseline = [score for score in baseline_scores if score is not None]

    matrix = pd.DataFrame(kept_scores.reshape((n_models, n_models)), columns=model_names, index=model_names)
    file_stem = f'{path}{MEASURE_NAMES[measure]}_{prefix}{sim_type}'
    ms.save_dataframe_to_csv(matrix, model_names, f'{file_stem}{eps}.csv')

    if kept_baseline:
        ms.save_baseline_scores(kept_baseline, f'{file_stem}all_{eps}.csv',
                                model_names, num_perturbations=baseline_shuffles)


def calculate_similarity(
        acts: Dict[int, Any],
        measure: str,
        model_names: list,
        path: str,
        eps: str,
        device: torch.device,
        sim_type: str,
        normalize_acts_procrustes: bool,
        center_acts: bool,
        baseline_shuffles: int,
        use_inverted_imgs_or_adv_examples: bool,
        use_imagenet_adv_examples: bool,
        knn: int,
        nn_simfunc: str,
    ):
    """Route ``measure`` to the statistic or the similarity computation.

    ``'mag'`` and ``'con'`` are handled by ``get_statistics``; every other
    measure by ``get_rep_sim``. In both cases the experiment number from the
    module-level ``args.exp`` is appended to ``eps`` as ``{eps}_{exp}``.

    Args:
        acts: Activations to evaluate.
        measure: Measure key.
        model_names: Names of the compared models.
        path: Output directory prefix.
        eps: Epsilon tag used in output file names.
        device: Device forwarded to ``get_rep_sim``.
        sim_type: Tag inserted into output file names.
        normalize_acts_procrustes: Forwarded to ``get_rep_sim``.
        center_acts: Forwarded to ``get_rep_sim``.
        baseline_shuffles: Forwarded to ``get_rep_sim``.
        use_inverted_imgs_or_adv_examples: Forwarded to ``get_rep_sim``.
        use_imagenet_adv_examples: Forwarded to ``get_rep_sim``.
        knn: Forwarded to ``get_rep_sim``.
        nn_simfunc: Forwarded to ``get_rep_sim``.
    """
    if measure in ('mag', 'con'):
        print(f'Calculating {measure}')
        get_statistics(acts, measure, model_names, path, f'{eps}_{args.exp}', sim_type)
        return

    get_rep_sim(
        activations=acts,
        measure=measure,
        model_names=model_names,
        path=path,
        eps=f'{eps}_{args.exp}',
        device=device,
        sim_type=sim_type,
        normalize_acts_procrustes=normalize_acts_procrustes,
        center_acts=center_acts,
        baseline_shuffles=baseline_shuffles,
        use_inverted_imgs_or_adv_examples=use_inverted_imgs_or_adv_examples,
        use_imagenet_adv_examples=use_imagenet_adv_examples,
        knn=knn,
        nn_simfunc=nn_simfunc,
    )


def get_save_path(results_dir: str, measure: str):
    """Build the output directory (with trailing slash) for ``measure``.

    Layout: ``{results_dir}{measure folder}{inverted|standard}/{variant}{dataset}/{models}``
    where ``variant`` is the nearest-neighbour function for the k-NN based
    measures, the literal ``10000/`` for CKA and empty otherwise, and the
    trailing ``models`` component is only present for the ``sat6`` dataset.
    All settings are read from the module-level ``args``.

    Args:
        results_dir: Root results directory, including the trailing separator.
        measure: Measure key; must be a key of ``MEASURE_FOLDERS``.

    Returns:
        The directory path as a string.
    """
    if bool(args.inv) or bool(args.adv):
        input_dir = 'inverted/'
    else:
        input_dir = 'standard/'

    if args.dataset == 'sat6':
        models_dir = f'{args.models}/'
    else:
        models_dir = ''

    if measure in ('jacc', 'rank', 'rank_jacc'):
        variant_dir = f'{args.function}/'
    elif measure == 'cka':
        variant_dir = '10000/'
    else:
        variant_dir = ''

    return f'{results_dir}{MEASURE_FOLDERS[measure]}{input_dir}{variant_dir}{args.dataset}/{models_dir}'


def main():
    """Run the similarity computation described by the module-level ``args``.

    Steps:
        1. Validate the dataset/model combination and the numeric options.
        2. Look up the dataset attributes and load the models.
        3. Compute the activations once.
        4. For each requested measure, make sure the output directory exists
           and compute/store the similarity.
        5. If ``--config-file-path`` was given, dump the arguments as JSON
           (or only log them when the file already exists).

    Raises:
        ValueError: If ``--perturbs`` is negative, ``--batch`` is not
            positive, or ``--knn`` is outside the open interval (0, 10000).
    """
    sa.check_dataset_models(args.dataset, args.models)

    # Explicit checks instead of `assert`: assertions are removed when Python
    # runs with -O, which would silently disable the validation.
    if args.perturbs < 0:
        raise ValueError(f"--perturbs must be >= 0, but is {args.perturbs}")
    if args.batch <= 0:
        raise ValueError(f"--batch must be > 0, but is {args.batch}")
    if not 0 < args.knn < 10000:
        raise ValueError(f"--knn must satisfy 0 < knn < 10000, but is {args.knn}")

    dataset = args.dataset + '/'
    dataset_attr = ds.get_dataset_attr(dataset, args.models + '/', BASE_DIR, args.eps)

    # --adv is only honoured for the ImageNet dataset.
    if args.dataset != 'imagenet':
        args.adv = False

    model_names = dataset_attr['model_names']
    layers = dataset_attr['layers']
    four_dim_models = dataset_attr['four_dim_models']
    eps = dataset_attr['eps']
    results_dir = dataset_attr['results_dir']
    exp_number = args.exp
    use_inverted = bool(args.inv)
    use_adv = bool(args.adv)

    # Tag inserted into result file names to tell run variants apart.
    sim_type = 'adv_' if use_adv else ''
    sim_type += 'sample_' if args.sample else ''
    sim_type += 'base_' if args.perturbs > 0 else ''

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models = mi.load_models(model_names, args.models + '/', args.model_dir, device)

    if args.index_file_dir:
        idx_path = os.path.join(args.index_file_dir, f"{dataset}indices_{exp_number}.csv")
    else:
        idx_path = f'{results_dir}magnitude/standard/{dataset}indices_{exp_number}.csv'

    activations = get_activations(
        models=models,
        model_names=model_names,
        layers=layers,
        device=device,
        dataset_name=args.dataset,
        dataset_attr=dataset_attr,
        epsilon=args.eps,
        inverted_imgs_dir=args.inverted_imgs_dir,
        path_to_index_file=idx_path,
        # Pass real booleans, matching the declared parameter types and the
        # values handed to calculate_similarity below.
        use_inverted_imgs_or_adv_examples=use_inverted,
        use_imagenet_adv_examples=use_adv,
        pool_images_across_models=bool(args.sample),
    )

    # 'all' expands to every known measure; otherwise run just the chosen one.
    measures = MEASURES if args.rep == 'all' else [args.rep]
    for measure in measures:
        path = get_save_path(results_dir, measure)
        if not os.path.exists(path):
            logger.info(f"Creating directory for results: {path}")
            # exist_ok guards against another process creating the directory
            # between the check above and this call.
            os.makedirs(path, exist_ok=True)
        # The shape metric uses its own list of model names.
        measure_model_names = four_dim_models if measure == 'shape' else model_names
        calculate_similarity(
            activations,
            measure,
            measure_model_names,
            path,
            eps,
            device,
            sim_type,
            args.norm,
            args.center,
            args.perturbs,
            use_inverted,
            use_adv,
            args.knn,
            args.function,
        )

    if args.config_file_path is not None:
        import json
        config = vars(args)
        if os.path.exists(args.config_file_path):
            # Never overwrite an existing config; just report the current one.
            logger.info(f"Config already exists under {args.config_file_path}. Only printing output:")
            logger.info(config)
        else:
            with open(args.config_file_path, "w") as f:
                json.dump(config, f)

if __name__ == '__main__':
    # `args` is bound as a module-level global on purpose: main() and several
    # helper functions above read it directly, so it must exist before main()
    # is called.
    args = parser.parse_args()
    main()
