import itertools
import logging
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Mapping, Sequence, Type, Tuple

import pandas as pd
from latent_invariances.stitching import DatasetConfig, DatasetParams
import torch
from datasets import DatasetDict, load_from_disk
from pytorch_lightning import seed_everything
from torch import nn
from tqdm import tqdm

from nn_core.common import PROJECT_ROOT

from latent_invariances.evaluation import evaluate_retrieval
from latent_invariances.utils.relreps import SIMPLE_PROJECTION_TYPE
from latent_invariances.utils.space import LatentSpace, RelativeSpace
from latent_invariances.modules.aggregation_modules import (
    Identity,
)

# Default device used by the ``__main__`` entry point.
DEVICE: str = "cuda"
log = logging.getLogger(__name__)


# The typename must match the module-level name: pickle locates namedtuple classes by
# ``<module>.<typename>``, so a mismatched typename ("Result") makes instances unpicklable.
OrigResult = namedtuple(
    "OrigResult",
    ["seed", "projections", "num_anchors", "encoder", "score", "classifier", "aggregation"],
)


def _params_key(dataset_params: DatasetParams) -> str:
    """Join every non-None DatasetParams field except ``hf_key`` with "_" to form the dataset key."""
    return "_".join(
        str(value) for field, value in dataset_params._asdict().items() if field != "hf_key" and value is not None
    )


def _build_dataset_config(
    dataset_params: DatasetParams,
    label_column: str,
    encoding_column_template: str,
    directory: Path,
    encoders: Sequence[str],
) -> DatasetConfig:
    """Check that ``directory`` exists and wrap the given fields into a DatasetConfig.

    Raises:
        FileNotFoundError: if ``directory`` does not exist (an explicit check instead of
            ``assert``, which is stripped when Python runs with ``-O``).
    """
    if not directory.exists():
        raise FileNotFoundError(f"Encoded data directory not found: {directory}")
    return DatasetConfig(
        key=_params_key(dataset_params),
        directory=directory,
        label_column=label_column,
        encoding_column_template=encoding_column_template,
        encoders=encoders,
    )


def data_config(
    dataset_name: str,
):
    """Return the DatasetConfig describing where and how an encoded dataset is stored.

    Args:
        dataset_name: one of "fashion_mnist", "mnist", "cifar10", a name starting with
            "cifar100" (fine labels if the name contains "fine", coarse labels otherwise),
            "trec", "dbpedia_14", "n24news_text" or "n24news_image".

    Returns:
        A DatasetConfig with the dataset key, the encoded-data directory, the label column,
        the encoding column template and the encoder names to evaluate.

    Raises:
        FileNotFoundError: if the encoded-data directory for the dataset does not exist.
        NotImplementedError: if ``dataset_name`` is not supported.
    """
    domain2encoders = {
        "vision": [
            "vit_base_patch16_224",
            "rexnet_100",
            "vit_base_patch16_384",
            "vit_small_patch16_224",
            "vit_base_resnet50_384",
            "openai/clip-vit-base-patch32",
        ],
        "text": [
            "bert-base-cased",
            "bert-base-uncased",
            "google/electra-base-discriminator",
            "roberta-base",
            "albert-base-v2",
            "xlm-roberta-base",
            "openai/clip-vit-base-patch32",
        ],
    }
    encoded_root: Path = PROJECT_ROOT / "data" / "encoded_data"
    # Kept as the int literal 1 (not 1.0): it is stringified into keys and directory names
    # ("..._train_test_1").
    perc: float = 1

    if dataset_name in ("fashion_mnist", "mnist", "cifar10"):
        dataset_params = DatasetParams(dataset_name, None, "train", "test", perc, (dataset_name,))
        return _build_dataset_config(
            dataset_params,
            label_column="label",
            encoding_column_template="{encoder}",
            directory=encoded_root / f"{dataset_name}_train_test_{perc}",
            encoders=domain2encoders["vision"],
        )

    if dataset_name.startswith("cifar100"):
        dataset_params = DatasetParams("cifar100", "fine" in dataset_name, "train", "test", perc, ("cifar100",))
        return _build_dataset_config(
            dataset_params,
            label_column="fine_label" if dataset_params.fine_grained else "coarse_label",
            encoding_column_template="{encoder}",
            # The directory uses the base name "cifar100", not the "_coarse"/"_fine" variant.
            directory=encoded_root / f"{dataset_params.name}_train_test_{perc}",
            encoders=domain2encoders["vision"],
        )

    # Text datasets whose directory is named after the dataset key. The fine_grained value
    # differs per dataset and is part of the key (False is kept, None is dropped).
    text_fine_grained = {"trec": False, "dbpedia_14": None}
    if dataset_name in text_fine_grained:
        dataset_params = DatasetParams(
            dataset_name, text_fine_grained[dataset_name], "train", "test", perc, (dataset_name,)
        )
        return _build_dataset_config(
            dataset_params,
            label_column="target",
            encoding_column_template="{encoder}_mean_encoding",
            directory=encoded_root / _params_key(dataset_params),
            encoders=domain2encoders["text"],
        )

    # NOTE: "amazon_reviews_multi" (hf_key ("amazon_reviews_multi", "all"), template
    # "lang_{encoder}_mean_encoding") is currently disabled.

    if dataset_name == "n24news_text":
        dataset_params = DatasetParams("n24news_text", False, "train", "test", perc, ("n24news",))
        return _build_dataset_config(
            dataset_params,
            label_column="label",
            encoding_column_template="{encoder}_cls_encoding",
            directory=encoded_root / "N24News",
            encoders=domain2encoders["text"],
        )

    if dataset_name == "n24news_image":
        dataset_params = DatasetParams("n24news_image", False, "train", "test", perc, ("n24news",))
        return _build_dataset_config(
            dataset_params,
            label_column="label",
            encoding_column_template="{encoder}",
            directory=encoded_root / "N24News",
            # vit_base_patch16_224 is excluded for this dataset; cspdarknet53 is added.
            encoders=[
                "rexnet_100",
                "vit_base_patch16_384",
                "vit_small_patch16_224",
                "vit_base_resnet50_384",
                "cspdarknet53",
            ],
        )

    raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")


# One row of the compatibility results table. The typename matches the module-level name so
# that instances can be pickled (pickle resolves namedtuple classes by ``<module>.<typename>``).
CompatibilityResult = namedtuple(
    "CompatibilityResult",
    [
        "seed",
        "encoding_space",
        "decoding_space",
        "aggregation",
        "projections",
        "num_anchors",
        "linear_cka",
        "l1",
        "mse",
        "cosine_sim",
        # Currently disabled metrics:
        # "k",
        # "spearman",
        # "pearson",
        # "topk_jaccard",
        # "mrr",
        # "rbf_kernel_cka",
    ],
)


def _concat_projections(
    projection2space: Mapping[str, RelativeSpace],
    projection_names: Sequence[str],
) -> RelativeSpace:
    """Concatenate several relative projections of one split into a single RelativeSpace.

    Vectors of ``projection_names`` are concatenated along the last dimension in the given
    order; keys, labels, encoder, anchors and num_classes are copied from the first projection.
    The resulting ``projection`` name is the comma-joined, sorted list of projection names.
    """
    reference = projection2space[projection_names[0]]
    return RelativeSpace(
        vectors=torch.cat([projection2space[name].vectors for name in projection_names], dim=-1),
        keys=reference.keys,
        labels=reference.labels,
        encoder=reference.encoder,
        anchors=reference.anchors,
        num_classes=reference.num_classes,
        projection=",".join(sorted(projection_names)),
    )


@torch.no_grad()
def run(
    seeds: Sequence[int],
    dataset: str,
    anchor_num_options: Sequence[int],
    device: torch.device,
    simple_projections: Mapping[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
    projections_aggregation_to_use: Sequence[Tuple[Sequence[str], Type[nn.Module]]],
    num_probes: int = 5000,
    k: int = 10,
):
    """Evaluate cross-encoder retrieval on relative representations of ``dataset``.

    For every anchor count, anchors are drawn (``anchor_choice="uniform"``, seed 0) from each
    encoder's train split, and the encoder's test split is projected onto them with every simple
    projection. Then, for every seed and (projections, aggregation) combination, the requested
    projections are concatenated and ``evaluate_retrieval`` is run on every ordered pair of
    encoders whose relative spaces have the same width.

    Results are written as TSV to ``results/compatibility/<dataset key>/results.tsv``. If that
    file already exists the dataset is skipped. While running, checkpoints go to
    ``results.partial.tsv``, which is renamed to ``results.tsv`` only once everything finished,
    so an interrupted run is redone instead of being skipped.

    Args:
        seeds: seeds for ``seed_everything`` and for shuffling the probe set.
        dataset: dataset name understood by ``data_config``.
        anchor_num_options: numbers of anchors to evaluate.
        device: device forwarded to ``evaluate_retrieval``.
        simple_projections: projection name -> projection function used by ``to_relative``.
        projections_aggregation_to_use: (projection names to concatenate, aggregation type);
            only the aggregation's ``__name__`` is recorded here.
        num_probes: maximum number of test samples used as retrieval queries.
        k: forwarded to ``evaluate_retrieval``.
    """
    dataset_config: DatasetConfig = data_config(dataset)
    enc_name2column = dataset_config.enc_name2column

    result_dir = PROJECT_ROOT / "results" / "compatibility" / dataset_config.key
    result_dir.mkdir(exist_ok=True, parents=True)
    results_path = result_dir / "results.tsv"
    # Incremental checkpoints must not use ``results_path``: its mere existence marks the
    # dataset as done, so a crash mid-run would otherwise skip the dataset forever.
    partial_path = result_dir / "results.partial.tsv"

    if results_path.exists():
        log.warning(f"Skipping {dataset}")
        return

    data: DatasetDict = load_from_disk(dataset_path=str(dataset_config.directory))
    if dataset_config.key.startswith("dbpedia_14"):
        # Keep a stratified 10% of each split; seeded so repeated runs use the same subset.
        data = DatasetDict(
            train=data["train"].train_test_split(
                train_size=0.1, seed=42, stratify_by_column=dataset_config.label_column
            )["train"],
            test=data["test"].train_test_split(
                train_size=0.1, seed=42, stratify_by_column=dataset_config.label_column
            )["train"],
        )

    tensor_columns = {
        column
        for column in data["train"].column_names
        if any(column.startswith(encoder) for encoder in dataset_config.encoders)
    }
    tensor_columns.add(dataset_config.label_column)
    # ``set_format`` documents ``columns`` as a list; sort for a deterministic column order.
    data.set_format(columns=sorted(tensor_columns), output_all_columns=True, type="torch")

    fit_data = data["train"].train_test_split(train_size=0.9, seed=42, stratify_by_column=dataset_config.label_column)
    train_data, val_data, test_data = fit_data["train"], fit_data["test"], data["test"]
    data = {"train": train_data, "val": val_data, "test": test_data}

    num_classes = train_data.features[dataset_config.label_column].num_classes

    enc_name2abs_space: Mapping[str, Mapping[str, LatentSpace]] = {
        enc_name: {
            split: LatentSpace(
                encoding_type="absolute",
                vectors=data[split][enc_name2column[enc_name]],
                encoder=enc_name,
                keys=data[split]["index"],
                labels=data[split][dataset_config.label_column],
                num_classes=num_classes,
            )
            for split in ["train", "val", "test"]
        }
        for enc_name in dataset_config.encoders
    }

    results = []
    for num_anchors in anchor_num_options:
        enc_name2anchors = {
            enc_name: split2space["train"].get_anchors(anchor_choice="uniform", seed=0, num_anchors=num_anchors)
            for enc_name, split2space in enc_name2abs_space.items()
        }

        # Only the test split is evaluated below, so only its relative projections are computed.
        enc_name2projection2test_space: Mapping[str, Mapping[str, RelativeSpace]] = {
            enc_name: {
                projection_name: split2space["test"].to_relative(
                    projection_func=projection_func,
                    projection_name=projection_name,
                    anchors=enc_name2anchors[enc_name],
                )
                for projection_name, projection_func in simple_projections.items()
            }
            for enc_name, split2space in enc_name2abs_space.items()
        }

        for seed, (projection_names, aggregation_type) in itertools.product(seeds, projections_aggregation_to_use):
            log.info(f"Running {dataset} | seed {seed} | {projection_names} | {aggregation_type.__name__}")
            enc_name2rel_space: Mapping[str, RelativeSpace] = {
                enc_name: _concat_projections(projection2space, projection_names)
                for enc_name, projection2space in enc_name2projection2test_space.items()
            }

            # The probe set depends only on the seed (``shuffle`` uses its own seeded generator),
            # so it is computed once here rather than once per encoder pair.
            search_ids = (
                data["test"].shuffle(seed=seed).select(list(range(min(num_probes, len(data["test"])))))["index"]
            )

            for (enc_name1, rel_space1), (enc_name2, rel_space2) in tqdm(
                list(itertools.product(enc_name2rel_space.items(), repeat=2))
            ):
                # Pairs with different widths (e.g. absolute spaces of different encoders)
                # cannot be compared.
                if rel_space1.shape[1] != rel_space2.shape[1]:
                    continue

                seed_everything(seed)

                retrieval_info = evaluate_retrieval(
                    latent_space1=rel_space1,
                    latent_space2=rel_space2,
                    search_ids=search_ids,
                    k=k,
                    device=device,
                )
                retrieval_info.update(
                    **{
                        "encoding_space": enc_name1,
                        "decoding_space": enc_name2,
                        "seed": seed,
                        "projections": ",".join(sorted(projection_names)),
                        "num_anchors": num_anchors,
                        "aggregation": aggregation_type.__name__,
                    }
                )
                # NOTE(review): the keys returned by evaluate_retrieval must exactly match the
                # remaining CompatibilityResult fields (linear_cka, l1, mse, cosine_sim); those
                # look like compatibility metrics rather than retrieval ones -- confirm.
                # TODO: fix Faiss
                results.append(CompatibilityResult(**retrieval_info))
                pd.DataFrame(results).to_csv(partial_path, sep="\t", index=False)

    # Publish the final table atomically; only now does the dataset count as done.
    pd.DataFrame(results).to_csv(partial_path, sep="\t", index=False)
    partial_path.replace(results_path)


if __name__ == "__main__":
    seeds = tuple(range(3))
    datasets = [
        # "cifar10",
        "cifar100_coarse",
        # "fashion_mnist",
        "mnist",
        # "n24news_image",
        # "dbpedia_14",
        "trec",
        "n24news_text",
    ]
    # Each entry: (simple projections to concatenate, aggregation type whose name is recorded).
    projections_aggregation_to_use = [
        (["Absolute"], Identity),
        (["Cosine"], Identity),
        (["Euclidean"], Identity),
        (["CenterCosine"], Identity),
        (["NormEuclidean"], Identity),
        (["L1"], Identity),
        (["Linf"], Identity),
    ]
    anchor_num_options: Sequence[int] = (2000, 1280, 300)

    # Dynamically collect every simple projection referenced above. Sorting makes the
    # computation order independent of string-hash randomization.
    needed_projections = sorted(
        {name for projection_names, _ in projections_aggregation_to_use for name in projection_names}
    )
    simple_projections = {name: SIMPLE_PROJECTION_TYPE[name] for name in needed_projections}

    for dataset in (pbar := tqdm(datasets)):
        pbar.set_description(f"Running {dataset}")
        run(
            seeds=seeds,
            dataset=dataset,
            simple_projections=simple_projections,
            projections_aggregation_to_use=projections_aggregation_to_use,
            anchor_num_options=anchor_num_options,
            device=DEVICE,
        )
