import os
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
)
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset, random_split

from code_demeanor.logger import logger
from code_demeanor.utils import get_device, load_tensors_jsonl, set_seed


class MLP(nn.Module):
    """Three-layer feed-forward classifier that outputs raw logits.

    Layout: ``Linear -> norm -> ReLU -> Dropout`` twice, then a final
    ``Linear`` to ``output_dim`` logits. ``norm`` is ``BatchNorm1d`` when
    ``use_bn`` is True and ``nn.Identity`` otherwise, so the attribute names
    and forward path are the same in both configurations.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output logits. Defaults to 2.
        hidden: Pair ``(h1, h2)`` of hidden-layer widths.
        dropout: Dropout probability applied after each hidden activation.
        use_bn: Whether to insert ``BatchNorm1d`` after each hidden linear layer.
    """

    def __init__(
        self, input_dim, output_dim=2, hidden=(128, 64), dropout=0.3, use_bn=True
    ):
        super().__init__()
        width1, width2 = hidden
        self.use_bn = use_bn
        # nn.Identity ignores its constructor arguments, so either class can
        # be called with the layer width.
        norm_cls = nn.BatchNorm1d if use_bn else nn.Identity

        # Registration order (fc1, bn1, fc2, bn2, fc3, drop) is kept fixed so
        # state_dict keys and seeded weight initialisation do not change.
        self.fc1 = nn.Linear(input_dim, width1)
        self.bn1 = norm_cls(width1)
        self.fc2 = nn.Linear(width1, width2)
        self.bn2 = norm_cls(width2)
        self.fc3 = nn.Linear(width2, output_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Return unnormalised logits; ``x`` is expected to be (N, input_dim)."""
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            x = self.drop(F.relu(norm(linear(x))))
        return self.fc3(x)


def evaluate(loader, model, device=get_device(verbose=False)):
    """Compute the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode and run without gradient tracking; the
    predicted class is the argmax along dim 1 of the model output.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        model: Module whose output holds per-class scores along dim 1.
        device: Device batches are moved to before the forward pass. The
            default is evaluated once, when this module is imported.

    Returns:
        ``(accuracy, y_true, y_pred)`` where ``y_true`` / ``y_pred`` are CPU
        tensors concatenated over all batches.
    """
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            predictions.append(scores.argmax(dim=1).cpu())
            targets.append(labels.cpu())
    y_true, y_pred = torch.cat(targets), torch.cat(predictions)
    return accuracy_score(y_true, y_pred), y_true, y_pred


def get_stats_and_labels(
    token_stats_train_labels_path: str,
    token_stats_test_labels_path: str,
    token_stats_train_path: str,
    token_stats_test_path: str,
):
    """Load train/test feature tensors and their labels from JSONL files.

    Each file is read with ``load_tensors_jsonl``. For feature files, the
    first row (``record[0]``) of every record is stacked and reshaped to
    ``(-1, records[0].shape[1])``, so every record is expected to have at
    least two dimensions (e.g. ``(1, D)`` — TODO confirm against the writer).
    The trailing ``squeeze(1)`` only changes the result when that width is 1.
    For label files, ``record[0]`` of every record is stacked and flattened
    to 1-D.

    Returns:
        ``(train_stats, train_labels, test_stats, test_labels)``.
    """

    def _stack_features(records):
        # First row of every record, reshaped using the first record's width.
        stacked = torch.stack([record[0] for record in records])
        return stacked.view(-1, records[0].shape[1]).squeeze(1)

    def _stack_labels(records):
        # First entry of every record, flattened to a 1-D tensor.
        return torch.stack([record[0] for record in records]).view(-1)

    train_label_records = load_tensors_jsonl(token_stats_train_labels_path)
    train_records = load_tensors_jsonl(token_stats_train_path)
    test_label_records = load_tensors_jsonl(token_stats_test_labels_path)
    test_records = load_tensors_jsonl(token_stats_test_path)

    train_stats = _stack_features(train_records)
    train_labels = _stack_labels(train_label_records)
    test_stats = _stack_features(test_records)
    test_labels = _stack_labels(test_label_records)

    logger.info(
        f"Train stats shape: {train_stats.shape}, Train labels shape: {train_labels.shape}"
    )
    logger.info(
        f"Test stats shape: {test_stats.shape}, Test labels shape: {test_labels.shape}"
    )

    return train_stats, train_labels, test_stats, test_labels


def normalise_stats(train_stats, test_stats):
    """Z-score both tensors column-wise using statistics from ``train_stats`` only.

    The per-column mean and (unbiased) standard deviation are computed over
    dim 0 of ``train_stats`` and applied to both inputs, so the test set is
    scaled with training statistics.

    Returns:
        ``(normalised_train, normalised_test)``.
    """
    centre = train_stats.mean(dim=0, keepdim=True)
    # The small epsilon keeps constant columns from dividing by zero.
    scale = train_stats.std(dim=0, keepdim=True) + 1e-6

    def _standardise(tensor):
        return (tensor - centre) / scale

    return _standardise(train_stats), _standardise(test_stats)


def get_pca_visualisation(
    train_norm,
    train_labels,
    base_output_name: str,
    show_figure: bool = True,
    save_fig: bool = True,
):
    """Scatter-plot the first two principal components of the features by class.

    Features are standardised with ``StandardScaler`` and projected onto two
    principal components. Samples labelled 0 are drawn as "Normal Behavior"
    (blue) and samples labelled 1 as "Misbehavior" (red); any other label
    value is not plotted.

    Args:
        train_norm: 2-D feature tensor (samples x features); moved to CPU.
        train_labels: Tensor with one label per sample; cast to ``int``.
        base_output_name: Path prefix; the PDF is written to
            ``base_output_name + "pca_scatter.pdf"``.
        show_figure: If True, call ``plt.show()``. If False, the figure is
            closed after (optionally) saving so repeated calls do not
            accumulate open figures.
        save_fig: If True, save the figure as a PDF.
    """
    X = train_norm.cpu().numpy()
    y = train_labels.cpu().numpy().astype(int)

    # Standardise before PCA so high-variance features do not dominate.
    Xz = StandardScaler().fit_transform(X)
    X_pca = PCA(n_components=2, whiten=False).fit(Xz).transform(Xz)

    # Indexed by class id: 0 -> blue (Normal Behavior), 1 -> red (Misbehavior).
    colors = ["#1f77b4", "#d62728"]
    labels = ["Normal Behavior", "Misbehavior"]
    fig = plt.figure(figsize=(7, 5))
    for cls in (0, 1):
        mask = y == cls
        plt.scatter(
            X_pca[mask, 0],
            X_pca[mask, 1],
            s=36,
            alpha=0.55,
            edgecolors="k",
            linewidths=0.2,
            c=colors[cls],
            label=labels[cls],
        )

    plt.xlabel("PC1", fontsize=20)
    plt.ylabel("PC2", fontsize=20)
    plt.title("PCA (scaled features)", fontsize=22)
    plt.legend(frameon=True, fontsize=18, title_fontsize=18, loc="upper right")
    plt.tight_layout()
    if save_fig:
        plt.savefig(base_output_name + "pca_scatter.pdf")
    if show_figure:
        plt.show()
    else:
        # Release the figure; otherwise every call leaves one open.
        plt.close(fig)


def get_lda_visualisation(
    train_norm,
    train_labels,
    base_output_name: str,
    show_figure: bool = True,
    save_fig: bool = True,
    xlim=(-40, 40),
):
    """Plot class histograms of the 1-D LDA projection of the features.

    Features are standardised with ``StandardScaler``, an LDA with one
    component is fitted on them, and the projections of class 0
    ("Normal Behavior") and class 1 ("Misbehavior") are overlaid as
    histograms that share the same 40 bin edges, so bar heights are directly
    comparable between classes.

    Args:
        train_norm: 2-D feature tensor (samples x features); moved to CPU.
        train_labels: Tensor with one label per sample; cast to ``int``.
        base_output_name: Path prefix; the PDF is written to
            ``base_output_name + "lda_histogram.pdf"``.
        show_figure: If True, call ``plt.show()``; otherwise close the figure.
        save_fig: If True, save the figure as a PDF.
        xlim: ``(left, right)`` x-axis limits; ``None`` keeps matplotlib's
            automatic limits. Defaults to the previous fixed ``(-40, 40)``.
    """
    X = train_norm.cpu().numpy()
    y = train_labels.cpu().numpy().astype(int)

    # Standardise before LDA, matching the PCA visualisation.
    Xz = StandardScaler().fit_transform(X)

    lda = LDA(n_components=1).fit(Xz, y)
    X_lda = lda.transform(Xz).ravel()  # one value per sample

    # Compute the bin edges once over all samples. Binning each class
    # separately gives different bin widths, which makes overlaid counts
    # misleading.
    bin_edges = np.histogram_bin_edges(X_lda, bins=40)

    fig = plt.figure(figsize=(7, 4))
    plt.hist(
        X_lda[y == 0],
        bins=bin_edges,
        alpha=0.55,
        label="Normal Behavior",
        color="#1f77b4",
    )
    plt.hist(
        X_lda[y == 1],
        bins=bin_edges,
        alpha=0.55,
        label="Misbehavior",
        color="#d62728",
    )
    plt.xlabel("LD1 (maximizes class separation)", fontsize=20)
    plt.ylabel("count", fontsize=20)
    plt.title("LDA 1-D projection", fontsize=22)
    plt.legend(fontsize=18, title_fontsize=18, loc="upper right")

    if xlim is not None:
        plt.xlim(list(xlim))
    # Lay out after fixing the limits so the final tick labels are accounted for.
    plt.tight_layout()
    if save_fig:
        plt.savefig(base_output_name + "lda_histogram.pdf")
    if show_figure:
        plt.show()
    else:
        plt.close(fig)


def train_mlp_and_evaluate(
    train_norm, train_labels, test_norm, test_labels, device=get_device(verbose=False)
):
    """Train an ``MLP`` with early stopping on a validation split, then test it once.

    A validation split (10% of the training tensors, at least one sample) is
    carved out with ``random_split``. The model is trained for up to 80
    epochs with AdamW, and the weights with the best validation accuracy are
    kept. Training stops early after 10 epochs without improvement. The
    restored best model is then evaluated on the test tensors, and accuracy,
    ROC-AUC and a classification report are printed.

    Args:
        train_norm: Training feature tensor (samples x features).
        train_labels: Training class ids, used as ``CrossEntropyLoss`` targets.
        test_norm: Test feature tensor with the same feature width.
        test_labels: Test class ids.
        device: Device to train and evaluate on. The default is the value
            ``get_device(verbose=False)`` returned at import time. If
            ``None``, CUDA is used when available, otherwise CPU.

    Returns:
        ``(acc, auc, y_true, y_pred)``: test accuracy, test ROC-AUC (``None``
        if it could not be computed), and NumPy arrays of the true and
        predicted test labels.
    """
    # -------------------- Reproducibility --------------------
    set_seed(42)

    # Honour the caller's device; it was previously ignored in favour of a
    # hard-coded CUDA/CPU check.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    DEVICE = torch.device(device)
    # Pinned host memory only helps host->CUDA copies; elsewhere it just warns.
    pin_memory = DEVICE.type == "cuda"

    # -------------------- Data --------------------
    full_train = TensorDataset(train_norm, train_labels)
    val_frac = 0.1
    # Keep at least one validation sample; an empty split would make the
    # validation pass divide by zero.
    val_size = max(1, int(len(full_train) * val_frac))
    train_size = len(full_train) - val_size
    # The split uses its own freshly constructed generator, not the global
    # RNG seeded above.
    train_ds, val_ds = random_split(
        full_train, [train_size, val_size], generator=torch.Generator()
    )

    train_batch_size = 64
    train_loader = DataLoader(
        train_ds,
        batch_size=train_batch_size,
        shuffle=True,
        pin_memory=pin_memory,
        # BatchNorm1d raises in train mode on a one-sample batch; drop the
        # final batch only in the case where it would hold exactly one sample.
        drop_last=len(train_ds) % train_batch_size == 1,
    )
    val_loader = DataLoader(
        val_ds, batch_size=256, shuffle=False, pin_memory=pin_memory
    )
    test_loader = DataLoader(
        TensorDataset(test_norm, test_labels),
        batch_size=256,
        shuffle=False,
        pin_memory=pin_memory,
    )

    # -------------------- Model --------------------
    model = MLP(
        input_dim=train_norm.shape[1],
        output_dim=2,  # change if >2 classes
        hidden=(128, 64),
        dropout=0.3,
        use_bn=True,
    ).to(DEVICE)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # -------------------- Train / Validate --------------------
    def run_epoch(loader, train: bool) -> Tuple[float, float]:
        """Run one pass over ``loader`` and return ``(mean loss, accuracy)``.

        Gradients are computed and the optimizer is stepped only when ``train``.
        """
        model.train(train)
        total_loss, total, correct = 0.0, 0, 0
        with torch.set_grad_enabled(train):
            for xb, yb in loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                if train:
                    optimizer.zero_grad()
                out = model(xb)  # logits (N, C)
                loss = criterion(out, yb)
                if train:
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item() * xb.size(0)
                correct += (out.argmax(dim=1) == yb).sum().item()
                total += yb.numel()
        return total_loss / total, correct / total

    best_val = -1.0
    # In-memory snapshot of the best weights. This replaces a checkpoint
    # file in the CWD, which concurrent runs could clobber and which was
    # left behind if anything raised before cleanup.
    best_state = None
    epochs = 80
    patience = 10
    since_improved = 0

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = run_epoch(train_loader, train=True)
        val_loss, val_acc = run_epoch(val_loader, train=False)

        # Model selection is by validation accuracy.
        if val_acc > best_val:
            best_val = val_acc
            since_improved = 0
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            since_improved += 1

        print(
            f"Epoch {epoch:03d} | "
            f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | "
            f"Best Val Acc: {best_val:.4f}"
        )

        if since_improved >= patience:
            print(f"Early stopping at epoch {epoch}. Best Val Acc: {best_val:.4f}")
            break

    # -------------------- Test (ONE TIME) --------------------
    # Restore the weights that scored best on validation. The first epoch
    # always improves on -1.0, so best_state is set by this point.
    model.load_state_dict(best_state)
    model.eval()

    def collect_logits_and_labels(loader):
        """Return test logits (CPU tensor) and true labels (NumPy array)."""
        logits_list, y_list = [], []
        with torch.no_grad():
            for xb, yb in loader:
                logits_list.append(model(xb.to(DEVICE)).cpu())
                y_list.append(yb.cpu())
        return torch.cat(logits_list, dim=0), torch.cat(y_list, dim=0).numpy()

    logits, y_true = collect_logits_and_labels(test_loader)
    probs = torch.softmax(logits, dim=1).numpy()
    y_pred = probs.argmax(axis=1)

    acc = accuracy_score(y_true, y_pred)

    # AUC: binary uses the positive-class probability; multi-class uses
    # one-vs-rest macro averaging.
    num_classes = probs.shape[1]
    auc = None
    try:
        if num_classes == 2:
            auc = roc_auc_score(y_true, probs[:, 1])
        else:
            auc = roc_auc_score(y_true, probs, multi_class="ovr", average="macro")
    except ValueError as err:
        # e.g. only one class present in y_true; report instead of hiding it.
        logger.warning(f"Could not compute test ROC-AUC: {err}")

    print(f"Test Accuracy: {acc:.4f}")
    if auc is not None:
        print(f"Test ROC-AUC: {auc:.4f}")

    print(classification_report(y_true, y_pred, digits=4))

    return acc, auc, y_true, y_pred


def get_confusion_matrix_plot(
    y_true,
    y_pred,
    base_output_name: str,
    show_figure: bool = False,
    save_fig: bool = True,
):
    """Plot a row-normalised 2x2 confusion matrix for the binary task.

    Rows are true classes and are normalised to sum to 1
    (``normalize="true"``). Class 0 is labelled "Normal Behavior" and class 1
    "Misbehavior".

    Args:
        y_true: Array-like of true class ids (0 or 1).
        y_pred: Array-like of predicted class ids (0 or 1).
        base_output_name: Path prefix; the PDF is written to
            ``base_output_name + "confusion_matrix.pdf"``.
        show_figure: If True, call ``plt.show()``; otherwise close the figure.
        save_fig: If True, save the figure as a PDF.
    """
    # Pin the label set so the matrix is always 2x2 and lines up with the
    # two display labels, even when one class is absent from both inputs.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1], normalize="true")
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm, display_labels=["Normal Behavior", "Misbehavior"]
    )
    fig, ax = plt.subplots(figsize=(7, 6))
    disp.plot(cmap=plt.cm.Blues, values_format=".2f", ax=ax)

    # Style the axes explicitly rather than relying on the pyplot "current axes".
    ax.set_title("Confusion Matrix", fontsize=22)
    ax.set_xlabel(ax.get_xlabel(), fontsize=22)
    ax.set_ylabel(ax.get_ylabel(), fontsize=22)
    ax.tick_params(axis="both", labelsize=16)
    # Font size for the value annotations inside the cells.
    for text in ax.texts:
        text.set_fontsize(16)
    fig.tight_layout()
    if save_fig:
        fig.savefig(base_output_name + "confusion_matrix.pdf")
    if show_figure:
        plt.show()
    else:
        plt.close(fig)


def get_layer_wise_violin(
    train_stats,
    train_labels,
    n_heads: int,
    n_layers: int,
    base_output_name: str,
    show_figure: bool = True,
    save_fig: bool = True,
):
    """Violin-plot per-(layer, head) feature values by layer, split by class.

    The last ``n_heads * n_layers`` columns of ``train_stats`` are taken as
    one value per (layer, head) and are assumed to be layer-major: column
    ``offset + l * n_heads + h`` belongs to layer ``l``, head ``h``. Leading
    columns before ``offset`` are ignored. All heads of a layer are pooled
    into that layer's violin. Label 0 is shown as "Normal Behavior" and label
    1 as "Misbehavior".

    Args:
        train_stats: 2-D tensor (samples x features).
        train_labels: Tensor with one label per sample.
        n_heads: Number of heads per layer.
        n_layers: Number of layers.
        base_output_name: Path prefix; the PDF is written to
            ``base_output_name + "layer_head_violin.pdf"``.
        show_figure: If True, call ``plt.show()``; otherwise close the figure.
        save_fig: If True, save the figure as a PDF.

    Raises:
        ValueError: If ``train_stats`` has fewer than ``n_heads * n_layers``
            columns.
    """
    n_features = n_heads * n_layers
    offset = train_stats.shape[1] - n_features
    if offset < 0:
        # A negative offset would silently slice the wrong columns.
        raise ValueError(
            f"train_stats has {train_stats.shape[1]} columns, but "
            f"n_heads * n_layers = {n_features}"
        )

    X = train_stats[:, offset : offset + n_features].cpu().numpy()
    y = train_labels.cpu().numpy()
    n_samples = X.shape[0]

    # Layer id of each (layer-major) feature column: column j -> j // n_heads.
    layers = np.repeat(np.arange(n_layers), n_heads)

    # Long format for seaborn: one row per (sample, feature). X.reshape(-1)
    # is sample-major, so the per-feature layer ids repeat once per sample.
    df_long = pd.DataFrame(
        {
            "value": X.reshape(-1),
            "layer": np.tile(layers, n_samples),
            "label": np.repeat(y, n_features),
        }
    )
    df_long["class_name"] = df_long["label"].map(
        {0: "Normal Behavior", 1: "Misbehavior"}
    )

    fig = plt.figure(figsize=(16, 6))
    sns.violinplot(
        data=df_long,
        x="layer",
        y="value",
        hue="class_name",
        split=True,
        inner="quart",
        palette={"Normal Behavior": "lightblue", "Misbehavior": "lightcoral"},
    )
    plt.title("Layer-wise Logit Feature Distribution by Class", fontsize=22)
    plt.xlabel("Layer", fontsize=20)
    plt.ylabel("Intervention Effect", fontsize=20)
    plt.legend(title="Class", fontsize=18, title_fontsize=18, loc="upper right")

    plt.tight_layout()
    if save_fig:
        plt.savefig(base_output_name + "layer_head_violin.pdf")
    if show_figure:
        plt.show()
    else:
        plt.close(fig)

def get_attention_layer_heatmap(
    train_stats,
    train_labels,
    n_heads: int,
    n_layers: int,
    base_output_name: str,
    show_figure: bool = True,
    save_fig: bool = True,
    vmax=2e-16,
):
    """Heatmap of the per-(layer, head) feature mean over all samples.

    The last ``n_heads * n_layers`` columns of ``train_stats`` are taken as
    one value per (layer, head), assumed layer-major: column
    ``offset + l * n_heads + h`` belongs to layer ``l``, head ``h``. Each
    column is averaged over all samples, with both classes pooled.
    ``train_labels`` is accepted for signature parity with the violin plot
    but is not used.

    Args:
        train_stats: 2-D tensor (samples x features).
        train_labels: Unused.
        n_heads: Number of heads per layer (heatmap columns).
        n_layers: Number of layers (heatmap rows).
        base_output_name: Path prefix; the PDF is written to
            ``base_output_name + "attention_layer_heatmap.pdf"``.
        show_figure: If True, call ``plt.show()``; otherwise close the figure.
        save_fig: If True, save the figure as a PDF.
        vmax: Upper limit of the colour scale passed to ``sns.heatmap``;
            ``None`` lets seaborn pick it from the data.

    Raises:
        ValueError: If ``train_stats`` has fewer than ``n_heads * n_layers``
            columns.
    """
    n_features = n_heads * n_layers
    offset = train_stats.shape[1] - n_features
    if offset < 0:
        # A negative offset would silently slice the wrong columns.
        raise ValueError(
            f"train_stats has {train_stats.shape[1]} columns, but "
            f"n_heads * n_layers = {n_features}"
        )

    X = train_stats[:, offset : offset + n_features].cpu().numpy()

    # Column-wise mean over samples, accumulated in float64. Reshaping the
    # layer-major vector to (n_layers, n_heads) puts column j at
    # (j // n_heads, j % n_heads), without building a samples*features-row
    # long DataFrame.
    per_feature_mean = X.mean(axis=0, dtype=np.float64)
    heatmap_data = pd.DataFrame(
        per_feature_mean.reshape(n_layers, n_heads),
        index=pd.Index(np.arange(n_layers), name="layer"),
        columns=pd.Index(np.arange(n_heads), name="head"),
    )

    fig = plt.figure(figsize=(12, 8))
    # NOTE(review): the default vmax=2e-16 clips the colour scale at
    # essentially zero, so every cell whose mean exceeds it is drawn in the
    # top colour. Pass vmax=None for a data-driven scale if that was not
    # intended.
    sns.heatmap(heatmap_data, annot=False, cmap="viridis", vmax=vmax)
    plt.title("Mean Intervention Effect by Layer and Head", fontsize=22)
    plt.xlabel("Head", fontsize=20)
    plt.ylabel("Layer", fontsize=20)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.tight_layout()
    if save_fig:
        plt.savefig(base_output_name + "attention_layer_heatmap.pdf")
    if show_figure:
        plt.show()
    else:
        plt.close(fig)