"""
Student‑network training script (DDP) — **fast permutation lookup**
------------------------------------------------------------------
This version replaces the Python‑loop permutation generation in the
`HYDROGEN_DISPLACEMENTS` task with a vectorised 24‑entry lookup table, removing
batch‑prep bottlenecks.
"""

from __future__ import annotations

import argparse
import os
import shutil
import time
from typing import Generator, Tuple, Dict, Union, Iterable

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import yaml
from itertools import permutations
from scipy.spatial.transform import Rotation  # random rotations
from torch.nn.parallel import DistributedDataParallel as DDP
from joblib import Parallel, delayed
from functools import lru_cache
from displacements_invariants import compute_invariants_displacements_wrapper

# ---------------------------------------------------------------------------
# Constant‑time permutation lookup (24 possible permutations of 4 elements):
# provided by `_perm_table`, defined next to the mini‑batch generator below.
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class FeedForwardNet(nn.Module):
    """Fully connected scalar regressor.

    Layout: ``BatchNorm1d`` on the input, then ``n_inner_layers + 1`` hidden
    blocks of ``Linear -> BatchNorm1d -> ReLU`` (width ``hidden_dim``), then a
    ``Linear(hidden_dim, 1)`` head. Input is ``(B, input_dim)``; the head's
    trailing size-1 axis is squeezed, so the output is ``(B,)``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, n_inner_layers: int):
        super().__init__()
        self.input_bn = nn.BatchNorm1d(input_dim)

        # Hidden Linear layers, created in forward order: input -> hidden, then
        # ``n_inner_layers`` hidden -> hidden layers.
        hidden_linears = [nn.Linear(input_dim, hidden_dim)]
        hidden_linears += [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_inner_layers)]
        self.fcs = nn.ModuleList(hidden_linears)
        # One BatchNorm per hidden Linear.
        self.bns = nn.ModuleList(nn.BatchNorm1d(hidden_dim) for _ in hidden_linears)
        # Output head, kept as the final entry of ``fcs`` (state-dict key layout).
        self.fcs.append(nn.Linear(hidden_dim, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # (B, input_dim)
        h = self.input_bn(x)
        *hidden, head = self.fcs
        for linear, norm in zip(hidden, self.bns):
            h = F.relu(norm(linear(h)))
        return head(h).squeeze(-1)


# ---------------------------------------------------------------------------
# Mini‑batch generator / augmentation
# ---------------------------------------------------------------------------

def _make_rotations(num: int) -> np.ndarray:
    """Generate *num* random 3×3 rotation matrices (half improper), float32."""
    R = Rotation.random(num).as_matrix().astype(np.float32)  # (num, 3, 3)
    R[np.random.rand(num) >= 0.5] *= -1  # flip half to improper rotations
    return R

@lru_cache(maxsize=None)
def _perm_table(device: torch.device) -> torch.Tensor:
    return torch.tensor(list(permutations(range(4))),
                        dtype=torch.long, device=device)  # (24,4)

def iterate_minibatches(
    X: np.ndarray,
    y: np.ndarray,
    batch_size: int,
    shuffle: bool,
    task: str,
    mode: str,
    device: Union[str, torch.device] = "cpu",
) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
    """Yield ``(xb, yb)`` float32 mini-batches on *device*, with task-specific
    augmentation/selection applied to ``xb``.

    Behaviour per *task*:

    * ``INVARIANTS_DISTANCES``: features are yielded as loaded.
    * ``DISTANCES``: axis 1 of ``X`` is indexed by an integer in ``[0, 24)``.
      In training that index is drawn at random per sample; otherwise index 0
      is used.
    * ``HYDROGEN_DISPLACEMENTS``: ``xb`` is assumed to be ``(B, 4, 3)``.
      Training permutes the 4 rows at random (one of the 24 orderings) and then
      applies a random proper/improper rotation. Every mode then flattens each
      sample (``(B, 12)`` for that shape).
    * ``DISPLACEMENTS_INVARIANTS``: training applies a random rotation only.
      ``compute_invariants_displacements_wrapper`` then maps the batch to its
      invariant features.

    Args:
        X: feature array; batches are gathered with ``X[idx]``.
        y: target array; batches are gathered with ``y[idx]``.
        batch_size: samples per batch; the last batch may be smaller.
        shuffle: shuffle the sample order (NumPy global RNG) before batching.
        task: one of the four task names above.
        mode: ``"train"`` enables augmentation; any other value disables it.
        device: device the yielded tensors are placed on.

    Raises:
        ValueError: *task* is not supported (raised on the first ``next()``).
    """
    device = torch.device(device)

    if task not in ("INVARIANTS_DISTANCES", "DISTANCES",
                    "HYDROGEN_DISPLACEMENTS", "DISPLACEMENTS_INVARIANTS"):
        raise ValueError(f"Unsupported TASK '{task}'")

    n_samples = X.shape[0]
    if n_samples == 0:
        # Nothing to yield. This also avoids asking joblib for 0 workers below.
        return

    # Constant (24, 4) permutation table, cached once per device.
    perm_tbl = _perm_table(device)

    order = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(order)

    augment = mode == "train"

    # -------------------------------------------------------------------------
    # One fresh rotation per sample for this call. Rotations are generated on
    # the CPU in chunks (possibly across worker processes) and then moved to
    # *device* once, so each batch only needs a gather.
    # -------------------------------------------------------------------------
    rot_pool = None
    if augment and task in ("HYDROGEN_DISPLACEMENTS", "DISPLACEMENTS_INVARIANTS"):
        N_ROT_WORKERS = min(20, os.cpu_count() or 1)
        CHUNK_SIZE = 250_000
        chunk_sizes = [
            min(CHUNK_SIZE, n_samples - s) for s in range(0, n_samples, CHUNK_SIZE)
        ]
        rot_pool_np = np.concatenate(
            Parallel(n_jobs=min(N_ROT_WORKERS, len(chunk_sizes)), backend="loky")(
                delayed(_make_rotations)(m) for m in chunk_sizes
            ),
            axis=0,
        )  # (N, 3, 3)
        rot_pool = torch.tensor(rot_pool_np, device=device)  # (N, 3, 3)
        del rot_pool_np  # free host RAM

    def _rotate(coords: torch.Tensor, sample_idx: np.ndarray) -> torch.Tensor:
        """Map every row vector v of each sample to R v, using that sample's pooled R."""
        rots = rot_pool[torch.as_tensor(sample_idx, device=device)]  # (B, 3, 3)
        return torch.matmul(rots, coords.transpose(1, 2)).transpose(1, 2)

    # -------------------------------------------------------------------------
    # iterate over batches
    # -------------------------------------------------------------------------
    for start in range(0, n_samples, batch_size):
        idx = order[start : start + batch_size]

        # NumPy → Torch, host → device
        xb = torch.as_tensor(X[idx], dtype=torch.float32, device=device)
        yb = torch.as_tensor(y[idx], dtype=torch.float32, device=device)

        if task == "DISTANCES":
            if augment:
                sel = torch.randint(0, 24, (xb.shape[0],), device=device)
                xb = xb[torch.arange(xb.shape[0], device=device), sel, :]  # (B, 10)
            else:
                xb = xb[:, 0, :]  # (B, 10)

        elif task == "HYDROGEN_DISPLACEMENTS":
            if augment:
                B = xb.shape[0]
                # Random 24-way row permutation. The (B, 1) row index broadcasts
                # against the (B, 4) permutation rows.
                perm_idx = torch.randint(0, 24, (B,), device=device)
                rows = torch.arange(B, device=device).unsqueeze(1)  # (B, 1)
                xb = xb[rows, perm_tbl[perm_idx], :]  # (B, 4, 3)
                xb = _rotate(xb, idx)
            # All modes: flatten each sample (no augmentation outside training).
            xb = xb.reshape(xb.shape[0], -1)  # (B, 12)

        elif task == "DISPLACEMENTS_INVARIANTS":
            if augment:
                xb = _rotate(xb, idx)  # rotation only, no permutation
            xb = compute_invariants_displacements_wrapper(xb)  # expected (B, 34)

        # INVARIANTS_DISTANCES: features are used exactly as loaded.
        yield xb, yb

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def cpu_clone_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Snapshot ``model.state_dict()`` as detached CPU tensors.

    Each entry is cloned, so later in-place updates to the model (or its device
    copies) do not alter the snapshot.
    """
    snapshot: Dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().cpu().clone()
    return snapshot


# --------------------------------------------------------------------------- #
# Training (per rank)
# --------------------------------------------------------------------------- #
def _load_task_dataset(
    task: str,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
    """Load the float32 train/val features and targets for *task*.

    Returns ``(X_train, X_val, y_train, y_val, input_dim)``. ``input_dim`` is the
    feature width the model receives after ``iterate_minibatches`` has
    preprocessed a batch.

    Raises:
        FileNotFoundError: a required ``.npy`` file is missing.
        ValueError: unknown *task*, or feature/target sample counts differ.
    """

    def _load(path: str) -> np.ndarray:
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        return np.load(path)

    if task == "INVARIANTS_DISTANCES":
        stem, input_dim = "methane_invariants_distances", 31
    elif task == "DISTANCES":
        stem, input_dim = "methane_distances", 10
    elif task == "HYDROGEN_DISPLACEMENTS":
        stem, input_dim = "methane_hydrogen_displacements", 12
    elif task == "DISPLACEMENTS_INVARIANTS":
        # Same raw displacements as above; invariants are computed per batch.
        stem, input_dim = "methane_hydrogen_displacements", 34
    else:
        raise ValueError(f"Unsupported TASK '{task}'.")

    X_train = _load(f"./dataset_preparation/{stem}_train.npy")
    X_val = _load(f"./dataset_preparation/{stem}_val.npy")
    y_train = _load("./dataset_preparation/methane_energies_train_normalized.npy")
    y_val = _load("./dataset_preparation/methane_energies_val_normalized.npy")
    if X_train.shape[0] != y_train.shape[0] or X_val.shape[0] != y_val.shape[0]:
        raise ValueError("Feature/target sample mismatch.")

    return (
        X_train.astype(np.float32),
        X_val.astype(np.float32),
        y_train.astype(np.float32),
        y_val.astype(np.float32),
        input_dim,
    )


def ddp_train(args: argparse.Namespace) -> None:
    """Train the student network on this rank under DistributedDataParallel.

    Steps:

    * Reads the YAML run specification at ``args.calc_specification_path``.
      Required keys: ``BASE_LR``, ``LEARNING_RATE_DECAY_SCALE``, ``NUM_EPOCHS``,
      ``BATCH_SIZE``, ``HIDDEN_DIM``, ``N_INNER_LAYERS``, ``RECORD_INTERVAL``.
      Optional keys: ``TASK`` and ``N_WARMUP``.
    * Trains a ``FeedForwardNet`` with Adam.
    * Uses an optional linear LR warmup over ``N_WARMUP`` epochs, then a StepLR
      decay (×0.5 every ``LEARNING_RATE_DECAY_SCALE`` epochs).
    * On rank 0, writes ``history.npz``, last/best weights, ``summary.txt`` and
      a copy of the specification into ``args.calc_folder``.

    ``RANK``/``WORLD_SIZE`` are read from the environment (defaults 0/1), and
    the process group uses ``env://`` rendezvous. ``LOCAL_RANK`` selects the
    CUDA device and defaults to ``RANK``.

    Raises:
        ValueError: unsupported ``TASK``, ``BATCH_SIZE < 2``, or a
            feature/target sample mismatch.
        FileNotFoundError: a dataset file is missing.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # CUDA ordinals are node-local. On multi-node jobs the global RANK exceeds
    # the number of local GPUs, so the device must come from LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    dist.init_process_group("nccl", init_method="env://",
                            world_size=world_size, rank=rank)
    try:
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)

        # ---------------- config ----------------
        with open(args.calc_specification_path) as f:
            cfg = yaml.safe_load(f)

        TASK = cfg.get("TASK", "INVARIANTS_DISTANCES")
        SUPPORTED = {"INVARIANTS_DISTANCES", "DISTANCES",
                     "HYDROGEN_DISPLACEMENTS", "DISPLACEMENTS_INVARIANTS"}
        if TASK not in SUPPORTED:
            raise ValueError(f"Unsupported TASK '{TASK}'. Choose from {SUPPORTED}")

        BASE_LR = float(cfg["BASE_LR"])
        LR_DECAY_STEP = int(cfg["LEARNING_RATE_DECAY_SCALE"])
        NUM_EPOCHS = int(cfg["NUM_EPOCHS"])
        BATCH_SIZE = int(cfg["BATCH_SIZE"])
        HIDDEN_DIM = int(cfg["HIDDEN_DIM"])
        N_INNER_LAYERS = int(cfg["N_INNER_LAYERS"])
        RECORD_INT = int(cfg["RECORD_INTERVAL"])
        N_WARMUP = int(cfg.get("N_WARMUP", 0))  # linear-warmup epochs (0 = none)

        if BATCH_SIZE < 2:
            # Every hidden layer is BatchNorm1d, which cannot normalise a
            # single sample in training mode.
            raise ValueError("BATCH_SIZE must be >= 2 (the model uses BatchNorm1d).")

        # ---------------- dataset ---------------
        # NOTE(review): every rank iterates the *full* training set (there is no
        # DistributedSampler-style sharding). One epoch therefore performs
        # world_size passes over the data, and the all-reduced train loss
        # counts each sample world_size times. Left unchanged because sharding
        # would change what NUM_EPOCHS / the LR schedule mean — confirm intent.
        X_train, X_val, y_train, y_val, INPUT_DIM = _load_task_dataset(TASK)

        # ---------------- model & optim ---------
        student = FeedForwardNet(INPUT_DIM, HIDDEN_DIM, N_INNER_LAYERS).to(device)
        ddp_model = DDP(student, device_ids=[local_rank], output_device=local_rank)

        criterion = nn.MSELoss()
        optimizer = optim.Adam(ddp_model.parameters(), lr=BASE_LR)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP, gamma=0.5)

        # ---------------- training loop ---------
        best_val, best_val_rmse, best_epoch, best_state = (
            float("inf"), float("inf"), 0, None
        )
        # Pre-set so the summary below can still be written when NUM_EPOCHS == 0.
        val_loss = val_rmse = float("nan")
        hist: Dict[str, list] = {
            "epoch": [],
            "lr": [],
            "train": [],
            "val": [],
            "train_rmse": [],
            "val_rmse": [],
        }
        t0 = time.time()

        for epoch in range(1, NUM_EPOCHS + 1):
            lr_now = optimizer.param_groups[0]["lr"]

            # ---- LR warmup: ramp linearly up to BASE_LR over N_WARMUP epochs ----
            if epoch <= N_WARMUP:
                lr_now = BASE_LR * (epoch / N_WARMUP)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr_now

            # ---- train ----
            ddp_model.train()
            sum_loss = torch.tensor(0.0, device=device)
            n_samp = torch.tensor(0, device=device)
            for xb_t, yb_t in iterate_minibatches(
                X_train, y_train, BATCH_SIZE, True, TASK, "train", device
            ):
                # A 1-sample batch makes BatchNorm1d raise in training mode.
                # One occurs at the end of every epoch when
                # len(X_train) % BATCH_SIZE == 1. Batch sizes depend only on
                # position and every rank loads the same data, so all ranks skip
                # the same batch and DDP's gradient all-reduces stay in step.
                if xb_t.size(0) < 2:
                    continue
                optimizer.zero_grad(set_to_none=True)
                loss = criterion(ddp_model(xb_t), yb_t)
                loss.backward()
                optimizer.step()
                sum_loss += loss.detach() * xb_t.size(0)
                n_samp += xb_t.size(0)
            dist.all_reduce(sum_loss)
            dist.all_reduce(n_samp)
            train_loss = (sum_loss / n_samp).item()
            train_rmse = train_loss ** 0.5

            # ---- val ----
            ddp_model.eval()
            with torch.no_grad():
                sum_val = torch.tensor(0.0, device=device)
                n_val = torch.tensor(0, device=device)
                for xb_t, yb_t in iterate_minibatches(
                    X_val, y_val, BATCH_SIZE, False, TASK, "val", device
                ):
                    # criterion is a batch mean; re-weight by batch size for a
                    # sample-weighted average.
                    sum_val += criterion(ddp_model(xb_t), yb_t) * xb_t.size(0)
                    n_val += xb_t.size(0)
            dist.all_reduce(sum_val)
            dist.all_reduce(n_val)
            val_loss = (sum_val / n_val).item()
            val_rmse = val_loss ** 0.5

            # ---- checkpoints (best-val weights kept on rank 0 only) ----
            if rank == 0 and val_loss < best_val:
                best_val, best_val_rmse, best_epoch = val_loss, val_rmse, epoch
                best_state = cpu_clone_state_dict(ddp_model.module)

            # ---- scheduler step (only after warmup) ----
            if epoch > N_WARMUP:
                scheduler.step()

            if rank == 0 and (epoch == 1 or epoch % RECORD_INT == 0 or epoch == NUM_EPOCHS):
                print(f"[Ep {epoch:03d}/{NUM_EPOCHS}] LR {lr_now:.2e} "
                      f"train {train_loss:.2e} ({train_rmse:.2e})  "
                      f"val {val_loss:.2e} ({val_rmse:.2e}) "
                      f"| {(time.time()-t0):.1f}s")

            hist["epoch"].append(epoch)
            hist["lr"].append(lr_now)
            hist["train"].append(train_loss)
            hist["val"].append(val_loss)
            hist["train_rmse"].append(train_rmse)
            hist["val_rmse"].append(val_rmse)

        # ---------------- save artefacts --------
        if rank == 0:
            os.makedirs(args.calc_folder, exist_ok=True)

            np.savez(os.path.join(args.calc_folder, "history.npz"), **hist)
            torch.save(ddp_model.module.state_dict(),
                       os.path.join(args.calc_folder, "model_weights_last_epoch.pth"))
            if best_state is not None:
                torch.save(best_state,
                           os.path.join(args.calc_folder, "model_weights_best_val.pth"))

            with open(os.path.join(args.calc_folder, "summary.txt"), "w") as f:
                f.write(f"TASK: {TASK}\n")
                f.write(f"final_val_loss: {val_loss:.4e}\n")
                f.write(f"final_val_rmse: {val_rmse:.4e}\n")
                f.write(f"best_val_loss:  {best_val:.4e}  (epoch {best_epoch})\n")
                f.write(f"best_val_rmse:  {best_val_rmse:.4e}  (epoch {best_epoch})\n")
                f.write(f"wall_time_s:    {time.time()-t0:.1f}\n")
                f.write(f"num_gpus:       {world_size}\n")
                f.write(f"effect_batch:   {BATCH_SIZE * world_size}\n")

            shutil.copy(args.calc_specification_path,
                        os.path.join(args.calc_folder,
                                     os.path.basename(args.calc_specification_path)))
    finally:
        # Always release the process group, including on config/data errors.
        dist.destroy_process_group()


# --------------------------------------------------------------------------- #
# CLI
# --------------------------------------------------------------------------- #
def main() -> None:
    """Parse the command-line options and hand them to ``ddp_train``."""
    parser = argparse.ArgumentParser(
        description="Train student network (DDP, methane dataset)"
    )
    for flag in ("--calc_specification_path", "--calc_folder"):
        parser.add_argument(flag, required=True)
    cli_args = parser.parse_args()
    ddp_train(cli_args)


if __name__ == "__main__":
    # Script entry point (one process per rank under the DDP launcher).
    main()
