import argparse
import json
import math
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Dict, Optional, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import MSNWedge2D, build_optimizers
from utils import (
    set_seed,
    get_device,
    sample_arc_points,
    sample_edge_points,
    sample_interior_points,
    get_true_exponents,
    compute_arc_boundary_target,
    compute_exponent_error,
    compute_constraint_violation,
    compute_solution_l2_error,
)


@dataclass
class TrainingConfig:
    """Hyperparameters for one ``train_single_experiment`` run.

    Training steps are 1-indexed. Steps ``1..warmup_steps`` are the warmup
    phase; afterwards the constraint weight ramps linearly over ``ramp_steps``
    steps up to ``w_constraint_max`` and then stays there.
    """

    # Passed as ``K`` to MSNWedge2D; presumably the number of (mu, coeff)
    # terms in the model's expansion — TODO confirm against models.py.
    K: int = 6

    # Total number of optimization steps.
    total_steps: int = 20000
    # Steps with zero constraint weight, lr_mu_warmup as the mu learning rate,
    # and no mu optimizer step.
    warmup_steps: int = 3000
    # Length of the linear 0 -> w_constraint_max constraint-weight ramp.
    ramp_steps: int = 5000

    # Number of points resampled every step on the arc and on each edge.
    n_arc: int = 512
    n_edge: int = 256
    # NOTE(review): not referenced by the training loop in this file.
    n_interior: int = 1024
    # Forwarded to sample_edge_points; presumably the smallest sampled radius
    # and an exponent clustering samples toward the corner — TODO confirm.
    r_min: float = 1e-4
    r_power: float = 2.0

    # Passed to build_optimizers as lr_w (the optimizer for non-exponent
    # parameters, presumably).
    lr_w: float = 1e-2
    # Exponent (mu) learning rate used after warmup.
    lr_mu: float = 5e-4
    # Exponent learning rate during warmup. NOTE(review): the training loop
    # never calls the mu optimizer's step() during warmup, so this value
    # currently has no effect on training.
    lr_mu_warmup: float = 0.0

    # After warmup, the mu optimizer steps only on steps divisible by this
    # value (used as a modulo divisor, so it must be >= 1).
    inner_steps: int = 3

    # Weights of the arc-target and edge boundary-condition loss terms.
    w_arc: float = 50.0
    w_edge_dirichlet: float = 20.0
    w_edge_neumann: float = 10.0
    # Post-ramp weight of the constraint loss ("constraint" method only).
    w_constraint_max: float = 100.0
    # Weight of the mean-absolute-coefficient (L1 sparsity) penalty.
    w_sparse: float = 1e-4
    # Weight of the small-mu preference loss; applied only when > 0 and the
    # method is "constraint".
    w_small_mu: float = 0.0

    # Max gradient norm for parameters whose name does not contain "exps",
    # and for model.exps parameters, respectively.
    grad_clip_w: float = 1.0
    grad_clip_mu: float = 0.2

    # Period (in steps) for recording history and printing progress.
    log_every: int = 500


def get_constraint_weight(
    step: int, config: TrainingConfig, use_constraint: bool
) -> float:
    """Return the weight applied to the exponent-constraint loss at ``step``.

    The weight is 0 when the constraint is disabled or during warmup
    (``step <= config.warmup_steps``). After warmup it grows linearly from 0
    to ``config.w_constraint_max`` over ``config.ramp_steps`` steps, then
    stays at that maximum.

    A non-positive ``config.ramp_steps`` means "no ramp": the full weight is
    applied on the first step after warmup.

    Args:
        step: Current 1-indexed training step.
        config: Training hyperparameters.
        use_constraint: Whether the constraint loss is active at all.

    Returns:
        The constraint-loss weight as a float.
    """
    if not use_constraint or step <= config.warmup_steps:
        return 0.0

    # ramp_steps == 0 used to raise ZeroDivisionError here; negative values
    # would have clamped the progress to 0 forever.
    if config.ramp_steps <= 0:
        return float(config.w_constraint_max)

    ramp_progress = (step - config.warmup_steps) / config.ramp_steps
    ramp_progress = min(1.0, max(0.0, ramp_progress))

    return config.w_constraint_max * ramp_progress


def get_lr_mu(step: int, config: TrainingConfig) -> float:
    """Pick the exponent (mu) learning rate for ``step``.

    Steps up to and including ``config.warmup_steps`` get
    ``config.lr_mu_warmup``; every later step gets ``config.lr_mu``.
    """
    in_warmup = step <= config.warmup_steps
    return config.lr_mu_warmup if in_warmup else config.lr_mu


def train_single_experiment(
    omega: float,
    bc_type: str,
    method: str,
    seed: int,
    config: TrainingConfig,
    device: torch.device,
    verbose: bool = True,
) -> Dict:
    """Train one MSNWedge2D model on a wedge problem and summarise the result.

    Each step resamples arc and edge points and builds a weighted loss from:
    the arc boundary target, the Dirichlet/Neumann edge conditions implied by
    ``bc_type``, the optionally ramped constraint loss, an L1 penalty on the
    coefficients, and an optional small-mu term. Network weights are updated
    every step. The exponents (mu) are updated only after warmup, on steps
    divisible by ``config.inner_steps``.

    Args:
        omega: Wedge angle in radians.
        bc_type: Two-letter edge condition code, case-insensitive, one of
            "DD", "DN", "ND", "NN". The first letter applies to the "theta0"
            edge and the second to the "thetaw" edge.
        method: "constraint" enables the ramped constraint loss (and the
            small-mu term when ``config.w_small_mu > 0``). Any other value
            trains without them.
        seed: Seed for ``set_seed`` and for the per-step point sampling.
        config: Training hyperparameters.
        device: Device for the model and all tensors.
        verbose: Print progress on logged steps.

    Returns:
        Dict with final exponents/coefficients, error metrics, a config echo,
        and the logged ``history``.

    Raises:
        ValueError: If ``config.inner_steps`` or ``config.log_every`` is < 1.
    """
    # Both values are used as modulo divisors in the loop. Reject them here
    # rather than hitting ZeroDivisionError mid-run (for inner_steps, only
    # after the whole warmup has been spent).
    if config.inner_steps < 1:
        raise ValueError(f"inner_steps must be >= 1, got {config.inner_steps}")
    if config.log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {config.log_every}")

    set_seed(seed)
    bc_type = bc_type.upper()

    true_mus = get_true_exponents(omega, bc_type, n_modes=1)
    true_mu = true_mus[0]

    model = MSNWedge2D(
        K=config.K, omega=omega, bc_type=bc_type, use_adaptive_bounds=True
    ).to(device)

    opt_w, opt_mu = build_optimizers(model, lr_w=config.lr_w, lr_mu=config.lr_mu)

    history = {
        "step": [],
        "loss": [],
        "arc_loss": [],
        "edge_loss": [],
        "neumann_loss": [],
        "constraint_loss": [],
        "mus": [],
        "coeffs": [],
    }

    use_constraint = method == "constraint"

    # First letter of bc_type -> "theta0" edge, second letter -> "thetaw" edge.
    edge0_dirichlet = bc_type in ["DD", "DN"]
    edgew_dirichlet = bc_type in ["DD", "ND"]
    edge0_neumann = bc_type in ["NN", "ND"]
    edgew_neumann = bc_type in ["NN", "DN"]

    start_time = time.time()

    for step in range(1, config.total_steps + 1):
        w_con = get_constraint_weight(step, config, use_constraint)
        lr_mu = get_lr_mu(step, config)

        for pg in opt_mu.param_groups:
            pg["lr"] = lr_mu

        # Three independent sampler seeds per (seed, step). The old scheme
        # (seed * 100000 + step, +1, +2) made step k's theta0-edge seed equal
        # step k+1's arc seed and its thetaw-edge seed equal step k+1's
        # theta0-edge seed. It also let runs with different seeds collide once
        # total_steps approached 100000.
        arc_seed, e0_seed, ew_seed = (
            int(s) for s in np.random.SeedSequence([seed, step]).generate_state(3)
        )

        xy_arc = sample_arc_points(config.n_arc, omega, seed=arc_seed)
        theta_arc = np.arctan2(xy_arc[:, 1], xy_arc[:, 0])
        # Map arctan2's (-pi, pi] output onto [0, 2*pi).
        theta_arc = np.where(theta_arc < 0, theta_arc + 2 * np.pi, theta_arc)
        target_arc = compute_arc_boundary_target(theta_arc, omega, bc_type, mode=1)

        xy_arc_t = torch.tensor(xy_arc, device=device)
        target_arc_t = torch.tensor(target_arc, device=device)

        xy_e0 = sample_edge_points(
            config.n_edge,
            omega,
            "theta0",
            r_min=config.r_min,
            r_power=config.r_power,
            seed=e0_seed,
        )
        xy_ew = sample_edge_points(
            config.n_edge,
            omega,
            "thetaw",
            r_min=config.r_min,
            r_power=config.r_power,
            seed=ew_seed,
        )

        xy_e0_t = torch.tensor(xy_e0, device=device)
        xy_ew_t = torch.tensor(xy_ew, device=device)

        u_arc = model(xy_arc_t)
        u_e0 = model(xy_e0_t)
        u_ew = model(xy_ew_t)

        arc_loss = F.mse_loss(u_arc, target_arc_t)

        edge_dirichlet_loss = torch.tensor(0.0, device=device)
        if edge0_dirichlet:
            edge_dirichlet_loss = edge_dirichlet_loss + F.mse_loss(
                u_e0, torch.zeros_like(u_e0)
            )
        if edgew_dirichlet:
            edge_dirichlet_loss = edge_dirichlet_loss + F.mse_loss(
                u_ew, torch.zeros_like(u_ew)
            )

        edge_neumann_loss = torch.tensor(0.0, device=device)
        if edge0_neumann:
            du_dtheta_e0 = model.compute_angular_derivative(xy_e0_t)
            edge_neumann_loss = edge_neumann_loss + F.mse_loss(
                du_dtheta_e0, torch.zeros_like(du_dtheta_e0)
            )
        if edgew_neumann:
            du_dtheta_ew = model.compute_angular_derivative(xy_ew_t)
            edge_neumann_loss = edge_neumann_loss + F.mse_loss(
                du_dtheta_ew, torch.zeros_like(du_dtheta_ew)
            )

        constraint_loss = (
            model.constraint_loss()
            if use_constraint
            else torch.tensor(0.0, device=device)
        )

        # L1 penalty on the coefficients.
        sparse_loss = torch.mean(torch.abs(model.coeffs))

        small_mu_loss = torch.tensor(0.0, device=device)
        if config.w_small_mu > 0 and use_constraint:
            small_mu_loss = model.small_mu_preference_loss(strength=1.0)

        loss = (
            config.w_arc * arc_loss
            + config.w_edge_dirichlet * edge_dirichlet_loss
            + config.w_edge_neumann * edge_neumann_loss
            + w_con * constraint_loss
            + config.w_sparse * sparse_loss
            + config.w_small_mu * small_mu_loss
        )

        if not torch.isfinite(loss):
            if verbose:
                print(f"[Step {step}] Non-finite loss, skipping")
            continue

        opt_w.zero_grad(set_to_none=True)
        opt_mu.zero_grad(set_to_none=True)
        loss.backward()

        # Separate clipping budgets: everything except the exponents, then
        # the exponents (model.exps) alone.
        nn.utils.clip_grad_norm_(
            [p for n, p in model.named_parameters() if "exps" not in n],
            config.grad_clip_w,
        )
        nn.utils.clip_grad_norm_(list(model.exps.parameters()), config.grad_clip_mu)

        opt_w.step()
        # NOTE(review): mu is never stepped during warmup, so
        # config.lr_mu_warmup (applied via get_lr_mu above) has no effect.
        if step > config.warmup_steps and step % config.inner_steps == 0:
            opt_mu.step()

        # Always log the last step so history[...][-1] (and the final_loss /
        # final_arc_loss results below) reflect the end of training.
        if (
            step % config.log_every == 0
            or step == 1
            or step == config.total_steps
        ):
            mus_np = model.get_exponents()
            coeffs_np = model.get_coeffs()

            history["step"].append(step)
            history["loss"].append(float(loss.item()))
            history["arc_loss"].append(float(arc_loss.item()))
            history["edge_loss"].append(float(edge_dirichlet_loss.item()))
            history["neumann_loss"].append(float(edge_neumann_loss.item()))
            history["constraint_loss"].append(
                float(constraint_loss.item()) if use_constraint else 0.0
            )
            history["mus"].append(mus_np.tolist())
            history["coeffs"].append(coeffs_np.tolist())

            if verbose:
                closest_idx = np.argmin(np.abs(mus_np - true_mu))
                closest_mu = mus_np[closest_idx]
                _, rel_err = compute_exponent_error(closest_mu, true_mu)

                phase = (
                    "WARMUP"
                    if step <= config.warmup_steps
                    else (
                        "RAMP"
                        if step <= config.warmup_steps + config.ramp_steps
                        else "TRAIN"
                    )
                )

                print(
                    f"[{phase:6s}|{method.upper():10s}] Step {step:5d} | Loss: {loss.item():.3e} | "
                    f"Arc: {arc_loss.item():.3e} | EdgeD: {edge_dirichlet_loss.item():.3e} | "
                    f"EdgeN: {edge_neumann_loss.item():.3e} | Con: {constraint_loss.item():.3e} | "
                    f"Closest mu: {closest_mu:.4f} (true: {true_mu:.4f}, err: {rel_err:.2f}%)"
                )

    training_time = time.time() - start_time

    model.eval()
    final_mus = model.get_exponents()
    final_coeffs = model.get_coeffs()

    # The prediction is the exponent whose coefficient has the largest magnitude.
    dominant_idx = np.argmax(np.abs(final_coeffs))
    dominant_mu = final_mus[dominant_idx]
    dominant_coeff = final_coeffs[dominant_idx]

    abs_error, rel_error = compute_exponent_error(dominant_mu, true_mu)
    constraint_viol = compute_constraint_violation(
        final_mus, omega, bc_type, final_coeffs
    )

    l2_error, rel_l2_error = compute_solution_l2_error(
        model, omega, bc_type, mode=1, n_eval=2000, device=device, seed=seed + 999
    )

    success = rel_error < 5.0

    results = {
        "omega": omega,
        "omega_deg": np.degrees(omega),
        "bc_type": bc_type,
        "method": method,
        "seed": seed,
        "K": config.K,
        "total_steps": config.total_steps,
        "warmup_steps": config.warmup_steps,
        "ramp_steps": config.ramp_steps,
        "true_mu": float(true_mu),
        "mu_bounds": [model.mu_min, model.mu_max],
        "dominant_mu": float(dominant_mu),
        "dominant_coeff": float(dominant_coeff),
        "all_mus": final_mus.tolist(),
        "all_coeffs": final_coeffs.tolist(),
        "abs_error": abs_error,
        "rel_error": rel_error,
        "constraint_violation": constraint_viol,
        "l2_error": l2_error,
        "rel_l2_error": rel_l2_error,
        "success": success,
        "training_time": training_time,
        "final_loss": history["loss"][-1] if history["loss"] else float("nan"),
        "final_arc_loss": (
            history["arc_loss"][-1] if history["arc_loss"] else float("nan")
        ),
        "history": history,
    }

    return results


def main():
    """Command-line entry point for a single MSN-PINN wedge experiment.

    Parses CLI arguments into a ``TrainingConfig`` and runs
    ``train_single_experiment``. Every result except ``history`` is written as
    JSON to ``<output_dir>/<omega>_<bc_type>_<method>_seed<seed>.json``, where
    omega is formatted with 4 decimals and "." is replaced by "p". A summary
    is printed.

    Returns:
        The full results dict from ``train_single_experiment``, including
        ``history``.
    """
    parser = argparse.ArgumentParser(description="Train single MSN-PINN experiment")
    parser.add_argument(
        "--omega", type=float, required=True, help="Wedge angle in radians"
    )
    parser.add_argument(
        "--bc_type", type=str, required=True, choices=["DD", "NN", "DN", "ND"]
    )
    parser.add_argument(
        "--method", type=str, required=True, choices=["naive", "constraint"]
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--gpu", type=int, default=None, help="GPU device ID")
    parser.add_argument("--output_dir", type=str, default="results")
    parser.add_argument("--quiet", action="store_true", help="Suppress verbose output")

    # Read hyperparameter defaults from TrainingConfig so the CLI cannot
    # silently drift from the dataclass.
    parser.add_argument("--K", type=int, default=TrainingConfig.K)
    parser.add_argument(
        "--total_steps", type=int, default=TrainingConfig.total_steps
    )
    parser.add_argument(
        "--warmup_steps", type=int, default=TrainingConfig.warmup_steps
    )
    parser.add_argument("--ramp_steps", type=int, default=TrainingConfig.ramp_steps)
    parser.add_argument("--lr_mu", type=float, default=TrainingConfig.lr_mu)
    parser.add_argument(
        "--w_constraint_max", type=float, default=TrainingConfig.w_constraint_max
    )

    args = parser.parse_args()

    device = get_device(args.gpu)
    print(f"Using device: {device}")

    config = TrainingConfig(
        K=args.K,
        total_steps=args.total_steps,
        warmup_steps=args.warmup_steps,
        ramp_steps=args.ramp_steps,
        lr_mu=args.lr_mu,
        w_constraint_max=args.w_constraint_max,
    )

    print(f"\n{'='*70}")
    print(
        f"Exp7: omega={args.omega:.4f} ({np.degrees(args.omega):.1f} deg), "
        f"BC={args.bc_type}, method={args.method}, seed={args.seed}"
    )
    print(
        f"Config: K={config.K}, steps={config.total_steps}, "
        f"warmup={config.warmup_steps}, ramp={config.ramp_steps}"
    )
    print(f"{'='*70}\n")

    results = train_single_experiment(
        omega=args.omega,
        bc_type=args.bc_type,
        method=args.method,
        seed=args.seed,
        config=config,
        device=device,
        verbose=not args.quiet,
    )

    # The per-step history is large; only the summary goes to disk.
    results_to_save = {k: v for k, v in results.items() if k != "history"}

    def convert_to_serializable(obj):
        """Recursively convert numpy/torch values into JSON-native types."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # Covers every numpy scalar (all int/uint/float widths, np.bool_,
            # ...), not just float32/float64/int32/int64.
            return obj.item()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [convert_to_serializable(i) for i in obj]
        return obj

    results_to_save = convert_to_serializable(results_to_save)

    os.makedirs(args.output_dir, exist_ok=True)
    omega_str = f"{args.omega:.4f}".replace(".", "p")
    filename = f"{omega_str}_{args.bc_type}_{args.method}_seed{args.seed}.json"
    output_path = os.path.join(args.output_dir, filename)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results_to_save, f, indent=2)

    print(f"\n{'='*70}")
    print("RESULTS:")
    print(f"  True exponent: {results['true_mu']:.6f}")
    print(
        f"  Mu bounds: [{results['mu_bounds'][0]:.4f}, {results['mu_bounds'][1]:.4f}]"
    )
    print(f"  Predicted exponent: {results['dominant_mu']:.6f}")
    print(f"  Relative error: {results['rel_error']:.4f}%")
    print(f"  Constraint violation: {results['constraint_violation']:.6f}")
    print(f"  Solution L2 error: {results['rel_l2_error']:.4f}%")
    print(f"  Success: {results['success']}")
    print(f"  Training time: {results['training_time']:.1f}s")
    print(f"  Saved to: {output_path}")
    print(f"{'='*70}\n")

    return results


if __name__ == "__main__":
    # Script entry point: run a single experiment from the CLI. main()'s
    # returned results dict is only used when main() is called programmatically.
    main()
