from ..models.rangemodel import LightningRangeModel
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from skimage import exposure
import os
import fire
from ..config import cfg


def plot_embeds(model_path, method="ICA", save_name="ica.png", save_numpy=False):
    """Render a range model's positional embeddings as an RGB image.

    Loads the checkpoint at ``model_path`` on CPU and bilinearly resizes the
    encoder's ``pos_embed`` to the resolution of the ocean mask. It then
    standardises the per-pixel embedding vectors that lie inside the mask and
    reduces them to three components (one per RGB channel) with ICA or PCA.
    Each component is histogram-equalised, and the image is written to
    ``result_plots/<save_name>``. Pixels outside the mask are drawn white.

    Args:
        model_path: Path to a ``LightningRangeModel`` checkpoint.
        method: Dimensionality-reduction method, one of ``"ICA"`` or ``"PCA"``.
        save_name: File name of the output image inside ``result_plots``.
        save_numpy: If True, also save the resized per-pixel features (before
            masking and standardisation) to
            ``data/pos_embeds/pos_embeds_model.npy``.

    Raises:
        ValueError: If the ocean mask contains no pixels equal to 1.
        NotImplementedError: If ``method`` is not ``"ICA"`` or ``"PCA"``.
    """
    num_ds_dims = 3  # one reduced component per RGB channel

    range_model = LightningRangeModel.load_from_checkpoint(
        model_path, map_location=torch.device("cpu")
    )
    # Keep the embedding as a torch tensor: ``interpolate`` cannot take a
    # numpy array, and bilinear mode needs a 4-D (N, C, H, W) input.
    # NOTE(review): assumes pos_embed is stored as (1, C, H, W) — confirm.
    pos_embeddings = range_model.encoder.pos_embed.detach().cpu()

    mask = np.load(os.path.join("../data/masks", "ocean_mask.npy"))
    mask_h, mask_w = mask.shape[0], mask.shape[1]
    mask_inds = np.where(mask.reshape(-1) == 1)[0]
    if mask_inds.size == 0:
        raise ValueError("Ocean mask contains no pixels equal to 1.")

    feats = nn.functional.interpolate(
        pos_embeddings, (mask_h, mask_w), mode="bilinear"
    )
    # (1, C, H, W) -> (H, W, C) -> (H * W, C): one row per pixel, in the same
    # row-major order as ``mask.reshape(-1)`` so ``mask_inds`` lines up.
    feats = feats.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)
    feats = feats.reshape(feats.shape[0] * feats.shape[1], -1)
    if save_numpy:
        numpy_path = os.path.join("data/pos_embeds", "pos_embeds_model.npy")
        os.makedirs(os.path.dirname(numpy_path), exist_ok=True)
        np.save(numpy_path, feats)

    # Standardise each channel using statistics from in-mask pixels only.
    feats = feats[mask_inds, :]
    f_mu = feats.mean(0)
    f_std = feats.std(0)
    # A constant channel has zero std; leave it centred (all zeros) rather
    # than dividing by zero and producing NaNs.
    f_std[f_std == 0] = 1.0
    feats = (feats - f_mu) / f_std
    assert not np.any(np.isnan(feats))
    assert not np.any(np.isinf(feats))

    if method == "ICA":
        from sklearn.decomposition import FastICA

        reducer = FastICA(
            n_components=num_ds_dims,
            random_state=cfg.train.seed,
            whiten="unit-variance",
            max_iter=1000,
        )
    elif method == "PCA":
        from sklearn.decomposition import PCA

        reducer = PCA(n_components=num_ds_dims, random_state=cfg.train.seed)
    else:
        raise NotImplementedError(
            f"Unsupported method {method!r}; expected 'ICA' or 'PCA'."
        )
    feats_ds = reducer.fit_transform(feats)

    # Histogram-equalise each component independently so every colour
    # channel uses its full intensity range.
    for cc in range(num_ds_dims):
        feats_ds[:, cc] = exposure.equalize_hist(feats_ds[:, cc])

    # Scatter the masked pixels back into a full image; the background stays 1 (white).
    op_im = np.ones((mask_h * mask_w, num_ds_dims))
    op_im[mask_inds] = feats_ds
    op_im = op_im.reshape((mask_h, mask_w, num_ds_dims))

    # Save the output, creating the destination folder if needed.
    op_path = os.path.join("result_plots", save_name)
    op_dir = os.path.dirname(op_path)
    if op_dir:
        os.makedirs(op_dir, exist_ok=True)
    print("Saving image to: " + op_path)
    plt.imsave(op_path, (op_im * 255).astype(np.uint8))


if __name__ == "__main__":
    # Command-line entry point: python-fire turns plot_embeds into a CLI whose
    # arguments/flags correspond to the function's parameters.
    fire.Fire(component=plot_embeds)
