import os
import pickle

import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
from tqdm.auto import tqdm
from torchvision.datasets import ImageFolder

import src.utils as utils
from src.dimensionality import Dimensionality
from src.hooks import FeatureExtractor
from src.models import get_model


# Root directory containing all datasets used by this script.
# NOTE(review): `...` (Ellipsis) is presumably a redacted placeholder for the
# real base directory. os.path.join rejects it with a TypeError at import
# time, so it must be replaced with an actual path before running.
# Alternative (commented-out) dataset location:
# PATH_DATASETS = os.path.join(..., 'Documents', 'datasets')
PATH_DATASETS = os.path.join(..., 'datasets')
# THINGS image folder, read with torchvision's ImageFolder (one sub-folder per
# class) to obtain the stimulus images.
PATH_THINGS = os.path.join(PATH_DATASETS, 'THINGS', 'images', 'classes')


def reset_weights(model):
    """Re-initialise, in place, every submodule that defines ``reset_parameters``.

    For each such submodule (as yielded by ``model.named_modules()``), its
    ``state_dict`` (parameters and buffers) is snapshotted before the reset
    and compared afterwards, as a sanity check that the reset took effect.

    Args:
        model: a ``torch.nn.Module``; modified in place.

    Raises:
        RuntimeError: if a submodule with non-empty state is left completely
            unchanged by its ``reset_parameters()`` call.
    """
    for name, layer in model.named_modules():
        reset = getattr(layer, 'reset_parameters', None)
        if not callable(reset):
            continue

        # Clone so the snapshot is not aliased with the storage that
        # reset_parameters() overwrites in place.
        before = {k: v.clone() for k, v in layer.state_dict().items()}

        reset()

        # Modules with no parameters/buffers (e.g. LayerNorm with
        # elementwise_affine=False) have nothing to compare; without this
        # guard any([]) is False and the check would fail spuriously.
        if not before:
            continue

        after = layer.state_dict()
        changed = any(not torch.equal(before[k], after[k]) for k in before)

        # Explicit raise rather than `assert` so the check survives `python -O`.
        if not changed:
            raise RuntimeError(
                f"Layer {name} reset_parameters() did not change any weights!")


def load_train_filenames(subjects=('S1', 'S2', 'S3'), path_datasets=None):
    """Return the sorted training-split stimulus filenames shared by all subjects.

    For each subject, reads
    ``<path_datasets>/THINGS-fMRI/processed/<subject>/stimuli/<subject>_stimuli.csv``
    and keeps the ``Stimulus`` values of rows whose ``Split`` column equals
    ``'train'``. It then checks that every subject has exactly the same
    (sorted) set.

    Args:
        subjects: subject IDs to read; defaults to the three THINGS-fMRI
            subjects. The first subject's filenames are returned.
        path_datasets: dataset root directory; defaults to the module-level
            ``PATH_DATASETS``.

    Returns:
        numpy array of the sorted training ``Stimulus`` values.

    Raises:
        ValueError: if ``subjects`` is empty or the subjects' training sets differ.
    """
    subjects = list(subjects)
    if not subjects:
        raise ValueError('At least one subject is required.')
    if path_datasets is None:
        path_datasets = PATH_DATASETS

    path_processed = os.path.join(path_datasets, 'THINGS-fMRI', 'processed')

    filenames_per_subject = []
    for subject in subjects:
        path_stimuli = os.path.join(path_processed,
                                    subject,
                                    'stimuli',
                                    f'{subject}_stimuli.csv')
        stimuli = pd.read_csv(path_stimuli)
        is_train = stimuli['Split'] == 'train'
        filenames_per_subject.append(
            np.sort(stimuli.loc[is_train, 'Stimulus'].to_numpy()))

    reference = filenames_per_subject[0]
    for subject, filenames in zip(subjects[1:], filenames_per_subject[1:]):
        # array_equal also handles differing lengths, where elementwise `==`
        # would broadcast (length 1) or raise/warn depending on numpy version.
        if not np.array_equal(reference, filenames):
            raise ValueError(
                f'Training stimuli of subject {subject} differ from those of '
                f'subject {subjects[0]}.')

    return reference


def _resolve_layer_names(model, candidate_layers):
    """Translate requested layer names into names present in ``model.layers``.

    A name is used as-is when ``model.layers`` contains it; otherwise the
    variant with a trailing ``'1'`` (e.g. ``'layer1.0.relu1'``) is tried.

    Raises:
        NotImplementedError: if neither spelling exists in the model.
    """
    resolved = []
    for layer in candidate_layers:
        if layer in model.layers:
            resolved.append(layer)
        elif f'{layer}1' in model.layers:
            resolved.append(f'{layer}1')
        else:
            raise NotImplementedError(
                f"Neither '{layer}' nor '{layer}1' is a layer of this model.")
    return resolved


def _match_dataset_indices(dataset, filenames):
    """Return, for each filename, the index of the ``dataset`` sample with that basename.

    An explicit basename -> index map is used instead of ``searchsorted`` over
    a separately sorted name array. This way the index always refers to the
    dataset's own sample order, and names that are not found fail loudly
    instead of silently mapping to a neighbouring image.
    NOTE(review): assumes basenames are unique across class folders.

    Raises:
        FileNotFoundError: if any filename has no matching image in ``dataset``.
    """
    index_by_name = {os.path.basename(path): idx
                     for idx, (path, _) in enumerate(dataset.samples)}
    missing = [name for name in filenames if name not in index_by_name]
    if missing:
        raise FileNotFoundError(
            f'{len(missing)} training stimuli not found in the image dataset, '
            f'e.g. {missing[:5]}')
    return np.array([index_by_name[name] for name in filenames], dtype=np.int64)


def _pool_activation(activation):
    """Average a hooked activation over axes 2 and 3 and flatten it to 1-D.

    Assumes a 4-D (batch, channels, height, width) tensor with batch size 1,
    as produced by the single-image forward passes -- TODO confirm for every
    hooked layer.
    """
    return activation.detach().cpu().numpy().mean(axis=(2, 3)).flatten()


def _extract_activations(network, feature_extractor, dataset, indices, device):
    """Run ``network`` on each selected image and collect pooled activations.

    Returns:
        dict mapping each name in ``feature_extractor.layers`` to a float32
        array of shape ``(len(indices), num_units)``. Row ``i`` holds the
        pooled activation for ``dataset[indices[i]]``.
    """
    num_images = len(indices)
    activations = {}

    with torch.no_grad():
        for i_image, idx_image in tqdm(enumerate(indices), total=num_images):
            image, _ = dataset[int(idx_image)]
            # The network output is unused; activations are read back from
            # feature_extractor (presumably filled by its registered hooks).
            network(image.unsqueeze(0).to(device))

            for layer_name in feature_extractor.layers:
                pooled = _pool_activation(feature_extractor.activations[layer_name])
                if layer_name not in activations:
                    # Allocate lazily: the unit count is only known after the
                    # first forward pass.
                    activations[layer_name] = \
                        np.empty((num_images, pooled.shape[0]), dtype=np.float32)
                    print(layer_name, pooled.shape[0])
                activations[layer_name][i_image] = pooled

    return activations


def main():
    """Extract spatially pooled ReLU activations from untrained models.

    For each model in ``model_names``:
    1. Skip the model if ``models/<model>_activations_untrained.pickle``
       already exists.
    2. Load the model and resolve its ReLU layer names.
    3. Re-initialise all its weights with ``reset_weights``.
    4. Run every THINGS-fMRI training image through it.
    5. Pickle the per-layer activations together with filenames, class
       labels and ``model.layers``.
    """
    model_names = ['resnet50', 'simclr', 'clip_resnet50']
    relu_layers = \
        ['layer1.0.relu', 'layer1.1.relu', 'layer1.2.relu',
         'layer2.0.relu', 'layer2.1.relu', 'layer2.2.relu', 'layer2.3.relu',
         'layer3.0.relu', 'layer3.1.relu', 'layer3.2.relu', 'layer3.3.relu',
         'layer3.4.relu', 'layer3.5.relu', 'layer4.0.relu', 'layer4.1.relu',
         'layer4.2.relu']

    path_models = 'models'
    os.makedirs(path_models, exist_ok=True)

    for model_name in model_names:
        path_output = os.path.join(path_models,
                                   f'{model_name}_activations_untrained.pickle')

        # Check before loading the model so finished models cost nothing.
        if os.path.exists(path_output):
            print(f'Embeddings for model {model_name} already exist!')
            continue

        utils.log(f'Extracting embeddings from {model_name}...')
        model = get_model(model_name)
        layer_list = _resolve_layer_names(model, relu_layers)

        device = utils.get_device()
        feature_extractor = FeatureExtractor(model)
        feature_extractor.register_hooks(layer_list)

        # Re-initialise the weights: activations are from an untrained network.
        reset_weights(model.model)
        model.model.to(device)
        model.model.eval()

        # Training filenames (verified identical across all subjects).
        filenames_train = load_train_filenames()
        dataset = ImageFolder(PATH_THINGS, transform=model.preprocess)
        indices_images = _match_dataset_indices(dataset, filenames_train)

        dict_activations = _extract_activations(
            model.model, feature_extractor, dataset, indices_images, device)

        dict_results = {
            'activations': dict_activations,
            'filenames': filenames_train,
            'labels': dataset.classes,
            'layers': model.layers
        }

        utils.log(f'Saving embeddings from {model_name}...\n')
        # Write to a temporary file and rename atomically. A crash mid-dump
        # would otherwise leave a truncated pickle that later runs skip as
        # "already exist".
        path_tmp = f'{path_output}.tmp'
        with open(path_tmp, "wb") as fp:
            pickle.dump(dict_results, fp)
        os.replace(path_tmp, path_output)


# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
