import os
import math
import json
import time
from datetime import datetime
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from ..utils import set_seed as _set_seed, stratified_pick_per_class as _pick_per_class
import numpy as _np
from ..augment import geom_module as _geom_module, opt_module as _opt_module

class ResBlock(nn.Module):
    """Conv3x3 -> GroupNorm -> SiLU -> Conv3x3 -> GroupNorm, with an optional identity skip.

    Spatial size is preserved (stride 1, padding 1). When ``residual`` is True
    the input is added to the conv output, which requires ``inp_ch == out_ch``.
    Both GroupNorm layers use 8 groups, so the hidden and output widths must be
    divisible by 8.

    Args:
        inp_ch: input channels.
        out_ch: output channels.
        mid_ch: channels between the two convolutions; defaults to ``out_ch``.
        residual: add the input to the conv output.
    """

    def __init__(self, inp_ch: int, out_ch: int, mid_ch: Optional[int] = None, residual: bool = False):
        super().__init__()
        self.residual = residual
        hidden = out_ch if mid_ch is None else mid_ch
        # Layer order (and thus Sequential indices / state_dict keys) is fixed.
        layers = [
            nn.Conv2d(inp_ch, hidden, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, out_ch, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, out_ch),
        ]
        self.resnet_conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.resnet_conv(x)
        return x + h if self.residual else h


class SelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention over all spatial positions, with a residual add.

    The (B, C, H, W) feature map is flattened to H*W tokens of width C,
    GroupNorm-ed (8 groups), attended with 4 heads, reshaped back and added to
    the input. ``channels`` must therefore be divisible by 8 and by 4.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.attn_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, C, H*W): GroupNorm normalises over the channel axis 1.
        flat = self.attn_norm(x.flatten(start_dim=2))
        # (B, C, L) -> (B, L, C): batch_first token layout expected by the MHA.
        tokens = flat.permute(0, 2, 1)
        attended, _weights = self.mha(tokens, tokens, tokens)
        return x + attended.permute(0, 2, 1).reshape(x.shape)


class DownBlock(nn.Module):
    """Encoder stage: 2x max-pool, a residual ResBlock, a widening ResBlock, plus conditioning.

    The conditioning vector ``t`` is projected (SiLU -> Linear) to ``out_ch``
    values per sample and added to every spatial position of the output.

    Args:
        inp_ch: channels of the incoming feature map.
        out_ch: channels produced by this stage.
        t_emb_dim: width of the conditioning vector passed to ``forward``.
    """

    def __init__(self, inp_ch: int, out_ch: int, t_emb_dim: int = 256):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        keep_width = ResBlock(inp_ch=inp_ch, out_ch=inp_ch, residual=True)
        widen = ResBlock(inp_ch=inp_ch, out_ch=out_ch)
        self.down = nn.Sequential(pool, keep_width, widen)
        self.t_emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_ch))

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        feats = self.down(x)
        cond = self.t_emb_layers(t)
        # (B, C) -> (B, C, 1, 1); broadcasting spreads it over every (h, w).
        return feats + cond[:, :, None, None]


class UpBlock(nn.Module):
    """Decoder stage: bilinear 2x upsample, concat with the skip map, two ResBlocks, plus conditioning.

    The skip tensor is concatenated *before* the upsampled input along the
    channel axis, so ``inp_ch`` must equal skip channels + input channels.
    The conditioning vector ``t`` is projected to ``out_ch`` values per sample
    and added at every spatial position.

    Args:
        inp_ch: channels after concatenation.
        out_ch: channels produced by this stage.
        t_emb_dim: width of the conditioning vector passed to ``forward``.
    """

    def __init__(self, inp_ch: int, out_ch: int, t_emb_dim: int = 256):
        super().__init__()
        self.upsamp = nn.UpsamplingBilinear2d(scale_factor=2)
        mix = ResBlock(inp_ch=inp_ch, out_ch=inp_ch, residual=True)
        narrow = ResBlock(inp_ch=inp_ch, out_ch=out_ch, mid_ch=inp_ch // 2)
        self.up = nn.Sequential(mix, narrow)
        self.t_emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_ch))

    def forward(self, x: torch.Tensor, skip: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        merged = torch.cat([skip, self.upsamp(x)], dim=1)
        feats = self.up(merged)
        cond = self.t_emb_layers(t)
        # (B, C) -> (B, C, 1, 1); broadcasting spreads it over every (h, w).
        return feats + cond[:, :, None, None]


class Conditional_UNet(nn.Module):
    """Class- and timestep-conditioned UNet (used by ``train`` to predict the added noise).

    Encoder: an input ResBlock, then three DownBlocks (each halves H and W),
    each optionally followed by self-attention. Bottleneck: three ResBlocks
    (``lat1``..``lat3``). Decoder: three UpBlocks (each doubles H and W and
    concatenates the matching encoder map), each optionally followed by
    self-attention, then a 1x1 conv to ``out_ch`` channels.

    Conditioning: timesteps are encoded with sinusoidal features of width
    ``t_emb_dim`` and summed with a learned per-class embedding; the result is
    fed to every Down/Up block.

    Because of the three 2x max-pools and the skip concatenations, input height
    and width must be divisible by 8 (otherwise upsampled and skip maps differ
    in size).

    Args:
        t_emb_dim: width of the time/class conditioning vector; must be even
            (see ``position_embeddings``).
        num_classes: number of labels accepted by ``forward``.
        device_str: stored on the instance; not used anywhere in this class.
        in_ch: channels of the input image.
        out_ch: channels of the prediction.
        base_width: channel width of the first stage; deeper stages use 2x, 4x, 8x.
        enable_sa1 .. enable_sa6: when False, the corresponding attention block
            is replaced by ``nn.Identity``.
    """

    def __init__(
        self,
        t_emb_dim: int = 256,
        num_classes: int = 10,
        device_str: str = "cuda",
        in_ch: int = 1,
        out_ch: int = 1,
        base_width: int = 64,
        # Per-stage attention switches (sa1-sa3 encoder, sa4-sa6 decoder).
        enable_sa1: bool = True,
        enable_sa2: bool = True,
        enable_sa3: bool = True,
        enable_sa4: bool = True,
        enable_sa5: bool = True,
        enable_sa6: bool = True,
    ):
        super().__init__()
        self.device_str = device_str
        self.t_emb_dim = t_emb_dim
        self.num_classes = num_classes
        self.in_ch = in_ch
        self.out_ch = out_ch
        bw = int(base_width)

        # NOTE: module construction order determines seeded weight initialisation
        # and the state_dict layout; keep it stable for checkpoint compatibility.
        # Encoder: resolution H -> H/2 -> H/4 -> H/8.
        self.inp = ResBlock(inp_ch=in_ch, out_ch=bw)
        self.down1 = DownBlock(inp_ch=bw, out_ch=bw * 2)
        self.sa1 = SelfAttentionBlock(channels=bw * 2) if enable_sa1 else nn.Identity()
        self.down2 = DownBlock(inp_ch=bw * 2, out_ch=bw * 4)
        self.sa2 = SelfAttentionBlock(channels=bw * 4) if enable_sa2 else nn.Identity()
        self.down3 = DownBlock(inp_ch=bw * 4, out_ch=bw * 4)
        self.sa3 = SelfAttentionBlock(channels=bw * 4) if enable_sa3 else nn.Identity()

        # Bottleneck at H/8: widen to 8*bw then narrow back to 4*bw.
        self.lat1 = ResBlock(inp_ch=bw * 4, out_ch=bw * 8)
        self.lat2 = ResBlock(inp_ch=bw * 8, out_ch=bw * 8)
        self.lat3 = ResBlock(inp_ch=bw * 8, out_ch=bw * 4)

        # Decoder: each UpBlock's inp_ch = skip channels + upsampled channels.
        self.up1 = UpBlock(inp_ch=bw * 8, out_ch=bw * 2)   # 4bw (x3) + 4bw
        self.sa4 = SelfAttentionBlock(channels=bw * 2) if enable_sa4 else nn.Identity()
        self.up2 = UpBlock(inp_ch=bw * 4, out_ch=bw)       # 2bw (x2) + 2bw
        self.sa5 = SelfAttentionBlock(channels=bw) if enable_sa5 else nn.Identity()
        self.up3 = UpBlock(inp_ch=bw * 2, out_ch=bw)       # bw (x1) + bw
        self.sa6 = SelfAttentionBlock(channels=bw) if enable_sa6 else nn.Identity()
        self.out = nn.Conv2d(in_channels=bw, out_channels=out_ch, kernel_size=1)

        # Learned class embedding, added to the sinusoidal time embedding in forward().
        self.embeddings = nn.Embedding(num_embeddings=self.num_classes, embedding_dim=self.t_emb_dim)

    def position_embeddings(self, t: torch.Tensor, channels: int) -> torch.Tensor:
        """Sinusoidal embedding of timesteps.

        Args:
            t: float tensor of shape (B, 1) (``forward`` passes ``t.unsqueeze(1)``).
            channels: output width; must be even, otherwise the frequency vector
                (ceil(channels/2) entries) and the repeated ``t``
                (channels // 2 columns) have mismatched lengths.

        Returns:
            (B, channels) tensor ``[sin(t*w_0..w_{k}), cos(t*w_0..w_{k})]`` with
            ``w_j = 10000 ** (-2j / channels)``.
        """
        i = 1.0 / (10000 ** (torch.arange(start=0, end=channels, step=2, device=t.device) / channels))
        pos_emb_sin = torch.sin(t.repeat(1, channels // 2) * i)
        pos_emb_cos = torch.cos(t.repeat(1, channels // 2) * i)
        pos_emb = torch.cat([pos_emb_sin, pos_emb_cos], dim=-1)
        return pos_emb

    def forward(self, x: torch.Tensor, t: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Run the UNet on ``x`` conditioned on timesteps ``t`` and class ``labels``.

        Args:
            x: (B, in_ch, H, W) batch with H and W divisible by 8.
            t: (B,) timesteps (cast to float here).
            labels: (B,) integer class ids in ``[0, num_classes)``.

        Returns:
            (B, out_ch, H, W) tensor.
        """
        # Conditioning vector (B, t_emb_dim): sinusoidal time features + class embedding.
        t = t.unsqueeze(1).float()
        t = self.position_embeddings(t, self.t_emb_dim)
        t = t + self.embeddings(labels)

        # Encoder; x1..x3 are kept as skip connections.
        x1 = self.inp(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        # Bottleneck (no time conditioning applied here).
        x4 = self.lat1(x4)
        x4 = self.lat2(x4)
        x4 = self.lat3(x4)

        # Decoder, consuming the skips in reverse order.
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.out(x)
        return output


def linear_beta_schedule(T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> torch.Tensor:
    """Return ``T`` noise variances evenly spaced from ``beta_start`` to ``beta_end`` (both inclusive)."""
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=T)
    return schedule


def build_mnist_dataloaders(
    batch_size: int,
    num_workers: int,
    image_size: int,
    n_per_class: int,
    seed: int,
) -> DataLoader:
    """Build a shuffled loader over a class-balanced subset of the MNIST train split.

    Up to ``n_per_class`` images per digit are chosen with a RandomState seeded
    by ``seed`` (all of a class is kept when it has fewer). Images are resized
    to ``image_size`` x ``image_size`` and normalised to [-1, 1].

    Returns:
        DataLoader over a ``Subset`` of MNIST; shuffled, ``drop_last=True``.
    """
    from torch.utils.data import Subset

    to_model_range = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),  # [0, 1] -> [-1, 1]
        ]
    )
    mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=to_model_range)
    labels = mnist_train.targets.detach().cpu().numpy().astype("int64")

    if _pick_per_class is None:
        # Local fallback: per-class sampling without replacement, then one global shuffle.
        picker = _np.random.RandomState(seed)
        per_class = []
        for cls in range(10):
            members = _np.where(labels == cls)[0]
            if len(members) > n_per_class:
                members = picker.choice(members, size=n_per_class, replace=False)
            per_class.append(members)
        picked_idx = _np.concatenate(per_class, axis=0)
        picker.shuffle(picked_idx)
    else:
        picked_idx = _pick_per_class(
            labels,
            per_class=n_per_class,
            num_classes=10,
            rng=_np.random.RandomState(seed),
        )

    balanced = Subset(mnist_train, picked_idx.tolist())
    return DataLoader(
        balanced,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )


def build_aug_only_loader(
    batch_size: int,
    num_workers: int,
    image_size: int,
    n_per_class: int,
    seed: int,
    aug_types: str = "geom,opt",
) -> DataLoader:
    """Build an MNIST loader that serves only augmented copies of a balanced subset.

    Up to ``n_per_class`` images per digit are drawn from the MNIST train split
    (seeded by ``seed``). Each image is passed once through the requested
    augmentation chain, and only the augmented images are served; the originals
    are not included. Pixels are mapped from [0, 255] to [-1, 1]. When
    ``image_size != 28`` the images are bilinearly resized.

    If either augmentation module is unavailable, this falls back to
    ``build_mnist_dataloaders`` (no augmentation).

    Args:
        batch_size: loader batch size (``drop_last=True``).
        num_workers: DataLoader worker processes.
        image_size: output height and width.
        n_per_class: images selected per class.
        seed: seeds subset selection; ``seed + 13`` seeds the final permutation.
        aug_types: comma-separated subset of {"geom", "opt"}. Tokens are
            whitespace-stripped and case-insensitive. When neither is present,
            "geom" alone is applied. "geom" always runs before "opt".

    Returns:
        DataLoader over a ``TensorDataset`` of (float image (1, H, W), long label).
    """
    if _geom_module is None or _opt_module is None:
        return build_mnist_dataloaders(batch_size, num_workers, image_size, n_per_class, seed)

    from torch.utils.data import TensorDataset
    from torchvision.datasets import MNIST

    raw = MNIST(root="./data", train=True, download=True)
    X_all = raw.data.numpy().astype("uint8")
    y_all = raw.targets.numpy().astype("int64")

    if _pick_per_class is not None:
        picked_idx = _pick_per_class(
            y_all,
            per_class=n_per_class,
            num_classes=10,
            rng=_np.random.RandomState(seed),
        )
    else:
        rngp = _np.random.RandomState(seed)
        idx_list = []
        for c in range(10):
            cls_idx = _np.where(y_all == c)[0]
            if len(cls_idx) > n_per_class:
                cls_idx = rngp.choice(cls_idx, size=n_per_class, replace=False)
            idx_list.append(cls_idx)
        picked_idx = _np.concatenate(idx_list, axis=0)
        rngp.shuffle(picked_idx)

    X_sel = X_all[picked_idx]
    y_sel = y_all[picked_idx]

    # Normalise tokens so "geom, opt" or "GEOM" are honoured rather than silently dropped.
    requested = {tok.strip().lower() for tok in (aug_types or "").split(",")}
    use_geom = "geom" in requested
    use_opt = "opt" in requested
    if not (use_geom or use_opt):
        use_geom = True

    X_aug = X_sel
    if use_geom:
        X_aug, _ = _geom_module(X_aug, y_sel, per_sample=1)
    if use_opt:
        X_aug, _ = _opt_module(X_aug, y_sel, per_sample=1)
    # Labels returned by the augment modules are ignored. This assumes
    # per_sample=1 yields exactly one output per input, in input order
    # (TODO confirm in ..augment).
    y_aug = y_sel

    perm = _np.random.RandomState(seed + 13).permutation(len(X_aug))
    X_aug = X_aug[perm]
    y_aug = y_aug[perm]

    # uint8 [0, 255] -> float [-1, 1], matching Normalize(0.5, 0.5) after ToTensor.
    X_t = torch.from_numpy(X_aug).float().unsqueeze(1) / 127.5 - 1.0
    if image_size != 28:
        X_t = F.interpolate(X_t, size=(image_size, image_size), mode="bilinear", align_corners=False)
    y_t = torch.from_numpy(y_aug).long()
    return DataLoader(
        TensorDataset(X_t, y_t),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )


def build_cifar_dataloader(
    batch_size: int,
    num_workers: int,
    image_size: int,
    n_per_class: int,
    seed: int,
    gen_variant: str = "orig",
    aug_types: str = "geom,opt",
):
    from torchvision.datasets import CIFAR10
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)),  # [-1,1]
    ])
    ds = CIFAR10(root="./data", train=True, download=True, transform=None)
    X_all = ds.data.astype("uint8")  # (N,32,32,3)
    y_all = _np.asarray(ds.targets, dtype="int64")
    if _pick_per_class is not None:
        picked_idx = _pick_per_class(y_all, per_class=n_per_class, num_classes=10, rng=_np.random.RandomState(seed))
    else:
        rng = _np.random.RandomState(seed)
        idxs = []
        for c in range(10):
            cls_idx = _np.where(y_all == c)[0]
            picked = cls_idx if len(cls_idx) <= n_per_class else rng.choice(cls_idx, size=n_per_class, replace=False)
            idxs.append(picked)
        picked_idx = _np.concatenate(idxs, axis=0)
        rng.shuffle(picked_idx)
    X_sel = X_all[picked_idx]
    y_sel = y_all[picked_idx]
    variant = (gen_variant or "orig").lower()
    if variant == "aug" and (_geom_module is not None and _opt_module is not None):
        use_geom = "geom" in (aug_types or "").split(",")
        use_opt = "opt" in (aug_types or "").split(",")
        if not (use_geom or use_opt):
            use_geom = True
        if use_geom and use_opt:
            Xg, _ = _geom_module(
                X_sel, y_sel, per_sample=1,
                angle_range=(0.0, 0.0),
                translate_frac=0.10,
                scale_range=(0.95, 1.05),
                fill=(123, 116, 103),
            )
            Xgo, _ = _opt_module(
                Xg, y_sel, per_sample=1,
                brightness_range=(0.8, 1.2),
                contrast_range=(0.8, 1.2),
                blur_kernel=1,
                blur_sigma=(1.0, 1.0),
            )
            X_sel = Xgo
            y_sel = y_sel
        elif use_geom:
            Xg, _ = _geom_module(
                X_sel, y_sel, per_sample=1,
                angle_range=(0.0, 0.0),
                translate_frac=0.10,
                scale_range=(0.95, 1.05),
                fill=(123, 116, 103),
            )
            X_sel = Xg
        elif use_opt:
            Xo, _ = _opt_module(
                X_sel, y_sel, per_sample=1,
                brightness_range=(0.8, 1.2),
                contrast_range=(0.8, 1.2),
                blur_kernel=1,
                blur_sigma=(1.0, 1.0),
            )
            X_sel = Xo
    from PIL import Image
    class _NumpyCIFAR(torch.utils.data.Dataset):
        def __init__(self, X, y, tf):
            self.X = X
            self.y = y
            self.tf = tf
        def __len__(self):
            return len(self.X)
        def __getitem__(self, i):
            img = Image.fromarray(self.X[i], mode="RGB")
            x = self.tf(img)
            return x, int(self.y[i])
    ds_t = _NumpyCIFAR(X_sel, y_sel, transform)
    return DataLoader(ds_t, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, pin_memory=True)

def train(
    save_root: str,
    batch_size: int = 64,
    epochs: int = 30,
    T: int = 1000,
    lr: float = 1e-4,
    beta_start: float = 1e-4,
    beta_end: float = 0.02,
    image_size: int = 32,
    n_per_class: int = 5000,
    seed: int = 42,
    num_workers: int = 8,
    device_str: Optional[str] = None,
    gen_variant: str = "orig",
    dataset: str = "mnist",
    aug_types: str = "geom,opt",
) -> None:
    """Train the class-conditional UNet as a DDPM noise predictor and save one checkpoint.

    Each step draws ``t`` uniformly from ``{0, ..., T-1}`` per image and forms
    ``x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps`` under a linear beta
    schedule. The model output is regressed onto ``eps`` with MSE. After the
    last epoch, ``{"model", "betas", "config"}`` is written to
    ``<save_root>/params/<dataset>/<variant>/checkpoint.pth``. No intermediate
    checkpoints are saved.

    Args:
        save_root: root directory for the checkpoint.
        batch_size: training batch size.
        epochs: number of epochs. For CIFAR-10 a value of 20 or 30 is replaced by 200.
        T: number of diffusion steps.
        lr: Adam learning rate.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        image_size: training resolution; must be divisible by 8 for the UNet.
        n_per_class: images per class in the training subset.
        seed: global seed and data-selection seed.
        num_workers: DataLoader worker processes.
        device_str: torch device string. When falsy, CUDA is used if available,
            otherwise CPU.
        gen_variant: "orig" trains on the real subset. For MNIST any other value
            uses ``build_aug_only_loader``. For CIFAR-10, "aug" enables
            augmentation in ``build_cifar_dataloader``.
        dataset: "mnist" or "cifar10".
        aug_types: comma-separated augmentation families forwarded to the loaders.

    Raises:
        ValueError: if ``dataset`` is unsupported, or the loader yields no
            batches (fewer samples than ``batch_size`` with ``drop_last=True``).
    """
    # Validate up front: previously any non-"mnist" name got the CIFAR loader but
    # the MNIST model/epoch settings, an inconsistent hybrid.
    if dataset not in ("mnist", "cifar10"):
        raise ValueError(f"dataset must be 'mnist' or 'cifar10', got {dataset!r}")
    is_cifar = dataset == "cifar10"

    if _set_seed is not None:
        _set_seed(seed)
    else:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    device = torch.device(device_str) if device_str else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    variant = (gen_variant or "orig").lower()
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        image_size=image_size,
        n_per_class=n_per_class,
        seed=seed,
    )
    if is_cifar:
        channels = 3
        base_width = 128
        dl = build_cifar_dataloader(**loader_kwargs, gen_variant=variant, aug_types=aug_types)
        # The CIFAR model disables attention at the lowest resolution (sa3)
        # and at full resolution (sa6).
        sa_flags = dict(
            enable_sa1=True,
            enable_sa2=True,
            enable_sa3=False,
            enable_sa4=True,
            enable_sa5=True,
            enable_sa6=False,
        )
    else:
        channels = 1
        base_width = 64
        if variant == "orig":
            dl = build_mnist_dataloaders(**loader_kwargs)
        else:
            dl = build_aug_only_loader(**loader_kwargs, aug_types=aug_types)
        sa_flags = {f"enable_sa{k}": True for k in range(1, 7)}

    iters_per_epoch = len(dl)
    if iters_per_epoch == 0:
        raise ValueError(
            f"data loader is empty: {len(dl.dataset)} samples with batch_size={batch_size} and drop_last=True"
        )

    model = Conditional_UNet(
        t_emb_dim=256,
        num_classes=10,
        device_str=str(device),
        in_ch=channels,
        out_ch=channels,
        base_width=base_width,
        **sa_flags,
    ).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    betas_t = linear_beta_schedule(T=T, beta_start=beta_start, beta_end=beta_end).to(device)
    alphas = 1.0 - betas_t
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    ds_name = dataset
    params_root = os.path.join(save_root, "params", ds_name, variant)
    os.makedirs(params_root, exist_ok=True)

    # NOTE(review): for CIFAR-10, 20 or 30 epochs (30 is the default here and in
    # the CLI) is silently replaced by 200, even when passed explicitly.
    if is_cifar and epochs in (20, 30):
        epochs = 200

    cfg = {
        "batch_size": batch_size,
        "epochs": epochs,
        "T": T,
        "lr": lr,
        "beta_start": beta_start,
        "beta_end": beta_end,
        "image_size": image_size,
        "n_per_class": n_per_class,
        "seed": seed,
        "device": str(device),
        "num_classes": 10,
        "t_emb_dim": 256,
        "arch": "Conditional_UNet(attn, 32x32, base64-256)",
        "schedule": "linear",
        "gen_variant": variant,
        "dataset": dataset,
        "channels": channels,
        "base_width": base_width,
        "aug_types": aug_types,
        # Needed to rebuild the exact module layout when loading the state_dict.
        "sa_flags": dict(sa_flags),
    }

    print(f"Start Diffusion-new (notebook-arch) training: iters/epoch≈{iters_per_epoch}", flush=True)
    for ep in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        running_loss = 0.0
        num_batches = 0

        for xb, yb in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            batch_size_eff = xb.size(0)
            t_int = torch.randint(low=0, high=T, size=(batch_size_eff,), device=device, dtype=torch.long)
            a_bar_t = alphas_cumprod[t_int].view(batch_size_eff, 1, 1, 1)
            noise = torch.randn_like(xb)
            # Closed-form forward process q(x_t | x_0).
            x_t = torch.sqrt(a_bar_t) * xb + torch.sqrt(1.0 - a_bar_t) * noise

            pred = model(x_t, t_int.float(), yb)
            loss = loss_fn(noise, pred)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            running_loss += loss.item()
            num_batches += 1

        dt = time.time() - t0
        avg_loss = running_loss / max(1, num_batches)
        print(f"[{ep:03d}/{epochs}] time={dt:.1f}s loss={avg_loss:.4f}", flush=True)

    ckpt_path = os.path.join(params_root, "checkpoint.pth")
    torch.save(
        {
            "model": model.state_dict(),
            "betas": betas_t.cpu(),
            "config": cfg,
        },
        ckpt_path,
    )
    print(f"Saved checkpoint: {ckpt_path}", flush=True)


def main():
    """Parse command-line options and forward them to ``train``.

    ``--device`` defaults to None, so ``train`` picks CUDA when available and
    CPU otherwise. Pass e.g. ``--device cuda:1`` to pin a specific device.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a class-conditional diffusion UNet on MNIST or CIFAR-10."
    )
    parser.add_argument("--save_root", type=str, default=os.path.dirname(__file__),
                        help="root dir; checkpoint goes to <save_root>/params/<dataset>/<gen_variant>/checkpoint.pth")
    parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
    parser.add_argument("--epochs", type=int, default=30,
                        help="number of epochs (for cifar10, 20 or 30 is replaced by 200)")
    parser.add_argument("--T", type=int, default=1000, help="number of diffusion steps")
    parser.add_argument("--lr", type=float, default=1e-4, help="Adam learning rate")
    parser.add_argument("--beta_start", type=float, default=1e-4, help="first beta of the linear schedule")
    parser.add_argument("--beta_end", type=float, default=0.02, help="last beta of the linear schedule")
    parser.add_argument("--image_size", type=int, default=32, help="training resolution (divisible by 8)")
    parser.add_argument("--n_per_class", type=int, default=5000, help="training images per class")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--num_workers", type=int, default=8, help="DataLoader worker processes")
    # Default None (not "cuda:0") so train() can fall back to CPU when CUDA is unavailable.
    parser.add_argument("--device", type=str, default=None,
                        help="torch device, e.g. cuda:0 or cpu (default: cuda if available, else cpu)")
    parser.add_argument("--gen_variant", type=str, default="orig", choices=["orig", "aug"],
                        help="train on original images or augmented copies")
    parser.add_argument("--aug_types", type=str, default="geom,opt",
                        help="comma-separated augmentation families: geom, opt")
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"],
                        help="training dataset")
    args = parser.parse_args()

    train(
        save_root=args.save_root,
        batch_size=args.batch_size,
        epochs=args.epochs,
        T=args.T,
        lr=args.lr,
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        image_size=args.image_size,
        n_per_class=args.n_per_class,
        seed=args.seed,
        num_workers=args.num_workers,
        device_str=args.device,
        gen_variant=args.gen_variant,
        dataset=args.dataset,
        aug_types=args.aug_types,
    )


if __name__ == "__main__":
    # Script entry point. The module uses relative imports (``from ..utils``),
    # so it has to be run inside its package, e.g. ``python -m <package>.<module>``.
    main()


