"""
train_autoencoder.py

inputs: train/val/test files
    fc_mats    : (S, N, N) float
    dm_features_norm: (S, N, F) float

required local modules (importable from the working directory / PYTHONPATH):
  - edge_self_atten_encoder.py : provides GraphEncoder
  - memory_xatten_decoder.py   : provides GraphDecoder

dependencies:
  - numpy, torch
  - h5py

example usage:  python train_autoencoder.py \
      --train_path path/to/train.h5 --val_path path/to/val.h5 --test_path path/to/test.h5 \
      --out_dir runs/rest_run1
"""

from __future__ import annotations

import argparse
import json
import os
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


# data loading

def _load_arrays(path: str, fc_key: str, x_key: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    load (fc, X) arrays from a file
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    # h5/hdf5
    try:
        import h5py
    except Exception as e:
        raise ImportError("h5py is required to read .h5/.hdf5 files.") from e

    with h5py.File(path, "r") as f:
        if fc_key not in f:
            raise KeyError(f"'{fc_key}' not found in {path}. Keys: {list(f.keys())}")
        fc = f[fc_key][:]

        if x_key not in f:
            raise KeyError(f"Could not find node features key '{x_key}' in {path}. Keys: {list(f.keys())}")
        x = f[x_key][:]

    return fc, x


class GraphDataset(Dataset):
    """
    in-memory dataset pairing each connectivity matrix with its node features

    fc_mats       : 3-D array-like, converted to float32 and stored as self.A
    node_features : 3-D array-like, converted to float32 and stored as self.X

    both arrays must agree on their first two dimensions; otherwise a
    ValueError is raised. item ``idx`` is the tuple (A[idx], X[idx]).
    """

    def __init__(self, fc_mats: np.ndarray, node_features: np.ndarray):
        adj = np.asarray(fc_mats, dtype=np.float32)
        feats = np.asarray(node_features, dtype=np.float32)

        # rank checks come first so the shape comparison below is well defined
        if adj.ndim != 3:
            raise ValueError(f"fc_mats must be (S,N,N); got {adj.shape}")
        if feats.ndim != 3:
            raise ValueError(f"node_features must be (S,N,F); got {feats.shape}")

        # leading (S, N) dimensions have to match between the two arrays
        if adj.shape[:2] != feats.shape[:2]:
            raise ValueError(f"Mismatch: fc={adj.shape}, x={feats.shape}")

        self.A = torch.from_numpy(adj)    # [S,N,N]
        self.X = torch.from_numpy(feats)  # [S,N,F]

    def __len__(self) -> int:
        """number of samples (size of the first dimension of A)"""
        return int(self.A.size(0))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """return the (connectivity, node features) pair stored at position idx"""
        adj_i = self.A[idx]
        feats_i = self.X[idx]
        return adj_i, feats_i


# training
@dataclass
class TrainConfig:
    """
    all settings for one training run

    the four path fields have no default and must be supplied; every other
    field carries a default. main() serialises the whole config to
    <out_dir>/config.json via dataclasses.asdict.
    """

    # io
    train_path: str  # file read by _load_arrays for the training split
    val_path: str    # file read by _load_arrays for the validation split
    test_path: str   # file read by _load_arrays for the test split
    out_dir: str     # directory receiving config.json, best_checkpoint.pt, history.json, test_metrics.json
    fc_key: str = "fc_mats"           # dataset name of the connectivity matrices inside each file
    x_key: str = "dm_features_norm"   # dataset name of the node features inside each file

    # optimization
    epochs: int = 150               # upper bound on the number of training epochs
    batch_size: int = 8             # batch size used by all three DataLoaders
    lr: float = 2e-3                # initial Adam learning rate
    betas0: float = 0.9             # first element of the Adam betas tuple
    betas1: float = 0.95            # second element of the Adam betas tuple
    early_stop_patience: int = 30   # consecutive epochs without a new best val edge MSE before stopping

    # scheduler (arguments of torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min")
    sched_factor: float = 0.5
    sched_patience: int = 10
    sched_threshold: float = 1e-4
    min_lr: float = 1e-5

    # device
    device: str = "cuda"  # a value starting with "cuda" selects CUDA when available; otherwise CPU is used

    # encoder (rest)
    # forwarded to GraphEncoder as d_h, d_e, num_heads, d_ff, d_ff_edge, num_layers
    enc_d_h: int = 48
    enc_d_e: int = 2
    enc_heads: int = 4
    enc_d_ff: int = 64
    enc_d_ff_edge: int = 16
    enc_layers: int = 4
    d_g: int = 16           # passed as d_g to both GraphEncoder and GraphDecoder
    dropout_p: float = 0.2  # passed as dropout_p to both GraphEncoder and GraphDecoder

    # decoder
    # forwarded to GraphDecoder as num_layers, num_heads, d_ff, d_m, d_h, d_r
    dec_layers: int = 2
    dec_heads: int = 4
    dec_d_ff: int = 128
    dec_d_m: int = 32
    dec_d_h: int = 32
    dec_d_r: int = 32


def _select_device(device_pref: str) -> torch.device:
    if device_pref.lower().startswith("cuda") and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _import_models():
    """
    lazily import the two model classes and return them

    returns the tuple (GraphEncoder, GraphDecoder), taken from the modules
    edge_self_atten_encoder and memory_xatten_decoder respectively. the
    imports run on each call rather than at module import time.
    """
    from edge_self_atten_encoder import GraphEncoder as encoder_cls
    from memory_xatten_decoder import GraphDecoder as decoder_cls

    return encoder_cls, decoder_cls


@torch.no_grad()
def evaluate(loader: DataLoader, encoder, decoder, device: torch.device) -> Dict[str, float]:
    """
    compute the mean per-sample edge MSE over a loader, without gradients

    loader  : iterable yielding (A, X) batches
    encoder : called as encoder(A, X); its output must expose .z_g
    decoder : called as decoder(z_g); its output must expose .A_hat, and it
              must provide edge_mse(A, A_hat) returning per-sample losses
    device  : device the batches are moved to

    both modules are switched to eval mode and left in it.
    returns {"edge_mse": summed loss / number of samples}; an empty loader
    yields 0.0 because the divisor is clamped to 1.
    """
    for module in (encoder, decoder):
        module.eval()

    loss_sum = 0.0
    n_seen = 0

    for adj, feats in loader:
        adj, feats = adj.to(device), feats.to(device)

        latent = encoder(adj, feats).z_g
        recon = decoder(latent).A_hat
        per_sample = decoder.edge_mse(adj, recon)  # [B]

        loss_sum += float(per_sample.sum().item())
        n_seen += int(adj.size(0))

    mean_loss = loss_sum / max(n_seen, 1)
    return {"edge_mse": mean_loss}


def train_one_epoch(loader: DataLoader, encoder, decoder, optimizer, device: torch.device) -> Dict[str, float]:
    """
    run one optimisation pass over the loader

    loader    : iterable yielding (A, X) batches
    encoder   : called as encoder(A, X); its output must expose .z_g
    decoder   : called as decoder(z_g); its output must expose .A_hat, and it
                must provide edge_mse(A, A_hat) returning per-sample losses
    optimizer : stepped once per batch on the batch-mean loss
    device    : device the batches are moved to

    both modules are switched to train mode.
    returns {"edge_mse": summed per-sample loss / number of samples}, using
    the losses measured before each optimizer step; an empty loader yields
    0.0 because the divisor is clamped to 1.
    """
    for module in (encoder, decoder):
        module.train()

    loss_sum = 0.0
    n_seen = 0

    for adj, feats in loader:
        adj, feats = adj.to(device), feats.to(device)

        # forward pass: encode to the graph-level latent, decode edges only
        latent = encoder(adj, feats).z_g
        recon = decoder(latent).A_hat
        per_sample = decoder.edge_mse(adj, recon)  # [B]
        batch_loss = per_sample.mean()

        # backward pass and parameter update
        optimizer.zero_grad(set_to_none=True)
        batch_loss.backward()
        optimizer.step()

        loss_sum += float(per_sample.sum().item())
        n_seen += int(adj.size(0))

    mean_loss = loss_sum / max(n_seen, 1)
    return {"edge_mse": mean_loss}


def main(cfg: TrainConfig) -> None:
    """
    train the encoder/decoder pair, keep the best checkpoint, evaluate on test

    cfg : run configuration; see TrainConfig

    files written into cfg.out_dir:
        config.json        : the configuration, via dataclasses.asdict
        best_checkpoint.pt : encoder/decoder state dicts at the lowest val edge MSE
        history.json       : per-epoch train/val edge MSE and learning rate
        test_metrics.json  : edge MSE on the test split

    raises ValueError when the training loader would yield no batch at all
    (the training set is smaller than cfg.batch_size while drop_last=True),
    since every epoch would then be a silent no-op.

    the test split is evaluated with the checkpoint written during THIS run.
    when no checkpoint was written (cfg.epochs == 0, or the validation loss
    never compared below infinity, e.g. NaN), the final in-memory weights are
    evaluated instead of loading whatever best_checkpoint.pt may already
    exist in out_dir from an earlier run.
    """
    os.makedirs(cfg.out_dir, exist_ok=True)
    device = _select_device(cfg.device)

    # load data
    fc_tr, x_tr = _load_arrays(cfg.train_path, cfg.fc_key, cfg.x_key)
    fc_va, x_va = _load_arrays(cfg.val_path, cfg.fc_key, cfg.x_key)
    fc_te, x_te = _load_arrays(cfg.test_path, cfg.fc_key, cfg.x_key)

    # build datasets / loaders
    ds_tr = GraphDataset(fc_tr, x_tr)
    ds_va = GraphDataset(fc_va, x_va)
    ds_te = GraphDataset(fc_te, x_te)

    dl_tr = DataLoader(ds_tr, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False)
    dl_te = DataLoader(ds_te, batch_size=cfg.batch_size, shuffle=False)

    # drop_last=True discards the only (partial) batch of a too-small training set
    if len(dl_tr) == 0:
        raise ValueError(
            f"Training set has {len(ds_tr)} samples, fewer than batch_size={cfg.batch_size}; "
            "with drop_last=True no batch would ever be trained on."
        )

    # model dimensions are taken from the training arrays
    input_dim = int(x_tr.shape[2])
    num_nodes = int(fc_tr.shape[1])

    # import models
    GraphEncoder, GraphDecoder = _import_models()

    # instantiate encoder (deterministic default; no variational)
    encoder = GraphEncoder(
        in_dim=input_dim,
        d_h=cfg.enc_d_h,
        d_e=cfg.enc_d_e,
        num_heads=cfg.enc_heads,
        d_ff=cfg.enc_d_ff,
        d_ff_edge=cfg.enc_d_ff_edge,
        num_layers=cfg.enc_layers,
        d_g=cfg.d_g,
        dropout_p=cfg.dropout_p,
        variational=False,
    ).to(device)

    # instantiate decoder (edge recon only)
    decoder = GraphDecoder(
        num_nodes=num_nodes,
        node_feat_dim=input_dim,
        d_g=cfg.d_g,
        d_m=cfg.dec_d_m,
        d_h=cfg.dec_d_h,
        d_r=cfg.dec_d_r,
        num_heads=cfg.dec_heads,
        num_layers=cfg.dec_layers,
        d_ff=cfg.dec_d_ff,
        dropout_p=cfg.dropout_p,
        reconstruct_nodes=False,
    ).to(device)

    # a single optimizer updates encoder and decoder jointly
    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=cfg.lr,
        betas=(cfg.betas0, cfg.betas1),
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=cfg.sched_factor,
        patience=cfg.sched_patience,
        threshold=cfg.sched_threshold,
        min_lr=cfg.min_lr,
    )

    with open(os.path.join(cfg.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(asdict(cfg), f, indent=2)

    best_val = float("inf")
    best_path = os.path.join(cfg.out_dir, "best_checkpoint.pt")
    # tracks whether THIS run wrote best_path, so a stale file is never loaded
    saved_checkpoint = False
    patience = 0
    history = {"train_edge_mse": [], "val_edge_mse": [], "lr": []}

    for epoch in range(1, cfg.epochs + 1):
        tr = train_one_epoch(dl_tr, encoder, decoder, optimizer, device)
        va = evaluate(dl_va, encoder, decoder, device)

        # the plateau scheduler is driven by the validation loss
        scheduler.step(va["edge_mse"])
        lr = float(optimizer.param_groups[0]["lr"])

        history["train_edge_mse"].append(tr["edge_mse"])
        history["val_edge_mse"].append(va["edge_mse"])
        history["lr"].append(lr)

        if epoch == 1 or epoch % 10 == 0:
            print(
                f"Epoch {epoch:03d} | "
                f"train edge MSE={tr['edge_mse']:.6f} | "
                f"val edge MSE={va['edge_mse']:.6f} | "
                f"lr={lr:.2e}"
            )

        # early stopping on val edge MSE
        if va["edge_mse"] < best_val:
            best_val = va["edge_mse"]
            patience = 0
            torch.save(
                {
                    "enc_state": encoder.state_dict(),
                    "dec_state": decoder.state_dict(),
                    "best_val_edge_mse": best_val,
                    "input_dim": input_dim,
                    "num_nodes": num_nodes,
                },
                best_path,
            )
            saved_checkpoint = True
        else:
            patience += 1
            if patience >= cfg.early_stop_patience:
                print(f"Early stopping at epoch {epoch} (best val edge MSE={best_val:.6f})")
                break

    # save history first so a failure during test evaluation cannot lose it
    with open(os.path.join(cfg.out_dir, "history.json"), "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    # load best and test
    if saved_checkpoint:
        checkpoint = torch.load(best_path, map_location=device)
        encoder.load_state_dict(checkpoint["enc_state"])
        decoder.load_state_dict(checkpoint["dec_state"])
    else:
        print(
            "Warning: no checkpoint was written during this run "
            "(validation edge MSE never improved); evaluating the final weights instead."
        )

    te = evaluate(dl_te, encoder, decoder, device)
    print(f"Test edge MSE: {te['edge_mse']:.6f}")

    # save test metrics
    with open(os.path.join(cfg.out_dir, "test_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(te, f, indent=2)


def parse_args(argv: Optional[list] = None) -> TrainConfig:
    """
    build a TrainConfig from command-line arguments

    argv : argument strings to parse; None (the default) makes argparse read
           sys.argv[1:], which is the behaviour of the original no-argument call

    one ``--<field>`` flag is generated per TrainConfig field, so the CLI
    defaults can never drift from the dataclass defaults and every
    hyperparameter (scheduler, encoder, decoder, ...) is settable:
      - fields without a default become required string flags
        (train_path, val_path, test_path, out_dir)
      - fields with a default become optional flags whose type is the type
        of that default value

    returns the populated TrainConfig.
    """
    # local import: these names are not among the module-level imports
    from dataclasses import MISSING, fields

    p = argparse.ArgumentParser(description="Train resting-state connectome autoencoder (edge MSE).")

    for f in fields(TrainConfig):
        flag = f"--{f.name}"
        if f.default is MISSING:
            # every default-less field of TrainConfig is a str path
            p.add_argument(flag, type=str, required=True)
        else:
            # the default's own type is the converter: int/float/str for every current field
            p.add_argument(flag, type=type(f.default), default=f.default, help="default: %(default)s")

    args = p.parse_args(argv)

    # argparse dest names equal the field names, so the namespace maps 1:1
    return TrainConfig(**vars(args))

if __name__ == "__main__":
    # script entry point: parse the command line into a config, then train
    main(parse_args())