# %% [markdown]
# # Optimizer Playground
# Checks that the optimization procedure works.

# %%
import copy
from dataclasses import dataclass, asdict
from functools import partial
from typing import Callable
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from symo.notebooks.plot_utils import default_rcparams, plot_metric_norms
from symo.group import S, B, I, O, Eq, flat_shape
from symo.experiments.models import Activation, MLP
from symo.experiments.mlp_groups import group_config
from symo.data import mlp_teacher_data
from symo.metrics import compute_metrics
from symo.utils import InverseStepScheduler
from symo.experiments.utils import (
    mlp_kernel_init,
    mlp_bias_init,
    train_loop,
)

import symo.optim as symo_optim
import matplotlib.pyplot as plt

# Merge the project's default matplotlib settings into the global rcParams.
plt.rcParams |= default_rcparams(dpi=500)

# Type aliases used in the signatures below.
Tensor = torch.Tensor
# A dataset is unpacked as `x, y = data` by the train/eval helpers.
Data = tuple[Tensor, Tensor]


# %%

# Length of the default `ModelConfig.hidden_dims` tuple (read once, when the
# class body is evaluated).
depth: int = 3


@dataclass(frozen=True)
class ExperimentConfig:
    """Experiment-wide settings: compute device, RNG seed and dataset sizes."""

    # Common
    # Resolved once, when the class body is evaluated.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 2025

    # Data
    # These carry annotations on purpose: an un-annotated assignment is only a
    # class attribute, not a dataclass field, so it could not be overridden
    # through the constructor and would be missing from asdict()/repr()/==.
    num_train_points: int = 5000
    num_test_points: int = 5000

    # For a quick run use:
    #   ExperimentConfig(num_train_points=512, num_test_points=512)


@dataclass(frozen=True)
class ModelConfig:
    """Keyword arguments for the ``MLP`` constructor.

    ``create_mlp`` expands an instance with ``asdict`` and passes the fields
    straight to ``MLP``; teacher and student are both built from the same
    instance.
    """

    input_dim: int = 100
    # Default has `depth` entries of width 70 (presumably one per hidden layer;
    # the plot title reports len(hidden_dims) as the layer count).
    hidden_dims: tuple[int, ...] = (70,) * depth
    output_dim: int = 40

    # NOTE(review): the meaning of the fields below is defined by `MLP`, which
    # is not visible here — confirm against symo.experiments.models.
    skip_every: int | None = None
    use_bias: bool = True
    use_bias_last: bool = True
    activation: Activation = "relu"


# %%

# Default `num_epochs` of every optimizer config class defined below
# (captured when each class body is evaluated).
num_epochs = 1000
# num_epochs = 4000


@dataclass(frozen=True)
class SymoConfig:
    """Settings for the "SymO-1" run (``symo_run`` with ``symo_optim.Symo``).

    ``symo_run`` uses ``num_epochs`` for the training loop and forwards every
    other field to the optimizer constructor as a keyword argument.
    """

    num_epochs: int = num_epochs
    # NOTE(review): presumably decay rates of running averages kept by the
    # optimizer, and whether to bias-correct them — confirm in symo.optim.
    grads_beta: float = 0.93
    factors_beta: float = 0.93
    grads_bias_corr: bool = True
    factors_bias_corr: bool = True
    update_correction: bool = False
    # momentum: float = 0.0
    damping: float = 1.0e-10
    lr: float = 1.0


@dataclass(frozen=True)
class Symo2Config1:
    """Base settings for Symo2 runs, with both bias corrections switched off.

    Shaped for ``symo_run``: ``num_epochs`` drives the training loop and the
    remaining fields are forwarded to the optimizer constructor.
    NOTE(review): presumably meant for ``symo_optim.Symo2`` (``symo2_run``);
    no active entry in ``opt_configs`` uses it.
    """

    num_epochs: int = num_epochs
    grads_beta: float = 0.93
    sigma_g_beta: float = 0.93
    grads_bias_corr: bool = False
    sigma_g_bias_corr: bool = False
    damping: float = 4.265e-06
    lr: float = 1.76
    update_correction: bool = False


@dataclass(frozen=True)
class Symo2Config2(Symo2Config1):
    """``Symo2Config1`` with ``sigma_g_bias_corr`` enabled."""

    sigma_g_bias_corr: bool = True


@dataclass(frozen=True)
class Symo2Config3(Symo2Config1):
    """``Symo2Config1`` with both ``grads`` and ``sigma_g`` bias correction enabled."""

    grads_bias_corr: bool = True
    sigma_g_bias_corr: bool = True


@dataclass(frozen=True)
class AdamConfig:
    """Settings shaped for ``pytorch_optim_run`` (intended for ``optim.Adam``).

    ``pytorch_optim_run`` renames ``learning_rate`` to ``lr`` before calling
    the optimizer constructor; ``num_epochs`` is used by the training loop only.
    """

    num_epochs: int = num_epochs
    learning_rate: float = 0.019


@dataclass(frozen=True)
class AdamWConfig:
    """Settings for the "AdamW" run (``pytorch_optim_run`` with ``optim.AdamW``).

    ``learning_rate`` is renamed to ``lr`` by ``pytorch_optim_run``; ``betas``,
    ``eps`` and ``weight_decay`` are forwarded to the optimizer unchanged.
    """

    num_epochs: int = num_epochs
    learning_rate: float = 0.019
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.01


@dataclass(frozen=True)
class SGDConfig:
    """Settings shaped for ``pytorch_optim_run`` (intended for ``optim.SGD``).

    ``learning_rate`` is renamed to ``lr`` by ``pytorch_optim_run``;
    ``momentum`` and ``weight_decay`` are forwarded to the optimizer unchanged.
    """

    num_epochs: int = num_epochs
    learning_rate: float = 0.01
    momentum: float = 0.9
    weight_decay: float = 0.0


# %%

# Default experiment and model settings used by every cell below.
cfg = ExperimentConfig()
model_cfg = ModelConfig()


# %%

# Set random seeds for reproducibility
# (only torch's generators are seeded here).
torch.manual_seed(cfg.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(cfg.seed)

# Device passed to model creation and data generation below.
device = torch.device(cfg.device)

# %%


def record_train_step(
    model: MLP,
    optimizer: optim.Optimizer | optim.lr_scheduler.LRScheduler,
    data: Data,
):
    """Run one training step and compute metrics for it.

    Args:
        model: Network to update in place.
        optimizer: Optimizer, or a scheduler wrapping one; passed through to
            ``train_step`` unchanged.
        data: ``(inputs, targets)`` pair.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the float returned by
        ``train_step`` and ``metrics`` is whatever ``compute_metrics`` returns
        for the loss, the post-step parameters, their gradients and the
        per-parameter update (new minus old).
    """
    # Snapshot the pre-step values. `detach()` matters: a bare `clone()` of a
    # parameter stays attached to autograd, so both the snapshot and the
    # difference computed below would carry a needless graph.
    old_params = {
        name: param.detach().clone() for name, param in model.named_parameters()
    }

    loss = train_step(model, optimizer, data)

    # Live parameters after the step, and the gradients that produced it.
    new_params = dict(model.named_parameters())
    grads = {name: param.grad for name, param in new_params.items()}

    # The update is bookkeeping only; keep it out of autograd.
    with torch.no_grad():
        update = {
            name: param - old_params[name] for name, param in new_params.items()
        }

    metrics = compute_metrics(loss, new_params, grads, update)
    return loss, metrics


def train_step(
    model: MLP,
    optimizer: optim.Optimizer | optim.lr_scheduler.LRScheduler,
    data: Data,
):
    """Perform one MSE gradient step on ``data`` and return the loss as a float.

    ``optimizer`` may be an LR scheduler instead of an optimizer; in that case
    the wrapped optimizer performs the update and the scheduler is stepped
    once afterwards.
    """
    # Split the argument into (optimizer, optional scheduler).
    if isinstance(optimizer, optim.lr_scheduler.LRScheduler):
        lr_schedule, opt = optimizer, optimizer.optimizer
    else:
        lr_schedule, opt = None, optimizer

    model.train()
    inputs, targets = data

    opt.zero_grad()
    batch_loss = nn.functional.mse_loss(model(inputs), targets)
    batch_loss.backward()
    opt.step()

    if lr_schedule is not None:
        lr_schedule.step()

    return batch_loss.item()


@torch.compile
def eval_loss(model: MLP, data: Data):
    """Return the MSE of ``model`` on ``data`` as a float, without gradients.

    Switches the model to eval mode as a side effect.
    """
    model.eval()
    inputs, targets = data

    with torch.no_grad():
        mse = nn.functional.mse_loss(model(inputs), targets)

    return mse.item()


# # "Global" curvature

# %%


def create_mlp(config: ModelConfig, device: torch.device | None = None):
    """Build an ``MLP`` from ``config`` and optionally move it to ``device``.

    Args:
        config: Model settings; its fields are passed to ``MLP`` as keyword
            arguments, together with the project's kernel/bias initialisers.
        device: Target device. When ``None`` the model is returned where it
            was constructed.

    Returns:
        The constructed model, on ``device`` when one is given.
    """
    mlp = MLP(
        kernel_init=mlp_kernel_init,
        bias_init=mlp_bias_init,
        **asdict(config),
    )
    # The check was previously inverted: a requested device was ignored and
    # `.to(None)` was called when no device was requested.
    if device is None:
        return mlp
    return mlp.to(device)


# The teacher generates the regression targets; the student is the model
# trained by each optimizer.
mlp_teacher = create_mlp(model_cfg, device)
mlp_student = create_mlp(model_cfg, device)

# Snapshot of the student's initial weights, restored before every run.
# state_dict() returns tensors that share storage with the live parameters,
# so it must be deep-copied here; otherwise in-place optimizer updates would
# silently change the "initial" state and later runs would start from
# already-trained weights.
mlp_student_state = copy.deepcopy(mlp_student.state_dict())

# %%

# Dedicated generator seeded from the experiment seed for data generation.
generator = torch.Generator().manual_seed(cfg.seed)
# Build train/validation sets from the teacher network; the first element of
# the returned pair is not needed here.
_, ((train_data, val_data)) = mlp_teacher_data(
    generator,
    mlp_teacher,
    num_train_points=cfg.num_train_points,
    num_test_points=cfg.num_test_points,
    device=device,
)

# %%


def symo_run(
    train_data: Data,
    val_data: Data,
    model: MLP,
    opt_cfg,
    opt_class,
):
    """Train ``model`` with a Symo-family optimizer and return the loop output.

    ``opt_cfg`` is a config dataclass: ``num_epochs`` sets the length of the
    training loop, ``lr`` is passed explicitly, and all remaining fields are
    forwarded to ``opt_class`` as keyword arguments. The optimizer is wrapped
    in an ``InverseStepScheduler`` and metrics are recorded at every step.
    """
    groups = group_config(model, hid_group=B, inout_group=I, same=False)

    # Everything except the loop length and the learning rate is a plain
    # constructor keyword for the optimizer.
    hyperparams = asdict(opt_cfg)
    del hyperparams["num_epochs"]
    del hyperparams["lr"]

    symo_optimizer = opt_class(
        model.named_parameters(),
        groups_spec=groups,
        lr=opt_cfg.lr,
        **hyperparams,
    )

    return train_loop(
        model,
        train_data,
        val_data,
        InverseStepScheduler(optimizer=symo_optimizer),
        record_train_step,
        eval_loss,
        num_epochs=opt_cfg.num_epochs,
        record_metrics=True,
        print_output=True,
    )


# %%


def pytorch_optim_run(
    train_data: Data,
    val_data: Data,
    model: MLP,
    opt_cfg,
    opt_class,
):
    """Train ``model`` with a ``torch.optim`` optimizer and return the loop output.

    ``opt_cfg`` is a config dataclass: ``num_epochs`` sets the length of the
    training loop and every other field becomes a keyword argument of
    ``opt_class`` (with ``learning_rate`` renamed to ``lr``). No scheduler is
    used and metric recording is switched off.
    """
    hyperparams = asdict(opt_cfg)
    del hyperparams["num_epochs"]

    # The config classes spell it `learning_rate`; torch optimizers take `lr`.
    try:
        hyperparams["lr"] = hyperparams.pop("learning_rate")
    except KeyError:
        pass

    torch_optimizer = opt_class(model.parameters(), **hyperparams)

    return train_loop(
        model,
        train_data,
        val_data,
        torch_optimizer,
        record_train_step,
        eval_loss,
        num_epochs=opt_cfg.num_epochs,
        record_metrics=False,
        print_output=True,
    )


# %%

# Create partial functions for different optimizers
# (each one fixes only `opt_class`; data, model and config are supplied later).
symo1_run = partial(symo_run, opt_class=symo_optim.Symo)
symo2_run = partial(symo_run, opt_class=symo_optim.Symo2)
adam_run = partial(pytorch_optim_run, opt_class=optim.Adam)
adamw_run = partial(pytorch_optim_run, opt_class=optim.AdamW)
sgd_run = partial(pytorch_optim_run, opt_class=optim.SGD)


# Keyword arguments shared by every run.
common_args = dict(
    train_data=train_data,
    val_data=val_data,
)

# (label, (run function, run-specific kwargs)) for each optimizer to compare.
# The kwargs must be a dict because the run loop merges them with
# `common_args` using `|`.
opt_configs = (
    ("SymO-1", (symo1_run, dict(opt_cfg=SymoConfig()))),
    # ("SymO-2.3", (symo2_run, dict(opt_cfg=Symo2Config3()))),
    # ("Adam", (adam_run, dict(opt_cfg=AdamConfig()))),
    ("AdamW", (adamw_run, dict(opt_cfg=AdamWConfig()))),
    # ("SGD", (sgd_run, dict(opt_cfg=SGDConfig()))),
)

# %%

# Placeholder binding; `model` is rebound to the student on every iteration.
model = mlp_teacher

# (label, train_loop output) for each configured optimizer, in run order.
outputs = []
for name, (run_fn, kwargs) in opt_configs:
    # Reset the student so that every optimizer starts from the same weights.
    # NOTE(review): state_dict() tensors share storage with the live
    # parameters, so this only restores the *initial* weights if
    # `mlp_student_state` was deep-copied when it was captured — confirm.
    state = copy.deepcopy(mlp_student_state)
    mlp_student.load_state_dict(state)
    # model = torch.compile(mlp_student, fullgraph=True)
    model = mlp_student

    # Later operands win: run-specific kwargs override the shared ones, and
    # the model is always the freshly reset student.
    kwargs = common_args | kwargs | dict(model=model)
    print(f"\n{'='*60}")
    print(f">>> Running: {name}")
    print(f"{'='*60}")
    out = run_fn(**kwargs)
    outputs.append((name, out))

# %%

# Plot metrics if available
# Only the first run is inspected; its output is unpacked as a 3-tuple whose
# last element holds the recorded metrics (or None).
if outputs:
    _, _, metrics = outputs[0][1]
    if metrics is not None:
        fig, _ = plot_metric_norms(metrics, linewidth=0.2)
        fig.tight_layout()
        fig.show()


# %%

# Plot settings. NOTE(review): `fontsize` is not used anywhere in this cell.
fontsize = 10
figsize = (5, 3)
ax_args = dict(alpha=0.5, linewidth=0.5)

# Left panel: training loss; right panel: validation loss.
fig, axes = plt.subplots(ncols=2, figsize=figsize)

# One curve per optimizer on each panel; the third element of each run's
# output is ignored here.
for name, (losses, vals, _) in outputs:
    axes[0].plot(losses, label=name, **ax_args)
    axes[1].plot(vals, label=name, **ax_args)

losses_name = "Loss"
vals_name = "Validation Loss"

axes[0].set_ylabel(losses_name)
axes[1].set_ylabel(vals_name)

# Shared formatting: log-scaled y axis and a title describing the model.
for ax in axes:
    ax.set_yscale("log")
    ax.set_xlabel("Iteration")
    ax.set_title(f"MLP {len(model_cfg.hidden_dims)}-layer '{model_cfg.activation}'")

# The legend is drawn on the left panel only.
axes[0].legend()
fig.tight_layout()
plt.show()

# %%
