import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import itertools
import numpy as np
import os
from pathlib import Path
import pytest
import time
import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
import pandas as pd
from datasets import DatasetDict, load_from_disk
from decision.xp.data.base import (
    ForwardedMixin,
    ds_registry,
    ds_rename,
    get_finetuned_dataset_path,
    get_finetuned_model_path,
    get_processed_dataset_path,
)
from decision.xp.model.base import PretrainedMixin, model_registry, model_rename
from decision.xp.model.hate import BaseSequenceClassification
from utils.io import save_path, save_fig
from utils.plot import set_latex_font


def forward_ds_model(ds: ForwardedMixin, model: PretrainedMixin, batch_size: int = 16):
    """Run every sample of ``ds`` through ``model`` and persist the outputs.

    The dataset returned by ``ds.load_dataset()`` is mapped batch by batch.
    For each batch the logits, probabilities, predicted labels, latent space
    and label are stored as columns, and the mapped dataset is saved under
    ``get_processed_dataset_path(ds, model)``.

    Args:
        ds: Dataset wrapper providing ``load_dataset``, ``extract`` and
            ``get_label``.
        model: Model wrapper providing ``process``, ``forward_both`` and the
            ``get_*`` output accessors.
        batch_size: Number of examples passed to each ``map`` call.

    Returns:
        A dict of timing metrics: wall-clock and CPU time in seconds, the
        device, the number of samples and, when running on CUDA, the elapsed
        GPU time in milliseconds together with the device name.
    """
    out_dir = get_processed_dataset_path(ds, model)
    out_dir.mkdir(parents=True, exist_ok=True)

    raw_dataset = ds.load_dataset()

    # Use the GPU whenever one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_gpu = device.type == "cuda"
    print(f"DEVICE: {device}")
    model.eval()
    model.to(device)

    if on_gpu:
        # Drain pending GPU work so the CUDA events only time the forward pass
        torch.cuda.synchronize()
        gpu_start = torch.cuda.Event(enable_timing=True)
        gpu_stop = torch.cuda.Event(enable_timing=True)
        gpu_start.record()

    # Wall-clock and CPU time are measured on every device
    wall_t0 = time.time()
    cpu_t0 = time.process_time()

    def forward_batch(examples):
        # Extract the raw inputs, let the model wrapper process them, then
        # move every resulting tensor to the target device
        batch = model.process(ds.extract(examples))
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        with torch.no_grad():
            truncated_outputs, outputs = model.forward_both(**batch)

            # Pull the quantities of interest out of the raw model outputs
            latent_space = model.get_latent_space(truncated_outputs)
            logits = model.get_logits(outputs)
            probabilities = model.get_probabilities(outputs)
            y_pred = model.get_y_pred(outputs)
            label = ds.get_label(examples)

        return dict(
            logits=logits,
            probabilities=probabilities,
            y_pred=y_pred,
            latent_space=latent_space,
            label=label,
        )

    # load_from_cache_file=False forces the forward pass to be recomputed
    forwarded = raw_dataset.map(
        forward_batch,
        batched=True,
        batch_size=batch_size,
        load_from_cache_file=False,
    )
    torch_columns = ["probabilities", "y_pred", "latent_space", "logits", "label"]
    forwarded.set_format("pt", columns=torch_columns)

    wall_t1 = time.time()
    cpu_t1 = time.process_time()

    metrics = {
        "elapsed_wall_time": wall_t1 - wall_t0,
        "elapsed_cpu_time": cpu_t1 - cpu_t0,
        "device": device,
        "n_samples": len(raw_dataset),
    }

    if on_gpu:
        # Close the CUDA timer; elapsed_time reports milliseconds
        gpu_stop.record()
        torch.cuda.synchronize()
        metrics["elapsed_gpu_time"] = gpu_start.elapsed_time(gpu_stop)
        metrics["device_name"] = torch.cuda.get_device_name()

    forwarded.save_to_disk(out_dir)

    return metrics


def ds_forwarded_exists(
    ds: ForwardedMixin,
    model: PretrainedMixin,
    check_length: bool = True,
):
    """Tell whether ``ds`` has already been forwarded through ``model``.

    The forwarded dataset must be loadable from
    ``get_processed_dataset_path(ds, model)`` and contain every output
    column. When ``check_length`` is true its row count must also match the
    row count of ``ds.load_dataset()``.

    Returns:
        ``True`` when all requested checks pass, ``False`` otherwise.
    """
    location = get_processed_dataset_path(ds, model)
    try:
        forwarded = load_from_disk(location)
    except FileNotFoundError:
        return False

    required_columns = {"logits", "probabilities", "y_pred", "latent_space", "label"}
    has_all_columns = required_columns.issubset(forwarded.features)

    # Without the length check (or with missing columns) the column test decides
    if not (has_all_columns and check_length):
        return has_all_columns

    # Compare the number of samples with the source dataset
    source = ds.load_dataset()
    n_forwarded, n_original = forwarded.num_rows, source.num_rows

    print(n_forwarded, n_original)

    return n_forwarded == n_original


def ds_finetuned_exists(ds: ForwardedMixin, model: PretrainedMixin):
    """Tell whether the finetuned dataset for ``ds`` and ``model`` is complete.

    The dataset stored at ``get_finetuned_dataset_path(ds, model)`` must load,
    contain only ``train``/``val``/``test`` splits, hold as many rows in total
    as ``ds.load_dataset()`` and expose both the forwarded and the
    ``finetuned_*`` columns in every split. The reason for a failed check is
    printed.

    Returns:
        ``True`` when every check passes, ``False`` otherwise.
    """
    location = get_finetuned_dataset_path(ds, model)
    try:
        finetuned = load_from_disk(location)
    except FileNotFoundError:
        print(f"File not found {location}")
        return False

    splits = set(finetuned.keys())
    allowed_splits = {"train", "val", "test"}

    if not splits <= allowed_splits:
        print(f"Splits not subset of train/val/test {splits}")
        return False

    source = ds.load_dataset()

    # Rows summed over all splits must match the source dataset
    n_finetuned = np.sum(list(finetuned.num_rows.values()))
    n_original = source.num_rows

    print(n_finetuned, n_original)

    if n_finetuned != n_original:
        print(f"Number of samples not equal: {n_finetuned} != {n_original}")
        return False

    forwarded_columns = {"logits", "probabilities", "y_pred", "latent_space", "label"}
    finetuned_columns = {
        "finetuned_probabilities",
        "finetuned_y_pred",
        "finetuned_logits",
    }
    required_columns = forwarded_columns | finetuned_columns

    for split in splits:
        if required_columns.issubset(finetuned[split].features):
            continue
        print(f"Features not subset of expected features for split {split}")
        return False

    return True


# Model registry keys used to parametrize the tests defined between this
# assignment and the later re-assignment of ``model_names`` in this file.
# ``pytest.mark.parametrize`` receives the list when the decorator is
# evaluated, so those tests keep this selection even after the rebinding.
# Commented-out entries are currently excluded.
model_names = [
    "cnerg1",
    "cnerg2",
    "cnerg3",
    "cnerg4",
    "cnerg5",
    # "fb_roberta1",
    "fb_roberta2",
    # "mistral_instruct",
]


@pytest.mark.parametrize("model_name", model_names)
def test_ds_forwarded_exists(model_name):
    """Print whether the ``hate`` dataset has been forwarded through the model."""
    ds_name = "hate"
    model = model_registry[model_name]()
    ds = ds_registry[ds_name]()

    # Display names; a missing entry raises KeyError
    ds_label = f"{ds_name} ({ds_rename[ds_name]})"
    model_label = f"{model_name} ({model_rename[model_name]})"

    exists = ds_forwarded_exists(ds, model)
    print(ds_label, model_label, exists)


def check_forward_table() -> pd.DataFrame:
    """Build a boolean table of which (dataset, model) pairs are forwarded.

    Every dataset in ``ds_registry`` is paired with every model in
    ``model_registry``. Datasets that are not ``ForwardedMixin`` instances are
    reported and left as ``False``.

    Returns:
        A DataFrame indexed by dataset key with one column per model key.
    """
    ds_keys = ds_registry.keys()
    model_keys = model_registry.keys()

    forwarded = np.zeros((len(ds_keys), len(model_keys)), dtype=bool)

    for row, ds_name in enumerate(ds_keys):
        for col, model_name in enumerate(model_keys):
            model = model_registry[model_name]()
            ds = ds_registry[ds_name]()

            if not isinstance(ds, ForwardedMixin):
                print(f"{ds_name} is not a ForwardedMixin")
                continue

            # Fall back to the registry key when no display name is registered
            ds_label = ds_rename.get(ds_name, ds_name)
            model_label = model_rename.get(model_name, model_name)
            print(f"{ds_name} ({ds_label})", f"{model_name} ({model_label})")

            status = ds_forwarded_exists(ds, model)
            forwarded[row, col] = status
            print(status)

    return pd.DataFrame(forwarded, index=ds_keys, columns=model_keys)


def plot_forward_table(df: pd.DataFrame):
    """Draw a red/green grid of the forwarded (dataset, model) pairs.

    Args:
        df: Boolean table. Rows are labelled with ``df.index`` and columns
            with ``df.columns``.

    Returns:
        The matplotlib figure holding the grid.
    """
    fig, ax = plt.subplots()

    cmap = ListedColormap(["tab:red", "tab:green"])
    # Pin the colour limits. With autoscaling, a table that is entirely True
    # (or entirely False) has vmin == vmax and every cell is mapped to the
    # first colour, so an all-forwarded table would be drawn red.
    ax.imshow(df, cmap=cmap, aspect="auto", vmin=0, vmax=1)

    # One tick per column/row, labelled with the model/dataset names
    ax.set_xticks(np.arange(len(df.columns)))
    ax.set_xticklabels(df.columns)
    ax.set_yticks(np.arange(len(df.index)))
    ax.set_yticklabels(df.index)

    ax.set(
        xlabel="Model",
        ylabel="Dataset",
        title="Forwarded datasets",
    )

    return fig


def test_ds_forwarded_exists_all(out):
    """Compute the forwarded-status table and write it as ``forwarded`` CSV."""
    table = check_forward_table()
    table.to_csv(save_path(out, "csv", "forwarded"))


def test_plot_forward_table(out):
    """Plot the forwarded-status table written by ``test_ds_forwarded_exists_all``."""
    # The CSV is read from the sibling output directory of that test
    csv_path = Path(out).parent / "ds_forwarded_exists_all/forwarded.csv"
    table = pd.read_csv(csv_path, index_col=0)

    set_latex_font()
    save_fig(plot_forward_table(table), out)


@pytest.mark.parametrize("model_name", model_names)
def test_ds_finetuned_exists(model_name):
    """Print whether the ``hate`` dataset has been finetuned for the model."""
    ds_name = "hate"
    model = model_registry[model_name]()
    ds = ds_registry[ds_name]()

    print(ds_name, model_name, ds_finetuned_exists(ds, model))


# Dataset registry keys used to parametrize ``test_fwd`` below. The list is
# read when the ``parametrize`` decorator is evaluated, so the later
# re-assignment of ``ds_names`` in this file does not affect that test.
# Commented-out entries are currently excluded.
ds_names = [
    # "hate",
    # "hate_merged_en",
    # "hate_merged_en2",
    # "hate_merged_no_en",
    # "hate_merged_large_en",
    # "hate_dyn_gen",
    # "merged_hate_check",
    # "hate_merged_no_en2",
    # "hate_merged_large_no_en",
    # "hate_merged_large",
    "hate_en_tweets",
    "hate_en_speech18",
    "hate_en_speech_off",
    "hate_en_davidson",
    "hate_en_gender",
    "hate_en_frenk",
    "hate_en_check",
    "hate_en_twitter",
    "hate_en_open",
]


@pytest.mark.parametrize("model_name", model_names)
@pytest.mark.parametrize("ds_name", ds_names)
def test_fwd(model_name, ds_name, out, batch_size=50, skip_existing=False):
    """Forward one dataset through one model and log the timings.

    The timing record is appended to ``datasets/forwarded/times.csv`` under
    ``$WORKING_DIR`` and the full log is also copied to the test output.
    """
    model = model_registry[model_name]()
    ds = ds_registry[ds_name]()

    print(model)
    print(ds)

    already_forwarded = skip_existing and ds_forwarded_exists(ds, model)
    if already_forwarded:
        pytest.skip(f"Forwarded dataset already exists for {ds_name} and {model_name}.")

    record = forward_ds_model(ds, model, batch_size=batch_size)
    record.update(
        datetime_created=time.time(),
        operation="forward",
        model_name=model.model_name,
        ds_name=ds.ds_name,
        model_name_query=model_name,
        ds_name_query=ds_name,
    )

    log_path = Path(os.environ["WORKING_DIR"], "datasets/forwarded/times.csv")
    # Continue the existing log when there is one, otherwise start a new one
    history = pd.read_csv(log_path) if log_path.exists() else pd.DataFrame()

    history = pd.concat([history, pd.DataFrame([record])])
    history.to_csv(log_path, index=False)
    history.to_csv(Path(save_path(out, "csv", "times")), index=False)


@pytest.mark.parametrize("model_name", ["mistral_instruct"])
def test_load_model_gpu(model_name):
    """Load the whole model and move it to the GPU if available, else the CPU."""
    wrapper: BaseSequenceClassification = model_registry[model_name]()

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"DEVICE: {target}")

    # Only the fully loaded model is moved; the wrapper itself is left as is
    whole_model = wrapper.load_whole_model()
    whole_model.to(target)


def finetune(ds: ForwardedMixin, model: PretrainedMixin, pred_batch_size: int = 100):
    """Finetune ``model.remaining_model`` on the forwarded version of ``ds``.

    The forwarded dataset is loaded from ``get_processed_dataset_path`` and
    split into train/val/test with ``ds.get_train_val_test_split``. The
    remaining model is trained on the ``latent_space`` column with a
    cross-entropy loss against ``label``, using early stopping and
    best-checkpoint selection on ``val_loss``. The best checkpoint then
    predicts on every split; the predictions are added as ``finetuned_*``
    columns and the resulting ``DatasetDict`` is saved under
    ``get_finetuned_dataset_path``.

    Args:
        ds: Dataset wrapper providing ``get_groups`` and
            ``get_train_val_test_split``.
        model: Model wrapper providing ``remaining_model``,
            ``_forward_remaining`` and the ``get_*`` output accessors.
        pred_batch_size: Batch size used when predicting on the splits.

    Returns:
        A dict with the training wall-clock/CPU/GPU timings, the device and
        its name, and the number of samples per split and in total.

    Raises:
        AssertionError: If the forwarded dataset lacks the ``latent_space``
            or ``label`` column.
    """
    dataset_path = get_processed_dataset_path(ds, model)
    model_dirpath = get_finetuned_model_path(ds, model)
    finetuned_path = get_finetuned_dataset_path(ds, model)

    dataset = load_from_disk(dataset_path)

    # Training and prediction read these two columns
    assert "latent_space" in dataset.features, "forwarded dataset lacks 'latent_space'"
    assert "label" in dataset.features, "forwarded dataset lacks 'label'"

    groups = ds.get_groups(dataset)
    idx_train, idx_val, idx_test = ds.get_train_val_test_split(groups)

    train_dataset = dataset.select(idx_train)
    val_dataset = dataset.select(idx_val)
    test_dataset = dataset.select(idx_test)

    split_names = ("train", "val", "test")
    split_dataset = DatasetDict(
        {
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
        }
    )

    class DataModule(L.LightningDataModule):
        """Serve the three splits to the Lightning trainer."""

        def __init__(self, batch_size=8):
            super().__init__()
            self.batch_size = batch_size

        def train_dataloader(self):
            return DataLoader(train_dataset, shuffle=True, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(val_dataset, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(test_dataset, batch_size=self.batch_size)

        def predict_dataloader(self):
            # One loader per split: the DatasetDict is a mapping of splits,
            # not a dataset that a single DataLoader could iterate
            return [
                DataLoader(split_dataset[name], batch_size=pred_batch_size)
                for name in split_names
            ]

    class Model(L.LightningModule):
        """Lightning wrapper that trains only the remaining model."""

        def __init__(self, remaining_model):
            super().__init__()
            self.remaining_model = remaining_model

        def forward(self, x):
            return model._forward_remaining(self.remaining_model, x)

        def _cross_entropy(self, batch):
            # Loss shared by the training and validation steps
            logits = model.get_logits(self(batch["latent_space"]))
            return torch.nn.functional.cross_entropy(logits, batch["label"].long())

        def training_step(self, batch, batch_idx):
            loss = self._cross_entropy(batch)
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            loss = self._cross_entropy(batch)
            self.log("val_loss", loss)
            return loss

        def predict_step(self, batch, batch_idx):
            output = self(batch["latent_space"])
            logits = model.get_logits(output)
            probabilities = model.get_probabilities(output)
            y_pred = model.get_y_pred(output)
            return {
                "finetuned_probabilities": probabilities,
                "finetuned_y_pred": y_pred,
                "finetuned_logits": logits,
            }

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            # Multiply the learning rate by 0.7 after every epoch
            lr_scheduler = {
                "scheduler": torch.optim.lr_scheduler.StepLR(
                    optimizer, step_size=1, gamma=0.7
                ),
                "name": "learning_rate",
                "interval": "epoch",
                "frequency": 1,
            }
            return [optimizer], [lr_scheduler]

    data_module = DataModule(batch_size=8)
    remaining_model = model.remaining_model
    pl_model = Model(remaining_model=remaining_model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    class ElapsedTimeCallback(Callback):
        """Record wall-clock, CPU and (on CUDA) GPU time spent in training."""

        def on_train_start(self, trainer, pl_module):
            self.start_wall = time.time()
            self.start_cpu = time.process_time()
            self.device = pl_module.device
            if pl_module.device.type == "cuda":
                self.start_event = torch.cuda.Event(enable_timing=True)
                self.end_event = torch.cuda.Event(enable_timing=True)
                self.start_event.record()
                self.device_name = torch.cuda.get_device_name()
            else:
                self.start_event = None
                self.device_name = None

        def on_train_end(self, trainer, pl_module):
            self.end_wall = time.time()
            self.end_cpu = time.process_time()
            self.elapsed_cpu_time = self.end_cpu - self.start_cpu
            self.elapsed_wall_time = self.end_wall - self.start_wall
            if self.start_event is not None:
                self.end_event.record()
                # Wait for both events to be recorded before reading them
                torch.cuda.synchronize()
                self.elapsed_gpu_time = self.start_event.elapsed_time(self.end_event)
            else:
                self.elapsed_gpu_time = None

    elapsed_time_callback = ElapsedTimeCallback()

    early_stop_callback = EarlyStopping(
        monitor="val_loss", min_delta=0.00, patience=3, verbose=False, mode="min"
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=model_dirpath,
        filename="best-checkpoint-{epoch:02d}-{val_loss:.6f}",
        save_top_k=1,
        mode="min",
    )
    trainer = L.Trainer(
        max_epochs=20,
        callbacks=[early_stop_callback, checkpoint_callback, elapsed_time_callback],
    )
    trainer.fit(pl_model, data_module)

    torch.cuda.empty_cache()

    # Reload the weights of the checkpoint with the lowest validation loss
    pl_best_model = Model.load_from_checkpoint(
        checkpoint_callback.best_model_path, remaining_model=remaining_model
    )
    pl_best_model.eval()

    for split in split_names:
        loader = DataLoader(split_dataset[split], batch_size=pred_batch_size)
        batches = trainer.predict(pl_best_model, loader)
        # Concatenate the per-batch outputs into one tensor per column
        merged = {key: torch.cat([b[key] for b in batches]) for key in batches[0]}
        for column, values in merged.items():
            split_dataset[split] = split_dataset[split].add_column(
                column, values.tolist()
            )

    split_dataset.save_to_disk(finetuned_path)
    print("Saving finetuned dataset to", finetuned_path)

    metrics = {
        "elapsed_wall_time": elapsed_time_callback.elapsed_wall_time,
        "elapsed_cpu_time": elapsed_time_callback.elapsed_cpu_time,
        "device": elapsed_time_callback.device,
        "elapsed_gpu_time": elapsed_time_callback.elapsed_gpu_time,
        "device_name": elapsed_time_callback.device_name,
        "n_samples_train": len(train_dataset),
        "n_samples_val": len(val_dataset),
        "n_samples_test": len(test_dataset),
        # Total rows predicted; len(split_dataset) would count the splits
        "n_samples_predict": sum(split_dataset.num_rows.values()),
    }

    return metrics


# Rebinds ``model_names`` for the tests decorated after this point
# (``test_finetune``). Tests decorated earlier in the file keep the list
# that was bound when their ``parametrize`` decorator was evaluated.
# Commented-out entries are currently excluded.
model_names = [
    # "cnerg1",
    # "cnerg2",
    # "cnerg3",
    # "cnerg4",
    # "cnerg5",
    "fb_roberta1",
    "fb_roberta2",
    # "mistral_instruct",
]

# Rebinds ``ds_names`` for the tests decorated after this point
# (``test_finetune``). Tests decorated earlier in the file keep the list
# that was bound when their ``parametrize`` decorator was evaluated.
ds_names = [
    "hate",
    "hate_merged_en",
    "hate_merged_no_en",
    "hate_merged_en2",
    "hate_dyn_gen",
    "merged_hate_check",
    "hate_merged_no_en2",
    "hate_merged_large_en",
    "hate_merged_large_no_en",
    "hate_merged_large",
]


@pytest.mark.parametrize("model_name", model_names)
@pytest.mark.parametrize("ds_name", ds_names)
def test_finetune(model_name, ds_name, out):
    """Finetune one model on one forwarded dataset and log the timings.

    Existing finetuned outputs are not skipped; the run is always repeated.
    The timing record is appended to ``datasets/finetuned/times.csv`` under
    ``$WORKING_DIR`` and the full log is also copied to the test output.
    """
    model = model_registry[model_name]()
    ds = ds_registry[ds_name]()

    print(model)
    print(ds)

    record = finetune(ds, model)
    record.update(
        datetime_created=time.time(),
        operation="finetune",
        model_name=model.model_name,
        ds_name=ds.ds_name,
        model_name_query=model_name,
        ds_name_query=ds_name,
    )

    log_path = Path(os.environ["WORKING_DIR"], "datasets/finetuned/times.csv")
    # Continue the existing log when there is one, otherwise start a new one
    history = pd.read_csv(log_path) if log_path.exists() else pd.DataFrame()

    history = pd.concat([history, pd.DataFrame([record])])
    history.to_csv(log_path, index=False)
    history.to_csv(Path(save_path(out, "csv", "times")), index=False)


def test_dataset():
    """Check that no group is shared between the train, val and test splits.

    The check runs twice for the ``hate`` dataset and the ``cnerg1`` model:
    once on the split computed from ``ds.get_groups`` and once on the groups
    and indices returned by ``ds.get_arrays``.
    """
    model_name = "cnerg1"
    ds_name = "hate"
    model = model_registry[model_name]()
    ds = ds_registry[ds_name]()
    finetuned = False

    dataset = ds.get_dataset(model, finetuned=finetuned)
    if finetuned:
        # The finetuned dataset is a dict of splits; use the val and test groups
        groups_val = ds.get_groups(dataset["val"])
        groups_test = ds.get_groups(dataset["test"])
        groups = np.concatenate([groups_val, groups_test])
    else:
        groups = ds.get_groups(dataset)
    idx_train, idx_val, idx_test = ds.get_train_val_test_split(groups)

    def assert_no_intersection(groups, idx_train, idx_val, idx_test):
        """Assert that the three index sets select pairwise disjoint groups."""
        groups_train = set(groups[idx_train])
        groups_val = set(groups[idx_val])
        groups_test = set(groups[idx_test])

        assert len(groups_train.intersection(groups_val)) == 0
        assert len(groups_val.intersection(groups_test)) == 0
        # Train/test was previously unchecked, so leakage between those two
        # splits went undetected
        assert len(groups_train.intersection(groups_test)) == 0

    assert_no_intersection(groups, idx_train, idx_val, idx_test)

    (X, y, S, G), (idx_train, idx_val, idx_test) = ds.get_arrays(model, finetuned)
    assert_no_intersection(G, idx_train, idx_val, idx_test)
