import os
import pickle

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from torchvision.datasets import ImageFolder
from sklearn.random_projection import GaussianRandomProjection

import src.utils as utils
from src.dimensionality import Dimensionality
from src.hooks import FeatureExtractor
from src.models import get_model


# Root directory holding all datasets.
# NOTE(review): `...` (Ellipsis) is a placeholder where a real path belongs;
# os.path.join raises TypeError on it, so this module cannot be imported
# until it is replaced with an actual directory.
PATH_DATASETS = os.path.join(..., 'datasets')
# THINGS image directory; used as the ImageFolder root in main().
PATH_THINGS = os.path.join(PATH_DATASETS, 'THINGS', 'images', 'classes')
# Maximum number of training images randomly sampled for each layer.
SUBSET_IMAGES = 500
# Seed for the NumPy shuffle that selects the image subset.
SEED_SUBSET = 7878


def load_train_filenames(path_datasets=None, subjects=('S1', 'S2', 'S3')):
    """Return the sorted stimulus filenames of the THINGS-fMRI training split.

    For each subject ``S``, reads
    ``<path_datasets>/THINGS-fMRI/processed/<S>/stimuli/<S>_stimuli.csv``,
    keeps the rows whose ``Split`` column equals ``'train'`` and sorts their
    ``Stimulus`` values. All subjects must share the same training set.

    Args:
        path_datasets: Root datasets directory. Defaults to the module-level
            ``PATH_DATASETS`` (looked up at call time).
        subjects: Subject identifiers to load and cross-check. The first
            subject's (identical) filenames are returned.

    Returns:
        np.ndarray: Sorted training-stimulus filenames.

    Raises:
        ValueError: If ``subjects`` is empty or the subjects' training sets
            differ.
    """
    if path_datasets is None:
        path_datasets = PATH_DATASETS
    if not subjects:
        raise ValueError('At least one subject is required.')

    path_processed = os.path.join(path_datasets, 'THINGS-fMRI', 'processed')

    reference = None
    for subject in subjects:
        path_stimuli = os.path.join(path_processed, subject, 'stimuli',
                                    f'{subject}_stimuli.csv')
        stimuli = pd.read_csv(path_stimuli)
        is_train = stimuli['Split'] == 'train'
        filenames = np.sort(stimuli.loc[is_train, 'Stimulus'].to_numpy())

        if reference is None:
            reference = filenames
        # array_equal also handles differing lengths; `np.all(a == b)` can
        # raise or misbehave when the shapes do not broadcast.
        elif not np.array_equal(reference, filenames):
            raise ValueError(
                f'Training stimuli of subject {subject} differ from those '
                f'of subject {subjects[0]}.')

    return reference


def _resolve_layer_names(model, candidate_layers):
    """Map each candidate layer name to the name present in ``model.layers``.

    A candidate is used as-is when present; otherwise ``'<name>1'`` is tried.

    Raises:
        NotImplementedError: If neither variant is in ``model.layers``.
    """
    resolved = []
    for layer in candidate_layers:
        if layer in model.layers:
            resolved.append(layer)
        elif f'{layer}1' in model.layers:
            resolved.append(f'{layer}1')
        else:
            raise NotImplementedError(f'Layer {layer} not found in model.')
    return resolved


def _select_image_indices(dataset, filenames_train, num_images, seed):
    """Return ``dataset`` indices for a seeded random subset of training files.

    Indices are looked up by basename in ``dataset.samples`` itself, so they
    are valid for ``dataset[idx]`` whatever order ImageFolder uses.

    Raises:
        ValueError: If a training filename is absent from the dataset or no
            image would be selected.
    """
    index_by_name = {os.path.basename(path): idx
                     for idx, (path, _) in enumerate(dataset.samples)}
    missing = [name for name in filenames_train if name not in index_by_name]
    if missing:
        raise ValueError(f'{len(missing)} training stimuli not found under '
                         f'{dataset.root}, e.g. {missing[0]!r}.')

    indices = np.array([index_by_name[name] for name in filenames_train])
    np.random.seed(seed)
    np.random.shuffle(indices)
    indices = indices[:num_images]
    if len(indices) == 0:
        raise ValueError('No training images available for the analysis.')
    return indices


def _extract_activations(model, dataset, feature_extractor, layer_name,
                         indices, device):
    """Return ``layer_name``'s flattened activations for each selected image.

    The result is a float32 array of shape ``(len(indices), num_units)``,
    one row per image, in the order of ``indices``.
    """
    activations = None
    with torch.no_grad():
        for row, idx_image in enumerate(tqdm(indices)):
            image, _ = dataset[idx_image]
            model.model(image.unsqueeze(0).to(device))
            features = (feature_extractor.activations[layer_name]
                        .detach().cpu().numpy().ravel())
            if activations is None:
                # Size rows by the actual subset, so no row stays uninitialized.
                activations = np.empty((len(indices), features.shape[0]),
                                       dtype=np.float32)
                print(layer_name, features.shape[0])
            activations[row] = features
    return activations


def _projected_dimensionalities(activations, n_components, n_repeats):
    """Return (ED list, ID list) over ``n_repeats`` Gaussian random projections.

    GaussianRandomProjection is built with its default ``random_state=None``,
    so it draws from the global NumPy RNG seeded by the caller.
    """
    list_ED = []
    list_ID = []
    for _ in range(n_repeats):
        projection = GaussianRandomProjection(
            n_components=n_components).fit_transform(activations)
        list_ED.append(Dimensionality.effective_dimensionality(projection))
        list_ID.append(Dimensionality.intrinsic_dimensionality(projection))
    return list_ED, list_ID


def main():
    """Estimate ED/ID of random projections of ReLU activations per model.

    For each model, a seeded subset of THINGS-fMRI training images is passed
    through the network one layer at a time. The flattened activations are
    randomly projected ``total_subsample`` times to ``num_subsample``
    dimensions, and the effective and intrinsic dimensionality of each
    projection are recorded. Results are pickled to
    ``models/<model_name>_random_projections.pickle``. Models whose pickle
    already exists are skipped.
    """
    seeds = [1234, 4567, 5678]
    model_names = ['resnet50', 'simclr', 'clip_resnet50']
    relu_layers = \
        ['layer1.0.relu', 'layer1.1.relu', 'layer1.2.relu',
         'layer2.0.relu', 'layer2.1.relu', 'layer2.2.relu', 'layer2.3.relu',
         'layer3.0.relu', 'layer3.1.relu', 'layer3.2.relu', 'layer3.3.relu',
         'layer3.4.relu', 'layer3.5.relu', 'layer4.0.relu', 'layer4.1.relu',
         'layer4.2.relu']
    num_subsample = 50     # dimensions of each random projection
    total_subsample = 100  # random projections per layer

    path_models = os.path.join('models')
    os.makedirs(path_models, exist_ok=True)

    for model_name, seed in zip(model_names, seeds):
        # Check for existing results first to avoid loading the model needlessly.
        path_output = os.path.join(path_models,
                                   f'{model_name}_random_projections.pickle')
        if os.path.exists(path_output):
            print(f'Embeddings for model {model_name} already exist!')
            continue

        utils.log(f'Extracting embeddings from {model_name}...')
        model = get_model(model_name)
        layer_list = _resolve_layer_names(model, relu_layers)
        num_layers = len(layer_list)

        device = utils.get_device()
        feature_extractor = FeatureExtractor(model)
        feature_extractor.register_hooks(layer_list)

        model.model.to(device)
        model.model.eval()

        # Training-image filenames (checked identical across subjects).
        filenames_train = load_train_filenames()
        dataset = ImageFolder(PATH_THINGS, transform=model.preprocess)
        indices_images = _select_image_indices(dataset, filenames_train,
                                               SUBSET_IMAGES, SEED_SUBSET)

        dict_effective = {}
        dict_intrinsic = {}

        for i_layer, layer_name in enumerate(layer_list):
            activations = _extract_activations(model, dataset,
                                               feature_extractor, layer_name,
                                               indices_images, device)

            # Re-seed per layer so each layer's projections start from the
            # same global RNG state.
            np.random.seed(seed)
            num_features = activations.shape[1]

            if num_features <= num_subsample:
                utils.log(f'Skipping layer {layer_name} ({num_features})...')
            else:
                dict_effective[layer_name], dict_intrinsic[layer_name] = \
                    _projected_dimensionalities(activations, num_subsample,
                                                total_subsample)
                utils.log(f'Layer {layer_name} ({i_layer+1}/{num_layers}) completed!')

            # Release this layer's (potentially GB-sized) array before the
            # next layer allocates its own.
            del activations

        dict_results = {
            'effective': dict_effective,
            'intrinsic': dict_intrinsic,
            'filenames': filenames_train,
            'labels': dataset.classes,
            'layers': model.layers
        }

        utils.log(f'Saving embeddings from {model_name}...\n')
        with open(path_output, "wb") as fp:
            pickle.dump(dict_results, fp)


if __name__ == '__main__':
    main()
