import gc
import os
import sys
from argparse import ArgumentParser
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from bert_score import score

sys.path.append(str(Path(__file__).parent.parent))
import json

from settings.subsets import SUBSETS
from settings.text_transformations import TEXT_TRANSFORMATIONS
from utils.plotting import plot_train_test_val
from utils.utils import (
    find_empty,
    save_metrics4,
    stats_diff_train_test_val,
    write_results_to_csv,
)


class SimpleDataset:
    """Minimal in-memory dataset wrapping a list of row dictionaries.

    Indexing with a key returns that key's value from every row (a column),
    while iteration and ``len`` operate on the rows themselves.
    """

    def __init__(self, data):
        """Store ``data`` and take the column names from the first row, if any."""
        self.data = data
        if data:
            self.column_names = list(data[0].keys())
        else:
            self.column_names = []

    def __getitem__(self, key):
        """Return the list of ``row[key]`` values, one per row, in row order."""
        column = []
        for row in self.data:
            column.append(row[key])
        return column

    def __len__(self):
        """Return the number of rows."""
        return len(self.data)

    def __iter__(self):
        """Iterate over the stored rows."""
        return iter(self.data)

    def select(self, indices):
        """Return a new SimpleDataset holding the rows at ``indices``, in that order."""
        picked = []
        for position in indices:
            picked.append(self.data[position])
        return SimpleDataset(picked)


def my_load_dataset(path):
    """Load a JSON Lines file into a SimpleDataset.

    Each non-blank line of the file is parsed with ``json.loads`` and becomes
    one row of the dataset.

    Args:
        path: Path of the UTF-8 encoded ``.jsonl`` file to read.

    Returns:
        A SimpleDataset with one row per parsed line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Blank or whitespace-only lines (e.g. a trailing empty line) are not
        # records; json.loads would raise on them, so they are skipped.
        rows = [json.loads(line) for line in f if line.strip()]
    return SimpleDataset(rows)


def get_data_for_max_val_train(train_path, val_path, subset=None):
    """Load the train and val prediction files and split val into two halves.

    The val data is split by position: the first ``len // 2`` rows become
    ``test`` and the remaining rows become ``test_val``.

    Args:
        train_path: Path of the train ``.jsonl`` file.
        val_path: Path of the val ``.jsonl`` file.
        subset: Optional subset name used only in printed messages. When
            omitted, the module-level ``subset`` variable (set by the script's
            driver loop) is used if it exists.

    Returns:
        ``None`` when neither file exists. Otherwise a 4-tuple
        ``(train, train_val, test, test_val)``; ``train`` is ``None`` when the
        train file is missing, and ``train_val``, ``test`` and ``test_val``
        are all ``None`` when the val file is missing.
    """
    if subset is None:
        # This function used to read the global ``subset`` directly, which
        # raised NameError whenever the module was imported instead of run as
        # a script. Keep the global as a fallback so script output is the same.
        subset = globals().get("subset", "<unknown>")

    train_exists = os.path.exists(train_path)
    val_exists = os.path.exists(val_path)
    if not (train_exists or val_exists):
        print(f"No data files found for subset {subset}")
        return None

    train = my_load_dataset(train_path) if train_exists else None
    train_val = my_load_dataset(val_path) if val_exists else None
    print("Loaded datasets:")
    print(f"Train set size: {len(train) if train else 0}")
    print(f"Train_val set size: {len(train_val) if train_val else 0}")
    print("\nProcessing data...")

    if train is not None:
        find_empty(train, message=f"{subset} train set")
    if train_val is None:
        # Without val data there is nothing to split into test / test_val.
        return train, None, None, None
    find_empty(train_val, message=f"{subset} train_val set")

    # Split train_val into test (first half) and test_val (second half).
    total = len(train_val)
    half = total // 2
    test = train_val.select(range(half))
    test_val = train_val.select(range(half, total))

    return train, train_val, test, test_val


def _dataset_size(dataset):
    """Return ``len(dataset)``, or 0 when the dataset is ``None`` or empty."""
    return len(dataset) if dataset else 0


def _format_mean(values, spec=""):
    """Format the mean of a score tensor with ``spec``, or return "NA" for ``None``."""
    if values is None:
        return "NA"
    return format(values.mean().item(), spec)


def _p_value_columns(test_results):
    """Flatten ``{test: {comparison: p_value}}`` into ``{"test_comparison": "p"}`` CSV columns."""
    return {
        f"{test_name}_{comparison}": f"{p_value:.4e}"
        for test_name, comparisons in test_results.items()
        for comparison, p_value in comparisons.items()
    }


def _save_scores(args, safe_model_name, subset, metric_key, scores_by_split):
    """Save per-sample scores via ``save_metrics4`` when ``--metrics_folder`` is set."""
    if not args.metrics_folder:
        return
    out_dir = os.path.join(args.metrics_folder, safe_model_name, subset)
    for split_name, values in scores_by_split:
        save_metrics4(split_name, metric_key, values, out_dir)


def _load_max_val_split(train_path, val_path):
    """Load a max-val-train split, returning ``None`` when it cannot be scored."""
    loaded = get_data_for_max_val_train(train_path, val_path)
    if loaded is None:
        # The loader has already reported that no files were found.
        return None
    if loaded[0] is None or loaded[1] is None:
        print(
            f"Skipping: both {train_path} and {val_path} are required.",
            flush=True,
        )
        return None
    return loaded


def _save_pairwise_plot(
    first, second, labels, title_prefix, save_prefix, plot_dir, subset
):
    """Plot two score distributions against each other and save the figure as PNG."""
    from utils.plotting import plot_pairwise

    fig, _ = plot_pairwise(np.array(first), np.array(second), labels=labels)
    plt.title(
        f"{title_prefix} Distribution ({labels[0]} vs {labels[1]}) for {subset}"
    )
    out_file = os.path.join(plot_dir, f"{save_prefix}_{subset}.png")
    plt.savefig(out_file)
    print(f"Plot saved to {out_file}", flush=True)
    plt.close(fig)


def _report_max_val_train(
    args, safe_model_name, subset, datasets, scores, alternative, csv_name, plot_dir
):
    """Run the statistical tests, write the CSV row and save the pairwise plots."""
    from utils.utils import stats_diff_train_val_test_val

    train, train_val, test, test_val = datasets
    F_train, F_train_val, F_test, F_test_val = scores

    print("\nStatistical tests...", flush=True)
    test_results = stats_diff_train_val_test_val(
        F_train,
        F_train_val,
        F_test,
        F_test_val,
        alternative=alternative,
        verbosity=2,
    )
    results_path = os.path.join(args.result_output, safe_model_name)
    os.makedirs(results_path, exist_ok=True)
    write_results_to_csv(
        os.path.join(results_path, csv_name),
        {
            "subset_name": subset,
            "train_size": _dataset_size(train),
            "train_val_size": _dataset_size(train_val),
            "test_size": _dataset_size(test),
            "test_val_size": _dataset_size(test_val),
            "train_BERT_F1_mean": _format_mean(F_train, ".4f"),
            "train_val_BERT_F1_mean": _format_mean(F_train_val, ".4f"),
            "test_BERT_F1_mean": _format_mean(F_test, ".4f"),
            "test_val_BERT_F1_mean": _format_mean(F_test_val, ".4f"),
            **_p_value_columns(test_results),
        },
    )
    print(f"Results saved to {results_path}", flush=True)

    print("\nPlotting the results...", flush=True)
    os.makedirs(plot_dir, exist_ok=True)
    if any(values is None for values in scores):
        print("Not enough data to plot.")
        return
    _save_pairwise_plot(
        F_train,
        F_train_val,
        ["Train", "Train_Val"],
        "BERT F1",
        "train_vs_trainval",
        plot_dir,
        subset,
    )
    _save_pairwise_plot(
        F_test,
        F_test_val,
        ["Test", "Test_Val"],
        "BERT F1",
        "test_vs_testval",
        plot_dir,
        subset,
    )


def _print_split_scores(F_train, F_train_val, F_test, F_test_val):
    """Print the mean BERT F1 of the four max-val-train splits."""
    print(
        "BERT F1 scores: "
        f"train={_format_mean(F_train)}; "
        f"train_val={_format_mean(F_train_val)}; "
        f"test={_format_mean(F_test)}; "
        f"test_val={_format_mean(F_test_val)}"
    )


def _run_between_predictions(args, base_path, safe_model_name, subset):
    """Score baseline predictions against predictions for each text transformation."""
    only_plots = getattr(args, "only_plots_and_tables", False)
    if only_plots and args.metrics_folder is None:
        # Previously this surfaced as a TypeError from os.path.join(None, ...).
        raise ValueError(
            "--only-plots-and-tables requires --metrics_folder to read saved scores"
        )

    train_file = f"train_{args.n_suff}suff_predic.jsonl"
    val_file = f"val_{args.n_suff}suff_predic.jsonl"

    baseline = _load_max_val_split(
        os.path.join(base_path, subset, train_file),
        os.path.join(base_path, subset, val_file),
    )
    if baseline is None:
        return
    baseline_train, baseline_train_val = baseline[0], baseline[1]

    for transformation in TEXT_TRANSFORMATIONS:
        transformed = _load_max_val_split(
            os.path.join(base_path, subset, transformation, train_file),
            os.path.join(base_path, subset, transformation, val_file),
        )
        if transformed is None:
            continue
        transformed_train, transformed_train_val = transformed[0], transformed[1]
        half = len(transformed_train_val) // 2
        metric_key = f"bert_{args.n_suff}_{transformation}"

        if only_plots:
            print("\nReading BERT scores...", flush=True)
            metrics_dir = os.path.join(args.metrics_folder, safe_model_name, subset)
            with open(
                os.path.join(metrics_dir, "train_metrics.json"),
                "r",
                encoding="utf-8",
            ) as f:
                train_metrics = json.load(f)
            with open(
                os.path.join(metrics_dir, "val_metrics.json"),
                "r",
                encoding="utf-8",
            ) as f:
                val_metrics = json.load(f)
            F_train = torch.tensor(train_metrics[metric_key])
            F_train_val = torch.tensor(val_metrics[metric_key])
            F_test = F_train_val[:half]
            F_test_val = F_train_val[half:]
        else:
            print("\nComputing BERT scores...", flush=True)
            _, _, F_train = score(
                baseline_train["predic"], transformed_train["predic"], lang="en"
            )
            _, _, F_train_val = score(
                baseline_train_val["predic"],
                transformed_train_val["predic"],
                lang="en",
            )
            # test / test_val scores are the two halves of the val scores.
            F_test = F_train_val[:half]
            F_test_val = F_train_val[half:]

            _print_split_scores(F_train, F_train_val, F_test, F_test_val)
            # Save per-sample metrics
            _save_scores(
                args,
                safe_model_name,
                subset,
                metric_key,
                [("train", F_train), ("val", F_train_val)],
            )

        _report_max_val_train(
            args,
            safe_model_name,
            subset,
            transformed,
            (F_train, F_train_val, F_test, F_test_val),
            alternative="less",
            csv_name=f"{args.n_suff}suff_{transformation}.csv",
            plot_dir=os.path.join(
                args.result_output,
                safe_model_name,
                f"{args.n_suff}suff_plots",
                transformation,
            ),
        )


def _run_max_val_train(args, base_path, safe_model_name, subset):
    """Score predictions against their reference suffixes for a max-val-train split."""
    loaded = _load_max_val_split(
        os.path.join(base_path, subset, f"train_{args.n_suff}suff_predic.jsonl"),
        os.path.join(base_path, subset, f"val_{args.n_suff}suff_predic.jsonl"),
    )
    if loaded is None:
        return
    train, train_val = loaded[0], loaded[1]
    half = len(train_val) // 2

    print("\nComputing BERT scores...", flush=True)
    _, _, F_train = score(train["predic"], train["suffix"], lang="en")
    _, _, F_train_val = score(train_val["predic"], train_val["suffix"], lang="en")
    # test / test_val scores are the two halves of the val scores.
    F_test = F_train_val[:half]
    F_test_val = F_train_val[half:]

    _print_split_scores(F_train, F_train_val, F_test, F_test_val)
    # Save per-sample metrics
    _save_scores(
        args,
        safe_model_name,
        subset,
        f"bert_{args.n_suff}",
        [("train", F_train), ("val", F_train_val)],
    )

    _report_max_val_train(
        args,
        safe_model_name,
        subset,
        loaded,
        (F_train, F_train_val, F_test, F_test_val),
        alternative="greater",
        csv_name=f"{args.n_suff}suff.csv",
        plot_dir=os.path.join(
            args.result_output, safe_model_name, f"{args.n_suff}suff_plots"
        ),
    )


def _run_standard(args, base_path, safe_model_name, subset):
    """Score the train / test / val prediction files that exist for ``subset``."""
    split_paths = {
        split: os.path.join(
            base_path, subset, f"{split}_{args.n_suff}suff_predic.jsonl"
        )
        for split in ("train", "test", "val")
    }
    present = {split for split, path in split_paths.items() if os.path.exists(path)}
    if not present:
        print(f"No data files found for subset {subset}")
        return

    train = my_load_dataset(split_paths["train"]) if "train" in present else None
    test = my_load_dataset(split_paths["test"]) if "test" in present else None
    val = my_load_dataset(split_paths["val"]) if "val" in present else None

    print("Loaded datasets:")
    print(f"Train set size: {_dataset_size(train)}")
    print(f"Test set size: {_dataset_size(test)}")
    print(f"Val set size: {_dataset_size(val)}")

    print("\nProcessing data...")
    if train:
        find_empty(train, message=f"{subset} train set")
    if test:
        find_empty(test, message=f"{subset} test set")
    if val:
        find_empty(val, message=f"{subset} val set")

    print("\nComputing BERT scores...", flush=True)
    F_train = F_test = F_val = None
    if train:
        _, _, F_train = score(train["predic"], train["suffix"], lang="en")
    if test:
        _, _, F_test = score(test["predic"], test["suffix"], lang="en")
    if val:
        _, _, F_val = score(val["predic"], val["suffix"], lang="en")
    print(
        "BERT F1 scores: "
        f"train={_format_mean(F_train)}; "
        f"test={_format_mean(F_test)}; "
        f"val={_format_mean(F_val)}"
    )

    # Save per-sample metrics
    _save_scores(
        args,
        safe_model_name,
        subset,
        f"bert_{args.n_suff}",
        [("train", F_train), ("test", F_test), ("val", F_val)],
    )

    print("\nStatistical tests...", flush=True)
    test_results = stats_diff_train_test_val(
        F_train, F_test, F_val, alternative="greater", verbosity=2
    )

    results_path = os.path.join(args.result_output, safe_model_name)
    os.makedirs(results_path, exist_ok=True)
    write_results_to_csv(
        os.path.join(results_path, f"{args.n_suff}suff.csv"),
        {
            "subset_name": subset,
            "train_size": _dataset_size(train),
            "test_size": _dataset_size(test),
            "val_size": _dataset_size(val),
            "train_BERT_F1_mean": _format_mean(F_train, ".4f"),
            "test_BERT_F1_mean": _format_mean(F_test, ".4f"),
            "val_BERT_F1_mean": _format_mean(F_val, ".4f"),
            **_p_value_columns(test_results),
        },
    )
    print(f"Results saved to {results_path}", flush=True)

    print("\nPlotting the results...", flush=True)
    plot_dir = os.path.join(
        args.result_output, safe_model_name, f"{args.n_suff}suff_plots"
    )
    os.makedirs(plot_dir, exist_ok=True)

    if F_train is None or F_test is None or F_val is None:
        print("Not enough data to plot.")
        return

    fig, _ = plot_train_test_val(
        np.array(F_train),
        np.array(F_test),
        np.array(F_val),
        labels=["Train", "Test", "Val"],
    )
    plt.title(f"BERT F1 Score Distribution for {subset}")
    plot_file = os.path.join(plot_dir, f"dist_{subset}.png")
    plt.savefig(plot_file)
    print(f"Plot saved to {plot_file}", flush=True)
    # Close the figure so figures do not pile up when main() is called once
    # per subset.
    plt.close(fig)


def main(args):
    """Compute BERTScore F1 for one subset, then run tests, write a CSV row and plot.

    Prediction files are read from ``<data_dir>/<model_name with "/" replaced
    by "_">/<subset_name>/``. Three modes are selected from ``args``:

    * ``max_val_train`` and ``between_predictions_bert_score`` set: for every
      entry of ``TEXT_TRANSFORMATIONS``, baseline predictions are scored
      against the predictions stored in the transformation's sub-directory.
      With ``only_plots_and_tables`` the scores are read back from
      ``metrics_folder`` instead of being recomputed.
    * only ``max_val_train`` set: train and val predictions are scored against
      their ``suffix`` field, and the val scores are split into two halves
      (test / test_val).
    * otherwise: whichever of the train / test / val prediction files exist
      are scored against their ``suffix`` field.

    A subset (or a single transformation) whose required files are missing is
    reported and skipped.

    Args:
        args: Parsed command-line namespace. Reads ``model_name``,
            ``data_dir``, ``subset_name``, ``n_suff``, ``result_output``,
            ``metrics_folder`` and, when present, ``max_val_train``,
            ``between_predictions_bert_score`` and ``only_plots_and_tables``.

    Raises:
        ValueError: If ``only_plots_and_tables`` is used in the
            between-predictions mode without ``metrics_folder``.
    """
    print("\nFetching data...", flush=True)
    safe_model_name = args.model_name.replace("/", "_")
    base_path = os.path.join(args.data_dir, safe_model_name)
    subset = args.subset_name

    if not getattr(args, "max_val_train", False):
        _run_standard(args, base_path, safe_model_name, subset)
    elif getattr(args, "between_predictions_bert_score", False):
        _run_between_predictions(args, base_path, safe_model_name, subset)
    else:
        _run_max_val_train(args, base_path, safe_model_name, subset)


# Command-line entry point: parse the options, then run main() once for every
# subset listed in SUBSETS.
if __name__ == "__main__":
    parser = ArgumentParser(description="Script running catshift baseline method.")

    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Directory containing the data files",
    )
    parser.add_argument(
        "--model_name", type=str, default="EleutherAI/pythia-410m-deduped"
    )
    parser.add_argument("--n_suff", type=int, default=64)  # 64, 32
    parser.add_argument(
        "--result_output", type=str, required=True, help="Path to save the results."
    )
    parser.add_argument(
        "--subset_name", type=str, default="enron", help="Name of the subset to process"
    )
    parser.add_argument(
        "--metrics_folder",
        type=str,
        default=None,
        help="Folder to save per-sample metrics (BERTScore)",
    )
    parser.add_argument(
        "--max-val-train",
        action="store_true",
        help="If set, process files generated by split_dataset.py with --max-val-train: train_full.jsonl, train_val.jsonl, test.jsonl, test_val.jsonl. Otherwise, process train.jsonl, val.jsonl, test.jsonl.",
    )
    parser.add_argument(
        "--between-predictions-bert-score",
        action="store_true",
        help="If set, BERT score between normal prediction and one based on transformed prefix is calculated",
    )
    parser.add_argument(
        "--only-plots-and-tables",
        action="store_true",
        help="If set together with --max-val-train and --between-predictions-bert-score, skip the BERT score computation and read previously saved per-sample scores from --metrics_folder to rebuild the tables and plots.",
    )
    args = parser.parse_args()

    # Reading saved scores needs a metrics folder; report that as a usage
    # error up front instead of failing later inside main().
    if (
        args.only_plots_and_tables
        and args.max_val_train
        and args.between_predictions_bert_score
        and args.metrics_folder is None
    ):
        parser.error("--only-plots-and-tables requires --metrics_folder")

    print(args, flush=True)

    # NOTE: --subset_name is overwritten here; every subset in SUBSETS is
    # processed regardless of the value given on the command line.
    # The loop variable must remain a module-level name called `subset`:
    # get_data_for_max_val_train looks it up in the module globals for its
    # messages.
    for i, subset in enumerate(SUBSETS):
        args.subset_name = subset
        print("-" * 64)
        print(f"Processing subset: {subset}", flush=True)
        main(args)

        # Release Python garbage and cached CUDA memory between subsets.
        gc.collect()
        torch.cuda.empty_cache()

        print("\n", flush=True)
        print(f"{subset:<20} DONE | ({(i+1)/len(SUBSETS):.1%})", flush=True)
