import os
import math
import json
import time
from datetime import datetime
from typing import Tuple, Optional
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import MNIST, CIFAR10

from ..augment import geom_module, opt_module
from ..utils import stratified_pick_per_class, set_seed


def build_conditional_augmented_mnist(
    n_per_class: int = 2000,
    aug_ratio: float = 1.0,
    aug_types: str = "geom,opt",
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build an MNIST training set of originals plus a quota of augmented copies.

    ``n_per_class`` originals per digit are picked with ``stratified_pick_per_class``.
    Then ``round(len(originals) * aug_ratio)`` augmented images are appended. They
    all come from ``geom_module`` or all from ``opt_module`` when only one type is
    requested; when both are requested they are split in half (geom gets the floor).
    Each augmented image is one transformed copy of an original, sampled without
    replacement. Originals and augmented copies are shuffled together.

    Args:
        n_per_class: originals to keep per class.
        aug_ratio: augmented-to-original size ratio. NOTE(review): each augmentation
            type is capped at one copy per original (see ``sample_indices``), so
            ratios above 1.0 (one type) or 2.0 (both types) are silently truncated.
        aug_types: comma-separated subset of ``"geom"`` and ``"opt"``. Whitespace
            around tokens is ignored. ``None``, empty or unrecognised values fall
            back to ``"geom"``.
        seed: seeds the global RNGs via ``set_seed``, the stratified pick
            (``seed``), and the augmentation sampling and final shuffle (``seed + 1``).

    Returns:
        ``(X, y)``. ``X`` is a float tensor with a channel axis inserted at dim 1,
        mapped from ``[0, 255]`` to ``[-1, 1]`` (this assumes the augment modules
        return values in the same 0-255 range; confirm). ``y`` is an int64 label
        tensor.
    """
    set_seed(seed)
    ds = MNIST(root="./data", train=True, download=True)
    X_all = ds.data.numpy().astype(np.uint8)
    y_all = ds.targets.numpy().astype(np.int64)

    picked = stratified_pick_per_class(
        y_all,
        per_class=n_per_class,
        num_classes=10,
        rng=np.random.RandomState(seed),
    )
    X_sel = X_all[picked]
    y_sel = y_all[picked]

    total_orig = X_sel.shape[0]
    total_aug = int(round(total_orig * aug_ratio))

    # Accept None (like the sibling builders) and tolerate "geom, opt" spacing.
    requested = {token.strip() for token in (aug_types or "").split(",")}
    use_geom = "geom" in requested
    use_opt = "opt" in requested
    if not (use_geom or use_opt):
        use_geom = True

    if use_geom and use_opt:
        geom_quota = total_aug // 2
        opt_quota = total_aug - geom_quota
    else:
        geom_quota = total_aug if use_geom else 0
        opt_quota = total_aug if use_opt else 0

    rng = np.random.RandomState(seed + 1)

    def sample_indices(count: int) -> np.ndarray:
        """Return ``count`` distinct original indices, or all of them if count >= total."""
        if count <= 0:
            return np.empty((0,), dtype=np.int64)
        if count >= total_orig:
            return np.arange(total_orig, dtype=np.int64)
        return rng.choice(total_orig, size=count, replace=False)

    parts_X = [X_sel]
    parts_y = [y_sel]
    # Geom sampling happens before opt sampling so the rng stream order is fixed.
    if geom_quota > 0:
        idx = sample_indices(geom_quota)
        X_g, _ = geom_module(X_sel[idx], y_sel[idx], per_sample=1)
        parts_X.append(X_g)
        parts_y.append(y_sel[idx])
    if opt_quota > 0:
        idx = sample_indices(opt_quota)
        X_o, _ = opt_module(X_sel[idx], y_sel[idx], per_sample=1)
        parts_X.append(X_o)
        parts_y.append(y_sel[idx])

    X_all_final = np.concatenate(parts_X, axis=0)
    y_all_final = np.concatenate(parts_y, axis=0)

    perm = rng.permutation(len(X_all_final))
    X_t = torch.from_numpy(X_all_final[perm]).float().unsqueeze(1) / 127.5 - 1.0
    y_t = torch.from_numpy(y_all_final[perm]).long()
    return X_t, y_t


def build_dataset_for_variant(
    n_per_class: int = 2000,
    aug_types: str = "geom,opt",
    gen_variant: str = "orig",
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the MNIST training set for one generator variant.

    ``n_per_class`` originals per digit are picked with ``stratified_pick_per_class``.
    For ``gen_variant == "orig"`` (case-insensitive; ``None`` counts as ``"orig"``)
    they are returned as-is. For any other value, every original is replaced by
    exactly one augmented copy. ``geom_module`` is applied first and then
    ``opt_module``, according to ``aug_types``. Labels are carried over unchanged.

    Args:
        n_per_class: originals to keep per class.
        aug_types: comma-separated subset of ``"geom"`` and ``"opt"``. Whitespace
            around tokens is ignored. ``None``, empty or unrecognised values fall
            back to ``"geom"``.
        gen_variant: ``"orig"`` for raw images, anything else for augmented ones.
        seed: seeds the global RNGs via ``set_seed``, the stratified pick
            (``seed``), and the final shuffle (``seed + 7``).

    Returns:
        ``(X, y)`` in shuffled order. ``X`` is a float tensor with a channel axis
        inserted at dim 1, mapped from ``[0, 255]`` to ``[-1, 1]``. ``y`` holds
        int64 labels.
    """
    gen_variant = (gen_variant or "orig").lower()
    set_seed(seed)
    ds = MNIST(root="./data", train=True, download=True)
    X_all = ds.data.numpy().astype(np.uint8)
    y_all = ds.targets.numpy().astype(np.int64)

    picked = stratified_pick_per_class(
        y_all,
        per_class=n_per_class,
        num_classes=10,
        rng=np.random.RandomState(seed),
    )
    X_sel = X_all[picked]
    y_sel = y_all[picked]

    # Tolerate None and "geom, opt" spacing; default to geom if nothing matched.
    requested = {token.strip() for token in (aug_types or "").split(",")}
    use_geom = "geom" in requested
    use_opt = "opt" in requested
    if not (use_geom or use_opt):
        use_geom = True

    X_final = X_sel
    if gen_variant != "orig":
        # One augmented copy per original: geometric first, then optical.
        if use_geom:
            X_final, _ = geom_module(X_final, y_sel, per_sample=1)
        if use_opt:
            X_final, _ = opt_module(X_final, y_sel, per_sample=1)

    rng = np.random.RandomState(seed + 7)
    perm = rng.permutation(len(X_final))

    X_t = torch.from_numpy(X_final[perm]).float().unsqueeze(1) / 127.5 - 1.0
    y_t = torch.from_numpy(y_sel[perm]).long()
    return X_t, y_t


def build_dataset_for_variant_cifar(
    n_per_class: int = 2000,
    aug_types: str = "geom,opt",
    gen_variant: str = "orig",
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the CIFAR-10 training set for one generator variant.

    This is the CIFAR counterpart of the MNIST variant builder. ``n_per_class``
    images per class are picked with ``stratified_pick_per_class``. They are
    returned raw for ``gen_variant == "orig"``. For any other value, every image is
    replaced by one augmented copy: ``geom_module`` first, then ``opt_module``,
    using the CIFAR-specific settings below.

    Args:
        n_per_class: images to keep per class.
        aug_types: comma-separated subset of ``"geom"`` and ``"opt"``. Whitespace
            around tokens is ignored. ``None``, empty or unrecognised values fall
            back to ``"geom"``.
        gen_variant: ``"orig"`` for raw images, anything else for augmented ones.
        seed: seeds the global RNGs via ``set_seed``, the stratified pick
            (``seed``), and the final shuffle (``seed + 13``).

    Returns:
        ``(X, y)`` in shuffled order. ``X`` is a float tensor whose last axis has
        been moved to dim 1 (presumably HWC -> CHW), mapped from ``[0, 255]`` to
        ``[-1, 1]``. ``y`` holds int64 labels.
    """
    gen_variant = (gen_variant or "orig").lower()
    set_seed(seed)
    # NOTE(review): CIFAR is cached under ./data_src while MNIST uses ./data.
    ds = CIFAR10(root="./data_src", train=True, download=True)
    X_all = ds.data.astype(np.uint8)
    y_all = np.asarray(ds.targets, dtype=np.int64)

    picked = stratified_pick_per_class(
        y_all,
        per_class=n_per_class,
        num_classes=10,
        rng=np.random.RandomState(seed),
    )
    X_sel = X_all[picked]
    y_sel = y_all[picked]

    # Tolerate None and "geom, opt" spacing; default to geom if nothing matched.
    requested = {token.strip() for token in (aug_types or "").split(",")}
    use_geom = "geom" in requested
    use_opt = "opt" in requested
    if not (use_geom or use_opt):
        use_geom = True

    # No rotation, small shift/zoom. The fill colour (123, 116, 103) matches the
    # ImageNet RGB mean on a 0-255 scale; confirm that is the intended border.
    geom_kwargs = {
        "angle_range": (0.0, 0.0),
        "translate_frac": 0.10,
        "scale_range": (0.95, 1.05),
        "fill": (123, 116, 103),
    }
    # NOTE(review): blur_kernel=1 probably makes the blur a no-op; confirm intent.
    opt_kwargs = {
        "brightness_range": (0.8, 1.2),
        "contrast_range": (0.8, 1.2),
        "blur_kernel": 1,
        "blur_sigma": (1.0, 1.0),
    }

    X_final = X_sel
    if gen_variant != "orig":
        if use_geom:
            X_final, _ = geom_module(X_final, y_sel, per_sample=1, **geom_kwargs)
        if use_opt:
            X_final, _ = opt_module(X_final, y_sel, per_sample=1, **opt_kwargs)

    rng = np.random.RandomState(seed + 13)
    perm = rng.permutation(len(X_final))

    X_t = torch.from_numpy(np.moveaxis(X_final[perm], -1, 1)).float() / 127.5 - 1.0
    y_t = torch.from_numpy(y_sel[perm]).long()
    return X_t, y_t


class Gen(nn.Module):
    """Class-conditional convolutional generator.

    The noise ``z`` is concatenated with a learned label embedding. A linear layer
    (with BatchNorm and ReLU) projects the result to a ``(w256, base_hw, base_hw)``
    feature map. Two stride-2 transposed convolutions then upsample it to
    ``4 * base_hw`` per side, and a 3x3 convolution followed by ``tanh`` produces
    ``out_ch`` channels in ``[-1, 1]``. All channel widths scale with ``width_mult``.
    """

    def __init__(self, z_dim: int = 128, num_classes: int = 10, embed_dim: int = 50,
                 out_ch: int = 1, base_hw: int = 7, width_mult: float = 1.0):
        super().__init__()
        top_ch, mid_ch, low_ch = (int(base * width_mult) for base in (256, 128, 64))
        seed_units = top_ch * base_hw * base_hw

        # Attribute names and layer order fix the state_dict keys and the order in
        # which layers draw from the RNG at init time; keep them stable.
        self.embed = nn.Embedding(num_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(z_dim + embed_dim, seed_units),
            nn.BatchNorm1d(seed_units),
            nn.ReLU(True),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(top_ch, mid_ch, 4, 2, 1),  # base_hw -> 2 * base_hw
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(True),
            nn.ConvTranspose2d(mid_ch, low_ch, 4, 2, 1),  # -> 4 * base_hw
            nn.BatchNorm2d(low_ch),
            nn.ReLU(True),
            nn.Conv2d(low_ch, out_ch, 3, 1, 1),
            nn.Tanh(),
        )
        self.base_hw = base_hw
        self.w256 = top_ch

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Generate images for noise ``z`` conditioned on integer labels ``y``."""
        flat = self.net(torch.cat((z, self.embed(y)), dim=1))
        grid = flat.view(flat.size(0), self.w256, self.base_hw, self.base_hw)
        return self.conv(grid)


def sn(module: nn.Module) -> nn.Module:
    """Apply ``torch.nn.utils.spectral_norm`` (default settings) to *module* and return it."""
    normalized = nn.utils.spectral_norm(module)
    return normalized


class DiscProj(nn.Module):
    """Class-conditional discriminator with projection-style label conditioning.

    Three spectrally-normalised 3x3 convolutions (the first two with stride 2),
    each followed by LeakyReLU(0.2), produce features that are sum-pooled over the
    spatial axes. The score is an unconditional linear term plus the inner product
    between a projection of those features and a learned label embedding.
    Channel widths scale with ``width_mult``; ``embed_dim`` does not.
    """

    def __init__(self, num_classes: int = 10, feat_dim: int = 128, embed_dim: int = 128, in_ch: int = 1, width_mult: float = 1.0):
        super().__init__()
        ch_a = int(64 * width_mult)
        ch_b = int(128 * width_mult)
        feat_ch = int(feat_dim * width_mult)

        # Construction order (embed, conv stack, lin, proj) matches the original
        # layout so state_dict keys and RNG-dependent initialisation are unchanged.
        self.embed = nn.Embedding(num_classes, embed_dim)
        self.conv = nn.Sequential(
            sn(nn.Conv2d(in_ch, ch_a, 3, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            sn(nn.Conv2d(ch_a, ch_b, 3, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            sn(nn.Conv2d(ch_b, feat_ch, 3, 1, 1)),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.lin = sn(nn.Linear(feat_ch, 1))
        self.proj = sn(nn.Linear(feat_ch, embed_dim, bias=False))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return one real/fake score per sample for images ``x`` with labels ``y``."""
        pooled = self.conv(x).sum(dim=(2, 3))  # global sum-pool over H, W
        uncond_score = self.lin(pooled).squeeze(1)
        label_vec = self.embed(y)
        projected = self.proj(pooled)
        return uncond_score + (projected * label_vec).sum(dim=1)


def d_hinge(real_scores: torch.Tensor, fake_scores: torch.Tensor) -> torch.Tensor:
    """Discriminator hinge loss: mean(max(0, 1 - real)) + mean(max(0, 1 + fake))."""
    real_term = F.relu(1.0 - real_scores).mean()
    fake_term = F.relu(fake_scores + 1.0).mean()
    return real_term + fake_term


def g_hinge(fake_scores: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss: the negated mean discriminator score on fakes."""
    return -fake_scores.mean()

def r1_penalty(d_out: torch.Tensor, x_in: torch.Tensor) -> torch.Tensor:
    """R1 gradient penalty: batch mean of the squared L2 norm of d(sum d_out)/d x_in.

    ``x_in`` must have ``requires_grad=True``. The gradient is built with
    ``create_graph=True`` (which also retains the graph), so the penalty itself
    can be back-propagated into the discriminator.
    """
    (input_grad,) = torch.autograd.grad(d_out.sum(), [x_in], create_graph=True)
    per_sample = input_grad.square().flatten(start_dim=1).sum(dim=1)
    return per_sample.mean()


def save_config(save_dir: str, cfg: dict) -> None:
    """Write *cfg* as indented JSON to ``<save_dir>/config.json``, creating the directory."""
    target = os.path.join(save_dir, "config.json")
    os.makedirs(save_dir, exist_ok=True)
    with open(target, "w") as handle:
        handle.write(json.dumps(cfg, indent=2))


def train(
    save_root: str,
    n_per_class: int = 5000,
    aug_ratio: float = 1.0,
    aug_types: str = "geom,opt",
    batch_size: int = 128,
    epochs: int = 40,
    z_dim: int = 128,
    lr_g: float = 1e-4,
    lr_d: float = 2e-4,
    betas: Tuple[float, float] = (0.0, 0.9),
    seed: int = 42,
    num_workers: int = 8,
    device_str: Optional[str] = "cuda:0",
    save_samples: bool = False,
    gen_variant: str = "orig",
    dataset: str = "mnist",
    gamma_r1: float = 10.0,
) -> None:
    """Train a class-conditional GAN and save the final G/D weights.

    Each iteration makes one discriminator update and then one generator update,
    with Adam for both networks. The discriminator loss is the hinge loss plus an
    R1 penalty on real images weighted by ``gamma_r1 / 2``. The generator loss is
    the hinge loss. After the last epoch ``{"G", "D", "config"}`` is saved to
    ``<save_root>/params/<dataset>/<variant>/checkpoint.pth``.

    Args:
        save_root: root directory for the ``params/`` (and ``samples/``) outputs.
        n_per_class: originals per class passed to the dataset builder.
        aug_ratio: only recorded in the saved config; it does not affect the data.
        aug_types: comma-separated augmentation types used when ``gen_variant="aug"``.
        batch_size: mini-batch size; the final partial batch is dropped.
        epochs: passes over the data (raised to at least 60 for cifar10).
        z_dim: noise dimension (a value of 128 is bumped to 256 for cifar10).
        lr_g: Adam learning rate for the generator.
        lr_d: Adam learning rate for the discriminator.
        betas: Adam betas shared by both optimizers.
        seed: seed for data selection/augmentation and for model initialisation.
        num_workers: DataLoader worker processes.
        device_str: torch device string; a falsy value picks CUDA if available,
            else CPU.
        save_samples: only creates ``<save_root>/samples/<dataset>/<variant>``;
            no samples are written by this function.
        gen_variant: ``"aug"`` trains on augmented images; any other value trains
            on originals but still names the output directory after the variant.
        dataset: ``"mnist"`` or ``"cifar10"``.
        gamma_r1: R1 penalty weight (defaults to the previously hard-coded 10.0).

    Raises:
        ValueError: if ``dataset`` is unsupported, or the dataset has fewer samples
            than ``batch_size`` (``drop_last`` would leave no iterations).
    """
    if dataset not in ("mnist", "cifar10"):
        # Any other string used to load CIFAR-10 silently while skipping the
        # cifar10-only width/z_dim/epoch settings below.
        raise ValueError(f"unsupported dataset {dataset!r}; expected 'mnist' or 'cifar10'")

    set_seed(seed)
    device = torch.device(device_str) if device_str else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    variant = (gen_variant or "orig").lower()
    data_variant = "aug" if variant == "aug" else "orig"
    if dataset == "mnist":
        channels, height, width, base_hw = 1, 28, 28, 7
        X, y = build_dataset_for_variant(
            n_per_class=n_per_class, aug_types=aug_types, gen_variant=data_variant, seed=seed
        )
    else:
        channels, height, width, base_hw = 3, 32, 32, 8
        X, y = build_dataset_for_variant_cifar(
            n_per_class=n_per_class, aug_types=aug_types, gen_variant=data_variant, seed=seed
        )
    ds = TensorDataset(X, y)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned host memory only speeds up copies to a CUDA device.
        pin_memory=device.type == "cuda",
        drop_last=True,
    )
    # With drop_last=True the real count is floor(N / batch_size), not ceil.
    iters_per_epoch = len(dl)
    if iters_per_epoch == 0:
        # Otherwise the epoch loop never runs and the loss log hits unbound names.
        raise ValueError(
            f"dataset has {len(ds)} samples, fewer than batch_size={batch_size}; nothing to train on"
        )

    # Re-seed so model initialisation does not depend on RNG use during data prep.
    set_seed(seed)

    width_mult = 1.0
    if dataset == "cifar10":
        if z_dim == 128:
            z_dim = 256
        width_mult = 1.5

    G = Gen(z_dim=z_dim, out_ch=channels, base_hw=base_hw, width_mult=width_mult).to(device)
    D = DiscProj(in_ch=channels, width_mult=width_mult).to(device)
    opt_g = torch.optim.Adam(G.parameters(), lr=lr_g, betas=betas)
    opt_d = torch.optim.Adam(D.parameters(), lr=lr_d, betas=betas)

    params_root = os.path.join(save_root, "params", dataset, variant)
    samples_root = os.path.join(save_root, "samples", dataset, variant) if save_samples else None
    os.makedirs(params_root, exist_ok=True)
    if save_samples and samples_root is not None:
        os.makedirs(samples_root, exist_ok=True)

    if dataset == "cifar10" and epochs < 60:
        epochs = 60

    cfg = {
        "n_per_class": n_per_class,
        "aug_ratio": aug_ratio,
        "aug_types": aug_types,
        "batch_size": batch_size,
        "epochs": epochs,
        "z_dim": z_dim,
        "lr_g": lr_g,
        "lr_d": lr_d,
        "betas": list(betas),
        "seed": seed,
        "device": str(device),
        "gen_variant": variant,
        "dataset": dataset,
        "channels": channels,
        "height": height,
        "width": width,
        "base_hw": base_hw,
        "width_mult": width_mult,
        "gamma_r1": gamma_r1,
    }

    print(f"Start training: N={len(ds)}, iters/epoch={iters_per_epoch}, device={device}")

    for epoch in range(1, epochs + 1):
        G.train()
        D.train()
        t0 = time.time()

        for xb, yb in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            bsz = xb.size(0)

            # Discriminator step: fakes are generated without tracking G's graph.
            z = torch.randn(bsz, z_dim, device=device)
            with torch.no_grad():
                x_fake = G(z, yb)

            xb.requires_grad_(True)  # R1 needs dD/dx on the real inputs
            d_real = D(xb, yb)
            d_fake = D(x_fake.detach(), yb)

            loss_d = d_hinge(d_real, d_fake) + (gamma_r1 / 2.0) * r1_penalty(d_real, xb)

            opt_d.zero_grad(set_to_none=True)
            loss_d.backward()
            opt_d.step()
            xb.requires_grad_(False)

            # Generator step on fresh noise with the same labels.
            z = torch.randn(bsz, z_dim, device=device)
            x_fake = G(z, yb)
            d_fake = D(x_fake, yb)
            loss_g = g_hinge(d_fake)
            opt_g.zero_grad(set_to_none=True)
            loss_g.backward()
            opt_g.step()

        dt = time.time() - t0
        print(f"[{epoch:03d}/{epochs}] time={dt:.1f}s  loss_d={loss_d.item():.3f}  loss_g={loss_g.item():.3f}")

    ckpt_path = os.path.join(params_root, "checkpoint.pth")
    torch.save(
        {
            "G": G.state_dict(),
            "D": D.state_dict(),
            "config": cfg,
        },
        ckpt_path,
    )
    print(f"Saved GAN checkpoint: {ckpt_path}")


def main(argv: Optional[list] = None) -> None:
    """Parse command-line options and launch ``train``.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. ``None`` keeps
            the normal CLI behaviour; passing a list lets the CLI be driven
            programmatically.
    """
    parser = argparse.ArgumentParser(description="Train cGAN on MNIST or CIFAR-10 (save params only)")
    parser.add_argument("--save_root", type=str, default=os.path.dirname(__file__))
    parser.add_argument("--n_per_class", type=int, default=5000)
    parser.add_argument("--aug_ratio", type=float, default=1.0,
                        help="recorded in the saved config; not used to build the training set")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=40,
                        help="training epochs (cifar10 always trains for at least 60)")
    parser.add_argument("--z_dim", type=int, default=128)
    parser.add_argument("--lr_g", type=float, default=1e-4)
    parser.add_argument("--lr_d", type=float, default=2e-4)
    parser.add_argument("--beta1", type=float, default=0.0)
    parser.add_argument("--beta2", type=float, default=0.9)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--device", type=str, default="cuda:0", help="cuda:0 or cpu")
    parser.add_argument("--gen_variant", type=str, default="orig", choices=["orig", "aug"], help="train on originals only or augmentation-only")
    parser.add_argument("--aug_types", type=str, default="geom,opt", help="augmentation types when gen_variant=aug (comma: geom,opt)")
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"])
    args = parser.parse_args(argv)

    train(
        save_root=args.save_root,
        n_per_class=args.n_per_class,
        aug_ratio=args.aug_ratio,
        aug_types=args.aug_types,
        batch_size=args.batch_size,
        epochs=args.epochs,
        z_dim=args.z_dim,
        lr_g=args.lr_g,
        lr_d=args.lr_d,
        betas=(args.beta1, args.beta2),
        seed=args.seed,
        num_workers=args.num_workers,
        device_str=args.device,
        save_samples=False,
        gen_variant=args.gen_variant,
        dataset=args.dataset,
    )


if __name__ == "__main__":
    # Script entry point: parse CLI flags and train; skipped when imported as a module.
    main()


