import os
import json
import time
from datetime import datetime
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from ..utils import set_seed as _set_seed, stratified_pick_per_class as _pick_per_class
from ..augment import geom_module as _geom_module, opt_module as _opt_module
import numpy as _np
from generative.diffusion.diffusion_train import Conditional_UNet


def build_mnist_dataloaders(
    batch_size: int,
    num_workers: int,
    image_size: int,
    n_per_class: int,
    seed: int,
) -> DataLoader:
    """Create a shuffled training loader over a per-class subset of MNIST.

    The MNIST training split is read from (and, if missing, downloaded to)
    ``./data``. Each image is resized to ``image_size`` x ``image_size``,
    converted to a tensor and normalised with mean 0.5 and std 0.5.

    Args:
        batch_size: Number of samples per batch; a trailing incomplete batch
            is dropped.
        num_workers: Number of loader worker processes.
        image_size: Side length the images are resized to.
        n_per_class: Per-class sample count handed to the selection helper.
            In the fallback path a class holding no more than this many
            samples is kept in full.
        seed: Seed of the ``RandomState`` that drives the subset selection.

    Returns:
        A ``DataLoader`` over the selected ``Subset`` with ``shuffle=True``,
        ``drop_last=True`` and ``pin_memory=True``.
    """
    from torch.utils.data import Subset

    preprocess = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 brings the tensor values into [-1, 1].
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ]
    )
    mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=preprocess)
    labels = mnist_train.targets.detach().cpu().numpy().astype("int64")
    sampler = _np.random.RandomState(seed)

    if _pick_per_class is None:
        # Fallback selection used when the shared helper is unavailable:
        # sample each digit without replacement, then shuffle the merged set.
        per_digit = []
        for digit in range(10):
            members = _np.where(labels == digit)[0]
            if len(members) > n_per_class:
                members = sampler.choice(members, size=n_per_class, replace=False)
            per_digit.append(members)
        chosen = _np.concatenate(per_digit, axis=0)
        sampler.shuffle(chosen)
    else:
        chosen = _pick_per_class(
            labels,
            per_class=n_per_class,
            num_classes=10,
            rng=sampler,
        )

    return DataLoader(
        Subset(mnist_train, chosen.tolist()),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )


def build_aug_only_loader(
    batch_size: int,
    num_workers: int,
    image_size: int,
    n_per_class: int,
    seed: int,
    aug_types: str = "geom,opt",
) -> DataLoader:
    """Build a training loader over augmented versions of an MNIST subset.

    A per-class subset of the raw MNIST training images is selected, passed
    through the requested augmentation modules, shuffled with a fixed seed,
    scaled with ``x / 127.5 - 1.0`` and wrapped in a ``TensorDataset``.

    If either augmentation module is unavailable (``None``), this falls back
    to ``build_mnist_dataloaders`` with the same arguments.

    Args:
        batch_size: Number of samples per batch; a trailing incomplete batch
            is dropped.
        num_workers: Number of loader worker processes.
        image_size: Output side length; images are bilinearly resized when it
            differs from 28.
        n_per_class: Per-class sample count handed to the selection helper.
        seed: Seed for the subset selection; ``seed + 9`` seeds the
            post-augmentation shuffle.
        aug_types: Comma-separated list containing ``"geom"`` and/or
            ``"opt"``. Tokens are matched case-insensitively and surrounding
            whitespace is ignored, so ``"geom, opt"`` enables both. When
            neither token is present, ``"geom"`` is used. With both enabled
            the geometric module runs first and its output feeds the
            ``opt`` module.

    Returns:
        A ``DataLoader`` over ``(image, label)`` tensors with
        ``shuffle=True``, ``drop_last=True`` and ``pin_memory=True``.
    """
    if _geom_module is None or _opt_module is None:
        return build_mnist_dataloaders(batch_size, num_workers, image_size, n_per_class, seed)

    from torch.utils.data import TensorDataset as _TensorDataset

    ds = datasets.MNIST(root="./data", train=True, download=True)
    X_all = ds.data.numpy().astype("uint8")
    y_all = ds.targets.numpy().astype("int64")

    if _pick_per_class is not None:
        picked_idx = _pick_per_class(
            y_all,
            per_class=n_per_class,
            num_classes=10,
            rng=_np.random.RandomState(seed),
        )
    else:
        # Fallback selection: sample each digit without replacement, then
        # shuffle the merged index set.
        rngp = _np.random.RandomState(seed)
        idx_list = []
        for c in range(10):
            cls_idx = _np.where(y_all == c)[0]
            if len(cls_idx) > n_per_class:
                cls_idx = rngp.choice(cls_idx, size=n_per_class, replace=False)
            idx_list.append(cls_idx)
        picked_idx = _np.concatenate(idx_list, axis=0)
        rngp.shuffle(picked_idx)

    X_sel = X_all[picked_idx]
    y_sel = y_all[picked_idx]

    # Normalise the tokens so that "geom, opt" or "GEOM,OPT" are not silently
    # treated as unknown augmentation names.
    requested = {token.strip().lower() for token in (aug_types or "").split(",")}
    use_geom = "geom" in requested
    use_opt = "opt" in requested
    if not (use_geom or use_opt):
        # Nothing recognised: default to geometric augmentation.
        use_geom = True

    # NOTE(review): the labels returned by the augmentation modules are
    # discarded and y_sel is reused, which assumes per_sample=1 keeps the
    # output aligned one-to-one with the input order — TODO confirm.
    X_aug = X_sel
    if use_geom:
        X_aug, _ = _geom_module(X_aug, y_sel, per_sample=1)
    if use_opt:
        X_aug, _ = _opt_module(X_aug, y_sel, per_sample=1)
    y_aug = y_sel

    # Deterministic shuffle, decoupled from the selection RNG.
    rng = _np.random.RandomState(seed + 9)
    perm = rng.permutation(len(X_aug))
    X_aug = X_aug[perm]
    y_aug = y_aug[perm]

    # Insert a channel axis at dim 1 and rescale; presumably the augmented
    # images are still 0..255 pixel values, giving [-1, 1] — TODO confirm.
    X_t = torch.from_numpy(X_aug).float().unsqueeze(1) / 127.5 - 1.0
    if image_size != 28:
        import torch.nn.functional as _F
        X_t = _F.interpolate(X_t, size=(image_size, image_size), mode="bilinear", align_corners=False)
    y_t = torch.from_numpy(y_aug).long()
    ds_t = _TensorDataset(X_t, y_t)
    return DataLoader(ds_t, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)


def train_rectified_flow(
    save_root: str,
    batch_size: int = 64,
    epochs: int = 30,
    lr: float = 1e-4,
    image_size: int = 32,
    n_per_class: int = 5000,
    seed: int = 42,
    num_workers: int = 8,
    device_str: Optional[str] = None,
    time_scale: float = 1000.0,
    gen_variant: str = "orig",
    aug_types: str = "geom,opt",
    dataset: str = "mnist",
) -> None:
    """Train a class-conditional U-Net with a rectified-flow objective.

    For every batch a noise sample ``x0`` and a time ``t`` drawn uniformly
    from [0, 1) are used to form the linear interpolation
    ``x_t = (1 - t) * x0 + t * x``; the model is trained with an MSE loss to
    predict the velocity ``x - x0``. After the last epoch the model weights
    and the run configuration are written to
    ``<save_root>/params/<dataset>/<variant>/checkpoint.pth``.

    Args:
        save_root: Directory under which the ``params`` tree is created.
        batch_size: Samples per batch.
        epochs: Number of passes over the loader. For ``dataset="cifar10"``
            the values 20 and 30 are replaced by 200 (a notice is printed).
        lr: Adam learning rate.
        image_size: Side length of the training images.
        n_per_class: Per-class sample count forwarded to the loader builder.
        seed: Random seed for the run and for the data selection.
        num_workers: Number of loader worker processes.
        device_str: Torch device string; when empty or ``None``, CUDA is used
            if available, otherwise the CPU.
        time_scale: Factor applied to ``t`` before it is passed to the model.
        gen_variant: ``"orig"`` trains on the plain MNIST subset; any other
            value selects the augmented loader. Lower-cased before use and
            forwarded unchanged to the CIFAR loader builder.
        aug_types: Comma-separated augmentation names forwarded to the
            augmented loader builders.
        dataset: ``"mnist"`` selects the single-channel MNIST pipeline; any
            other value uses the three-channel CIFAR loader builder, and
            ``"cifar10"`` additionally selects the wider model configuration.

    Returns:
        None. The checkpoint file is the only output besides console logs.
    """
    if _set_seed is not None:
        _set_seed(seed)
    else:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    device = torch.device(device_str) if device_str else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    variant = (gen_variant or "orig").lower()
    if dataset == "mnist":
        channels = 1
        base_width = 64
        if variant == "orig":
            dl = build_mnist_dataloaders(
                batch_size=batch_size,
                num_workers=num_workers,
                image_size=image_size,
                n_per_class=n_per_class,
                seed=seed,
            )
        else:
            dl = build_aug_only_loader(
                batch_size=batch_size,
                num_workers=num_workers,
                image_size=image_size,
                n_per_class=n_per_class,
                seed=seed,
                aug_types=aug_types,
            )
    else:
        channels = 3
        # CIFAR-10 uses the wider network; other non-MNIST names keep 96.
        base_width = 128 if dataset == "cifar10" else 96
        from generative.diffusion.diffusion_train import build_cifar_dataloader
        dl = build_cifar_dataloader(
            batch_size=batch_size,
            num_workers=num_workers,
            image_size=image_size,
            n_per_class=n_per_class,
            seed=seed,
            gen_variant=variant,
            aug_types=aug_types,
        )

    if dataset == "cifar10":
        model = Conditional_UNet(
            t_emb_dim=256,
            num_classes=10,
            device_str=str(device),
            in_ch=channels,
            out_ch=channels,
            base_width=base_width,
            enable_sa1=True,
            enable_sa2=True,
            enable_sa3=False,
            enable_sa4=True,
            enable_sa5=True,
            enable_sa6=False,
        ).to(device)
    else:
        model = Conditional_UNet(t_emb_dim=256, num_classes=10, device_str=str(device), in_ch=channels, out_ch=channels, base_width=base_width).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # The short default schedules (20 or 30 epochs) are replaced by 200 for
    # CIFAR-10. Report the substitution so the change is visible in the log.
    if dataset == "cifar10" and epochs in (20, 30):
        print(f"CIFAR-10: overriding epochs={epochs} with epochs=200", flush=True)
        epochs = 200

    ds_name = dataset if dataset in ("mnist", "cifar10") else ("mnist" if channels == 1 else "cifar10")
    params_root = os.path.join(save_root, "params", ds_name, variant)
    os.makedirs(params_root, exist_ok=True)

    # Run configuration stored next to the weights in the checkpoint.
    cfg = {
        "batch_size": batch_size,
        "epochs": epochs,
        "lr": lr,
        "image_size": image_size,
        "n_per_class": n_per_class,
        "seed": seed,
        "device": str(device),
        "num_classes": 10,
        "t_emb_dim": 256,
        # NOTE(review): fixed descriptive label; it is not derived from the
        # actual image_size / base_width used for this run.
        "arch": "Conditional_UNet(attn, 32x32, base64-256)",
        "training": "rectified_flow_linear_path",
        "time_scale": time_scale,
        "gen_variant": variant,
        "dataset": dataset,
        "channels": channels,
        "aug_types": aug_types,
        "base_width": base_width,
    }
    print(f"Start Rectified-Flow training (notebook-UNet backbone), epochs={epochs}", flush=True)
    for ep in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        running_loss = 0.0
        num_batches = 0

        for xb, yb in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            b = xb.size(0)

            # Linear path between noise x0 (t = 0) and data xb (t = 1).
            x0 = torch.randn_like(xb)
            t = torch.rand(b, device=device)
            x_t = (1.0 - t).view(b, 1, 1, 1) * x0 + t.view(b, 1, 1, 1) * xb
            # Velocity of the linear path, constant in t.
            v_target = xb - x0

            t_scaled = t * float(time_scale)

            v_pred = model(x_t, t_scaled, yb)
            loss = loss_fn(v_pred, v_target)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0 and ep == 1:
            # Without this notice the run would report loss=0.0000 and save
            # untrained weights with no indication that anything went wrong.
            print(
                f"Warning: the data loader yielded no batches "
                f"(is the dataset smaller than batch_size={batch_size}?); "
                f"no optimisation steps are being taken.",
                flush=True,
            )

        dt = time.time() - t0
        avg_loss = running_loss / max(1, num_batches)
        print(f"[{ep:03d}/{epochs}] time={dt:.1f}s loss={avg_loss:.4f}", flush=True)

    ckpt_path = os.path.join(params_root, "checkpoint.pth")
    torch.save(
        {
            "model": model.state_dict(),
            "config": cfg,
        },
        ckpt_path,
    )
    print(f"Saved FM checkpoint: {ckpt_path}", flush=True)


def main():
    """Parse command-line options and launch rectified-flow training.

    Each parsed option is forwarded as the corresponding keyword argument of
    ``train_rectified_flow`` (``--device`` becomes ``device_str``).
    """
    import argparse

    # Keep "cuda:0" as the default where CUDA exists, but fall back to the
    # CPU instead of failing on hosts without a GPU.
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_root", type=str, default=os.path.dirname(__file__))
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--image_size", type=int, default=32)
    parser.add_argument("--n_per_class", type=int, default=5000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--device",
        type=str,
        default=default_device,
        help="torch device string (default: cuda:0 if CUDA is available, else cpu)",
    )
    parser.add_argument("--time_scale", type=float, default=1000.0)
    parser.add_argument("--gen_variant", type=str, default="orig", choices=["orig", "aug"])
    parser.add_argument("--aug_types", type=str, default="geom,opt")
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"])
    args = parser.parse_args()

    train_rectified_flow(
        save_root=args.save_root,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        image_size=args.image_size,
        n_per_class=args.n_per_class,
        seed=args.seed,
        num_workers=args.num_workers,
        device_str=args.device,
        time_scale=args.time_scale,
        gen_variant=args.gen_variant,
        aug_types=args.aug_types,
        dataset=args.dataset,
    )


# Run the command-line entry point only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()


