import os
import sys

import hydra
import numpy as np
from omegaconf import DictConfig
from sklearn.decomposition import PCA

from constants.paths import _datasets_dir
from get_dataset import get_image_dataset
from image_transforms import ResnetFeatures
from strategies_images import Image_Strategy


def prepare_features(ds_name, split):
    """Load one split of an image dataset and compute its ResNet features.

    The dataset is fetched through ``get_image_dataset`` at resolution 224 and
    passed through ``Image_Strategy(ResnetFeatures())``.

    Args:
        ds_name: Dataset name forwarded to ``get_image_dataset``.
        split: Split identifier forwarded to ``get_image_dataset``
            (the script uses ``"train"`` and ``"test"``).

    Returns:
        A ``(features, sensitive, labels)`` tuple. Each element is the result
        of ``.numpy().squeeze()`` on the corresponding dataset/extractor
        output (presumably torch tensors converted to NumPy arrays --
        confirm against ``get_image_dataset`` / ``Image_Strategy``).
    """
    data = get_image_dataset(ds_name, res=224, split=split)

    # Read the per-sample annotations before running the feature extractor,
    # in the same order as the pipeline has always done.
    sensitive_array = data.sensitives.numpy()
    sensitive_array = sensitive_array.squeeze()
    label_array = data.labels.numpy()
    label_array = label_array.squeeze()

    extractor = Image_Strategy(ResnetFeatures())
    embedded = extractor.transform_images(data)
    # Detach from any autograd graph and move to host memory before converting.
    feature_array = embedded.detach().cpu().numpy().squeeze()

    return feature_array, sensitive_array, label_array


def debug(features):
    """Plot cumulative PCA explained variance, open an IPython shell, then exit.

    Development helper: fits a full PCA on ``features``, plots the cumulative
    explained-variance ratio, drops into an interactive IPython session, and
    terminates the process once that session is closed.

    Args:
        features: Feature matrix accepted by ``sklearn.decomposition.PCA.fit``.

    Raises:
        SystemExit: Always, after the IPython session ends.
    """
    # Imported lazily so the regular pipeline does not depend on these packages.
    import IPython
    import matplotlib.pyplot as plt

    pca = PCA()
    # Only the fitted spectrum is needed; the projected data was never used.
    pca.fit(features)

    variances = np.cumsum(pca.explained_variance_ratio_)
    plt.plot(variances)

    IPython.embed()
    # ``exit()`` is injected by the ``site`` module for interactive use and is
    # not guaranteed to exist; ``sys.exit()`` is the supported way to stop.
    sys.exit()


# Number of principal components kept when ``full_features`` is false.
_PCA_COMPONENTS = 85


def _standardize(train, test):
    """Standardize ``train`` and ``test`` using the column statistics of ``train``.

    Columns whose standard deviation is zero are divided by 1 instead, so
    constant columns become 0 rather than NaN/inf.
    """
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    # ``ones_like`` keeps the dtype of ``std`` so the result dtype is unchanged.
    std = np.where(std == 0, np.ones_like(std), std)
    return (train - mean) / std, (test - mean) / std


def _save_split(save_dir, ds_name, split, features, sensitive, labels):
    """Write one split to ``<save_dir>/<ds_name>_<split>.npz`` (compressed)."""
    np.savez_compressed(
        os.path.join(save_dir, f"{ds_name}_{split}.npz"),
        features=features,
        metadata=sensitive,
        labels=labels,
    )


@hydra.main(version_base=None)
def main(cfg: DictConfig):
    """Extract, standardize and save features for the train and test splits.

    Reads ``cfg.dataset`` and ``cfg.full_features``. Features of both splits
    are standardized with the training-split mean and standard deviation.

    * If ``cfg.full_features`` is truthy, the standardized features are saved
      as ``<dataset>-full_train.npz`` / ``<dataset>-full_test.npz``.
    * Otherwise a PCA with ``_PCA_COMPONENTS`` components is fitted on the
      training features, both splits are projected and standardized again
      with the training-split statistics, and the results are saved as
      ``<dataset>_train.npz`` / ``<dataset>_test.npz``.

    All files are written under ``<_datasets_dir>/<dataset>_features`` with
    the arrays ``features``, ``metadata`` (the sensitive attributes) and
    ``labels``.

    Args:
        cfg: Hydra configuration providing ``dataset`` and ``full_features``.
    """
    ds_name = cfg.dataset
    full_features = cfg.full_features

    # The output directory is derived from the plain dataset name, before any
    # "-full" suffix is appended for the file names.
    save_dir = os.path.join(_datasets_dir, ds_name + "_features")
    os.makedirs(save_dir, exist_ok=True)

    train_features, train_sensitive, train_labels = prepare_features(ds_name, "train")
    test_features, test_sensitive, test_labels = prepare_features(ds_name, "test")

    # Test data is scaled with the training statistics to avoid leakage.
    train_features, test_features = _standardize(train_features, test_features)

    if full_features:
        ds_name = ds_name + "-full"
    else:
        pca = PCA(n_components=_PCA_COMPONENTS)
        train_features = pca.fit_transform(train_features)
        test_features = pca.transform(test_features)
        # Re-standardize so every retained component has unit variance on
        # the training split.
        train_features, test_features = _standardize(train_features, test_features)

    _save_split(save_dir, ds_name, "train", train_features, train_sensitive, train_labels)
    _save_split(save_dir, ds_name, "test", test_features, test_sensitive, test_labels)

    print("Done", file=sys.stderr)


if __name__ == "__main__":
    main()
