
# The helper signatures below are annotated with torch / numpy / typing names.
# Those annotations are evaluated when each `def` statement runs, i.e. before
# the main import block further down, so the names must already be bound here.
from typing import Dict, List, Optional

import numpy as np
import torch

# ================= MDS / TRAJECTORY HELPERS =================

def flatten_state_dict(
    model: torch.nn.Module,
    *,
    device: str = "cpu",
    only_trainable: bool = True,
    keys_allowlist: Optional[List[str]] = None,
) -> np.ndarray:
    """
    Flatten model parameters into a single 1D numpy vector.

    Parameters are visited in ``model.named_parameters()`` order, detached from
    the autograd graph, flattened, and concatenated.

    Args:
        model: Module whose parameters are read.
        device: Device on which the flattened pieces are gathered and
            concatenated. The result is always copied to host memory before
            the NumPy conversion, so any device value yields a valid array.
        only_trainable: If True, skip parameters with ``requires_grad=False``.
        keys_allowlist: If given, keep only parameters whose name is listed.

    Returns:
        1D array holding the selected parameters. It is backed by the freshly
        concatenated tensor, not by the live parameters.

    Raises:
        ValueError: If the filters leave no parameter to collect.
    """
    # Set membership keeps the per-parameter name check O(1).
    allowed = None if keys_allowlist is None else frozenset(keys_allowlist)

    vec_parts = []
    for name, p in model.named_parameters():
        if only_trainable and not p.requires_grad:
            continue
        if allowed is not None and name not in allowed:
            continue
        # detach(): no autograd history; reshape(-1): 1D view of the values.
        vec_parts.append(p.detach().to(device).reshape(-1))

    if not vec_parts:
        raise ValueError("No parameters collected. Check only_trainable / keys_allowlist filters.")

    flat = torch.cat(vec_parts, dim=0)
    # .numpy() only works on host tensors, so move explicitly whatever `device` was.
    return flat.cpu().contiguous().numpy()


class CheckpointCollector:
    """
    Collects (flattened_weights, loss, iteration) every `period` iterations.

    A checkpoint is taken when all of the following hold:
      - global_iteration >= start_collect_after
      - period > 0 and global_iteration is a multiple of period
      - fewer than max_checkpoints snapshots are stored (if a cap is set)

    Notes on memory:
      - Storing full vectors can be huge. Use:
        * max_checkpoints
        * keys_allowlist to keep only some layers
        * larger period
    """
    def __init__(
        self,
        *,
        period: int,
        max_checkpoints: Optional[int] = None,
        keys_allowlist: Optional[List[str]] = None,
        random_state: int = 42,
        start_collect_after: int = 0,
    ):
        self.period = int(period)
        self.max_checkpoints = max_checkpoints
        self.keys_allowlist = keys_allowlist
        # Stored for callers; nothing in this class reads it.
        self.random_state = int(random_state)
        self.start_collect_after = int(start_collect_after)

        # Parallel lists: entry i of each describes the same checkpoint.
        self.points: List[np.ndarray] = []
        self.losses: List[float] = []
        self.iters: List[int] = []

    def maybe_collect(self, model: torch.nn.Module, loss_value: float, global_iteration: int) -> None:
        """Snapshot the model's trainable parameters if this iteration qualifies."""
        if global_iteration < self.start_collect_after:
            return
        # A non-positive period disables collection (and avoids modulo by zero).
        if self.period <= 0 or global_iteration % self.period != 0:
            return
        # Once the cap is reached, later iterations are dropped, not rotated in.
        if self.max_checkpoints is not None and len(self.points) >= self.max_checkpoints:
            return

        w = flatten_state_dict(model, device="cpu", only_trainable=True, keys_allowlist=self.keys_allowlist)
        self.points.append(w)
        self.losses.append(float(loss_value))
        self.iters.append(int(global_iteration))

    def finalize(self) -> Dict[str, np.ndarray]:
        """
        Return the collected trajectory as arrays.

        Returns:
            Dict with "points" of shape (n, p), "losses" (n,) float64 and
            "iters" (n,) int64. With no checkpoints, "points" has shape (0, 0)
            and the other two are empty arrays of the same dtypes as above.
        """
        if not self.points:
            # Same dtypes as the populated case so consumers need no special-casing.
            return {
                "points": np.empty((0, 0)),
                "losses": np.empty((0,), dtype=np.float64),
                "iters": np.empty((0,), dtype=np.int64),
            }

        points = np.stack(self.points, axis=0)  # (n, p)
        losses = np.asarray(self.losses, dtype=np.float64)  # (n,)
        iters = np.asarray(self.iters, dtype=np.int64)  # (n,)
        return {"points": points, "losses": losses, "iters": iters}


def run_mds_and_plot(
    all_trajs: Dict[str, Dict[str, np.ndarray]],
    *,
    random_state: int = 42,
    max_total_points: int = 800,
    subsample_strategy: str = "even",  # "even" | "random"
) -> None:
    """
    Embed optimizer trajectories in 2D with MDS and plot them against loss.

    all_trajs: dict optimizer_name -> {"points": (n_i,p), "losses": (n_i,), "iters": (n_i,)}

    All trajectories are embedded jointly so they share one coordinate system.
    Each is then drawn as a 3D line (MDS-1, MDS-2, loss), with a black "x"
    marking its first retained point.

    MDS is O(n^2) in number of points, so the total is capped. When it exceeds
    `max_total_points`, every trajectory is subsampled on its own with a budget
    proportional to its length (at least 2 points, never more than it has).
    Subsampling per trajectory keeps each optimizer's points contiguous and in
    order. Because of the 2-point floor, the final total can slightly exceed
    the cap.

    Args:
        all_trajs: Per-optimizer trajectories; entries with zero points are ignored.
        random_state: Seed for "random" subsampling and for the MDS solver.
        max_total_points: Soft cap on the number of embedded points.
        subsample_strategy: "random" draws indices without replacement and
            sorts them; any other value uses evenly spaced indices.

    Raises:
        ValueError: If trajectories have different parameter-vector lengths.
    """
    # Filter empty
    filtered = {k: v for k, v in all_trajs.items() if v["points"].shape[0] > 0}
    if not filtered:
        print("[MDS] No trajectory points collected. Nothing to plot.")
        return

    opt_names = list(filtered)

    # Sanity: all points must have same dimension p
    p_dims = [filtered[o]["points"].shape[1] for o in opt_names]
    if len(set(p_dims)) != 1:
        raise ValueError(f"Different parameter vector sizes across optimizers: {dict(zip(opt_names, p_dims))}")

    n_total = sum(filtered[o]["points"].shape[0] for o in opt_names)
    needs_subsample = n_total > max_total_points
    if needs_subsample:
        print(
            f"[MDS] {n_total} points exceed max_total_points={max_total_points}; "
            "subsampling each optimizer trajectory proportionally."
        )
    rng = np.random.RandomState(random_state)

    seg_points = []
    seg_losses = []
    for o in opt_names:
        pts = filtered[o]["points"]
        los = filtered[o]["losses"]
        n = pts.shape[0]
        if needs_subsample:
            # Proportional share of the cap, floored at 2 and capped at n.
            budget = min(n, max(2, int(max_total_points * (n / n_total))))
            if budget < n:
                if subsample_strategy == "random":
                    keep = np.sort(rng.choice(n, size=budget, replace=False))
                else:
                    # Even spacing; includes the first and last checkpoint.
                    keep = np.linspace(0, n - 1, num=budget).astype(int)
                pts = pts[keep]
                los = los[keep]
        seg_points.append(pts)
        seg_losses.append(los)

    lengths = [seg.shape[0] for seg in seg_points]
    points = np.concatenate(seg_points, axis=0)
    losses = np.concatenate(seg_losses, axis=0)

    # MDS to 2D
    mds = MDS(n_components=2, dissimilarity="euclidean", random_state=random_state)
    points_2d = mds.fit_transform(points)
    points_3d = np.column_stack([points_2d, losses])  # (n,3)

    # Plot
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection="3d")

    # Segments are contiguous in the concatenated arrays, in opt_names order.
    # Every segment is non-empty: empty trajectories were filtered out above.
    offset = 0
    for opt_name, n in zip(opt_names, lengths):
        seg = points_3d[offset:offset + n]
        offset += n
        ax.plot(seg[:, 0], seg[:, 1], seg[:, 2], label=opt_name)
        ax.scatter(seg[0, 0], seg[0, 1], seg[0, 2], c="black", marker="x", s=60)

    ax.set_xlabel("MDS-1")
    ax.set_ylabel("MDS-2")
    ax.set_zlabel("Train loss (batch)")
    ax.set_title("Optimizer trajectories in MDS + loss space")
    ax.legend()
    plt.show()


def run_training_and_collect(
    *,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    clipping,
    train_loader,
    eval_loader,
    test_loader,
    criterion,
    device: str,
    num_epochs: int,
    use_diagnostic_hook: bool,
    hook,
    collector: CheckpointCollector,
) -> Dict[str, object]:
    """
    Runs training loop and collects trajectory checkpoints via collector.

    Each step: zero grads, forward, backward, optional inf-norm gradient
    clipping, optimizer step, optional scheduler step, then
    `collector.maybe_collect`. After every epoch the model is evaluated on
    `eval_loader` and `test_loader`, and the metrics are logged to wandb and
    printed.

    Args:
        model: Network to train; batches are moved to `device`, the model is not.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler, stepped once per batch when truthy.
        clipping: Max gradient inf-norm, or None to disable clipping.
        train_loader / eval_loader / test_loader: Iterables of (images, labels).
        criterion: Loss callable `criterion(outputs, labels)`.
        device: Device the batches are moved to.
        num_epochs: Number of passes over `train_loader`.
        use_diagnostic_hook: If True, `hook.compute_and_reset()` is logged every step.
        hook: Diagnostic hook; only used when `use_diagnostic_hook` is True.
        collector: Receives (model, batch loss, iteration) after every step.

    Returns:
        Dict with per-epoch "train_losses", "val_accuracies",
        "test_accuracies" and the finalized trajectory under "traj".
    """
    train_losses = []
    val_accuracies = []
    test_accuracies = []
    global_iteration = 0

    def evaluate(loader):
        """Return (accuracy in percent, sample-weighted mean loss) over `loader`."""
        model.eval()
        correct = 0
        total = 0
        total_loss = 0.0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                # Weight by batch size so a short last batch is not over-counted.
                total_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        avg_loss = total_loss / total
        accuracy = 100 * correct / total
        return accuracy, avg_loss

    for epoch in range(num_epochs):
        model.train()  # evaluate() switches to eval mode at the end of each epoch
        running_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad(set_to_none=True)
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()

            global_iteration += 1

            if clipping is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type=float("inf"))

            optimizer.step()
            if scheduler:
                scheduler.step()

            # Read the scalar once: every .item() call forces a device-to-host sync.
            loss_value = loss.item()

            # Collect AFTER optimizer.step(): weights correspond to "state after update".
            # The recorded loss is the pre-update loss of this batch.
            collector.maybe_collect(model, loss_value, global_iteration)

            running_loss += loss_value

            if use_diagnostic_hook:
                metrics = hook.compute_and_reset()
                metrics = {f"inner_metrics/{key}": value for key, value in metrics.items()}
                # The script declares step_metric="iteration" for inner_metrics/*,
                # so the counter has to be logged with them to serve as the x-axis.
                metrics["iteration"] = global_iteration
                wandb.log(metrics)

            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        avg_loss = running_loss / len(train_loader)
        val_acc, val_loss = evaluate(eval_loader)
        test_acc, test_loss = evaluate(test_loader)

        wandb.log({
            "train_loss": avg_loss,
            "val_loss": val_loss,
            "test_loss": test_loss,
            "val_acc": val_acc,
            "test_acc": test_acc,
            "epoch": epoch + 1,
        })

        train_losses.append(avg_loss)
        val_accuracies.append(val_acc)
        test_accuracies.append(test_acc)

        print(
            f"Epoch [{epoch+1}/{num_epochs}] "
            f"Loss: {avg_loss:.4f} | Val Acc: {val_acc:.2f}% | Test Acc: {test_acc:.2f}%"
        )

    traj = collector.finalize()
    return {
        "train_losses": train_losses,
        "val_accuracies": val_accuracies,
        "test_accuracies": test_accuracies,
        "traj": traj,
    }


import wandb
import torch
import torch.nn as nn
import json
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from utils import set_global_seed
from cifar10_training import get_data, get_optimizer
import models
from lipschitz_tracker import LipschitzTracker
from optimizers.hooks import DiagnosticHook

from sklearn.manifold import MDS
from typing import Dict, List, Optional, Tuple

# ========== SETTINGS ==========
# Sub-directory of tuning/ from which tuned hyper-parameters are loaded.
NAME = "user1"

# Optimizers to train; each one gets its own wandb run.
# Alternative selections kept for reference.
#   For SimpleCNN:
#     ['SoftSignumPT_not_decoupled_wd', 'SoftSignumPT', 'AdamW', 'Signum', 'Signum+SGD', 'Adam', 'SGD']
#     ['SoftSignum_decoupled_wd', 'SoftSignumPT_not_decoupled_wd']
#   For ResNet18_32x32:
#     ['Signum', 'Signum+SGD', 'SoftSignum', 'SoftSignumPT', 'AdamW']
OPTIMIZER_NAMES = ["SoftSignumPT"]

# Accepted by the model switch below: 'simplecnn', 'simplecnnbinclass', 'resnet18_32x32'.
MODEL_NAME = "simplecnn"
# Used in the tuning-file path and the wandb run name.
DATASET_NAME = "cifar10"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
SEED = 42
NUM_EPOCHS = 50
BATCH_SIZE = 128
USE_AUGMENTATIONS = False

# Diagnostic hook: when enabled it is handed to the optimizer and its metrics
# are logged every iteration. The two values below go to the DiagnosticHook constructor.
USE_DIAGNOSTIC_HOOK = True
SATURATION_THRESHOLD = 0.55
DAMPING_TOL = 1e-8

# wandb destination.
WANDB_PROJECT = ""
WANDB_ENTITY = ""

# If True, dump per-optimizer results to training_log_<optimizer>_wandb.json.
SAVE_JSON = False

# ====== Trajectory / MDS settings ======
CHECKPOINT_PERIOD = 200          # snapshot the weights every this many iterations
MAX_CHECKPOINTS_PER_OPT = 300    # stop snapshotting after this many per optimizer
MDS_MAX_TOTAL_POINTS = 800       # cap on points embedded by MDS, across all optimizers
MDS_RANDOM_STATE = 42

KEYS_ALLOWLIST = None            # None: no restriction by parameter name
START_COLLECT_AFTER = 2000       # no snapshots before this iteration




# ========== PREPARATION ==========
# cuDNN: no autotuned kernel selection, deterministic algorithms only.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

set_global_seed(SEED)

print(f"Using device: {DEVICE}")

# The three loaders are built once and shared by every optimizer run below.
train_loader, eval_loader, test_loader = get_data(
    batch_size=BATCH_SIZE,
    seed=SEED,
    use_augmentations=USE_AUGMENTATIONS,
)
print(
    f"Train samples: {len(train_loader.dataset)}",
    f"Val samples: {len(eval_loader.dataset)}",
    f"Test samples: {len(test_loader.dataset)}",
    sep="\n",
)


# Per-optimizer trajectories, filled inside the loop and plotted after it.
all_trajs: Dict[str, Dict[str, np.ndarray]] = {}

for OPTIMIZER_NAME in OPTIMIZER_NAMES:
    print("\n" + "="*70)
    print(f"Starting training with optimizer: {OPTIMIZER_NAME}")
    print("="*70 + "\n")

    # Re-apply the seed before every run (set_global_seed comes from utils;
    # presumably it resets the RNGs so runs start from the same state).
    set_global_seed(SEED)

    # Only the file read sits in the try: a missing file selects the defaults.
    try:
        with open(f'tuning/{NAME}/{DATASET_NAME}/{MODEL_NAME}/{OPTIMIZER_NAME}.json', 'r') as f:
            optimizer_params = json.load(f)
    except FileNotFoundError:
        print(f"Warning: No tuned parameters found for {OPTIMIZER_NAME}, using defaults")
        optimizer_params = {
            'lr': 0.01,
            'momentum': 0.9,
            'weight_decay': 0.001,
            'tmin': 2.0,
            'tmax': 20.0,
            'warmup_iters': 0.8,  # Fraction of total iterations
        }
    else:
        # Scores stored next to the tuned values are not optimizer arguments.
        optimizer_params.pop('val_score', None)
        optimizer_params.pop('test_score', None)
        print(f"Loaded parameters: {optimizer_params}")

    optimizer_params['batch_size'] = BATCH_SIZE

    # Plain-value copy for wandb and the JSON dump, taken before the hook
    # object (not JSON-serializable) is attached below.
    logged_params = dict(optimizer_params)

    # Create the hook on both paths: run_training_and_collect always receives
    # it, so it must exist even when the defaults are used.
    hook = DiagnosticHook(
        saturation_threshold=SATURATION_THRESHOLD,
        damping_tol=DAMPING_TOL
    )
    if USE_DIAGNOSTIC_HOOK:
        optimizer_params['hook'] = hook

    run = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        name=f'{OPTIMIZER_NAME}_{DATASET_NAME}_{MODEL_NAME}',
        config={
            'optimizer': OPTIMIZER_NAME,
            'model': MODEL_NAME,
            'seed': SEED,
            'num_epochs': NUM_EPOCHS,
            'batch_size': BATCH_SIZE,
            'use_augmentations': USE_AUGMENTATIONS,
            **logged_params
        }
    )

    # x-axis setup: epoch by default, iteration for the per-step metric groups.
    run.define_metric("epoch", hidden=True)
    run.define_metric("iteration", hidden=True)

    run.define_metric("*", step_metric="epoch")

    run.define_metric("grad_noise/*", step_metric="iteration")
    run.define_metric("tanh/*", step_metric="iteration")
    run.define_metric("inner_metrics/*", step_metric="iteration")

    for key, value in logged_params.items():
        run.summary[f'optimizer/{key}'] = value
    run.summary['optimizer/name'] = OPTIMIZER_NAME

    if MODEL_NAME == 'simplecnn':
        model = models.SimpleCNN().to(DEVICE)
    elif MODEL_NAME == 'simplecnnbinclass':
        model = models.SimpleCNNBinClass().to(DEVICE)
    elif MODEL_NAME == 'resnet18_32x32':
        model = models.ResNet18_32x32().to(DEVICE)
    else:
        raise ValueError(f"Invalid model name: {MODEL_NAME}")

    print(f"Model: {MODEL_NAME}")

    # Total number of optimizer steps, passed to get_optimizer as n_iters.
    n_iters = NUM_EPOCHS * len(train_loader)
    optimizer, (clipping, scheduler) = get_optimizer(
        OPTIMIZER_NAME,
        model,
        search_space=None,
        trial=None,
        optimizer_params=optimizer_params,
        n_iters=n_iters
    )
    print(f"Optimizer: {OPTIMIZER_NAME}")
    if clipping:
        print(f"Gradient clipping: {clipping}")
        wandb.config.update({'clipping': clipping})
    if scheduler:
        print(f"LR scheduler: {scheduler}")
        wandb.config.update({'scheduler': scheduler})

    criterion = nn.CrossEntropyLoss()

    # Fresh collector per optimizer so trajectories do not mix.
    collector = CheckpointCollector(
        period=CHECKPOINT_PERIOD,
        max_checkpoints=MAX_CHECKPOINTS_PER_OPT,
        keys_allowlist=KEYS_ALLOWLIST,
        random_state=MDS_RANDOM_STATE,
        start_collect_after=START_COLLECT_AFTER,
    )

    print("\n" + "="*50)
    print("Starting training...")
    print("="*50 + "\n")

    out = run_training_and_collect(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        clipping=clipping,
        train_loader=train_loader,
        eval_loader=eval_loader,
        test_loader=test_loader,
        criterion=criterion,
        device=DEVICE,
        num_epochs=NUM_EPOCHS,
        use_diagnostic_hook=USE_DIAGNOSTIC_HOOK,
        hook=hook,
        collector=collector,
    )

    train_losses = out["train_losses"]
    val_accuracies = out["val_accuracies"]
    test_accuracies = out["test_accuracies"]

    all_trajs[OPTIMIZER_NAME] = out["traj"]

    print(f"[Trajectory] Collected {all_trajs[OPTIMIZER_NAME]['points'].shape[0]} checkpoints for {OPTIMIZER_NAME}")

    # Best-epoch statistics, computed once and reused for printing and logging.
    best_val_acc = max(val_accuracies)
    best_test_acc = max(test_accuracies)
    best_val_epoch = val_accuracies.index(best_val_acc) + 1
    best_test_epoch = test_accuracies.index(best_test_acc) + 1

    print("\n" + "="*50)
    print("Training completed!")
    print("="*50)
    print(f"Best Val Accuracy: {best_val_acc:.2f}% at epoch {best_val_epoch}")
    print(f"Best Test Accuracy: {best_test_acc:.2f}% at epoch {best_test_epoch}")
    print(f"Final Val Accuracy: {val_accuracies[-1]:.2f}%")
    print(f"Final Test Accuracy: {test_accuracies[-1]:.2f}%")

    run.summary['best_val_acc'] = best_val_acc
    run.summary['best_test_acc'] = best_test_acc
    run.summary['final_val_acc'] = val_accuracies[-1]
    run.summary['final_test_acc'] = test_accuracies[-1]
    run.summary['best_val_epoch'] = best_val_epoch
    run.summary['best_test_epoch'] = best_test_epoch

    results = {
        'optimizer': OPTIMIZER_NAME,
        'model': MODEL_NAME,
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'test_accuracies': test_accuracies,
        'best_val_acc': best_val_acc,
        'best_test_acc': best_test_acc,
        'optimizer_params': logged_params
    }

    if SAVE_JSON:
        with open(f'training_log_{OPTIMIZER_NAME}_wandb.json', 'w') as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to training_log_{OPTIMIZER_NAME}_wandb.json")

    wandb.finish()
    print(f"WandB run completed for {OPTIMIZER_NAME}!")
    
    
# Embed every collected trajectory in one shared MDS space and show the plot.
run_mds_and_plot(
    all_trajs,
    max_total_points=MDS_MAX_TOTAL_POINTS,
    random_state=MDS_RANDOM_STATE,
    subsample_strategy="even",  # or "random"
)

# Closing banner: blank line, rule, message, rule.
print(
    "\n" + "=" * 70,
    f"All {len(OPTIMIZER_NAMES)} optimizers completed successfully!",
    "=" * 70,
    sep="\n",
)
