import itertools
import logging
from collections import namedtuple
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Mapping, Sequence, Tuple, Type

import pandas as pd
import torch
from datasets import DatasetDict, load_from_disk
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.discriminant_analysis import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import Normalizer
from sklearn.svm import SVC
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from tqdm import tqdm

from nn_core.common import PROJECT_ROOT

from latent_invariances.evaluation import evaluate_retrieval
from latent_invariances.modules.aggregation_modules import LayerNorm, NonLinearSumAggregation, LinearSelfAttentionLayer
from latent_invariances.modules.simple_classifier import Classifier, SVCModel
from latent_invariances.utils.relreps import SIMPLE_PROJECTION_TYPE
from latent_invariances.utils.space import LatentSpace, RelativeSpace

# Device used by the module-level evaluation helpers.
DEVICE: str = "cuda"

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Lightweight record describing which dataset/splits an encoded dump refers to.
DatasetParams = namedtuple(
    "DatasetParams",
    "name fine_grained train_split test_split perc hf_key",
)


@dataclass(frozen=True)
class DatasetConfig:
    """Immutable description of one encoded dataset."""

    # Identifier of the dataset configuration.
    key: str
    # Location of the encoded dataset.
    directory: Path
    # Name of the column holding the labels.
    label_column: str
    # ``str.format`` template with an ``{encoder}`` placeholder yielding a column name.
    encoding_column_template: str
    # Names of the encoders available for this dataset.
    encoders: Sequence[str]

    @property
    def enc_name2column(self):
        """Map every encoder name to the column name rendered from the template."""
        render = self.encoding_column_template.format
        columns = {}
        for encoder_name in self.encoders:
            columns[encoder_name] = render(encoder=encoder_name)
        return columns


def data_config(
    dataset_name: str,
) -> DatasetConfig:
    """Build the :class:`DatasetConfig` of a supported dataset.

    Supported names: ``fashion_mnist``, ``mnist``, ``cifar10``, any name starting with
    ``cifar100`` (fine labels are selected when the name contains ``"fine"``, coarse
    labels otherwise), ``trec``, ``dbpedia_14``, ``n24news_text`` and ``n24news_image``.

    Args:
        dataset_name: name of the dataset to configure.

    Returns:
        The configuration pointing to the encoded data under
        ``PROJECT_ROOT / "data" / "encoded_data"``.

    Raises:
        NotImplementedError: if ``dataset_name`` is not one of the supported datasets.
        FileNotFoundError: if the encoded data directory does not exist.
    """
    domain2encoders = {
        "vision": [
            # "vit_base_patch16_224",
            "rexnet_100",
            "vit_base_patch16_384",
            "vit_small_patch16_224",
            "vit_base_resnet50_384",
            "openai/clip-vit-base-patch32",
        ],
        "text": [
            "bert-base-cased",
            "bert-base-uncased",
            "google/electra-base-discriminator",
            "roberta-base",
            "albert-base-v2",
            "xlm-roberta-base",
            "openai/clip-vit-base-patch32",
        ],
    }
    perc = 1

    # When left to None, the directory name falls back to the dataset key.
    directory_name = None

    if dataset_name in ("fashion_mnist", "mnist", "cifar10"):
        dataset_params = DatasetParams(dataset_name, None, "train", "test", perc, (dataset_name,))
        label_column = "label"
        encoding_column_template = "{encoder}"
        directory_name = f"{dataset_name}_train_test_{perc}"
        encoders = domain2encoders["vision"]
    elif dataset_name.startswith("cifar100"):
        dataset_params = DatasetParams("cifar100", "fine" in dataset_name, "train", "test", perc, ("cifar100",))
        label_column = "fine_label" if dataset_params.fine_grained else "coarse_label"
        encoding_column_template = "{encoder}"
        # Fine and coarse variants share the same encoded directory.
        directory_name = f"{dataset_params.name}_train_test_{perc}"
        encoders = domain2encoders["vision"]
    elif dataset_name == "trec":
        dataset_params = DatasetParams("trec", False, "train", "test", perc, ("trec",))
        label_column = "target"
        encoding_column_template = "{encoder}_mean_encoding"
        encoders = domain2encoders["text"]
    elif dataset_name == "dbpedia_14":
        dataset_params = DatasetParams("dbpedia_14", None, "train", "test", perc, ("dbpedia_14",))
        label_column = "target"
        encoding_column_template = "{encoder}_mean_encoding"
        encoders = domain2encoders["text"]
    elif dataset_name == "n24news_text":
        dataset_params = DatasetParams("n24news_text", False, "train", "test", 1, ("n24news",))
        label_column = "label"
        encoding_column_template = "{encoder}_cls_encoding"
        directory_name = "N24News"
        encoders = domain2encoders["text"]
    elif dataset_name == "n24news_image":
        dataset_params = DatasetParams("n24news_image", False, "train", "test", 1, ("n24news",))
        label_column = "label"
        encoding_column_template = "{encoder}"
        directory_name = "N24News"
        encoders = (
            # "vit_base_patch16_224",
            "rexnet_100",
            "vit_base_patch16_384",
            "vit_small_patch16_224",
            "vit_base_resnet50_384",
            "cspdarknet53",
        )
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")

    # The key joins every non-None parameter (hf_key excluded) with underscores.
    data_key = "_".join(
        str(value) for field, value in dataset_params._asdict().items() if field != "hf_key" and value is not None
    )
    directory: Path = PROJECT_ROOT / "data" / "encoded_data" / (
        directory_name if directory_name is not None else data_key
    )
    # Explicit check instead of `assert`, which is stripped when running with -O.
    if not directory.exists():
        raise FileNotFoundError(f"Encoded data directory not found: {directory}")

    return DatasetConfig(
        key=data_key,
        directory=directory,
        label_column=label_column,
        encoding_column_template=encoding_column_template,
        encoders=encoders,
    )


@torch.no_grad()
def test(num_classes: int, test_loader: DataLoader, model: nn.Module):
    """Compute the top-1 multiclass accuracy of ``model`` over ``test_loader``.

    Both the model and every batch are moved to the module-level ``DEVICE``.

    Args:
        num_classes: number of classes given to the accuracy metric.
        test_loader: loader yielding batches with an ``"x"`` and a ``"y"`` entry.
        model: module producing per-class logits for the ``"x"`` entries.

    Returns:
        A dict with a single ``"score"`` entry holding the accuracy as a Python number.
    """
    model.to(DEVICE)
    accuracy = Accuracy(task="multiclass", num_classes=num_classes, top_k=1).to(DEVICE)

    for batch in test_loader:
        inputs, targets = batch["x"].to(DEVICE), batch["y"].to(DEVICE)
        # The predicted class is the one with the highest logit.
        accuracy.update(model(inputs).argmax(dim=1), targets)

    return {"score": accuracy.compute().cpu().item()}


def fit(
    subspace_dim: int,
    num_subspaces: int,
    num_classes: int,
    classifier_type: str,
    split2space: Mapping[str, LatentSpace],
    device: torch.device,
    aggregation_module: Type[nn.Module],
    seed: int = 42,
    batch_size: int = 10000,
    num_workers: int = 8,
    lr: float = 1e-3,  # TODO:change this
):
    """Train a classifier on the ``train`` split and score it on the ``test`` split.

    Args:
        subspace_dim: forwarded to the ``aggregation_module`` constructor.
        num_subspaces: forwarded to the ``aggregation_module`` constructor.
        num_classes: number of target classes.
        classifier_type: ``"svm"`` fits a scikit-learn linear SVC pipeline; any other value
            trains a ``Classifier`` with Lightning, which is deep only for ``"mlp"``.
        split2space: spaces for the ``"train"``, ``"val"`` and ``"test"`` splits.
        device: device the ``Classifier`` is moved to before training (unused for ``"svm"``).
        aggregation_module: class of the aggregation module (unused for ``"svm"``).
        seed: seed given to ``seed_everything`` and to the models.
        batch_size: batch size of the train and validation loaders.
        num_workers: number of worker processes of every loader.
        lr: learning rate of the ``Classifier``.

    Returns:
        A dict with the trained ``"model"`` (in eval mode, on CPU) and ``"info"``,
        the dict returned by ``test`` on the test split.
    """
    seed_everything(seed)
    train_loader = DataLoader(
        split2space["train"],
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=num_workers,
        # DataLoader rejects persistent workers when there are no workers at all.
        persistent_workers=num_workers > 0,
    )
    val_loader = DataLoader(
        split2space["val"], batch_size=batch_size, pin_memory=True, shuffle=False, num_workers=num_workers
    )
    test_loader = DataLoader(
        split2space["test"], batch_size=3000, pin_memory=True, shuffle=False, num_workers=num_workers
    )

    if classifier_type == "svm":
        svc_pipeline = make_pipeline(
            StandardScaler(), Normalizer(), SVC(gamma="auto", kernel="linear", random_state=seed)
        )
        # SVC has no incremental training: every call to `fit` discards the previous one.
        # Gather all the batches and fit once, otherwise only the last batch would be learned.
        xs, ys = [], []
        for batch in train_loader:
            xs.append(batch["x"])
            ys.append(batch["y"])
        svc_pipeline.fit(torch.cat(xs), torch.cat(ys))
        model = SVCModel(svc_pipeline)
    else:
        aggregator = aggregation_module(
            subspace_dim=subspace_dim,
            num_subspaces=num_subspaces,
        )

        model = Classifier(
            aggregation_module=aggregator,
            input_dim=aggregator.out_dim,
            num_classes=num_classes,
            lr=lr,
            deep=classifier_type == "mlp",
            seed=seed,
        ).to(device)

        trainer = Trainer(
            accelerator="auto",
            devices=1,
            max_epochs=50,  # TODO: change this
            logger=None,
            check_val_every_n_epoch=10,
            callbacks=[
                EarlyStopping(
                    monitor="accuracy",
                    verbose=True,
                    patience=1,
                    mode="max",
                )
            ],
            enable_progress_bar=True,
        )
        # Callers may run under `torch.no_grad()`: make sure gradients are tracked while training.
        with torch.enable_grad():
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # test the model
    test_info = test(
        num_classes=num_classes,
        test_loader=test_loader,
        model=model.eval(),
    )

    return dict(model=model.eval().cpu(), info=test_info)


# Row of the "original" results table: score of a classifier evaluated on the
# same encoder it was trained on.
# The typename matches the bound name so that repr and pickling work as expected.
OrigResult = namedtuple(
    "OrigResult",
    ["seed", "projections", "num_anchors", "encoder", "score", "classifier", "aggregation"],
)


# Row of the stitching results table: score of a decoder evaluated on another
# encoder's space, together with similarity measures between the two spaces.
StitchingResult = namedtuple(
    "StitchingResult",
    [
        "seed",
        "encoding_space",
        "decoding_space",
        "aggregation",
        "classifier",
        "projections",
        "num_anchors",
        "score",
        "linear_cka",
        "l1",
        "mse",
        "cosine_sim",
        # "k",
        # "spearman",
        # "pearson",
        # "topk_jaccard",
        # "mrr",
        # "rbf_kernel_cka",
    ],
)


def _concat_relative_spaces(spaces: Sequence[RelativeSpace], projection: str) -> RelativeSpace:
    """Merge relative spaces into one by concatenating their vectors along the last dim.

    Keys, labels, encoder, anchors and number of classes are taken from the first space.
    """
    head = spaces[0]
    return RelativeSpace(
        vectors=torch.cat([space.vectors for space in spaces], dim=-1),
        keys=head.keys,
        labels=head.labels,
        encoder=head.encoder,
        anchors=head.anchors,
        num_classes=head.num_classes,
        projection=projection,
    )


@torch.no_grad()
def run(
    dataset: str,
    classifiers: Sequence[str],
    anchor_num_options: Sequence[int],
    device: torch.device,
    seeds: Sequence[int],
    simple_projections: Mapping[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]],
    projections_aggregation_to_use: Sequence[Tuple[Sequence[str], Type[nn.Module]]],
    num_probes: int = 5000,
    k: int = 10,
):
    """Run the classification and stitching experiments on one dataset.

    For every (classifier, number of anchors, seed, projections/aggregation) combination
    a classifier is trained per encoder, its test score is stored in ``original.tsv`` and
    every decoder is then evaluated on the test space of every encoder, with the outcome
    stored in ``stitching.tsv``. Both files live under
    ``PROJECT_ROOT / "mlp_rae" / "stitching" / <dataset key>``.

    The dataset is skipped when both result files already exist.
    NOTE(review): the files are written incrementally, so an interrupted run leaves
    partial files behind that make the next invocation skip the dataset.

    Args:
        dataset: dataset name understood by ``data_config``.
        classifiers: classifier types forwarded to ``fit``.
        anchor_num_options: numbers of anchors to try.
        device: device forwarded to ``fit`` and ``evaluate_retrieval``.
        seeds: seeds to repeat every configuration with.
        simple_projections: projection name -> projection function.
        projections_aggregation_to_use: pairs of (projection names to concatenate,
            aggregation module class).
        num_probes: maximum number of test samples used by ``evaluate_retrieval``.
        k: forwarded to ``evaluate_retrieval``.
    """
    dataset_config: DatasetConfig = data_config(dataset)
    enc_name2column = dataset_config.enc_name2column
    label_column = dataset_config.label_column

    result_dir = PROJECT_ROOT / "mlp_rae" / "stitching" / dataset_config.key
    result_dir.mkdir(exist_ok=True, parents=True)
    original_result_path = result_dir / "original.tsv"
    stitching_results_path = result_dir / "stitching.tsv"

    if original_result_path.exists() and stitching_results_path.exists():
        log.warning(f"Skipping {dataset}")
        return

    orig_results = []
    stitching_results = []

    data: DatasetDict = load_from_disk(dataset_path=str(dataset_config.directory))
    if dataset_config.key.startswith("dbpedia_14"):
        # Keep a stratified 10% of each split; the seed makes the subsample reproducible.
        data = DatasetDict(
            train=data["train"].train_test_split(train_size=0.1, seed=42, stratify_by_column=label_column)["train"],
            test=data["test"].train_test_split(train_size=0.1, seed=42, stratify_by_column=label_column)["train"],
        )

    # Only the encoding columns and the label column are converted to tensors.
    tensor_columns = {
        column
        for column in data["train"].column_names
        if any(column.startswith(encoder) for encoder in dataset_config.encoders)
    }
    tensor_columns.add(label_column)
    data.set_format(columns=tensor_columns, output_all_columns=True, type="torch")

    # Carve a stratified validation split out of the training data.
    fit_data = data["train"].train_test_split(train_size=0.9, seed=42, stratify_by_column=label_column)
    test_data = data["test"]
    split2data = {"train": fit_data["train"], "val": fit_data["test"], "test": test_data}

    num_classes = split2data["train"].features[label_column].num_classes

    enc_name2abs_space: Mapping[str, Mapping[str, LatentSpace]] = {
        enc_name: {
            split: LatentSpace(
                encoding_type="absolute",
                vectors=split_data[enc_name2column[enc_name]],
                encoder=enc_name,
                keys=split_data["index"],
                labels=split_data[label_column],
                num_classes=num_classes,
            )
            for split, split_data in split2data.items()
        }
        for enc_name in dataset_config.encoders
    }

    # The probe ids only depend on the seed: compute them once instead of once per encoder pair.
    num_search_ids = min(num_probes, len(test_data))
    seed2search_ids = {
        seed: test_data.shuffle(seed=seed).select(list(range(num_search_ids)))["index"] for seed in seeds
    }

    for classifier_type, num_anchors in itertools.product(classifiers, anchor_num_options):
        enc_name2anchors = {
            enc_name: split2space["train"].get_anchors(anchor_choice="uniform", seed=0, num_anchors=num_anchors)
            for enc_name, split2space in enc_name2abs_space.items()
        }

        # encoder name -> split -> projection name -> relative space
        enc_name2split2projection2space: Mapping[str, Mapping[str, Mapping[str, RelativeSpace]]] = {
            enc_name: {
                split: {
                    projection_name: abs_space.to_relative(
                        projection_func=projection_func,
                        projection_name=projection_name,
                        anchors=enc_name2anchors[enc_name],
                    )
                    for projection_name, projection_func in simple_projections.items()
                }
                for split, abs_space in split2space.items()
            }
            for enc_name, split2space in enc_name2abs_space.items()
        }

        for seed, (projection_names, aggregation_type) in itertools.product(seeds, projections_aggregation_to_use):
            log.info(
                f"Running {dataset} | {classifier_type} | seed {seed} | {projection_names} | {aggregation_type.__name__}"
            )
            projections_key = ",".join(sorted(projection_names))
            # Accept any sequence type (list, tuple, ...) for the projection names.
            is_absolute = list(projection_names) == ["Absolute"]

            # encoder name -> split -> concatenation of the requested projections
            enc_name2split2rel_space: Mapping[str, Mapping[str, RelativeSpace]] = {
                enc_name: {
                    split: _concat_relative_spaces(
                        [projection2space[name] for name in projection_names],
                        projection=projections_key,
                    )
                    for split, projection2space in split2projection2space.items()
                }
                for enc_name, split2projection2space in enc_name2split2projection2space.items()
            }

            # Train one model per encoder to get the non-stitched ("original") performance.
            enc_name2fit = {
                enc_name: fit(
                    classifier_type=classifier_type,
                    split2space=split2rel_space,
                    seed=seed,
                    device=device,
                    # The absolute space keeps the encoder dimensionality, the others have one dim per anchor.
                    subspace_dim=(
                        enc_name2split2projection2space[enc_name]["train"][projection_names[0]].shape[1]
                        if is_absolute
                        else num_anchors
                    ),
                    num_subspaces=len(projection_names),
                    aggregation_module=aggregation_type,
                    num_classes=num_classes,
                )
                for enc_name, split2rel_space in enc_name2split2rel_space.items()
            }

            orig_results.extend(
                OrigResult(
                    seed=seed,
                    encoder=encoder_name,
                    score=fit_result["info"]["score"],
                    classifier=classifier_type,
                    aggregation=aggregation_type.__name__,
                    projections=projections_key,
                    num_anchors=num_anchors,
                )
                for encoder_name, fit_result in enc_name2fit.items()
            )
            pd.DataFrame(orig_results).to_csv(original_result_path, sep="\t", index=False)

            # Then evaluate every decoder on the test space of every encoder.
            for (enc_name1, split2space1), (enc_name2, split2space2) in tqdm(
                list(itertools.product(enc_name2split2rel_space.items(), repeat=2))
            ):
                rel_space1 = split2space1["test"]
                rel_space2 = split2space2["test"]

                # Spaces of different dimensionality cannot share a decoder.
                if rel_space1.shape[1] != rel_space2.shape[1]:
                    continue

                seed_everything(seed)

                decoder2 = enc_name2fit[enc_name2]["model"].to(DEVICE).eval()
                test_info = test(
                    num_classes=rel_space1.num_classes,
                    test_loader=DataLoader(rel_space1, batch_size=3000, pin_memory=True, shuffle=False, num_workers=8),
                    model=decoder2,
                )
                # Move the decoder back to the CPU once it has been evaluated.
                decoder2.cpu()

                eval_info = evaluate_retrieval(
                    latent_space1=rel_space1,
                    latent_space2=rel_space2,
                    search_ids=seed2search_ids[seed],
                    k=k,
                    device=device,
                )
                eval_info.update(
                    **{
                        "encoding_space": enc_name1,
                        "decoding_space": enc_name2,
                        "seed": seed,
                        "projections": projections_key,
                        "num_anchors": num_anchors,
                        "classifier": classifier_type,
                        "aggregation": aggregation_type.__name__,
                    }
                )
                # TODO: fix Faiss

                stitching_results.append(StitchingResult(**test_info, **eval_info))
                # Save after every pair so that partial results can be inspected while running.
                pd.DataFrame(stitching_results).to_csv(stitching_results_path, sep="\t", index=False)

            pd.DataFrame(stitching_results).to_csv(stitching_results_path, sep="\t", index=False)


if __name__ == "__main__":
    # Every configuration is repeated with each of these seeds.
    seeds = tuple(range(3))
    datasets = [
        "cifar10",
        "cifar100_coarse",
        "fashion_mnist",
        "mnist",
        "n24news_image",
        "dbpedia_14",
        "trec",
        "n24news_text",
    ]
    # Pairs of (projections to concatenate, aggregation module class).
    projections_aggregation_to_use = [
        (["Absolute"], LayerNorm),
        (["Cosine"], LayerNorm),
        (["Euclidean"], LayerNorm),
        (["L1"], LayerNorm),
        (["Linf"], LayerNorm),
        #
        (["CenterCosine"], LayerNorm),
        (["NormEuclidean"], LayerNorm),
        #
        (["Cosine", "Euclidean"], NonLinearSumAggregation),
        (["Cosine", "L1"], NonLinearSumAggregation),
        (["Cosine", "Linf"], NonLinearSumAggregation),
        (["Cosine", "Euclidean"], LinearSelfAttentionLayer),
        (["Cosine", "L1"], LinearSelfAttentionLayer),
        (["Cosine", "Linf"], LinearSelfAttentionLayer),
        #
        (["Euclidean", "L1"], NonLinearSumAggregation),
        (["Euclidean", "Linf"], NonLinearSumAggregation),
        (["Euclidean", "L1"], LinearSelfAttentionLayer),
        (["Euclidean", "Linf"], LinearSelfAttentionLayer),
        #
        (["L1", "Linf"], NonLinearSumAggregation),
        (["L1", "Linf"], LinearSelfAttentionLayer),
        #
        (["Cosine", "Euclidean", "L1", "Linf"], NonLinearSumAggregation),
        (["Cosine", "Euclidean", "L1", "Linf"], LinearSelfAttentionLayer),
        # ablation different aggregation
        # (["Cosine", "Euclidean"], NonLinearSelfAttentionLayer),
        # (["Cosine", "L1"], NonLinearSelfAttentionLayer),
        # (["Cosine", "Linf"], NonLinearSelfAttentionLayer),
        # (["Cosine", "Euclidean"], ConcatAggregation),
        # (["Cosine", "L1"], ConcatAggregation),
        # (["Cosine", "Linf"], ConcatAggregation),
        # #
        # (["Euclidean", "L1"], NonLinearSelfAttentionLayer),
        # (["Euclidean", "Linf"], NonLinearSelfAttentionLayer),
        # (["Euclidean", "L1"], ConcatAggregation),
        # (["Euclidean", "Linf"], ConcatAggregation),
        # #
        # (["L1", "Linf"], NonLinearSelfAttentionLayer),
        # (["L1", "Linf"], ConcatAggregation),
        # #
        # (["Cosine", "Euclidean", "L1", "Linf"], NonLinearSelfAttentionLayer),
        # (["Cosine", "Euclidean", "L1", "Linf"], ConcatAggregation),
    ]
    # classifiers = ["mlp", "linear", "svm"]
    classifiers = ["linear"]
    anchor_num_options: Sequence[int] = (1280,)

    # Dynamically extract the simple projections to compute: every projection name used
    # above, de-duplicated in first-seen order so that the mapping order is deterministic.
    unique_projection_names = dict.fromkeys(
        name for projection_names, _ in projections_aggregation_to_use for name in projection_names
    )
    simple_projections = {
        projection_name: SIMPLE_PROJECTION_TYPE[projection_name] for projection_name in unique_projection_names
    }

    for dataset in (pbar := tqdm(datasets)):
        pbar.set_description(f"Running {dataset} | {classifiers} | {anchor_num_options}")
        run(
            seeds=seeds,
            dataset=dataset,
            simple_projections=simple_projections,
            projections_aggregation_to_use=projections_aggregation_to_use,
            classifiers=classifiers,
            anchor_num_options=anchor_num_options,
            device=DEVICE,
        )
