import argparse
import yaml
import os
import torch
import copy
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
from sklearn.cluster import KMeans

from gwdr.src.clust_dr import (
    Clust_then_DR,
    DR_then_Clust,
    COOTClust,
    GWDR,
    WrongParameter,
)
from gwdr.src.affinities import (
    GramAffinity,
    NormalizedGaussianAndStudentAffinity,
    EntropicAffinity,
    NormalizedLorentzHyperbolicAndStudentAffinity,
    UMAPAffinityIn,
    UMAPAffinityOut,
)
from gwdr.data.data import load_dataset
from gwdr.exps.plot_clust_dr import (
    compute_save_scores_batch_exp,
    save_best_scores_batch_exp,
    plot_scores_batch_exp,
    plot_losses_batch_exp,
)


# Registry used to turn the model name found in an experiment configuration
# (the "model" entry) into the class that gets instantiated for it.
model_dict = dict(
    Clust_then_DR=Clust_then_DR,
    DR_then_Clust=DR_then_Clust,
    COOTClust=COOTClust,
    GWDR=GWDR,
)


path_config_folder = os.getcwd() + "/runs"


def _str_to_bool(value):
    """Convert a command-line string such as ``true``/``False``/``0`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if lowered in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _str_to_dtype(value):
    """Convert a dtype name (``double``, ``float32``, ``torch.float64``) to a torch.dtype.

    ``torch.dtype`` itself cannot be called with a string, so it cannot serve
    as an argparse ``type``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` does not name a torch dtype.
    """
    if isinstance(value, torch.dtype):
        return value
    name = str(value).strip()
    if name.startswith("torch."):
        name = name[len("torch."):]
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"unknown torch dtype: {value!r}")
    return dtype


def make_parser():
    """Build the command-line parser holding the default experiment settings.

    Returns:
        argparse.ArgumentParser: parser whose options cover general settings,
        the dataset, the model and the optimisation. Boolean options accept an
        explicit value (``--verbose false``) or no value at all (``--verbose``,
        meaning true); ``--dtype`` accepts a torch dtype name.
    """
    parser = argparse.ArgumentParser()
    # General
    parser.add_argument("--exp_name", type=str, default="default")
    parser.add_argument("--config", type=str, default="")
    parser.add_argument("--log_dir", type=str, default=path_config_folder)
    parser.add_argument(
        "--verbose", type=_str_to_bool, nargs="?", const=True, default=False
    )
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--dtype", type=_str_to_dtype, default=torch.double)
    parser.add_argument("--n_seeds", type=int, default=5)
    parser.add_argument("--max_output_sam", type=int, default=220)
    parser.add_argument("--step_output_sam", type=int, default=20)
    # Data
    parser.add_argument("--dataset", type=str, default="mnist")
    # Model
    parser.add_argument("--output_dim", type=int, default=2)
    parser.add_argument("--model", type=str, default="GWDR")
    parser.add_argument("--affinity_data", type=str, default="GramAffinity")
    parser.add_argument("--affinity_embedding", type=str, default="GramAffinity")
    parser.add_argument("--loss_fun", type=str, default="square_loss")
    parser.add_argument("--perplexity", type=int, default=50)
    # Optim
    parser.add_argument("--optimizer", type=str, default="Adam")
    parser.add_argument("--lr", type=float, default=1e0)
    parser.add_argument("--init", type=str, default="normal")
    parser.add_argument("--init_T", type=str, default="spectral")
    parser.add_argument("--entropic_reg", type=float, default=0.0)
    parser.add_argument("--tol", type=float, default=1e-6)
    parser.add_argument("--max_iter", type=int, default=1000)
    parser.add_argument("--max_iter_outer", type=int, default=20)
    parser.add_argument(
        "--hyperbolic", type=_str_to_bool, nargs="?", const=True, default=False
    )
    return parser


def make_args():
    """Collect one argument dictionary per experiment described in the config folder.

    The command line is parsed first; its ``config`` value selects the folder
    ``<path_config_folder>/<config>``. Every YAML document found in the
    ``.yaml``/``.yml`` files of that folder overrides a copy of the
    command-line arguments and defines one experiment. An experiment whose log
    folder already exists is skipped (``build_exp_folder`` returns ``None``).

    Returns:
        tuple: ``(args_list, path_config, args_dict)`` where ``args_list`` is
        the list of per-experiment dictionaries (``log_dir`` set to the folder
        created for the experiment), ``path_config`` is the configuration
        folder and ``args_dict`` holds the raw command-line arguments, kept so
        that global settings such as the device remain accessible.
    """
    parser = make_parser()
    args_dict = vars(parser.parse_args())
    path_config = os.path.join(path_config_folder, args_dict["config"])
    args_list = []
    # os.listdir returns entries in arbitrary order; sorting makes the order in
    # which experiments are created and run reproducible.
    for filename in sorted(os.listdir(path_config)):
        if Path(filename).suffix.lower() not in {".yaml", ".yml"}:
            continue
        with open(
            os.path.join(path_config, filename), "r", encoding="utf-8"
        ) as stream:
            for yml_args in yaml.safe_load_all(stream):
                if yml_args is None:
                    # Empty YAML document (e.g. a stray "---" separator):
                    # nothing to override, and dict.update(None) would raise.
                    continue
                print("yml_args:", yml_args)
                config_dict_to_add = copy.deepcopy(args_dict)
                config_dict_to_add.update(yml_args)
                path_log = build_exp_folder(config_dict_to_add)
                if path_log is None:
                    continue
                config_dict_to_add["log_dir"] = path_log
                args_list.append(config_dict_to_add)
    return args_list, path_config, args_dict


def build_exp_folder(args_dict):
    """Create the log folder of an experiment and store its parameters in it.

    The folder is ``<path_config_folder>/<config>/<exp_name>``. When it does
    not exist yet it is created and ``args_dict`` is saved inside it as
    ``params.pt``.

    Returns:
        The path of the newly created folder, or ``None`` when the folder was
        already present (nothing is written in that case).
    """
    exp_folder = os.path.join(
        path_config_folder, args_dict["config"], args_dict["exp_name"]
    )
    if not os.path.isdir(exp_folder):
        os.mkdir(exp_folder)
        torch.save(args_dict, exp_folder + "/params.pt")
        return exp_folder
    return None


if __name__ == "__main__":
    # Batch driver. For every experiment configuration returned by make_args():
    #   1. load the dataset and assemble the keyword arguments of the model,
    #   2. fit the model for every embedding size and every seed,
    #   3. save embeddings, plans and losses in the experiment's log folder.
    # Once all experiments have run, losses/scores are plotted and saved for
    # the whole configuration folder.

    exp_dict_list, path_config, args_dict = make_args()

    for exp_dict in tqdm(exp_dict_list, desc="running experiment batch"):

        X, Y = load_dataset(exp_dict["dataset"], device=exp_dict["device"])
        n_samples_init = X.shape[0]
        print(f"--- X shape: {X.shape} --- ")
        print(f"--- Y shape: {Y.shape} --- ")

        n_unique_labels = torch.unique(Y).shape[0]
        # exp_dict["init_T"] is never modified below (only kwargs_model is),
        # so this flag is valid for the whole experiment.
        spectral_init = exp_dict["init_T"] in ["spectral", "softspectral"]

        # Arguments that are shared across models
        kwargs_model = {
            "output_dim": exp_dict["output_dim"],
            "optimizer": exp_dict["optimizer"],
            "lr": exp_dict["lr"],
            "init": exp_dict["init"],
            "verbose": exp_dict["verbose"],
            "tol": exp_dict["tol"],
            "max_iter": exp_dict["max_iter"],
            "device": exp_dict["device"],
            "dtype": exp_dict["dtype"],
        }
        print("exp_dict - init_T:", exp_dict["init_T"])

        # Add arguments that are specific to some models
        if exp_dict["model"] in ["COOTClust", "GWDR"]:
            kwargs_model["max_iter_outer"] = exp_dict["max_iter_outer"]

        if exp_dict["model"] in [
            "GWDR",
            "Clust_then_DR",
            "ClustGW_then_DR",
            "DR_then_Clust",
        ]:
            kwargs_model["loss_fun"] = exp_dict["loss_fun"]
            kwargs_model["init_T"] = exp_dict["init_T"]

            # --- Input-space affinity ---
            # Files on disk share the prefix affinities/<dataset>/<dataset>
            # followed by a tag that depends on the affinity: the precomputed
            # affinity is "<stem>.pt" and the matching spectral embeddings are
            # "<stem>_spectralembeddings.pt".
            dataset = exp_dict["dataset"]
            perplexity = exp_dict["perplexity"]
            affinity_data = exp_dict["affinity_data"]
            file_prefix = f"affinities/{dataset}/{dataset}"

            if affinity_data == "GramAffinity":
                kwargs_model["affinity_data"] = GramAffinity(centering=True)
                affinity_stem = f"{file_prefix}_gram"

            elif affinity_data == "EntropicAffinity":
                kwargs_model["affinity_data"] = EntropicAffinity(perp=perplexity)
                affinity_stem = f"{file_prefix}_{perplexity}"

            elif affinity_data in ["SymmetricEntropicAffinity", "UMAPAffinityIn"]:
                if affinity_data == "UMAPAffinityIn":
                    affinity_stem = f"{file_prefix}_UMAP_{perplexity}"
                else:
                    affinity_stem = f"{file_prefix}_{perplexity}"

                if exp_dict["model"] != "Clust_then_DR":
                    # The model is fed the affinity loaded from disk instead
                    # of the raw data. For Clust_then_DR the affinity object
                    # is built later, once per embedding size.
                    kwargs_model["affinity_data"] = "precomputed"
                    affinity_path = f"{affinity_stem}.pt"
                    X = torch.load(affinity_path).to(device=exp_dict["device"])
                    print(f"loading affinity path: {affinity_path}")
                    print("new X.shape:", X.shape)

            else:
                raise WrongParameter(
                    '"affinity_data" must be either "GramAffinity", "SymmetricEntropicAffinity", "EntropicAffinity" or "UMAPAffinityIn".'
                )

            if spectral_init:
                spectral_embeddings = torch.load(
                    f"{affinity_stem}_spectralembeddings.pt"
                )

            # --- Embedding-space affinity (each one requires a matching loss) ---
            if exp_dict["affinity_embedding"] == "GramAffinity":
                assert exp_dict["loss_fun"] == "square_loss"
                kwargs_model["affinity_embedding"] = GramAffinity()
            elif (
                exp_dict["affinity_embedding"] == "NormalizedGaussianAndStudentAffinity"
            ):
                assert exp_dict["loss_fun"] in ["kl_nomarg_loss", "kl_loss"]
                kwargs_model["affinity_embedding"] = (
                    NormalizedGaussianAndStudentAffinity()
                )
            elif (
                exp_dict["affinity_embedding"]
                == "NormalizedLorentzHyperbolicAndStudentAffinity"
            ):
                assert exp_dict["loss_fun"] in ["kl_nomarg_loss", "kl_loss"]
                kwargs_model["affinity_embedding"] = (
                    NormalizedLorentzHyperbolicAndStudentAffinity(student=True)
                )
                # The hyperbolic affinity forces the optimizer and the
                # initialisation; exp_dict is updated as well so that the
                # saved configuration records what was actually used.
                kwargs_model["optimizer"] = "RAdam"
                exp_dict["optimizer"] = "RAdam"
                kwargs_model["init"] = "WrappedNormal"
                exp_dict["init"] = "WrappedNormal"

            elif exp_dict["affinity_embedding"] == "UMAPAffinityOut":
                assert exp_dict["loss_fun"] == "binary_cross_entropy"
                kwargs_model["affinity_embedding"] = UMAPAffinityOut()

            else:
                raise WrongParameter(
                    '"affinity_embedding" must be either "GramAffinity", "NormalizedGaussianAndStudentAffinity", "NormalizedLorentzHyperbolicAndStudentAffinity" or "UMAPAffinityOut".'
                )

            if exp_dict["model"] == "GWDR":
                try:
                    kwargs_model["entropic_reg"] = float(
                        exp_dict["entropic_reg"]
                    )  # if benchmarking mirror descent solvers
                except (TypeError, ValueError, OverflowError):
                    # Value not convertible to float (e.g. None coming from a
                    # YAML file): leave the keyword unset, as before.
                    pass

        saved_embeddings = {}
        saved_plans = {}
        saved_losses = {}

        output_sam_list = list(
            range(
                n_unique_labels, exp_dict["max_output_sam"], exp_dict["step_output_sam"]
            )
        )

        for output_sam in tqdm(output_sam_list, desc="validating embedding size"):

            kwargs_model["output_sam"] = output_sam

            saved_embeddings[output_sam] = []
            saved_plans[output_sam] = []
            saved_losses[output_sam] = []

            if spectral_init and exp_dict["model"] in ["GWDR", "Clust_then_DR"]:
                local_embeddings = spectral_embeddings[:, :output_sam]
                kmeans = KMeans(n_clusters=output_sam, random_state=0, n_init=10).fit(
                    local_embeddings
                )  # Apply kmeans on spectral embeddings as in Spectral clustering
                # One-hot encoding of the k-means labels.
                init_T = torch.eye(output_sam)[kmeans.labels_].to(
                    dtype=kwargs_model["dtype"], device=kwargs_model["device"]
                )

                if exp_dict["init_T"] == "softspectral":
                    # Average the one-hot assignment with the outer product of
                    # a ones vector and a uniform vector over the output_sam
                    # columns.
                    h0 = torch.ones(
                        X.shape[0],
                        dtype=kwargs_model["dtype"],
                        device=kwargs_model["device"],
                    )
                    q = (
                        torch.ones(
                            output_sam,
                            dtype=kwargs_model["dtype"],
                            device=kwargs_model["device"],
                        )
                        / output_sam
                    )
                    init_T = (init_T + (h0[:, None] * q[None, :])) / 2.0
                kwargs_model["init_T"] = init_T

            if exp_dict["model"] == "Clust_then_DR" and exp_dict["affinity_data"] in [
                "SymmetricEntropicAffinity",
                "EntropicAffinity",
                "UMAPAffinityIn",
            ]:
                # Rescale the perplexity by output_sam / n_samples_init, then
                # clamp it: at least 5, and finally at most output_sam - 1
                # (the upper bound is applied last, so it wins).
                # NOTE(review): presumably needed because Clust_then_DR builds
                # this affinity on output_sam points rather than on the full
                # dataset -- confirm against the model implementation.
                adjusted_perp = torch.max(
                    torch.Tensor(
                        [exp_dict["perplexity"] * output_sam / n_samples_init, 5]
                    )
                ).item()
                adjusted_perp = torch.min(
                    torch.Tensor([adjusted_perp, output_sam - 1])
                ).item()
                if exp_dict["affinity_data"] == "UMAPAffinityIn":
                    kwargs_model["affinity_data"] = UMAPAffinityIn(
                        n_neighbors=int(adjusted_perp), max_iter=1000
                    )
                else:
                    kwargs_model["affinity_data"] = EntropicAffinity(
                        perp=adjusted_perp
                    )

            for seed in range(exp_dict["n_seeds"]):

                kwargs_model["seed"] = seed

                model = model_dict[exp_dict["model"]](**kwargs_model)
                try:
                    Z = model.fit_transform(X)

                    assert Z.shape[0] == output_sam
                    assert Z.shape[1] == kwargs_model["output_dim"]
                    assert model.T.shape[0] == X.shape[0]
                    assert model.T.shape[1] == output_sam

                    # Results are always stored on the CPU.
                    if exp_dict["device"] == "cpu":
                        saved_embeddings[output_sam].append(Z)
                        saved_plans[output_sam].append(model.T)
                    else:
                        saved_embeddings[output_sam].append(Z.cpu())
                        saved_plans[output_sam].append(model.T.cpu())

                    saved_losses[output_sam].append(model.losses)
                except Exception:
                    # Without entropic regularisation a failure is a genuine
                    # error and must stop the run.
                    if exp_dict["entropic_reg"] == 0.0:
                        raise
                    # optimization could fail because the entropy is too small
                    # hence we just record NaN placeholders and continue
                    # experiments. `Exception` (not a bare except) is caught so
                    # that Ctrl+C / SystemExit still abort the batch.
                    saved_embeddings[output_sam].append(float("nan"))
                    saved_plans[output_sam].append(float("nan"))
                    saved_losses[output_sam].append(model.losses)

        torch.save([exp_dict, saved_embeddings], exp_dict["log_dir"] + "/embeddings.pt")
        torch.save([exp_dict, saved_plans], exp_dict["log_dir"] + "/plans.pt")
        torch.save([exp_dict, saved_losses], exp_dict["log_dir"] + "/losses.pt")

    # Saving and plotting results
    # The hyperbolic flag is taken from the last experiment of the batch; when
    # no experiment was run (or its affinity name is not a string) the
    # command-line value is used instead.
    hyperbolic = args_dict["hyperbolic"]
    if exp_dict_list:
        last_affinity_embedding = exp_dict_list[-1]["affinity_embedding"]
        if isinstance(last_affinity_embedding, str):
            hyperbolic = "Hyperbolic" in last_affinity_embedding
    print("hyperbolic:", hyperbolic)
    print("path config:", path_config)
    plot_losses_batch_exp(path_config)
    compute_save_scores_batch_exp(
        path_config,
        threshold=0,
        weighted=True,
        device=args_dict["device"],
        hyperbolic=hyperbolic,
    )
    save_best_scores_batch_exp(path_config)
    plot_scores_batch_exp(path_config, hyperbolic=hyperbolic)