import os
from collections import defaultdict
from typing import Iterable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# Numeric inputs accepted by the conversion helpers in this module:
# a Python list of floats, a NumPy array, or a torch tensor.
ArrayLike = Union[List[float], np.ndarray, torch.Tensor]

# -----------------------------
# Small, reusable utilities
# -----------------------------
def to_numpy(x: ArrayLike) -> np.ndarray:
    """Flatten ``x`` into a 1D NumPy array.

    Torch tensors are detached from autograd and moved to the CPU before
    conversion; any other input is handed to ``np.asarray`` as-is.
    """
    source = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
    return np.ravel(np.asarray(source))


def to_tensor(x: ArrayLike, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Convert ``x`` to a flattened (1D) torch tensor of the given ``dtype``.

    Tensors are cast with ``Tensor.to``, NumPy arrays go through
    ``torch.from_numpy`` and are then cast, and anything else is passed to
    ``torch.tensor``.
    """
    if isinstance(x, torch.Tensor):
        converted = x.to(dtype=dtype)
    elif isinstance(x, np.ndarray):
        converted = torch.from_numpy(x).to(dtype=dtype)
    else:
        converted = torch.tensor(x, dtype=dtype)
    return converted.flatten()


def is_all_none(lst: Iterable[Optional[object]]) -> bool:
    """Return True when every element of ``lst`` is None (True for an empty iterable)."""
    for item in lst:
        if item is not None:
            return False
    return True


def cumulative_neurons_before(layer_idx: int, sample_dict: dict) -> int:
    """Add up the values of ``sample_dict`` whose key is strictly below ``layer_idx``."""
    total = 0
    for idx, count in sample_dict.items():
        if idx < layer_idx:
            total += count
    return total


def moving_average_1d(x: ArrayLike, window: int = 3) -> Optional[torch.Tensor]:
    """
    Sliding-window mean over 1D data (list, NumPy array or torch tensor).

    The result is a torch tensor with ``len(x) - window + 1`` entries, one per
    full window. ``None`` is returned when ``window`` is smaller than 1 or
    larger than the series; ``window == 1`` yields a copy of the input.
    """
    series = to_tensor(x)
    if window < 1:
        return None
    if len(series) < window:
        return None
    if window == 1:
        return series.clone()

    # Uniform kernel: convolving with it computes the mean of each window.
    kernel = torch.ones(1, 1, window, dtype=series.dtype, device=series.device) / window
    batched = series[None, None, :]  # conv1d expects (batch, channels, length)
    smoothed = F.conv1d(batched, kernel)
    return smoothed.flatten()


def normalize_0_1(x: ArrayLike) -> torch.Tensor:
    """Min-max scale ``x`` into [0, 1]; a constant series maps to all zeros."""
    values = to_tensor(x)
    lowest = values.min()
    highest = values.max()
    if highest == lowest:
        # No spread to scale by; avoid dividing by zero.
        return torch.zeros_like(values)
    spread = highest - lowest
    return (values - lowest) / spread


# -----------------------------
# Time-series extrema helpers
# -----------------------------
def first_local_extremum(time_series: ArrayLike, patience: int = 10, kind: str = "max") -> int:
    """
    Locate the first local extremum of a series using a patience rule.

    Tracking only begins once the series has moved in the searched direction
    at least once (a rise for ``kind="max"``, a drop otherwise). From then on
    the best value seen so far is remembered, and the search stops as soon as
    more than ``patience`` consecutive points fail to improve on it.

    Any ``kind`` other than ``"max"`` is treated as a minimum search.
    Returns the index of the best point, or 0 when the series has at most one
    element or never moves in the searched direction.
    """
    values = to_tensor(time_series).tolist()
    if len(values) <= 1:
        return 0

    seeking_max = kind == "max"

    def improves(candidate: float, reference: float) -> bool:
        # Strict comparison in the searched direction.
        return candidate > reference if seeking_max else candidate < reference

    best_val = float("-inf") if seeking_max else float("inf")
    best_idx = -1
    trend_seen = False
    stale = 0

    for i, (prev, cur) in enumerate(zip(values, values[1:]), start=1):
        if improves(cur, prev):
            trend_seen = True
        if not trend_seen:
            continue
        if improves(cur, best_val):
            best_val, best_idx, stale = cur, i, 0
            continue
        stale += 1
        if stale > patience:
            break

    return best_idx if best_idx != -1 else 0


# -----------------------------
# Performances container
# -----------------------------
class Performances:
    """Accumulate per-epoch metrics and provide evaluation/plotting helpers.

    Call :meth:`update` once per epoch, then :meth:`torch_transformation` to
    build ``self.CND`` before using the CND-based evaluation or plots.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop every stored metric, including buffered per-epoch CND rows."""
        self.train_loss: List[float] = []
        self.train_acc: List[float] = []
        self.test_acc: List[float] = []
        self.PC: List[Optional[float]] = []
        self.KPA: List[Optional[float]] = []
        # Always keep CND as a torch.Tensor shaped (neurons, epochs) or empty (0,0)
        self.CND: torch.Tensor = torch.empty(0, 0)
        # Per-epoch CND vectors buffered by update() until torch_transformation()
        # stacks them. Cleared here so a reset never leaks rows from an earlier run.
        self._cnd_rows: List[torch.Tensor] = []

    def update(
        self,
        train_loss: float,
        train_acc: float,
        test_acc: float,
        performances_dict: dict,
        args,  # kept for API compatibility (unused here)
    ):
        """Record the metrics of one epoch.

        ``performances_dict`` is read with ``.get`` for the keys ``"PC"``,
        ``"CND"`` and ``"KPA"``; missing PC/KPA entries are stored as ``None``
        and a missing CND entry is skipped.
        """
        self.train_loss.append(train_loss)
        self.train_acc.append(train_acc)
        self.test_acc.append(test_acc)

        self.PC.append(performances_dict.get("PC"))
        cnd = performances_dict.get("CND")
        if cnd is not None:
            # Every row is stored as a 1D CPU tensor without autograd history, so
            # that stacking always yields (epochs, neurons) - also for scalars -
            # and the rows can be handed to matplotlib later on.
            if isinstance(cnd, torch.Tensor):
                row = cnd.detach().cpu().flatten()
            else:
                row = torch.tensor(cnd, dtype=torch.float32).flatten()
            self._cnd_rows.append(row)
        self.KPA.append(performances_dict.get("KPA"))

    def __len__(self):
        """Number of recorded epochs (length of ``train_acc``)."""
        return len(self.train_acc)

    def torch_transformation(self):
        """Convert the accumulated buffers into their analysis form.

        - ``test_acc`` becomes a 1D tensor (so :meth:`update` can no longer
          append to it afterwards).
        - ``PC`` is stripped of ``None`` entries and stays a list of floats.
          NOTE: if ``None`` values were interleaved with real values, the
          remaining indices no longer line up with epoch numbers.
        - ``CND`` is rebuilt as a ``(neurons, epochs)`` tensor from the rows
          buffered by :meth:`update`, or an empty ``(0, 0)`` tensor.
        """
        self.test_acc = to_tensor(self.test_acc)

        if not is_all_none(self.PC):
            # Round-trip through a float32 tensor, but keep a plain list of floats.
            self.PC = to_tensor([x for x in self.PC if x is not None]).tolist()

        rows = getattr(self, "_cnd_rows", [])
        if rows:
            # rows are 1D, so stack gives (epochs, neurons); transpose -> (neurons, epochs)
            self.CND = torch.stack(rows).transpose(0, 1)
        else:
            self.CND = torch.empty(0, 0)

    @staticmethod
    def moving_average(data: ArrayLike, window_size: int = 3) -> Optional[torch.Tensor]:
        """Thin wrapper around :func:`moving_average_1d`."""
        return moving_average_1d(data, window_size)

    def evaluate_memorization_metrics(self, args, window_size: int = 1) -> None:
        """
        Print a comparison table with:
        - Test_acc: epoch of global MAX test accuracy, and that accuracy
        - PC: epoch at first local MIN (on MA if window_size>1), accuracy at that epoch
        - CND (90th pct across last-layer neurons): epoch at first local MAX
          (MA if window_size>1), accuracy at that epoch
        - KPA: epoch at first local MAX (MA if window_size>1), accuracy at that epoch

        ``selected_epoch`` is a 0-based index into the recorded epochs. A metric
        is considered enabled when ``args`` has no ``metrics`` attribute or when
        its name is listed in ``args.metrics``. The CND row additionally reads
        ``args.neurs_x_hid_lyr``.
        """
        columns = ["Technique", "selected_epoch", "selected_accuracy"]
        test_acc_t = to_tensor(self.test_acc)
        n_epochs = len(test_acc_t)
        if n_epochs == 0:
            # Nothing recorded yet: argmax on an empty tensor would raise.
            print(pd.DataFrame(columns=columns).to_string(index=False))
            return

        def ma_and_offset(series: ArrayLike) -> Tuple[torch.Tensor, int]:
            """Return (possibly MA-smoothed) series and the epoch offset induced by MA."""
            t = to_tensor(series)
            if window_size and window_size > 1:
                ma = self.moving_average(t, window_size=window_size)
                if ma is not None:
                    # MA output j summarizes inputs j..j+window-1; attribute it to the last one.
                    return ma, window_size - 1
            # Raw series (no smoothing requested, or series shorter than the window):
            # indices already are epoch indices, so no shift applies.
            return t, 0

        def enabled(name: str) -> bool:
            return name in getattr(args, "metrics", [name])

        def row_for(technique: str, epoch: int) -> dict:
            return {
                "Technique": technique,
                "selected_epoch": epoch,
                "selected_accuracy": round(float(test_acc_t[epoch]), 6),
            }

        def select(technique: str, series: ArrayLike, kind: str) -> dict:
            smoothed, offset = ma_and_offset(series)
            idx = int(first_local_extremum(smoothed, patience=10, kind=kind))
            # Clamp: the metric series may be longer than the accuracy series.
            return row_for(technique, min(idx + offset, n_epochs - 1))

        # Test_acc (global max)
        rows = [row_for("Test_acc", int(torch.argmax(test_acc_t)))]

        # PC (local min)
        if not is_all_none(self.PC) and enabled("PC"):
            rows.append(select("PC", self.PC, "min"))

        # CND 90th percentile across last-layer neurons (local max)
        if self.CND.numel() > 0 and enabled("CND"):
            # presumably a dict {layer index: neuron count} with 0-based keys - TODO confirm
            layer_sizes = args.neurs_x_hid_lyr
            last_start = cumulative_neurons_before(len(layer_sizes) - 1, layer_sizes)
            cnd_last_layer = self.CND[last_start:, :]
            cnd_q90 = torch.quantile(cnd_last_layer, 0.90, dim=0)  # per-epoch
            rows.append(select("CND_q90", cnd_q90, "max"))

        # KPA (local max, as documented above)
        if not is_all_none(self.KPA) and enabled("KPA"):
            rows.append(select("KPA", self.KPA, "max"))

        # Print table
        df = pd.DataFrame(rows, columns=columns)
        print(df.to_string(index=False))

    # -----------------------------
    # Plotting
    # -----------------------------
    def plot_performances(self, args):
        """
        Plot training curves and (optionally) CND summaries; then an early-stopping proxy plot.

        Figures are written into ``args.results_dir`` (created if missing) with
        ``args.timestamp`` in the file names. The per-layer CND figure is only
        produced when ``"CND"`` is listed in ``args.metrics`` and CND data exists.
        Does nothing when no epoch has been recorded.
        """
        n_epochs = len(self.train_acc)
        if n_epochs == 0:
            # Nothing to draw; argmax/indexing below would fail on empty data.
            return
        # The first savefig happens before _early_stopping_plot creates the folder.
        os.makedirs(args.results_dir, exist_ok=True)

        # Accuracy/Loss curves
        epochs = list(range(1, n_epochs + 1))
        test_acc_t = to_tensor(self.test_acc)
        max_test_idx = int(torch.argmax(test_acc_t))

        plt.figure(figsize=(12, 8))
        ax1 = plt.gca()
        line1, = ax1.plot(epochs, self.train_acc, label='Train Accuracy', marker='o', linestyle='-')
        line2, = ax1.plot(epochs, test_acc_t, label='Test Accuracy', marker='o', linestyle='-')
        line5 = ax1.axvline(x=epochs[max_test_idx], color='red', linestyle='--', label='Max Test Accuracy')

        ax1.set_xlabel('Epoch', fontsize=16)
        ax1.set_ylabel('Accuracy', fontsize=16)
        ax1.grid(True, linestyle='--', alpha=0.6)
        ax1.tick_params(axis='both', labelsize=14)

        ax2 = ax1.twinx()
        line6, = ax2.plot(epochs, self.train_loss, label='Train Loss', marker='s', linestyle='-')
        ax2.set_ylabel('Train Loss', fontsize=16)
        ax2.tick_params(axis='both', labelsize=14)

        # One legend for the lines of both y-axes.
        lines = [line1, line2, line5, line6]
        labels = ['Train Accuracy', 'Test Accuracy', 'Max Test Accuracy', 'Train Loss']

        plt.legend(lines, labels, loc='best', fontsize=16)
        plt.xlim(epochs[0], epochs[-1])
        plt.tight_layout()
        plt.savefig(os.path.join(args.results_dir, f"training_curves_{args.timestamp}.png"))
        plt.close()

        # CND per-layer summary (mean ± std bands)
        if "CND" in getattr(args, "metrics", []) and self.CND.numel() > 0:
            n_layers = len(args.neurs_x_hid_lyr)
            starts = [cumulative_neurons_before(i, args.neurs_x_hid_lyr) for i in range(n_layers)]
            ends = [cumulative_neurons_before(i + 1, args.neurs_x_hid_lyr) for i in range(n_layers)]

            plt.figure(figsize=(12, 8))
            for l_idx, (s, e) in enumerate(zip(starts, ends)):
                layer = self.CND[s:e]  # (neurons_in_layer, epochs)
                mean = layer.mean(dim=0)
                std = layer.std(dim=0)
                plt.plot(epochs, mean, label=f'Layer {l_idx}', linewidth=3)
                plt.fill_between(epochs, mean - std, mean + std, alpha=0.15)

            plt.xlabel('Epoch', fontsize=16)
            plt.ylabel('CND', fontsize=16)
            plt.grid(True, linestyle='--', alpha=0.6)
            plt.legend(loc='upper right', fontsize=16)
            plt.tight_layout()
            plt.savefig(os.path.join(args.results_dir, f"CND_layers_{args.timestamp}.png"))
            plt.close()

        # Early stopping proxy plot (single window = 5)
        self._early_stopping_plot(
            test_acc=to_tensor(self.test_acc, torch.float32),
            CND=self.CND,
            KPA=to_tensor(self.KPA, torch.float32) if not is_all_none(self.KPA) else None,
            PC=to_tensor(self.PC, torch.float32) if not is_all_none(self.PC) else None,
            window=5,
            results_dir=args.results_dir,
            timestamp=args.timestamp,
        )

    def _early_stopping_plot(
        self,
        test_acc: torch.Tensor,
        CND: torch.Tensor,
        KPA: Optional[torch.Tensor],
        PC: Optional[torch.Tensor],
        window: int,
        results_dir: str,
        timestamp: str = "",
    ):
        """
        Plot early stopping proxy metrics (CND 90th pct, KPA, PC) vs test accuracy for a single window.

        Each proxy is smoothed with a moving average of size ``window`` and
        min-max normalized before being drawn on the left axis; raw test
        accuracy goes on the right axis. Proxies that are ``None``, empty or
        shorter than ``window`` are left out. The figure is saved as
        ``early_stopping_metrics[_<timestamp>].png`` inside ``results_dir``.
        """
        os.makedirs(results_dir, exist_ok=True)

        fig, ax1 = plt.subplots(figsize=(18, 8))
        ax2 = ax1.twinx()

        # Right axis: raw test accuracy
        ax2.plot(range(1, len(test_acc) + 1), test_acc, label="Test Accuracy", linestyle="-", linewidth=3.5)

        def ma(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            return None if x is None else moving_average_1d(x, window)

        # CND 90th percentile across neurons per epoch
        cnd_q = None
        if CND.numel() > 0:
            if CND.ndim == 2:  # (neurons, epochs)
                cnd_q = torch.quantile(CND, 0.90, dim=0)
            elif CND.ndim == 1:  # (epochs,)
                cnd_q = CND

        CND_w = ma(cnd_q)
        KPA_w = ma(KPA)
        PC_w = ma(PC)

        # X-axes aligned to the MA outputs: the j-th average covers (1-based)
        # epochs j+1..j+window and is drawn at the last of them.
        def epochs_for(w_series: Optional[torch.Tensor]) -> List[int]:
            return list(range(window, window + len(w_series))) if w_series is not None else []

        if CND_w is not None:
            ax1.plot(epochs_for(CND_w), normalize_0_1(CND_w), label=f"CND (w={window})", linestyle=":", linewidth=3.5)
        if KPA_w is not None:
            ax1.plot(epochs_for(KPA_w), normalize_0_1(KPA_w), label=f"KPA (w={window})", linestyle="--", linewidth=3.5)
        if PC_w is not None:
            ax1.plot(epochs_for(PC_w), normalize_0_1(PC_w), label=f"PC (w={window})", linestyle="-.", linewidth=3.5)

        ax1.set_ylabel("Normalized Proxy Value", fontsize=22)
        ax2.set_ylabel("Test Accuracy", fontsize=22)
        ax1.set_xlabel("Epoch", fontsize=22)
        ax1.tick_params(axis='both', which='major', labelsize=18)
        ax2.tick_params(axis='both', which='major', labelsize=18)
        ax1.grid(True, linestyle="--", alpha=0.5)

        # Legend grouping: proxies first (CND, KPA, PC), everything else last.
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        lines, labels = lines1 + lines2, labels1 + labels2
        groups = defaultdict(list)
        for lab, ln in zip(labels, lines):
            if "CND" in lab:
                groups["CND"].append((lab, ln))
            elif "KPA" in lab:
                groups["KPA"].append((lab, ln))
            elif "PC" in lab:
                groups["PC"].append((lab, ln))
            else:
                groups["Other"].append((lab, ln))
        ordered = groups["CND"] + groups["KPA"] + groups["PC"] + groups["Other"]
        if ordered:
            labels, lines = zip(*ordered)
            ax1.legend(lines, labels, fontsize=18, loc="upper right", framealpha=1.0, bbox_to_anchor=(1.0, 1.0))

        plt.tight_layout(pad=2.0)
        suffix = f"_{timestamp}" if timestamp else ""
        plt.savefig(os.path.join(results_dir, f"early_stopping_metrics{suffix}.png"))
        plt.close(fig)