import os
import pickle

import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
from tqdm.auto import tqdm
from torchvision.datasets import ImageFolder

import src.utils as utils
from src.dimensionality import Dimensionality
from src.hooks import FeatureExtractor
from src.models import get_model


# Root directory that holds the datasets used by this script.
# NOTE(review): the first argument is a literal `...` (Ellipsis) placeholder;
# `os.path.join` raises TypeError on it, so it must be replaced with the real
# base directory before this module can be imported.
PATH_DATASETS = os.path.join(..., 'datasets')
# Root of the THINGS image tree; handed to ImageFolder in main().
PATH_THINGS = os.path.join(PATH_DATASETS, 'THINGS', 'images', 'classes')
# Number of training images kept after shuffling (see main()).
SUBSET_IMAGES = 500
# Seed of the shuffle that picks the image subset.
SEED_SUBSET = 7878


def load_train_filenames():
    """Return the sorted training-stimulus filenames shared by S1, S2 and S3.

    For every subject the file
    ``<PATH_DATASETS>/THINGS-fMRI/processed/<subject>/stimuli/<subject>_stimuli.csv``
    is read, the rows whose ``Split`` column equals ``'train'`` are kept and
    their ``Stimulus`` values are sorted.

    Returns:
        numpy.ndarray: the sorted training filenames (identical for all
        three subjects).

    Raises:
        ValueError: if a subject's training filenames differ from those of
            the first subject.
    """
    path_processed = os.path.join(PATH_DATASETS,
                                  'THINGS-fMRI',
                                  'processed')

    filenames_reference = None
    for subject in ('S1', 'S2', 'S3'):
        path_stimuli = os.path.join(path_processed,
                                    subject,
                                    'stimuli',
                                    f'{subject}_stimuli.csv')
        stimuli = pd.read_csv(path_stimuli)
        is_train = stimuli['Split'] == 'train'
        filenames = np.sort(stimuli.loc[is_train, 'Stimulus'].to_numpy())

        if filenames_reference is None:
            filenames_reference = filenames
        # array_equal also compares shapes, so lists of different length are
        # reported as a mismatch instead of being broadcast (or erroring
        # inside the elementwise comparison).
        elif not np.array_equal(filenames_reference, filenames):
            raise ValueError(
                f'Training stimuli of subject {subject} differ from '
                f'those of subject S1')

    return filenames_reference


def _resolve_layer_names(model, relu_layers):
    """Translate the canonical ReLU layer names into names found in `model`.

    Each name is used as-is when it appears in ``model.layers``; otherwise
    the same name with a ``'1'`` suffix (e.g. ``'layer1.0.relu1'``) is tried.

    Raises:
        NotImplementedError: if neither variant is present in the model.
    """
    layer_list = []
    for layer in relu_layers:
        if layer in model.layers:
            layer_list.append(layer)
        elif f'{layer}1' in model.layers:
            layer_list.append(f'{layer}1')
        else:
            raise NotImplementedError(
                f'Layer {layer} (or {layer}1) not found in model.layers')
    return layer_list


def _select_image_indices(filenames_dataset, filenames_train,
                          subset_size, seed):
    """Pick a seeded random subset of dataset indices for the given filenames.

    Args:
        filenames_dataset: basenames of the dataset samples, in dataset order.
        filenames_train: basenames that must be looked up in the dataset.
        subset_size: maximum number of indices to return.
        seed: seed for the global NumPy RNG used by the shuffle.

    Returns:
        numpy.ndarray: at most `subset_size` indices into the dataset.

    Raises:
        ValueError: if a requested filename is not part of the dataset.
    """
    filenames_dataset = np.asarray(filenames_dataset)
    filenames_train = np.asarray(filenames_train)

    # searchsorted needs sorted input, but the positions it returns refer to
    # the sorted array; `order` maps them back to the dataset's own ordering.
    order = np.argsort(filenames_dataset, kind='stable')
    sorted_names = filenames_dataset[order]
    positions = np.searchsorted(sorted_names, filenames_train)

    # searchsorted returns an insertion point even for absent values, so
    # confirm that every requested filename was really found.
    found = positions < len(sorted_names)
    found[found] = sorted_names[positions[found]] == filenames_train[found]
    if not np.all(found):
        missing = filenames_train[~found]
        raise ValueError(
            f'{len(missing)} training image(s) not found in the dataset, '
            f'e.g. {missing[0]}')

    indices_images = order[positions]
    np.random.seed(seed)
    np.random.shuffle(indices_images)
    return indices_images[:subset_size]


def _collect_activations(model, feature_extractor, dataset, indices_images,
                         layer_name, device):
    """Run the selected images through the model and stack one layer's output.

    Returns:
        numpy.ndarray: float32 array with one row per image holding the
        flattened activation of `layer_name`, or ``None`` when
        `indices_images` is empty.
    """
    num_images = len(indices_images)
    activations = None

    with torch.no_grad():
        for i_image, idx_image in tqdm(enumerate(indices_images),
                                       total=num_images):
            image, _ = dataset[idx_image]
            image = image.unsqueeze(0).to(device)
            # The forward pass fills feature_extractor.activations through
            # the registered hooks; the model output itself is not needed.
            model.model(image)

            features = (feature_extractor.activations[layer_name].detach()
                                                                 .cpu()
                                                                 .numpy()
                                                                 .flatten())

            if activations is None:
                # The layer width is only known after the first forward pass.
                num_units = features.shape[0]
                activations = np.empty((num_images, num_units),
                                       dtype=np.float32)
                print(layer_name, num_units)

            activations[i_image] = features

    return activations


def _subsampled_dimensionality(activations, num_subsample, total_subsample):
    """Estimate ED and ID on `total_subsample` random subsets of the columns.

    Every draw selects `num_subsample` distinct columns with the global NumPy
    RNG. Returns the two lists ``(list_ED, list_ID)``.
    """
    num_features = activations.shape[1]
    list_ED = []
    list_ID = []

    for _ in range(total_subsample):
        subset_indices = np.random.choice(num_features,
                                          num_subsample,
                                          replace=False)

        # Index separately for each estimator so each one receives its own
        # copy of the data, as in the original implementation.
        ED = Dimensionality.effective_dimensionality(activations[:, subset_indices])
        ID = Dimensionality.intrinsic_dimensionality(activations[:, subset_indices])

        list_ED.append(ED)
        list_ID.append(ID)

    return list_ED, list_ID


def main():
    """Compute per-layer dimensionality estimates for each model and pickle them.

    For every model a seeded random subset of the THINGS training images is
    passed through the network. For each hooked ReLU layer, the effective and
    intrinsic dimensionality are estimated on random subsets of units. The
    results are written to ``models/<model_name>_units.pickle``; models whose
    output file already exists are skipped.

    Raises:
        NotImplementedError: if a requested layer is missing from a model.
        ValueError: if training images cannot be matched to the dataset.
    """
    seeds = [1234, 4567, 5678]
    model_names = ['resnet50', 'simclr', 'clip_resnet50']
    relu_layers = \
        ['layer1.0.relu', 'layer1.1.relu', 'layer1.2.relu',
         'layer2.0.relu', 'layer2.1.relu', 'layer2.2.relu', 'layer2.3.relu',
         'layer3.0.relu', 'layer3.1.relu', 'layer3.2.relu', 'layer3.3.relu',
         'layer3.4.relu', 'layer3.5.relu', 'layer4.0.relu', 'layer4.1.relu',
         'layer4.2.relu']

    # Units drawn per subsample and number of subsamples per layer.
    num_subsample = 50
    total_subsample = 100

    path_models = os.path.join('models')
    os.makedirs(path_models, exist_ok=True)

    for i_model, model_name in enumerate(model_names):
        utils.log(f'Extracting embeddings from {model_name}...')

        # Check for existing results before loading the model.
        path_output = os.path.join(path_models, f'{model_name}_units.pickle')
        if os.path.exists(path_output):
            print(f'Embeddings for model {model_name} already exist!')
            continue

        model = get_model(model_name)
        layer_list = _resolve_layer_names(model, relu_layers)
        num_layers = len(layer_list)

        device = utils.get_device()
        feature_extractor = FeatureExtractor(model)
        feature_extractor.register_hooks(layer_list)

        model.model.to(device)
        model.model.eval()

        # Training image filenames (checked to be identical across subjects)
        filenames_train = load_train_filenames()
        dataset = \
            ImageFolder(PATH_THINGS,
                        transform=model.preprocess)
        filenames_dataset = \
            [os.path.basename(path) for path, _ in dataset.samples]
        indices_images = _select_image_indices(filenames_dataset,
                                               filenames_train,
                                               SUBSET_IMAGES,
                                               SEED_SUBSET)
        if len(indices_images) == 0:
            raise ValueError('No training images available for extraction')

        dict_effective = {}
        dict_intrinsic = {}

        # Activations are gathered one layer at a time, re-running the
        # forward passes for each layer (presumably to bound memory use).
        for i_layer, layer_name in enumerate(layer_list):
            activations = _collect_activations(model,
                                               feature_extractor,
                                               dataset,
                                               indices_images,
                                               layer_name,
                                               device)

            # Re-seed per layer so the unit subsampling is reproducible.
            np.random.seed(seeds[i_model])
            num_features = activations.shape[1]

            if num_features <= num_subsample:
                utils.log(f'Skipping layer {layer_name} ({num_features})...')
                continue

            list_ED, list_ID = _subsampled_dimensionality(activations,
                                                          num_subsample,
                                                          total_subsample)

            dict_effective[layer_name] = list_ED
            dict_intrinsic[layer_name] = list_ID

            utils.log(f'Layer {layer_name} ({i_layer+1}/{num_layers}) completed!')

        # NOTE(review): 'filenames' stores the full training list, not the
        # subset of images that was actually used - confirm this is intended.
        dict_results = {
            'effective': dict_effective,
            'intrinsic': dict_intrinsic,
            'filenames': filenames_train,
            'labels': dataset.classes,
            'layers': model.layers
        }

        utils.log(f'Saving embeddings from {model_name}...\n')
        with open(path_output, "wb") as fp:
            pickle.dump(dict_results, fp)

# Run the extraction pipeline only when the file is executed as a script.
if __name__ == '__main__':
    main()
