import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

from .config import ExperimentConfig
from .model import MSNWedge2D, create_model
from .dataset import WedgeDataLoader


@dataclass
class TrainingLogs:
    """Time series recorded by ``train_model``, one entry per logged step.

    Every list is appended to in lockstep, so index ``i`` of each list
    describes the same logged step.
    """

    # Step numbers at which an entry was recorded.
    steps: List[int]
    # Total weighted loss at each logged step.
    losses: List[float]
    # Arc data-fit component of the loss.
    arc_losses: List[float]
    # Combined edge boundary-condition component of the loss.
    edge_losses: List[float]
    # Unweighted constraint component of the loss.
    constraint_losses: List[float]
    # Snapshot of the model exponents at each logged step.
    mus: List[List[float]]
    # Snapshot of the model coefficients at each logged step.
    coeffs: List[List[float]]
    # Learning rate applied to the exponent optimizer.
    lr_mus: List[float]
    # Constraint-loss weight in effect.
    w_constraints: List[float]

    def to_dict(self) -> dict:
        """Return every series in a plain ``dict`` keyed by its field name.

        The dict values are the very list objects held by this instance;
        nothing is copied.
        """
        field_names = (
            "steps",
            "losses",
            "arc_losses",
            "edge_losses",
            "constraint_losses",
            "mus",
            "coeffs",
            "lr_mus",
            "w_constraints",
        )
        return {name: getattr(self, name) for name in field_names}


def get_constraint_weight(
    step: int,
    warmup_steps: int,
    ramp_steps: int,
    max_weight: float,
) -> float:
    """Return the constraint-loss weight scheduled for ``step``.

    The schedule has three phases: ``0.0`` before ``warmup_steps``, a linear
    ramp from 0 towards ``max_weight`` over the next ``ramp_steps`` steps, and
    then a constant ``max_weight``.
    """
    if step < warmup_steps:
        return 0.0

    ramp_end = warmup_steps + ramp_steps
    # A zero/negative ramp length lands here immediately, so the division
    # below is only reached with a positive ramp_steps.
    if step >= ramp_end:
        return max_weight

    fraction_done = (step - warmup_steps) / ramp_steps
    return max_weight * fraction_done


def get_lr_mu(
    step: int,
    warmup_steps: int,
    lr_warmup: float,
    lr_main: float,
) -> float:
    """Return the exponent learning rate for ``step``.

    ``lr_warmup`` applies while ``step < warmup_steps``; from
    ``warmup_steps`` onward the rate switches to ``lr_main``.
    """
    in_warmup = step < warmup_steps
    return lr_warmup if in_warmup else lr_main


def _split_param_groups(
    model: MSNWedge2D,
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
    """Split ``model``'s parameters into exponent and non-exponent groups.

    Returns:
        ``(mu_params, other_params)``: the parameters of ``model.exps`` and
        every other parameter of ``model``, in ``model.parameters()`` order.
    """
    mu_params = list(model.exps.parameters())
    # Membership is tested by object identity: tensors do not support
    # equality-based ``in`` checks reliably (``==`` is elementwise).
    mu_param_ids = {id(p) for p in mu_params}
    other_params = [p for p in model.parameters() if id(p) not in mu_param_ids]
    return mu_params, other_params


def _compute_losses(
    model: MSNWedge2D,
    batch: dict,
    config: ExperimentConfig,
    w_constraint: float,
) -> Dict[str, torch.Tensor]:
    """Evaluate the weighted training loss and its logged components.

    Args:
        model: Model being trained.
        batch: Mapping returned by ``WedgeDataLoader.get_batch``.
        config: Experiment configuration (loss weights, ``omega``).
        w_constraint: Current weight on the constraint loss.

    Returns:
        Dict with tensor entries ``"total"`` (the quantity back-propagated),
        ``"arc"``, ``"edge"`` and ``"constraint"``.
    """
    # Mean squared misfit against the targets on the arc samples.
    u_arc = model.forward(batch["r_arc"], batch["theta_arc"])
    arc_loss = torch.mean((u_arc - batch["target_arc"]) ** 2)

    # The e0 edge term drives the model output itself toward zero.
    u_e0 = model.forward(batch["r_e0"], batch["theta_e0"])
    edge0_loss = torch.mean(u_e0**2)

    # The model returns both residual kinds for the omega edge; keep the one
    # matching its configured second-edge condition.
    dirichlet_loss_w, neumann_loss_w = model.edge_bc_loss_theta_omega(
        batch["r_ew"], config.omega
    )
    edge_w_loss = dirichlet_loss_w if model.second_edge_dirichlet else neumann_loss_w
    edge_loss = edge0_loss + edge_w_loss

    constraint_loss = model.constraint_loss(config.omega)
    small_mu_loss = model.small_mu_preference_loss()
    l1_loss = torch.mean(torch.abs(model.coeffs))

    total = (
        config.w_arc * arc_loss
        + config.w_edge * edge_loss
        + w_constraint * constraint_loss
        + config.w_mode_prefer_small * small_mu_loss
        + config.w_l1 * l1_loss
    )
    return {
        "total": total,
        "arc": arc_loss,
        "edge": edge_loss,
        "constraint": constraint_loss,
    }


def _append_log(
    logs: TrainingLogs,
    step: int,
    values: Dict[str, float],
    model: MSNWedge2D,
    lr_mu: float,
    w_constraint: float,
) -> None:
    """Append one entry (scalar losses plus parameter snapshots) to ``logs``."""
    logs.steps.append(step)
    logs.losses.append(values["total"])
    logs.arc_losses.append(values["arc"])
    logs.edge_losses.append(values["edge"])
    logs.constraint_losses.append(values["constraint"])
    logs.mus.append(model.get_exponents().tolist())
    logs.coeffs.append(model.get_coeffs().tolist())
    logs.lr_mus.append(lr_mu)
    logs.w_constraints.append(w_constraint)


def train_model(
    model: MSNWedge2D,
    config: ExperimentConfig,
    verbose: bool = False,
) -> Tuple[MSNWedge2D, TrainingLogs]:
    """Train ``model`` on freshly sampled wedge batches for ``config.total_steps``.

    Two Adam optimizers are used:

    * One for the ``model.exps`` parameters. Its learning rate is
      ``config.lr_mu_warmup`` before ``config.warmup_steps`` and
      ``config.lr_mu_main`` afterwards.
    * One at ``config.lr_coeff`` for every other parameter.

    The constraint-loss weight is 0 during warmup. It then ramps linearly to
    ``config.w_constraint_max`` over ``config.constraint_ramp_steps``. A step
    whose total loss is non-finite is skipped without any parameter update.

    Args:
        model: Model to optimize; its parameters are updated in place.
        config: Experiment configuration (schedules, loss weights, sampling).
        verbose: If True, print progress on logged steps and skipped steps.

    Returns:
        The same ``model`` instance and the ``TrainingLogs`` collected at step 1
        and every ``config.log_every`` steps.
    """
    mu_params, other_params = _split_param_groups(model)

    opt_coeff = torch.optim.Adam(other_params, lr=config.lr_coeff)
    opt_mu = torch.optim.Adam(mu_params, lr=config.lr_mu_warmup)

    data_loader = WedgeDataLoader(
        omega=config.omega,
        true_mu=config.true_mu,
        bc_type=config.bc_type,
        n_arc=config.n_arc,
        n_edge=config.n_edge,
        r_min=config.r_min,
        r_power=config.r_sample_power,
        base_seed=config.seed * 10000,
        device=config.device,
    )

    logs = TrainingLogs(
        steps=[],
        losses=[],
        arc_losses=[],
        edge_losses=[],
        constraint_losses=[],
        mus=[],
        coeffs=[],
        lr_mus=[],
        w_constraints=[],
    )

    for step in range(1, config.total_steps + 1):
        lr_mu = get_lr_mu(
            step, config.warmup_steps, config.lr_mu_warmup, config.lr_mu_main
        )
        for pg in opt_mu.param_groups:
            pg["lr"] = lr_mu

        w_constraint = get_constraint_weight(
            step,
            config.warmup_steps,
            config.constraint_ramp_steps,
            config.w_constraint_max,
        )

        batch = data_loader.get_batch()
        losses = _compute_losses(model, batch, config, w_constraint)
        loss = losses["total"]

        # isfinite also rejects +/-inf, not only NaN; such steps are neither
        # applied nor logged.
        if not torch.isfinite(loss):
            if verbose:
                print(f"[Step {step}] NaN loss, skipping")
            continue

        opt_coeff.zero_grad(set_to_none=True)
        opt_mu.zero_grad(set_to_none=True)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(other_params, config.grad_clip_coeff)
        torch.nn.utils.clip_grad_norm_(mu_params, config.grad_clip_mu)

        opt_coeff.step()
        opt_mu.step()

        if step % config.log_every == 0 or step == 1:
            # Losses come from the pre-update forward pass. The mus/coeffs
            # snapshots in _append_log are read after the optimizer step.
            values = {name: t.item() for name, t in losses.items()}
            _append_log(logs, step, values, model, lr_mu, w_constraint)

            if verbose:
                dominant_mu = model.get_dominant_mu()
                print(
                    f"[Step {step}] loss={values['total']:.3e} arc={values['arc']:.3e} "
                    f"edge={values['edge']:.3e} constraint={values['constraint']:.3e} "
                    f"dominant_μ={dominant_mu:.4f}"
                )

    return model, logs


def run_single_experiment(
    config: ExperimentConfig,
    verbose: bool = False,
) -> Tuple[MSNWedge2D, TrainingLogs]:
    """Seed torch and NumPy, build a fresh model from ``config`` and train it.

    Args:
        config: Experiment configuration; ``config.seed`` seeds both RNGs.
        verbose: Forwarded to ``train_model``.

    Returns:
        The ``(model, logs)`` pair produced by ``train_model``.
    """
    seed = config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    fresh_model = create_model(config)
    return train_model(fresh_model, config, verbose)
