#!/usr/bin/env python3
import os
import math
import argparse
from typing import Tuple, List, Optional

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import trange

# project-local imports
from problem import ackley
from model import FCNN

# ----------------------------
# Data generation / sampling
# ----------------------------

@torch.no_grad()
def sample_ackley_full_fcoverage(
    n_dim: int = 10,
    N_plateau: int = 300_000,   # uniform cube
    N_shell: int = 600_000,     # total over all f-shells
    N_center: int = 100_000,    # extra around f≈0
    f_max: float = 10.0,        # how deep you want your “funnel” coverage
    M_shells: int = 50,         # how many f-slices
    r_center: float = 0.05,     # radius of the center ball
    device: str = "cpu",
) -> torch.Tensor:
    """
    Sample inputs with stratified coverage of Ackley function values.

    Returns X: [N_plateau + N_shell + N_center, n_dim], rows ordered as
      - plateau: N_plateau points uniform in [-10, 10]^n_dim
      - shells:  N_shell points split over M_shells equal-width f-bins in
                 [0, f_max] (bins in increasing-f order). Each point draws a
                 target f uniformly in its bin, is mapped to a radius by
                 inverting a radial envelope of Ackley, and is placed along a
                 uniformly random direction.
      - center:  N_center points uniform in the ball of radius r_center

    Envelope: assumes the standard Ackley constants a=20, b=0.2, c=2*pi
    (NOTE(review): confirm these match problem.ackley). sqrt(mean(x_i^2)) is
    exactly r/sqrt(n). Approximating each coordinate as N(0, r^2/n), the
    cosine term averages to exp(-c^2 r^2 / (2n)), giving
        f_env(r) = a*(1 - exp(-b r/sqrt(n))) + e - exp(exp(-c^2 r^2 / (2n))),
    which is 0 at r=0 and rises strictly to a + e - 1. Dropping the cosine
    term instead (fixing it at its large-r mean 0) gives an envelope whose
    minimum is e - 1. Every target below that would then collapse onto the
    exact origin, so the smooth form is used.
    """
    # 1) plateau
    X_plateau = torch.empty(N_plateau, n_dim, device=device).uniform_(-10, 10)

    # --- envelope constants (standard Ackley parameters) ---
    a = 20.0
    b = 0.2
    c = 2.0 * math.pi
    sqrt_n = math.sqrt(n_dim)

    def envelope(r: torch.Tensor) -> torch.Tensor:
        # Approximate Ackley value at radius r for a random direction.
        rms = r / sqrt_n
        mean_cos = torch.exp(-0.5 * (c * rms) ** 2)
        return a * (1.0 - torch.exp(-b * rms)) + math.e - torch.exp(mean_cos)

    # Search radius cap: the point where exp(-b * rms) hits 1e-6. Targets the
    # envelope cannot reach below it (f >= ~a + e - 1) are placed at r_max.
    r_max = (sqrt_n / b) * math.log(1e6)

    def invert_envelope(f_target: torch.Tensor, n_iter: int = 60) -> torch.Tensor:
        # Vectorised bisection; valid because envelope() is strictly increasing.
        lo = torch.zeros_like(f_target)
        hi = torch.full_like(f_target, r_max)
        for _ in range(n_iter):
            mid = 0.5 * (lo + hi)
            too_small = envelope(mid) < f_target
            lo = torch.where(too_small, mid, lo)
            hi = torch.where(too_small, hi, mid)
        return 0.5 * (lo + hi)

    # 2) stratified f-shells: the first `remainder` shells get one extra sample
    per_shell, remainder = divmod(N_shell, M_shells)
    counts = [per_shell + (1 if i < remainder else 0) for i in range(M_shells)]

    f_bins = torch.linspace(0.0, f_max, M_shells + 1, device=device)
    f_parts = [
        torch.rand(cnt, device=device) * (f_bins[i + 1] - f_bins[i]) + f_bins[i]
        for i, cnt in enumerate(counts)
        if cnt > 0
    ]

    if f_parts:
        f_t = torch.cat(f_parts)  # shell order preserved: increasing f-bins
        r = invert_envelope(f_t)
        # random directions (normalised Gaussians are uniform on the sphere)
        dirs = torch.randn(f_t.numel(), n_dim, device=device)
        dirs = dirs / dirs.norm(dim=1, keepdim=True).clamp_min(1e-12)
        X_shell = dirs * r.unsqueeze(1)
    else:
        X_shell = torch.empty(0, n_dim, device=device)

    # 3) small ball around origin to flood f≈0 (u^(1/n) makes it uniform in volume)
    dirs0 = torch.randn(N_center, n_dim, device=device)
    dirs0 = dirs0 / dirs0.norm(dim=1, keepdim=True).clamp_min(1e-12)
    u0 = torch.rand(N_center, device=device)
    r0 = r_center * u0.pow(1.0 / n_dim)
    X_center = dirs0 * r0.unsqueeze(1)

    # combine
    X = torch.cat([X_plateau, X_shell, X_center], dim=0)
    return X


def make_dataloaders(
    X: torch.Tensor,
    y: torch.Tensor,
    batch_size: int = 1024,
    val_split: float = 0.2,
    num_workers: int = 0,
    device: str = "cpu",
) -> Tuple[DataLoader, DataLoader, torch.Tensor, torch.Tensor]:
    """Randomly split (X, y), standardize X with train-split stats, and build loaders.

    Args:
        X: inputs; rows are samples.
        y: targets, row-aligned with X.
        batch_size: batch size for both loaders.
        val_split: fraction of rows held out for validation; must be in [0, 1).
            The first ``int(N * (1 - val_split))`` randomly permuted rows train.
        num_workers: DataLoader worker processes.
        device: device the caller will move batches to. For a CUDA device
            (when CUDA is available) the loaders pin host memory, so
            ``.to(device, non_blocking=True)`` copies can run asynchronously.

    Returns:
        (train_loader, val_loader, mean, std). ``mean``/``std`` are per-feature
        train-split statistics (reduced over dim 0 with keepdim, on CPU) that
        were applied to both splits. The train loader shuffles; the
        validation loader does not.

    Raises:
        ValueError: if ``val_split`` is outside [0, 1), which would leave the
            train split empty (NaN statistics) or otherwise be meaningless.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    N = X.shape[0]
    perm = torch.randperm(N, device=X.device)
    split = int(N * (1 - val_split))
    train_idx, val_idx = perm[:split], perm[split:]

    X_train, y_train = X[train_idx], y[train_idx]
    X_val, y_val = X[val_idx], y[val_idx]

    mean = X_train.mean(dim=0, keepdim=True)
    # Population std; the clamp keeps constant features from dividing by zero.
    std = X_train.std(dim=0, unbiased=False, keepdim=True).clamp_min(1e-12)

    X_train_std = (X_train - mean) / std
    X_val_std = (X_val - mean) / std

    # Datasets live on CPU; batches are moved to `device` by the caller.
    train_ds = TensorDataset(X_train_std.cpu(), y_train.cpu())
    val_ds = TensorDataset(X_val_std.cpu(), y_val.cpu())

    # Pinned host memory is what lets non_blocking host->GPU copies overlap
    # with compute; it only makes sense when a CUDA device is the target.
    pin = str(device).startswith("cuda") and torch.cuda.is_available()

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin
    )

    return train_loader, val_loader, mean.cpu(), std.cpu()


# ----------------------------
# Training utilities
# ----------------------------

def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """Run one optimisation pass over ``loader`` and report the mean loss.

    Puts the model in training mode, takes one optimizer step per batch, and
    returns the sample-weighted average of the per-batch losses (0.0 if the
    loader yields nothing).
    """
    model.train()
    loss_sum, n_samples = 0.0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Clearing grads before the forward pass is equivalent: forward never touches .grad.
        optimizer.zero_grad(set_to_none=True)
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        batch_n = inputs.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss.item() * batch_n
        n_samples += batch_n
    return loss_sum / max(n_samples, 1)


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """Compute the sample-weighted mean loss of ``model`` over ``loader``.

    Runs in eval mode without autograd. Returns 0.0 for an empty loader.
    """
    model.eval()
    weighted_sum = 0.0
    count = 0
    for features, labels in loader:
        features = features.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        n = features.size(0)
        weighted_sum += criterion(model(features), labels).item() * n
        count += n
    return weighted_sum / max(count, 1)


def save_checkpoint(
    ckpt_dir: str,
    epoch: int,
    model: nn.Module,
    optimizer: optim.Optimizer,
    train_loss: float,
    val_loss: float,
    train_losses: List[float],
    val_losses: List[float],
):
    """Write a training snapshot to ``<ckpt_dir>/ckpt_epoch_<epoch>.pt``.

    The saved dict holds the epoch number, the model and optimizer state
    dicts, this epoch's train/val losses, and the full loss histories so far.
    ``ckpt_dir`` is created if missing.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    snapshot = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
        train_losses=train_losses,
        val_losses=val_losses,
    )
    target_path = os.path.join(ckpt_dir, f"ckpt_epoch_{epoch}.pt")
    torch.save(snapshot, target_path)


def plot_losses(train_losses: List[float], val_losses: List[float], out_path: str):
    """Render train and validation loss curves (x-axis: epoch index) to ``out_path``.

    matplotlib is imported inside the function, so importing this module does
    not require it. The figure is closed after saving.
    """
    import matplotlib.pyplot as plt

    plt.figure()
    curves = (("Train Loss", train_losses), ("Validation Loss", val_losses))
    for label, values in curves:
        plt.plot(values, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss Curves")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


# ----------------------------
# Main training flow
# ----------------------------

def run(args: argparse.Namespace):
    """Sample Ackley data, train an FCNN with early stopping, and save artifacts.

    Everything is written under ``args.ckpt_dir``:
      - ``stats.pt``: ``{"mean", "std"}`` used to standardize the inputs.
      - ``ckpt_epoch_<k>.pt``: checkpoints every ``args.save_every`` epochs and
        at ``args.max_epochs`` (disabled when ``save_every <= 0``).
      - ``best_model.pt``: state_dict with the lowest validation loss seen.
      - ``final_model.pt``: state_dict after the last trained epoch.
      - ``loss_curves.png``: only when ``args.plot`` is set.

    Training targets are ``target_scale * exp(-f(X)) ** target_power``,
    unsqueezed to one column (assumes ``ackley.f_func`` returns one value per row).
    """
    # device
    device = torch.device(args.device if args.device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Using device: {device}")

    # Reproducibility-ish: seeds the torch and NumPy RNGs, but does not
    # request deterministic kernels.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    # 1) Sample inputs X with full f-coverage
    X = sample_ackley_full_fcoverage(
        n_dim=args.n_dim,
        N_plateau=args.n_plateau,
        N_shell=args.n_shell,
        N_center=args.n_center,
        f_max=args.f_max,
        M_shells=args.m_shells,
        r_center=args.r_center,
        device=str(device),
    )

    # 2) Compute targets from Ackley and transform
    ack = ackley(n_dim=args.n_dim)
    f_vals = ack.f_func(X)                   # [N]
    y = torch.exp(-f_vals)                   # [N]
    y_trans = y.pow(args.target_power).unsqueeze(1) * args.target_scale  # [N,1]

    # 3) Dataloaders + standardization stats
    train_loader, val_loader, mean, std = make_dataloaders(
        X=X,
        y=y_trans,
        batch_size=args.batch_size,
        val_split=args.val_split,
        num_workers=args.num_workers,
        device=str(device),
    )
    # The loaders hold their own CPU copies, so the full-size tensors (possibly
    # on the GPU) are no longer needed during training.
    del X, f_vals, y, y_trans

    # 4) Model / Optim / Loss
    model = FCNN(n_dim=args.n_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()

    # 5) Save stats for later inference
    os.makedirs(args.ckpt_dir, exist_ok=True)
    torch.save({"mean": mean, "std": std}, os.path.join(args.ckpt_dir, "stats.pt"))

    # 6) Train loop with early stopping and periodic checkpoints
    patience = args.patience
    best_val = float("inf")       # reference value for the patience counter (respects min_delta)
    epochs_no_improve = 0
    best_seen = float("inf")      # true lowest validation loss, for best_model.pt
    best_state = None
    train_losses: List[float] = []
    val_losses: List[float] = []

    bar = trange(1, args.max_epochs + 1, desc="Epochs")
    for epoch in bar:
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss = evaluate(model, val_loader, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        if args.save_every > 0 and (epoch % args.save_every == 0 or epoch == args.max_epochs):
            save_checkpoint(args.ckpt_dir, epoch, model, optimizer, train_loss, val_loss, train_losses, val_losses)

        bar.set_postfix(train=f"{train_loss:.4e}", val=f"{val_loss:.4e}")

        # Snapshot the best weights on CPU. Without this, early stopping would
        # only ever keep the model from `patience` epochs after the best one.
        if val_loss < best_seen:
            best_seen = val_loss
            best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

        if val_loss < best_val - args.min_delta:
            best_val = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).")
            break

    # Save best model (if any epoch ran with a finite comparison) + final model + loss curves
    if best_state is not None:
        best_model_path = os.path.join(args.ckpt_dir, "best_model.pt")
        torch.save(best_state, best_model_path)
        print(f"Saved best model (val loss {best_seen:.4e}) to {best_model_path}")

    final_model_path = os.path.join(args.ckpt_dir, "final_model.pt")
    torch.save(model.state_dict(), final_model_path)
    print(f"Saved final model to {final_model_path}")

    if args.plot:
        loss_curve_path = os.path.join(args.ckpt_dir, "loss_curves.png")
        plot_losses(train_losses, val_losses, loss_curve_path)
        print(f"Saved loss curves to {loss_curve_path}")


# ----------------------------
# Argparse
# ----------------------------

def build_argparser() -> argparse.ArgumentParser:
    """Create the CLI parser for the Ackley/FCNN training script.

    Options are declared in a (flag, keyword-arguments) table and registered
    in order, which fixes the order of the --help listing.
    """
    parser = argparse.ArgumentParser(
        description="Train FCNN on Ackley data with stratified f-coverage sampling."
    )
    option_table = [
        # Data / problem
        ("--n-dim", dict(type=int, default=10, dest="n_dim", help="Dimensionality of input space.")),
        ("--n-plateau", dict(type=int, default=1_800_000, dest="n_plateau", help="Uniform samples in [-10,10]^n.")),
        ("--n-shell", dict(type=int, default=1_800_000, dest="n_shell", help="Total samples across f-shells.")),
        ("--n-center", dict(type=int, default=1_000, dest="n_center", help="Samples in small central ball.")),
        ("--f-max", dict(type=float, default=10.0, dest="f_max", help="Max target f value for shell sampling.")),
        ("--m-shells", dict(type=int, default=50, dest="m_shells", help="Number of f-slices / shells.")),
        ("--r-center", dict(type=float, default=0.05, dest="r_center", help="Radius of central ball.")),
        ("--val-split", dict(type=float, default=0.2, dest="val_split", help="Validation split fraction.")),
        ("--target-power", dict(type=float, default=0.1, dest="target_power", help="Exponent on exp(-f).")),
        ("--target-scale", dict(type=float, default=5.0, dest="target_scale", help="Scale on transformed target.")),
        # Training
        ("--batch-size", dict(type=int, default=1024, dest="batch_size", help="Batch size.")),
        ("--lr", dict(type=float, default=1e-3, help="Learning rate.")),
        ("--patience", dict(type=int, default=1000, help="Early stopping patience.")),
        ("--min-delta", dict(type=float, default=0.0, dest="min_delta", help="Minimum improvement to reset patience.")),
        ("--max-epochs", dict(type=int, default=1_000_000, dest="max_epochs", help="Maximum epochs.")),
        ("--num-workers", dict(type=int, default=0, dest="num_workers", help="DataLoader workers.")),
        # System / IO
        ("--device", dict(type=str, default="auto", choices=["auto", "cpu", "cuda"], help="Compute device.")),
        ("--ckpt-dir", dict(type=str, default="tanh_model", dest="ckpt_dir", help="Directory to save outputs.")),
        ("--save-every", dict(type=int, default=10, dest="save_every", help="Checkpoint interval (epochs). 0=off.")),
        ("--plot", dict(action="store_true", help="Save loss curve PNG.")),
        ("--seed", dict(type=int, default=None, help="Random seed (int).")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


def main():
    """Command-line entry point: parse the arguments and launch training."""
    run(build_argparser().parse_args())


if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()