# %% [markdown]
# # Optimizer Playground
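# Sanity-check the optimization procedure: train a student MLP on data generated
# by a teacher MLP and compare SymO against standard PyTorch optimizers.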

# %%
import copy
from dataclasses import dataclass, asdict
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

from symo.notebooks.plot_utils import default_rcparams, plot_metric_norms
from symo.experiments.models import Activation, MLP
from symo.experiments.mlp_groups import group_config_v2
from symo.data import mlp_teacher_data
from symo.metrics import compute_metrics
from symo.utils import InverseStepScheduler
from symo.experiments.utils import (
    mlp_kernel_init,
    mlp_bias_init,
    train_loop,
)

import symo.optim2 as symo_optim
import matplotlib.pyplot as plt

# Apply the notebook's plotting defaults through ``RcParams.update`` rather
# than ``|=``: ``RcParams`` is a ``dict`` subclass and the in-place ``|=``
# operator is inherited from ``dict``, which can store entries without going
# through ``RcParams.__setitem__`` and its per-key validation.
plt.rcParams.update(default_rcparams(dpi=500))

# Type aliases used in annotations throughout the notebook.
Tensor = torch.Tensor
Data = tuple[Tensor, Tensor]  # (inputs, targets) pair


# %%

# Number of hidden layers; ``ModelConfig.hidden_dims`` repeats the hidden
# width this many times.
depth: int = 3


@dataclass(frozen=True)
class ExperimentConfig:
    """Run-level settings shared by every optimizer run in this notebook.

    All attributes are annotated so they are real dataclass fields: each one
    can be overridden in the constructor (e.g.
    ``ExperimentConfig(num_train_points=512, num_test_points=512)`` for a
    quick smoke test) and is included in ``asdict``. Un-annotated class
    attributes are ignored by ``@dataclass``.
    """

    # Common
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 2025

    # Data
    num_train_points: int = 5000
    num_test_points: int = 5000


@dataclass(frozen=True)
class ModelConfig:
    """Architecture shared by the teacher and the student MLP.

    ``create_mlp`` forwards every field, via ``asdict``, as a keyword
    argument to ``MLP``. The field names must therefore match ``MLP``'s
    constructor parameters.
    """

    input_dim: int = 100
    # Width 70 repeated ``depth`` times; its length is reported as the layer
    # count in the loss-plot title.
    hidden_dims: tuple[int, ...] = tuple(70 for _ in range(depth))
    output_dim: int = 40

    skip_every: int | None = None
    use_bias: bool = True
    use_bias_last: bool = True
    activation: Activation = "relu"


# %%

# Training length shared by every optimizer config below (4000 gives longer runs).
num_epochs = 1000


@dataclass(frozen=True)
class SymoConfig:
    """Hyperparameters for ``symo_optim.Symo``.

    ``symo_run`` passes ``lr`` explicitly and every field other than
    ``num_epochs`` as an optimizer keyword argument.
    """

    num_epochs: int = num_epochs
    grads_beta: float = 0.93
    factors_beta: float = 0.93
    grads_bias_corr: bool = True
    factors_bias_corr: bool = True
    update_correction: bool = False
    damping: float = 1e-10
    lr: float = 0.01


@dataclass(frozen=True)
class Symo2Config1:
    """Hyperparameters for ``symo_optim.Symo2`` with both bias corrections off."""

    num_epochs: int = num_epochs
    grads_beta: float = 0.93
    sigma_g_beta: float = 0.93
    grads_bias_corr: bool = False
    sigma_g_bias_corr: bool = False
    damping: float = 4.265e-6
    lr: float = 1.76
    update_correction: bool = False


@dataclass(frozen=True)
class Symo2Config2(Symo2Config1):
    """``Symo2Config1`` with bias correction enabled for ``sigma_g`` only."""

    sigma_g_bias_corr: bool = True


@dataclass(frozen=True)
class Symo2Config3(Symo2Config1):
    """``Symo2Config1`` with both bias corrections enabled."""

    grads_bias_corr: bool = True
    sigma_g_bias_corr: bool = True


@dataclass(frozen=True)
class AdamConfig:
    """Settings for ``torch.optim.Adam``; ``learning_rate`` is renamed to ``lr``."""

    num_epochs: int = num_epochs
    learning_rate: float = 0.019


@dataclass(frozen=True)
class AdamWConfig:
    """Settings for ``torch.optim.AdamW``; ``learning_rate`` is renamed to ``lr``."""

    num_epochs: int = num_epochs
    learning_rate: float = 0.019
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2


@dataclass(frozen=True)
class SGDConfig:
    """Settings for ``torch.optim.SGD``; ``learning_rate`` is renamed to ``lr``."""

    num_epochs: int = num_epochs
    learning_rate: float = 1e-2
    momentum: float = 0.9
    weight_decay: float = 0.0


# %%

cfg, model_cfg = ExperimentConfig(), ModelConfig()

# %%

# Reproducibility. ``torch.manual_seed`` already seeds every device; the
# explicit CUDA call mirrors that for the current GPU when one is present.
torch.manual_seed(cfg.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(cfg.seed)

# Placement and precision used for the models and the generated data.
device = torch.device(cfg.device)
dtype = torch.get_default_dtype()

# %%


def record_train_step(
    model: MLP,
    optimizer: optim.Optimizer | optim.lr_scheduler.LRScheduler,
    data: Data,
):
    """Train step that also records metrics.

    Args:
        model: Network being trained.
        optimizer: An optimizer, or an LR scheduler wrapping one (``symo_run``
            passes a scheduler). Forwarded unchanged to ``train_step``.
        data: ``(inputs, targets)`` pair.

    Returns:
        ``(loss, metrics)``. ``loss`` is the value returned by ``train_step``
        (computed before the parameter update). ``metrics`` is the result of
        ``compute_metrics`` on the loss, the post-step parameters, their
        gradients and the parameter update.
    """
    # Snapshot parameters outside autograd. A plain ``clone()`` of a leaf that
    # requires grad records a CloneBackward node per parameter per step.
    old_params = {name: p.detach().clone() for name, p in model.named_parameters()}

    loss = train_step(model, optimizer, data)

    new_params = dict(model.named_parameters())
    grads = {name: p.grad for name, p in new_params.items()}
    # Detach so the difference is a plain tensor, not an autograd graph node.
    update = {name: new_params[name].detach() - old_params[name] for name in new_params}

    metrics = compute_metrics(loss, new_params, grads, update)
    return loss, metrics


def train_step(
    model: MLP,
    optimizer: optim.Optimizer | optim.lr_scheduler.LRScheduler,
    data: Data,
):
    """Take one optimizer step on ``data`` and return the loss.

    ``optimizer`` may also be an LR scheduler. In that case its wrapped
    optimizer updates the parameters, and the scheduler is stepped once
    afterwards.

    Returns:
        The MSE loss as a Python float, evaluated before the parameter update.
    """
    is_scheduler = isinstance(optimizer, optim.lr_scheduler.LRScheduler)
    scheduler = optimizer if is_scheduler else None
    opt = scheduler.optimizer if is_scheduler else optimizer

    inputs, targets = data
    model.train()

    opt.zero_grad()
    batch_loss = nn.functional.mse_loss(model(inputs), targets)
    batch_loss.backward()
    opt.step()

    if scheduler is not None:
        scheduler.step()

    return batch_loss.item()


@torch.compile
def eval_loss(model: MLP, data: Data):
    """Return the MSE of ``model`` on ``data`` as a Python float.

    Switches the model to eval mode and does not switch it back
    (``train_step`` calls ``model.train()`` itself). The forward pass runs
    without autograd.
    """
    inputs, targets = data
    model.eval()

    with torch.no_grad():
        mse = nn.functional.mse_loss(model(inputs), targets)

    return mse.item()


# # "Global" curvature

# %%


def create_mlp(config: ModelConfig, device: torch.device | None = None, dtype=None):
    """Build an ``MLP`` from ``config`` and move it to ``device``/``dtype``.

    Every field of ``config`` becomes a keyword argument of ``MLP``.
    ``mlp_kernel_init`` and ``mlp_bias_init`` are passed as ``kernel_init``
    and ``bias_init``.
    """
    # Named ``model_kwargs`` so it does not shadow the module-level ``cfg``.
    model_kwargs = asdict(config)
    network = MLP(
        kernel_init=mlp_kernel_init,
        bias_init=mlp_bias_init,
        **model_kwargs,
    )
    return network.to(device=device, dtype=dtype)


# Both networks share one architecture. The teacher produces the data below
# and the student is the network that gets trained.
mlp_teacher = create_mlp(model_cfg, device=device, dtype=dtype)
mlp_student = create_mlp(model_cfg, device=device, dtype=dtype)

# Snapshot of the student's *initial* weights; each optimizer run is reset
# to it. ``state_dict()`` returns tensors that share storage with the live
# parameters. Without the deep copy, the first run's in-place updates would
# overwrite this snapshot and later optimizers would start from trained
# weights.
mlp_student_state = copy.deepcopy(mlp_student.state_dict())

# %%

# Build the train/validation sets from the teacher network, drawing samples
# from a dedicated generator seeded from the experiment config. The first
# element of the returned pair is not used here.
generator = torch.Generator().manual_seed(cfg.seed)
_, (train_data, val_data) = mlp_teacher_data(
    generator,
    mlp_teacher,
    num_train_points=cfg.num_train_points,
    num_test_points=cfg.num_test_points,
    device=device,
)

# %%


def symo_run(
    train_data: Data,
    val_data: Data,
    model: MLP,
    opt_cfg,
    opt_class,
):
    """Train ``model`` with a SymO-family optimizer, recording metrics.

    Args:
        train_data: ``(inputs, targets)`` pair for training.
        val_data: ``(inputs, targets)`` pair for validation.
        model: Network to train.
        opt_cfg: Config dataclass. It must define ``num_epochs`` and ``lr``;
            every other field is forwarded to ``opt_class`` as a keyword.
        opt_class: Optimizer class. It is built with ``groups_spec`` from
            ``group_config_v2`` and wrapped in ``InverseStepScheduler``.

    Returns:
        Whatever ``train_loop`` returns.
    """
    groups = group_config_v2(model, hid_group="S", inout_group="I", same=False)

    # ``num_epochs`` drives the loop and ``lr`` is passed explicitly, so strip
    # both before forwarding the remaining fields (KeyError if either is absent).
    hparams = asdict(opt_cfg)
    del hparams["num_epochs"], hparams["lr"]

    scheduler = InverseStepScheduler(
        optimizer=opt_class(
            model.parameters(),
            groups_spec=groups,
            lr=opt_cfg.lr,
            **hparams,
        )
    )

    return train_loop(
        model,
        train_data,
        val_data,
        scheduler,
        record_train_step,
        eval_loss,
        num_epochs=opt_cfg.num_epochs,
        record_metrics=True,
        print_output=True,
    )


# %%


def pytorch_optim_run(
    train_data: Data,
    val_data: Data,
    model: MLP,
    opt_cfg,
    opt_class,
):
    """Train ``model`` with a built-in ``torch.optim`` optimizer.

    ``opt_class`` receives every ``opt_cfg`` field except ``num_epochs``,
    with ``learning_rate`` renamed to PyTorch's ``lr``. No LR scheduler is
    used, and ``train_loop`` is called with ``record_metrics=False``.

    Returns:
        Whatever ``train_loop`` returns.
    """
    hparams = asdict(opt_cfg)
    del hparams["num_epochs"]

    # The config dataclasses spell it ``learning_rate``; torch.optim wants ``lr``.
    if "learning_rate" in hparams:
        hparams["lr"] = hparams.pop("learning_rate")

    # NOTE(review): ``record_train_step`` is still the step function even
    # though metrics are not recorded; confirm ``train_loop`` needs it here.
    return train_loop(
        model,
        train_data,
        val_data,
        opt_class(model.parameters(), **hparams),
        record_train_step,
        eval_loss,
        num_epochs=opt_cfg.num_epochs,
        record_metrics=False,
        print_output=True,
    )


# %%

# Create partial functions for different optimizers
# Bind each optimizer class to the runner that knows how to build it.
symo1_run = partial(symo_run, opt_class=symo_optim.Symo)
symo2_run = partial(symo_run, opt_class=symo_optim.Symo2)
adam_run = partial(pytorch_optim_run, opt_class=optim.Adam)
adamw_run = partial(pytorch_optim_run, opt_class=optim.AdamW)
sgd_run = partial(pytorch_optim_run, opt_class=optim.SGD)


# Keyword arguments shared by every run; the model is added per run.
common_args = dict(
    train_data=train_data,
    val_data=val_data,
)

# Entries are ``(display name, (runner, extra kwargs))``. Comment entries in
# or out to choose which optimizers are compared. Every entry, including the
# disabled ones, uses the ``dict(opt_cfg=...)`` shape that the run loop
# unpacks.
opt_configs = (
    ("SymO-1", (symo1_run, dict(opt_cfg=SymoConfig()))),
    # ("SymO-2.3", (symo2_run, dict(opt_cfg=Symo2Config3()))),
    # ("Adam", (adam_run, dict(opt_cfg=AdamConfig()))),
    ("AdamW", (adamw_run, dict(opt_cfg=AdamWConfig()))),
    # ("SGD", (sgd_run, dict(opt_cfg=SGDConfig()))),
)

# %%

# Each element is ``(name, train_loop output)``.
outputs = []
for name, (run_fn, run_kwargs) in opt_configs:
    # Reset the student to the ``mlp_student_state`` snapshot before each run.
    mlp_student.load_state_dict(copy.deepcopy(mlp_student_state))
    # Optionally compile instead: model = torch.compile(mlp_student, fullgraph=True)
    model = mlp_student

    banner = "=" * 60
    print(f"\n{banner}")
    print(f">>> Running: {name}")
    print(banner)
    out = run_fn(**(common_args | run_kwargs | dict(model=model)))
    outputs.append((name, out))

# %%

# Plot metrics if available
# Plot metric norms for every run that recorded them, with one figure per
# run titled by its name. Runs whose metrics are ``None`` are skipped;
# presumably these are the ``pytorch_optim_run`` runs, which pass
# ``record_metrics=False``.
for run_name, (_, _, run_metrics) in outputs:
    if run_metrics is None:
        continue
    fig, _ = plot_metric_norms(run_metrics, linewidth=0.2)
    fig.suptitle(run_name)
    fig.tight_layout()
    fig.show()


# %%

fontsize = 10  # NOTE(review): currently unused by the plotting code below
figsize = (5, 3)
ax_args = dict(alpha=0.5, linewidth=0.5)  # shared line style for every curve

losses_name = "Loss"
vals_name = "Validation Loss"

# Left panel: training loss; right panel: validation loss; one line per optimizer.
fig, axes = plt.subplots(ncols=2, figsize=figsize)

for name, (losses, vals, _) in outputs:
    for ax, series in zip(axes, (losses, vals)):
        ax.plot(series, label=name, **ax_args)

for ax, ylabel in zip(axes, (losses_name, vals_name)):
    ax.set_ylabel(ylabel)
    ax.set_yscale("log")
    ax.set_xlabel("Iteration")
    ax.set_title(f"MLP {len(model_cfg.hidden_dims)}-layer '{model_cfg.activation}'")

axes[0].legend()
fig.tight_layout()
plt.show()

# %%
