import os, time, statistics
from pathlib import Path
import csv

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import sys, contextlib
from io import StringIO
from loguru import logger

from custom_mamba import MambaBlock


def log_cell_output(
        path="output.log",
        rotation="10 MB",
        fmt="{time} | {message}",
    ):
    """
    Capture everything printed inside a ``with`` body and mirror it to a log file.

    While the body runs, ``sys.stdout`` and ``sys.stderr`` are both redirected
    into one in-memory buffer, so nothing reaches the console until exit. On
    exit (normal or via an exception) the original streams are restored, the
    captured text is replayed to the real stdout, and every non-blank captured
    line is sent to ``logger.info``. A temporary Loguru sink at *path* is added
    for the duration of the block and removed afterwards.

    Args:
        path: file for the temporary Loguru sink (str or Path).
        rotation: Loguru ``rotation`` setting for that sink.
        fmt: Loguru ``format`` string for that sink.

    Returns:
        A context-manager object.

    Usage:
        with log_cell_output("run1.txt"):
            # your code here
            print("hello")
    """
    @contextlib.contextmanager
    def _cm():
        # Temporary handler; keep its id so it can be removed on exit.
        handler_id = logger.add(path, rotation=rotation, format=fmt, enqueue=True)

        orig_out, orig_err = sys.stdout, sys.stderr
        buffer = StringIO()
        sys.stdout = sys.stderr = buffer

        try:
            yield
        finally:
            # Restore the real streams before doing anything that could raise,
            # so a logging failure can never leave stdout/stderr hijacked.
            sys.stdout, sys.stderr = orig_out, orig_err
            contents = buffer.getvalue()
            try:
                print(contents, end="")
                for line in contents.splitlines():
                    if line.strip():
                        logger.info(line)
            finally:
                # Always drop the temporary sink, even if replay/logging failed.
                logger.remove(handler_id)

    return _cm()   # returns a real context-manager object



def generate_mv_data(
    N: int = 100,         # number of sequences
    L: int = 100,         # sequence length
    d: int = 10,          # feature dimension
    tau: float = 0.01,    # noise std-dev for all tokens
    alpha_r: float = 0.2, # fraction of relevant tokens
    alpha_c: float = 0.1, # fraction of confusion tokens
    seed: int = 42
):
    """
    Majority-Voting data.

    Every sequence is laid out as three contiguous blocks of tokens:
    ``round(alpha_r * L)`` relevant tokens, then ``round(alpha_c * L)``
    confusion tokens (clipped so the total never exceeds L), then the
    remaining irrelevant tokens.

    Labels y ∈ {+1, -1}:
      y = +1 → relevant uses noisy o_+, confusion uses noisy o_-.
      y = -1 → relevant uses noisy o_-, confusion uses noisy o_+.
    Irrelevant tokens are i.i.d. Gaussian noise N(0, tau^2 I_d).
    o_+ and o_- are the first two rows of a random orthonormal d x d matrix,
    so d must be at least 2.

    Returns
    -------
    X : np.ndarray, shape (N, L, d), float32
    y : np.ndarray, shape (N,), float32

    Raises
    ------
    ValueError
        If the fractions are outside [0, 1] or sum to more than 1.
    """
    fractions_valid = (0 <= alpha_r <= 1) and (0 <= alpha_c <= 1) and (alpha_r + alpha_c <= 1)
    if not fractions_valid:
        raise ValueError("Require 0 ≤ alpha_r, alpha_c and alpha_r + alpha_c ≤ 1.")

    rng = np.random.default_rng(seed)

    # Rows of Q^T form an orthonormal basis; the first two serve as prototypes.
    basis = np.linalg.qr(rng.standard_normal((d, d)))[0].T
    proto_pos, proto_neg = basis[0], basis[1]

    n_rel = min(int(round(alpha_r * L)), L)
    n_conf = min(int(round(alpha_c * L)), L - n_rel)

    # (start, stop) of each contiguous block, in sampling order.
    rel_span = (0, n_rel)
    conf_span = (n_rel, n_rel + n_conf)
    irr_span = (n_rel + n_conf, L)

    X = np.zeros((N, L, d), dtype=np.float32)
    y = np.zeros(N, dtype=np.float32)

    for i in range(N):
        label = rng.choice([-1, 1])
        y[i] = label
        if label == 1:
            majority, minority = proto_pos, proto_neg
        else:
            majority, minority = proto_neg, proto_pos

        # Blocks are drawn in a fixed order (relevant, confusion, irrelevant)
        # and empty blocks draw nothing, keeping the RNG stream reproducible.
        for (start, stop), centre in ((rel_span, majority),
                                      (conf_span, minority),
                                      (irr_span, 0.0)):
            count = stop - start
            if count > 0:
                X[i, start:stop] = rng.normal(loc=centre, scale=tau, size=(count, d))

    return X, y





def generate_locality_data(
        N: int = 100,          # number of sequences
        L: int = 100,          # sequence length
        d: int = 10,           # feature dimension
        tau: float = 0.0,      # noise std-dev
        # distance constraints   (0 < d_close < d_far < L)
        Delta_p_close: int = 2,     # gap of o_+ pair   in positive samples
        Delta_p_far:   int = 20,    # gap of o_- pair   in positive samples
        Delta_n_close: int = 2,     # gap of o_- pair   in negative samples
        Delta_n_far:   int = 20,    # gap of o_+ pair   in negative samples
        seed: int = 42             # RNG seed
    ):
    """
    Synthetic dataset where both prototypes appear twice.

        • positive (+1):  o_+ tokens are the CLOSE pair   (distance = Delta_p_close)
                          o_- tokens are the FAR   pair   (distance = Delta_p_far)
        • negative (–1):  o_- tokens are the CLOSE pair   (distance = Delta_n_close)
                          o_+ tokens are the FAR   pair   (distance = Delta_n_far)

    The far pair starts at a uniformly random index. The close pair is placed
    strictly between the two far tokens, which requires d_far - d_close >= 2.
    All other positions are N(0, tau^2 I_d) noise (exact zeros when tau = 0).
    o_+ and o_- are the first two rows of a random orthonormal d x d matrix.

    Returns
    -------
    X : np.ndarray, shape (N, L, d), float32
    y : np.ndarray, shape (N,), float32 with values in {+1, -1}

    Raises
    ------
    ValueError
        Unless 0 < d_close < d_far < L holds for both classes.
    RuntimeError
        If a sampled class has no room to nest its close pair inside the far
        pair (d_far - d_close < 2).
    """
    # ----- safety check --------------------------------------------------------
    # The lower bound 0 < d_close was documented but previously not enforced;
    # d_close <= 0 yields a collapsed or reversed "pair".
    if not (0 < Delta_p_close < Delta_p_far < L and
            0 < Delta_n_close < Delta_n_far < L):
        raise ValueError("Need 0 < d_close < d_far < L for both classes.")

    rng = np.random.default_rng(seed)

    # Orthonormal basis  (rows = prototypes o_1 … o_d)
    O = np.linalg.qr(rng.standard_normal((d, d)))[0].T
    o_plus, o_minus = O[0], O[1]

    X = np.zeros((N, L, d), dtype=np.float32)
    y = np.zeros(N, dtype=np.float32)

    for n in range(N):
        label = rng.choice([-1, 1])
        y[n] = label

        # pick distances and prototypes for this sample
        if label == 1:    # positive
            d_close, d_far = Delta_p_close, Delta_p_far
            close_proto, far_proto = o_plus, o_minus
        else:             # negative
            d_close, d_far = Delta_n_close, Delta_n_far
            close_proto, far_proto = o_minus, o_plus

        # choose FAR pair start; integers() excludes the upper bound, so the
        # second far token lands at index <= L - 1
        s_far = rng.integers(0, L - d_far)          # first far token index
        s_far_pair = s_far + d_far                  # second far token index

        # close pair must sit strictly inside (s_far, s_far_pair)
        low  = s_far + 1
        high = s_far_pair - d_close - 1
        if low > high:
            raise RuntimeError("No room for close pair; adjust deltas or L.")
        s_close = rng.integers(low, high + 1)
        s_close_pair = s_close + d_close

        occupied = {s_far, s_far_pair, s_close, s_close_pair}

        # inject discriminative tokens -----------------------------------------
        X[n, [s_close, s_close_pair]] = rng.normal(loc=close_proto,
                                                   scale=tau,
                                                   size=(2, d))
        X[n, [s_far,   s_far_pair  ]] = rng.normal(loc=far_proto,
                                                   scale=tau,
                                                   size=(2, d))

        # fill the rest with irrelevant noise -------------------------------
        filler_idx = np.setdiff1d(np.arange(L), list(occupied), assume_unique=True)
        X[n, filler_idx] = rng.normal(loc=0.0,
                                      scale=tau,
                                      size=(len(filler_idx), d))

    return X, y

class SyntheticTokenDataset(Dataset):
    """Map-style dataset over pre-generated token sequences and their labels.

    Both inputs are copied into float32 tensors. Item ``i`` is the pair
    ``(X[i], y[i])``.

    Args:
        X: array-like of per-sample token features (the generators in this
            file produce shape [N, L, d]).
        y: array-like of per-sample labels (shape [N]).

    Raises:
        ValueError: if ``X`` and ``y`` do not contain the same number of samples.
    """

    def __init__(self, X, y):
        super().__init__()
        self.X = torch.tensor(X, dtype=torch.float32)  # [N, L, d]
        self.y = torch.tensor(y, dtype=torch.float32)  # [N]
        # __len__ is taken from y, so a mismatch would otherwise surface later as
        # an IndexError (X shorter) or silently drop samples (X longer).
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y have different numbers of samples: {len(self.X)} vs {len(self.y)}"
            )

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
    
class MambaBlock(nn.Module):
    """
    Simplified Mamba SSM (+ selection)
    """

    def __init__(self, d: int, state_dim: int | None = None, learn_W: bool = False):
        super().__init__()
        self.d   = d
        self.N   = state_dim or d            # often we just set N = d
        self.w_delta = nn.Parameter(torch.zeros(d))   # trainable gating vector

        # W_B and W_C 
        W_init = torch.eye(d, self.N)
        if learn_W:
            self.W_B = nn.Parameter(W_init.clone())   
            self.W_C = nn.Parameter(W_init.clone())
        else:
            self.register_buffer("W_B", W_init, persistent=False)
            self.register_buffer("W_C", W_init, persistent=False)

    def forward(self, x):                  # x : [B, L, d]
        B, L, d = x.shape
        assert d == self.d, "input dim mismatch"


        b_seq = torch.matmul(x, self.W_B)          # [B, L, N]
        c_seq = torch.matmul(x, self.W_C)          # [B, L, N]

        sigma_seq = torch.sigmoid(torch.matmul(   # [B, L, 1]
            x, self.w_delta)
        ).unsqueeze(-1)                            # keep last dim

        H = torch.zeros(B, self.N, d, device=x.device)   

        outputs = []
        for t in range(L):
            sigma_t = sigma_seq[:, t, :].unsqueeze(-1)    # [B, 1, 1]
            alpha_t = 1.0 - sigma_t                      # [B, 1, 1]

            b_t     = b_seq[:, t].unsqueeze(-1)    # [B, N, 1]
            x_t     = x[:, t].unsqueeze(1)         # [B, 1, d]

            H = alpha_t * H + b_t @ x_t            # rank-1 update → [B, N, d]

            y_t = torch.bmm(H.transpose(1, 2),     # H^T_t  : [B, d, N]
                             c_seq[:, t].unsqueeze(-1)  # c_t : [B, N, 1]
                            ).squeeze(-1)          # → [B, d]
            outputs.append(y_t.unsqueeze(1))

        return torch.cat(outputs, dim=1)           # [B, L, d]
    
class MambaClassifier(nn.Module):
    """Binary classifier: a MambaBlock followed by a two-branch ReLU readout.

    The block output Y ([B, L, d]) feeds two linear layers of width ``m // 2``.
    The per-token score is mean(relu(fc_pos(Y))) - mean(relu(fc_neg(Y))) over
    hidden units, and the returned logit ([B]) is that score averaged over tokens.

    Args:
        d: token feature dimension.
        m: total hidden width; each branch gets ``m // 2`` units.
        L: sequence length; not used by this implementation.
        init_std: std of the Gaussian init for both readout weight matrices.
            Biases are initialised to zero.
    """

    def __init__(self, d=128, m=1000, L=30, init_std=0.1):
        super().__init__()
        self.block = MambaBlock(d)
        half_width = m // 2
        self.fc_pos = nn.Linear(d, half_width)
        self.fc_neg = nn.Linear(d, half_width)

        # Re-initialise both readouts: weights ~ N(0, init_std^2), zero biases.
        # zeros_ draws no random numbers, so the RNG consumption order
        # (pos weights, then neg weights) is unchanged.
        for layer in (self.fc_pos, self.fc_neg):
            nn.init.normal_(layer.weight, mean=0.0, std=init_std)
            nn.init.zeros_(layer.bias)

    def forward(self, x):                                   # x : [B, L, d]
        features = self.block(x)                            # [B, L, d]
        pos_score = F.relu(self.fc_pos(features)).mean(2)   # [B, L]
        neg_score = F.relu(self.fc_neg(features)).mean(2)   # [B, L]
        return (pos_score - neg_score).mean(1)              # [B]

def hinge_loss(logits, labels):
    """Return the mean hinge loss ``max(0, 1 - labels * logits)``.

    ``labels`` are expected in {+1, -1} and must broadcast against ``logits``.
    """
    margins = 1 - labels * logits
    return margins.clamp(min=0.0).mean()


def evaluate(model, dataset, verbose=True, device=None):
    """Compute hinge loss and accuracy of *model* on an entire dataset.

    The whole dataset (``dataset.X`` / ``dataset.y``) is pushed through the model
    in a single forward pass under ``torch.no_grad()``. Predictions are
    ``sign(logits)``, so a logit of exactly 0 counts as wrong.

    The model is left in eval mode. Callers that continue training must call
    ``model.train()`` themselves.

    Args:
        model: module mapping ``dataset.X`` to one logit per sample.
        dataset: object exposing tensor attributes ``X`` and ``y`` (labels in {+1, -1}).
        verbose: print the metrics when True.
        device: device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has none). The previous implementation
            relied on a global ``device`` that only exists when the file runs as
            a script.

    Returns:
        (loss, acc): mean hinge loss and accuracy in percent, as Python floats.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        X, y = dataset.X.to(device), dataset.y.to(device)
        logits = model(X)
        loss = hinge_loss(logits, y).item()
        preds = torch.sign(logits)
        acc = (preds == y).float().mean().item() * 100
        if verbose:
            print(f"Test Loss: {loss:.4f} | Test Acc: {acc:.2f}%")
    return loss, acc

def log_metrics(path: Path, header, rows):
    """Append *rows* to the CSV file at *path*.

    The header row is written only when the file does not exist yet. Existing
    files are appended to as-is; the header is not compared against theirs.
    """
    write_header = not path.exists()
    with path.open("a", newline="") as handle:
        writer = csv.writer(handle)
        if write_header:
            writer.writerow(header)
        for row in rows:
            writer.writerow(row)

def train_until_threshold(model, train_set, test_set, epochs=30, batch_size=2048, lr=1e-2,
                          threshold=99.0, test_loss_threshold=1e-3, verbose=True,
                          early_patience=50,       # stop failing runs if no val-loss improvement
                          min_delta=1e-4,          # required improvement in test loss
                          min_epochs=20,           # don't even consider stopping before this
                          device=None):
    """Train *model* with Adam on the hinge loss until the test set is solved.

    After every epoch the model is evaluated on *test_set* (via ``evaluate``).
    Training stops early in two cases:

    * success: test accuracy >= ``threshold`` (percent) AND test loss <=
      ``test_loss_threshold``. ``epoch_reached`` is then that 1-based epoch.
    * plateau: at least ``min_epochs`` epochs have run and the test loss has not
      improved by more than ``min_delta`` for ``early_patience`` consecutive
      epochs. ``epoch_reached`` stays ``None``.

    Args:
        device: device for the training batches. Defaults to the device of the
            model's first parameter (CPU if none). The previous version relied on
            a global ``device`` that only exists when the file runs as a script.

    Returns:
        (epoch_reached, history). ``history`` holds one tuple per completed epoch:
        ``(epoch, train_loss, train_acc, test_loss, test_acc)``, with
        accuracies in percent.

    Raises:
        ValueError: if *train_set* is empty. Previously this surfaced only as a
            ZeroDivisionError after the first epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    if len(train_set) == 0:
        raise ValueError("train_set is empty; cannot compute training metrics.")

    model.train()
    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    epoch_reached = None
    history = []

    # --- plateau tracking ---
    best_test_loss = float('inf')
    best_test_epoch = 0
    no_improve_epochs = 0

    for epoch in range(epochs):
        total_loss, correct, total = 0.0, 0, 0

        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            logits = model(X_batch)
            loss = hinge_loss(logits, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running metrics use the pre-update logits of each batch.
            total_loss += loss.item() * X_batch.size(0)
            preds = torch.sign(logits)
            correct += (preds == y_batch).sum().item()
            total += X_batch.size(0)

        # Evaluate on test set at this epoch
        test_loss, test_acc = evaluate(model, test_set, verbose=False)
        model.train()  # evaluate() leaves the model in eval mode

        avg_loss = total_loss / total
        train_acc = correct / total * 100

        history.append((epoch+1, avg_loss, train_acc, test_loss, test_acc))

        if verbose:
            print(f"Epoch {epoch+1:2d} | Train Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                  f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

        # --- update plateau bookkeeping ---
        if test_loss < best_test_loss - min_delta:
            best_test_loss = test_loss
            best_test_epoch = epoch + 1
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1

        # SUCCESS early stop: both the accuracy and the loss criteria must hold.
        if test_acc >= threshold and test_loss <= test_loss_threshold:
            epoch_reached = epoch + 1
            print(f"Reached {threshold}% test accuracy (test loss ≤ {test_loss_threshold}) "
                  f"at epoch {epoch_reached}. Stopping early.")
            break

        # --- FAILURE early stop on plateau ---
        if (epoch + 1) >= min_epochs and no_improve_epochs >= early_patience:
            # Mark as "did not hit threshold" by leaving epoch_reached as None
            print(
                f"Early stop (plateau): no val-loss improvement ≥ {min_delta} "
                f"for {no_improve_epochs} epochs (best at epoch {best_test_epoch}, "
                f"best loss {best_test_loss:.4f})."
            )
            break

    return epoch_reached, history

      
def run_trial(N_train, N_test=1000, alpha_r=0.25, alpha_c=0.1, epochs=30, batch_size=8,
              data_seed=42, torch_seed=0, threshold=99.0,
              L=30, d=32, tau=1e-1, lr=1e-1, device=None):
    """Train one MambaClassifier on freshly generated majority-voting data.

    ``N_train + N_test`` sequences come from a single ``generate_mv_data`` call.
    The first ``N_train`` form the training set and the rest the test set.

    Args:
        L, d, tau: sequence length, feature dimension and noise std passed to the
            generator. The defaults equal the previously hard-coded values.
        lr: Adam learning rate (previously hard-coded to 1e-1).
        device: device for the model. Defaults to CUDA when available, else CPU.
            The previous version read a global ``device`` that only exists when
            the file runs as a script.

    Returns:
        dict with keys ``alpha_r``, ``alpha_c``, ``N_train``, ``epoch_reached``
        (``epochs`` when the threshold was not hit), ``hit_threshold`` (0/1),
        ``final_acc`` (last test accuracy; NaN if no epoch ran) and ``history``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.manual_seed(torch_seed)

    X_all, y_all = generate_mv_data(
        N       = N_train + N_test,
        L       = L,
        d       = d,
        tau     = tau,
        alpha_r = alpha_r,
        alpha_c = alpha_c,
        seed    = data_seed
    )

    # Alternative locality task. NOTE: Delta_L is not defined in this function;
    # supply it before enabling.
    # X_all, y_all = generate_locality_data(
    #     N              = N_train + N_test,
    #     L              = L,
    #     d              = d,
    #     tau            = 0, # Changed tau later to 1e-6
    #     Delta_p_close  = Delta_L,
    #     Delta_p_far    = 28,
    #     Delta_n_close  = Delta_L,
    #     Delta_n_far    = 28,
    #     seed           = data_seed
    # )

    X_train, y_train = X_all[:N_train], y_all[:N_train]
    X_test,  y_test  = X_all[N_train:], y_all[N_train:]

    train_set = SyntheticTokenDataset(X_train, y_train)
    test_set  = SyntheticTokenDataset(X_test,  y_test)

    model = MambaClassifier(d=d, m=1000, L=L).to(device)

    epoch_reached, history = train_until_threshold(
        model, train_set, test_set,
        epochs=epochs, batch_size=batch_size, lr=lr,
        threshold=threshold
    )

    return dict(
        alpha_r=alpha_r,
        alpha_c=alpha_c,
        N_train   = N_train,
        epoch_reached = epoch_reached if epoch_reached is not None else epochs,
        hit_threshold = int(epoch_reached is not None),
        # history is empty when epochs == 0; avoid IndexError on history[-1]
        final_acc = history[-1][4] if history else float('nan'),  # last test_acc
        history   = history
    )

def _summarize_runs(N_train, alpha_r, alpha_c, epoch_counts, hit_flags, runs):
    """Aggregate the per-run results of one (N_train, alpha_r) configuration into a summary row."""

    def _mean_std_stderr(values):
        # Sample std (ddof=1) and standard error; spread is 0.0 for a single value
        # and every statistic is NaN when there are no values.
        if not values:
            return float('nan'), float('nan'), float('nan')
        mean = float(np.mean(values))
        if len(values) == 1:
            return mean, 0.0, 0.0
        std = float(np.std(values, ddof=1))
        return mean, std, float(std / np.sqrt(len(values)))

    successes = int(sum(hit_flags))
    total_runs = int(runs)
    success_rate = successes / total_runs if total_runs > 0 else float('nan')

    # Success-only epoch stats.
    success_epochs = [e for e, h in zip(epoch_counts, hit_flags) if h == 1]
    mean_s, std_s, stderr_s = _mean_std_stderr(success_epochs)
    # All-trials epoch stats (failures are already counted as 'epochs').
    mean_all, std_all, stderr_all = _mean_std_stderr(epoch_counts)

    return {
        "N_train": N_train,
        "alpha_r": alpha_r,
        "alpha_c": alpha_c,
        "successes": successes,
        "total_runs": total_runs,
        "success_rate": success_rate,
        "mean_epochs": mean_s,
        "std_epochs": std_s,
        "stderr_epochs": stderr_s,
        "mean_epochs_all": mean_all,
        "std_epochs_all": std_all,
        "stderr_epochs_all": stderr_all,
    }


def sweep_experiment(
    alpha_r_list,
    N_list,
    alpha_c=0.1,
    runs=5,
    epochs=30,
    OUT_DIR=Path("NumericalExperiments"),
    run_tag="mv_frac"
):
    """Run ``run_trial`` over every (N_train, alpha_r) pair and summarise the results.

    For each configuration, ``runs`` trials are executed with data seeds
    ``42 + r``. The model-init seed is 42 for every trial, so only the data
    varies between runs. Console output is captured through
    ``log_cell_output`` into ``OUT_DIR / f"{run_tag}_config.txt"``. One summary
    row per configuration is appended to ``OUT_DIR / f"{run_tag}_summary.csv"``.

    Args:
        OUT_DIR: output directory (Path or str); created with parents if missing.

    Returns:
        The list of summary-row dicts (empty if either sweep list is empty, in
        which case no CSV is written).
    """
    OUT_DIR = Path(OUT_DIR)
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    rows = []

    config_path = OUT_DIR / f"{run_tag}_config.txt"
    with log_cell_output(config_path):
        print(f"Running experiment with runs={runs}, epochs={epochs}")
        print(f"alpha_r_list={alpha_r_list}, alpha_c={alpha_c}, N_list={N_list}")
        print(f"Output directory: {OUT_DIR}")

        for N_train in N_list:
            for alpha_r in alpha_r_list:
                config_start_time = time.time()
                epoch_counts = []
                hit_flags    = []   # 1 if success, 0 otherwise

                for r in range(runs):
                    print(f"Run {r+1}/{runs} | N={N_train} | alpha_r={alpha_r:.2f} | alpha_c={alpha_c:.2f}")
                    result = run_trial(
                        N_train=N_train,
                        alpha_r=alpha_r,
                        alpha_c=alpha_c,
                        epochs=epochs,
                        data_seed=42 + r,
                        torch_seed=42
                    )
                    epoch_counts.append(result["epoch_reached"])
                    hit_flags.append(result["hit_threshold"])

                config_runtime = time.time() - config_start_time
                print(f"Configuration N={N_train}, alpha_r={alpha_r:.2f} completed in "
                      f"{config_runtime:.2f}s ({config_runtime/60:.2f}m)")

                rows.append(_summarize_runs(N_train, alpha_r, alpha_c,
                                            epoch_counts, hit_flags, runs))

        if rows:
            out_path = OUT_DIR / f"{run_tag}_summary.csv"
            log_metrics(out_path, header=rows[0].keys(), rows=[row.values() for row in rows])
            print(f"Wrote: {out_path}")
        else:
            # rows[0] would raise IndexError for an empty sweep.
            print("No configurations were run; summary CSV not written.")

    return rows



if __name__ == "__main__":
    script_start_time = time.time()

    # Kept at module level so any function in this file that reads the global
    # `device` still finds it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Script started at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(script_start_time))}")

    # Sweep settings
    alpha_r_list = [0.25, 0.30, 0.35, 0.40, 0.45, 0.50]
    alpha_c      = 0.20
    N_list       = [50,100,200]
    runs         = 20
    epochs       = 100
    run_tag      = "mv_frac_Exp2_run"

    # A configuration is one (N_train, alpha_r) pair; each is trained `runs` times.
    n_configs    = len(alpha_r_list) * len(N_list)
    total_trials = n_configs * runs
    print(f"Total configurations to run: {n_configs} ({total_trials} trials at {runs} runs each)")
    print(f"Configuration: alpha_r_list={alpha_r_list}, alpha_c={alpha_c}, N_list={N_list}, runs={runs}, epochs={epochs}")

    summary = sweep_experiment(
        alpha_r_list=alpha_r_list,
        N_list=N_list,
        alpha_c=alpha_c,
        runs=runs,
        epochs=epochs,
        OUT_DIR=Path("NumericalExperiments"),
        run_tag=run_tag
    )

    # Timing
    script_end_time = time.time()
    total_runtime = script_end_time - script_start_time

    print("Experiment complete.")
    print(f"Script ended at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(script_end_time))}")
    print(f"Total runtime: {total_runtime:.2f} seconds ({total_runtime/60:.2f} minutes)")
    if total_runtime > 3600:
        print(f"Total runtime: {total_runtime/3600:.2f} hours")
    if n_configs > 0:
        print(f"Average time per configuration: {total_runtime/n_configs:.2f} seconds")
    if total_trials > 0:
        print(f"Average time per trial: {total_runtime/total_trials:.2f} seconds")