import math
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from typing import Literal, List, Optional, Sequence, Tuple
from dataclasses import dataclass

import h5py
import numpy as np
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.amp import autocast
from torch.utils.data import Dataset, DataLoader, IterableDataset, get_worker_info
from torch.optim import Adam
from torch.utils.checkpoint import checkpoint

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from tqdm.auto import tqdm
from ema_pytorch import EMA
from accelerate import Accelerator

from denoising_diffusion_pytorch.attend import Attend
from denoising_diffusion_pytorch.version import __version__

@dataclass
class ControlConfig:
    """Options for picking future-frame offsets and building sparse observations."""

    # --- frame-offset (dt) selection ---
    # upper bound used when offsets are generated automatically
    dt_max: int = 10
    # explicit offsets; when set (non-empty) they are used verbatim
    dt_ids: Optional[List[int]] = None
    # how many future offsets to produce per batch
    num_future: int = 2
    # 'uniform' -> evenly spaced offsets, 'random' -> random subset plus dt_max
    dt_strategy: Literal['uniform', 'random'] = 'uniform'

    # --- observation construction ---
    obs_mode: Literal['mask_full', 'mask_regular', 'mask_random', 'downsample'] = 'mask_regular'
    # stride of the regular mask
    mask_stride: int = 2
    # spatial downsampling factor
    downsample_factor: int = 4
    # probability of keeping a pixel for the random mask
    mask_prob: float = 0.1

    # --- I/O ---
    # forwarded as ``swmr=`` when the HDF5 file is opened
    hdf5_swmr: bool = True

class ControlBatchIterableDataset(IterableDataset):
    """
    Yields full batches so we can resample dt_ids *once per batch*.
    Compatible with multi-worker; each worker iterates over a disjoint slice of indices.

    Each yielded item is a tuple ``(prev, obs, mask, dt, true)`` of stacked tensors:
    the starting frame, the observed future frames, their masks, the frame
    offsets and the unmasked future frames.

    Args:
        h5_file: path of the HDF5 file; dataset ``'x'`` is read and every entry
            is unpacked as ``(T, C, H, W)``.
        config: a ``ControlConfig`` (or any object exposing the same attributes).
        batch_size: maximum number of sequences per yielded batch; the last
            batch of a worker may be smaller.
        shuffle: shuffle this worker's indices at the start of each iteration.
        stats_path: file read with ``torch.load``; must contain ``'min'`` and
            ``'max'`` tensors used for min-max normalisation.
    """
    def __init__(self, h5_file: str, config: "ControlConfig", batch_size: int, shuffle: bool = True,
                 stats_path: str = 'kolm_stats.pt'):
        super().__init__()
        self.config = config
        self.h5_file = h5_file
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)

        # stats are loaded once here; reshaped so they broadcast over (T, C, H, W)
        stats = torch.load(stats_path)
        self.min = stats['min'].view(1, -1, 1, 1)
        self.max = stats['max'].view(1, -1, 1, 1)

    def _open_h5(self):
        """Open the HDF5 file read-only, honouring ``config.hdf5_swmr``."""
        return h5py.File(self.h5_file, 'r', libver='latest', swmr=self.config.hdf5_swmr)

    def _resample_dt_ids_for_batch(self, rng: np.random.Generator):
        """Return the list of frame offsets to use for one batch.

        Explicit ``config.dt_ids`` win; otherwise offsets are either evenly
        spaced in ``[1, dt_max)`` ('uniform') or a random subset of
        ``[1, dt_max)`` that always ends with ``dt_max`` itself.
        """
        cfg = self.config
        if cfg.dt_ids:
            return list(cfg.dt_ids)

        if cfg.dt_strategy == 'uniform':
            # plain Python ints rather than numpy scalars
            return [int(v) for v in np.linspace(1, cfg.dt_max, cfg.num_future, endpoint=False, dtype=int)]

        assert cfg.num_future >= 1
        pool = np.arange(1, max(1, cfg.dt_max))
        k = max(0, cfg.num_future - 1)
        if len(pool) >= k and k > 0:
            chosen = rng.choice(pool, size=k, replace=False).tolist()
        else:
            chosen = []
        chosen.append(cfg.dt_max)
        chosen = sorted(set(int(x) for x in chosen))
        # top up with the smallest unused offsets, keeping dt_max last;
        # stops early when no unused offset is left
        while len(chosen) < cfg.num_future:
            for v in range(1, cfg.dt_max):
                if v not in chosen:
                    chosen.insert(-1, v)
                    break
            else:
                break
        return chosen

    def _make_mask_and_obs(self, tgt: torch.Tensor, rng: np.random.Generator):
        """Build ``(obs, mask)`` for one ``(C, H, W)`` target frame per ``config.obs_mode``.

        ``rng`` is accepted for interface symmetry but is not used: the random
        mask is drawn from torch's global RNG.
        """
        cfg = self.config
        C, H, W = tgt.shape

        if cfg.obs_mode == 'mask_full':
            mask = torch.ones_like(tgt)
            obs  = tgt
        elif cfg.obs_mode == 'mask_regular':
            mask = torch.zeros_like(tgt)
            mask[..., ::cfg.mask_stride, ::cfg.mask_stride] = 1.0
            obs  = tgt * mask
        elif cfg.obs_mode == 'mask_random':
            keep = (torch.rand(1, H, W) < float(cfg.mask_prob)).float()
            if keep.sum() == 0:
                # guarantee at least one observed pixel
                keep.view(-1)[torch.randint(0, H * W, (1,))] = 1.0
            mask = keep.expand(C, -1, -1).contiguous()
            obs  = tgt * mask
        else:  # 'downsample'
            f = int(cfg.downsample_factor)
            mask = torch.ones_like(tgt)
            ds = F.avg_pool2d(tgt.unsqueeze(0), kernel_size=f, stride=f)[0]
            obs = F.interpolate(ds.unsqueeze(0), size=(H, W), mode='nearest')[0]

        return obs, mask

    def _iter_index_range(self, N_total, worker_info):
        """Return the (optionally shuffled) sequence indices owned by this worker."""
        if worker_info is None:
            idxs = list(range(N_total))
        else:
            # strided split: worker w gets w, w + n, w + 2n, ...
            wid = worker_info.id
            nworkers = worker_info.num_workers
            idxs = list(range(wid, N_total, nworkers))
        if self.shuffle:
            rng = np.random.default_rng()
            rng.shuffle(idxs)
        return idxs

    def __iter__(self):
        """Iterate over batches ``(prev, obs, mask, dt, true)``.

        Raises:
            ValueError: if a sequence has too few frames for the largest offset.
        """
        worker_info = get_worker_info()
        rng = np.random.default_rng()

        # The context manager closes the file on exhaustion, on errors, and when
        # the consumer abandons the generator before it is exhausted.
        with self._open_h5() as h5:
            data = h5['x']
            idxs = self._iter_index_range(data.shape[0], worker_info)
            i_ptr = 0

            while i_ptr < len(idxs):
                dt_ids = self._resample_dt_ids_for_batch(rng)
                max_dt = max(dt_ids)
                dt_tensor = torch.tensor(dt_ids, dtype=torch.long)

                batch_prev, batch_obs, batch_mask, batch_dt, batch_true = [], [], [], [], []

                for idx in idxs[i_ptr:i_ptr + max(self.batch_size, 0)]:
                    i_ptr += 1

                    seq = torch.from_numpy(data[idx]).float()
                    seq = (seq - self.min) / (self.max - self.min)

                    T, C, H, W = seq.shape
                    if T <= max_dt:
                        raise ValueError(
                            f'sequence {idx} has {T} frames but needs more than max dt {max_dt}'
                        )
                    t0 = rng.integers(0, T - max_dt)

                    frames = []; masks = []; trues = []
                    for dt in dt_ids:
                        tgt = seq[t0 + dt]
                        obs, mask = self._make_mask_and_obs(tgt, rng)
                        frames.append(obs)
                        masks.append(mask)
                        trues.append(tgt)

                    batch_prev.append(seq[t0])
                    batch_obs.append(torch.stack(frames, dim=0))
                    batch_mask.append(torch.stack(masks, dim=0))
                    batch_dt.append(dt_tensor)
                    batch_true.append(torch.stack(trues, dim=0))

                # only reachable with a non-positive batch_size
                if len(batch_prev) == 0:
                    break

                yield (
                    torch.stack(batch_prev, dim=0),
                    torch.stack(batch_obs, dim=0),
                    torch.stack(batch_mask, dim=0),
                    torch.stack(batch_dt, dim=0),
                    torch.stack(batch_true, dim=0),
                )

class ControlNet(nn.Module):
    """Convolutional network that predicts an updated control field.

    The five image-like inputs are concatenated along channels, encoded at
    full and half resolution, modulated FiLM-style by embeddings of the frame
    offset, the frame index and the SNR, and projected to a residual that is
    added to the normalised previous control.
    """
    def __init__(self, channels, image_size: tuple[int,int], hid=64, emb_dim=32):
        super().__init__()
        self.H, self.W = image_size
        self.hid = hid
        self.emb_dim = emb_dim

        def conv3x3(c_in, c_out, **extra):
            return nn.Conv2d(c_in, c_out, 3, padding=1, **extra)

        def group_norm(groups, width):
            return nn.GroupNorm(num_groups=groups, num_channels=width, eps=1e-5, affine=True)

        def scalar_mlp():
            return nn.Sequential(nn.Linear(1, hid), nn.ReLU(), nn.Linear(hid, hid))

        # x_d, x_cond, obs, mask, last_u stacked along channels
        self.encoder = nn.Sequential(
            conv3x3(channels * 5, hid),
            nn.ReLU(),
            conv3x3(hid, hid),
            nn.ReLU(),
        )
        self.enc_norm = group_norm(8, hid)

        # half-resolution branch and its way back up
        self.down = nn.Sequential(
            conv3x3(hid, hid, stride=2),
            nn.ReLU(),
            conv3x3(hid, hid),
            nn.ReLU(),
        )
        self.up_proj = nn.Sequential(nn.ReLU(), conv3x3(hid, hid))
        self.fuse_conv = nn.Conv2d(hid, hid, 1)
        self.fused_norm = group_norm(8, hid)

        # FiLM conditioning: one scalar MLP each for dt, tau and snr
        self.dt_mlp = scalar_mlp()
        self.tau_mlp = scalar_mlp()
        self.snr_mlp = scalar_mlp()
        # concatenated embeddings -> (gamma, beta)
        self.film_mapper = nn.Linear(hid * 3, hid * 2)

        self.delta_proj = nn.Sequential(
            conv3x3(hid, hid),
            nn.ReLU(),
            conv3x3(hid, channels),
        )
        self.lastu_norm = group_norm(1, channels)

    def forward(
        self,
        x_d: torch.Tensor, x_cond: torch.Tensor, obs: torch.Tensor, mask: torch.Tensor,
        dt: torch.Tensor, max_dt: int, last_u: torch.Tensor, snr: torch.Tensor, frame_i: int
    ) -> torch.Tensor:
        """Return ``lastu_norm(last_u)`` plus the predicted residual.

        ``obs`` and ``mask`` are indexed with ``[:, 0]``, i.e. only their first
        entry along dim 1 is used. ``snr`` may be a scalar, a 1-D tensor or
        anything that flattens to at least ``B`` values.
        """
        batch, _, height, width = x_d.shape

        # --- spatial backbone ---
        stacked = torch.cat([x_d, x_cond, obs[:, 0], mask[:, 0], last_u], dim=1)
        fine = self.enc_norm(self.encoder(stacked))

        coarse = F.interpolate(self.down(fine), size=(height, width), mode='nearest')
        coarse = self.up_proj(coarse)
        fused = self.fused_norm(self.fuse_conv(fine + coarse))

        # --- scalar conditioning, each normalised into [0, 1] where applicable ---
        dt_scaled = (dt / max_dt).clamp(0.0, 1.0)
        tau_scaled = torch.full((batch, 1), frame_i / float(max_dt), device=x_d.device).clamp(0.0, 1.0)

        rank = snr.dim()
        if rank == 0:
            snr_vec = snr.unsqueeze(0).expand(batch)
        elif rank == 1:
            snr_vec = snr if snr.shape[0] == batch else snr.expand(batch)
        else:
            snr_vec = snr.flatten()[:batch]
        snr_col = snr_vec.unsqueeze(-1)

        embeddings = [
            self.dt_mlp(dt_scaled.float()),
            self.tau_mlp(tau_scaled.float()),
            self.snr_mlp(snr_col.float()),
        ]
        gamma, beta = self.film_mapper(torch.cat(embeddings, dim=1)).chunk(2, dim=1)
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]

        # --- FiLM modulation and residual head ---
        delta = self.delta_proj(fused * (1 + gamma) + beta)
        return self.lastu_norm(last_u) + delta

# pair of (predicted noise, predicted clean sample)
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
# record with fields (frame, mask, dt)
SensorRecord = namedtuple('SensorRecord', ['frame','mask','dt'])

# helper functions

def exists(x):
    """Return True for every value except ``None``."""
    return not (x is None)

def default(val, d):
    """Return ``val`` unless it is None; otherwise ``d`` (called first if callable)."""
    if not exists(val):
        return d() if callable(d) else d
    return val

def cast_tuple(t, length = 1):
    """Return ``t`` if it is already a tuple, else ``t`` repeated ``length`` times."""
    return t if isinstance(t, tuple) else (t,) * length

def divisible_by(numer, denom):
    """Return whether ``numer`` leaves no remainder when divided by ``denom``."""
    remainder = numer % denom
    return remainder == 0

def identity(t, *args, **kwargs):
    """Hand back ``t`` untouched; any extra arguments are accepted and ignored."""
    return t

def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it runs out."""
    while True:
        yield from dl

def has_int_squareroot(num):
    """Return True when ``num`` is a perfect square.

    Uses exact integer arithmetic (``math.isqrt``) so the answer stays correct
    for values beyond float precision. Floats are accepted only when they hold
    an integral value; any other float is reported as False.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if isinstance(num, float):
        if not num.is_integer():
            return False
        num = int(num)
    return math.isqrt(num) ** 2 == num

def num_to_groups(num, divisor):
    """Split ``num`` into chunks of ``divisor`` plus one trailing remainder chunk, if any."""
    full_chunks, leftover = divmod(num, divisor)
    sizes = [divisor] * full_chunks
    if leftover > 0:
        sizes.append(leftover)
    return sizes

def convert_image_to_fn(img_type, image):
    """Return ``image`` as-is when its ``mode`` equals ``img_type``, else ``image.convert(img_type)``."""
    return image if image.mode == img_type else image.convert(img_type)

def normalize_to_neg_one_to_one(img):
    """Map values from the [0, 1] range onto [-1, 1]."""
    doubled = img * 2
    return doubled - 1

def unnormalize_to_zero_to_one(t):
    """Map values from the [-1, 1] range back onto [0, 1]."""
    shifted = t + 1
    return shifted * 0.5

# small helper modules

def Upsample(dim, dim_out = None):
    """2x nearest-neighbour upsampling followed by a 3x3 conv to ``dim_out`` (default ``dim``) channels."""
    out_channels = default(dim_out, dim)
    stages = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1),
    ]
    return nn.Sequential(*stages)

def Downsample(dim, dim_out = None):
    """Fold each 2x2 spatial patch into channels, then mix with a 1x1 conv to ``dim_out`` (default ``dim``)."""
    out_channels = default(dim_out, dim)
    # every 2x2 patch becomes 4x the channels at half the resolution
    fold_patches = Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2)
    mix = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(fold_patches, mix)

class RMSNorm(Module):
    """Channel-wise normalisation: unit L2 norm over dim 1, times sqrt(dim) and a learned gain."""

    def __init__(self, dim):
        super().__init__()
        # learned per-channel gain, broadcast over batch and spatial dims
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = 1)
        return unit * self.g * self.scale

class SinusoidalPosEmb(Module):
    """Fixed sinusoidal embedding of a 1-D tensor: sines first, then cosines."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        half = self.dim // 2
        # geometric frequency ladder running from 1 down to 1 / theta
        log_step = math.log(self.theta) / (half - 1)
        freqs = torch.exp(torch.arange(half, device = x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim = -1)

class RandomOrLearnedSinusoidalPosEmb(Module):
    """Fourier-feature embedding with random frequencies that are learned unless ``is_random``.

    The output is the raw input followed by its sine and cosine features, so
    its width is ``dim + 1``.
    """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert divisible_by(dim, 2)
        num_freqs = dim // 2
        # frozen (not trained) when is_random is set
        self.weights = nn.Parameter(torch.randn(num_freqs), requires_grad = not is_random)

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        weights = rearrange(self.weights, 'd -> 1 d')
        freqs = x * weights * 2 * math.pi
        return torch.cat((x, freqs.sin(), freqs.cos()), dim = -1)

class Block(Module):
    """3x3 conv -> RMSNorm -> optional scale/shift -> SiLU -> dropout."""

    def __init__(self, dim, dim_out, dropout = 0.):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = RMSNorm(dim_out)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, scale_shift = None):
        out = self.norm(self.proj(x))

        if exists(scale_shift):
            # FiLM-style modulation supplied by the caller
            scale, shift = scale_shift
            out = out * (scale + 1) + shift

        return self.dropout(self.act(out))

class ResnetBlock(Module):
    """Two ``Block``s with a residual connection; the first can be modulated by a time embedding."""

    def __init__(self, dim, dim_out, *, time_emb_dim = None, dropout = 0.):
        super().__init__()
        if exists(time_emb_dim):
            # maps the time embedding to a (scale, shift) pair for block1
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, dropout = dropout)
        self.block2 = Block(dim_out, dim_out)
        # 1x1 conv on the skip path only when the channel count changes
        self.res_conv = nn.Identity() if dim == dim_out else nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            cond = rearrange(self.mlp(time_emb), 'b c -> b c 1 1')
            scale_shift = cond.chunk(2, dim = 1)

        hidden = self.block1(x, scale_shift = scale_shift)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)

class LinearAttention(Module):
    """Linear attention over spatial positions, with learned memory key/values prepended."""

    def __init__(self, dim, heads = 4, dim_head = 32, num_mem_kv = 4):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        # index 0 holds memory keys, index 1 memory values
        self.mem_kv = nn.Parameter(torch.randn(2, heads, dim_head, num_mem_kv))
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Sequential(nn.Conv2d(inner_dim, dim, 1), RMSNorm(dim))

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.norm(x)

        # split heads and flatten the spatial grid into one axis
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = 1)
        )

        # prepend the learned memory slots along the position axis
        mem_k, mem_v = (repeat(mem, 'h c n -> b h c n', b = b) for mem in self.mem_kv)
        k = torch.cat((mem_k, k), dim = -1)
        v = torch.cat((mem_v, v), dim = -1)

        # softmax over features for q, over positions for k
        q = q.softmax(dim = -2) * self.scale
        k = k.softmax(dim = -1)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(attended)

class Attention(Module):
    """Full attention over spatial positions via ``Attend``, with learned memory key/values prepended."""

    def __init__(self, dim, heads = 4, dim_head = 32, num_mem_kv = 4, flash = False):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.attend = Attend(flash = flash)
        # index 0 holds memory keys, index 1 memory values
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_mem_kv, dim_head))
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.norm(x)

        # split heads and flatten the spatial grid into a sequence
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h (x y) c', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = 1)
        )

        # prepend the learned memory slots along the sequence axis
        mem_k, mem_v = (repeat(mem, 'h n d -> b h n d', b = b) for mem in self.mem_kv)
        k = torch.cat((mem_k, k), dim = -2)
        v = torch.cat((mem_v, v), dim = -2)

        attended = self.attend(q, k, v)
        attended = rearrange(attended, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(attended)

class Unet(Module):
    """Time-conditioned U-Net trunk.

    ``channels`` is the width of the tensor passed to ``forward``; the default
    output width is ``channels // 2`` (doubled with ``learned_variance``).
    With ``self_condition`` the first conv additionally expects a
    ``channels // 2``-wide self-conditioning tensor concatenated in front of
    the input.

    Args:
        dim: base feature width; stage widths are ``dim * m`` for ``m`` in ``dim_mults``.
        init_dim: width after the first conv (defaults to ``dim``).
        out_dim: output channels (defaults as described above).
        dim_mults: per-stage width multipliers.
        channels: channel count of the input tensor.
        self_condition: whether a self-conditioning tensor is concatenated to the input.
        learned_variance: doubles the default output width.
        learned_sinusoidal_cond / random_fourier_features: use
            ``RandomOrLearnedSinusoidalPosEmb`` instead of ``SinusoidalPosEmb``.
        learned_sinusoidal_dim: size of that learned/random embedding.
        sinusoidal_pos_emb_theta: ``theta`` for ``SinusoidalPosEmb``.
        dropout: dropout used in the first ``Block`` of every resnet block.
        attn_dim_head / attn_heads: per-stage (or shared) attention settings.
        full_attn: per-stage flags selecting ``Attention`` over ``LinearAttention``;
            defaults to full attention on the last stage only.
        flash_attn: forwarded to ``Attention`` as ``flash``.
    """
    def __init__(
        self, dim, init_dim = None, out_dim = None, dim_mults = (1, 2, 4, 8), channels = 3,
        self_condition = False, learned_variance = False, learned_sinusoidal_cond = False,
        random_fourier_features = False, learned_sinusoidal_dim = 16, sinusoidal_pos_emb_theta = 10000,
        dropout = 0., attn_dim_head = 32, attn_heads = 4, full_attn = None, flash_attn = False
    ):
        super().__init__()
        self.channels = channels
        self.self_condition = self_condition
        # self-conditioning adds channels // 2 extra input channels
        input_channels = (channels + (channels // 2)) if self_condition else channels
        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time embedding
        time_dim = dim * 4
        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features

        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim, theta = sinusoidal_pos_emb_theta)
            fourier_dim = dim

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # attention settings, broadcast to one entry per stage
        if not full_attn:
            full_attn = (*((False,) * (len(dim_mults) - 1)), True)
        num_stages = len(dim_mults)
        full_attn  = cast_tuple(full_attn, num_stages)
        attn_heads = cast_tuple(attn_heads, num_stages)
        attn_dim_head = cast_tuple(attn_dim_head, num_stages)

        FullAttention = partial(Attention, flash = flash_attn)
        resnet_block = partial(ResnetBlock, time_emb_dim = time_dim, dropout = dropout)

        self.downs = ModuleList([])
        self.ups = ModuleList([])
        num_resolutions = len(in_out)

        # encoder: the last stage keeps its resolution (3x3 conv instead of Downsample)
        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
            is_last = ind >= (num_resolutions - 1)
            attn_klass = FullAttention if layer_full_attn else LinearAttention
            self.downs.append(ModuleList([
                resnet_block(dim_in, dim_in),
                resnet_block(dim_in, dim_in),
                attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        # bottleneck
        mid_dim = dims[-1]
        self.mid_block1 = resnet_block(mid_dim, mid_dim)
        self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
        self.mid_block2 = resnet_block(mid_dim, mid_dim)

        # decoder: mirrors the encoder; resnet inputs are widened by the skip connections
        for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
            is_last = ind == (len(in_out) - 1)
            attn_klass = FullAttention if layer_full_attn else LinearAttention
            self.ups.append(ModuleList([
                resnet_block(dim_out + dim_in, dim_out),
                resnet_block(dim_out + dim_in, dim_out),
                attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        default_out_dim = channels//2 * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)
        self.final_res_block = resnet_block(init_dim * 2, init_dim)
        self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

    @property
    def downsample_factor(self):
        """Total spatial reduction between input and bottleneck."""
        return 2 ** (len(self.downs) - 1)

    def forward(self, x, time, x_self_cond = None):
        """Run the U-Net on ``x`` at diffusion step ``time``.

        When ``self_condition`` is enabled and ``x_self_cond`` is omitted, a
        zero tensor with ``channels // 2`` channels is used so the concatenated
        input matches the width ``init_conv`` was built for.
        """
        assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]])
        if self.self_condition:
            if not exists(x_self_cond):
                # zeros_like(x) would be `channels` wide and overflow init_conv
                x_self_cond = x.new_zeros((x.shape[0], self.channels // 2, *x.shape[-2:]))
            x = torch.cat((x_self_cond, x), dim = 1)

        x = self.init_conv(x)
        r = x.clone()   # kept for the final skip connection
        t = self.time_mlp(time)
        h = []          # stack of encoder activations for the decoder skips

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)
            x = block2(x, t)
            x = attn(x) + x
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x) + x
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)
            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x) + x
            x = upsample(x)

        x = torch.cat((x, r), dim = 1)
        x = self.final_res_block(x, t)
        return self.final_conv(x)

# gaussian diffusion trainer class

def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch entry and reshape to ``(b, 1, ..., 1)`` so it broadcasts against ``x_shape``."""
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)

def linear_beta_schedule(timesteps):
    """Linearly spaced betas, with both endpoints rescaled by ``1000 / timesteps``."""
    scale = 1000 / timesteps
    first, last = scale * 0.0001, scale * 0.02
    return torch.linspace(first, last, timesteps, dtype = torch.float64)

def cosine_beta_schedule(timesteps, s = 0.008):
    """Betas derived from a squared-cosine cumulative-alpha curve, clipped to [0, 0.999]."""
    num_points = timesteps + 1
    frac = torch.linspace(0, timesteps, num_points, dtype = torch.float64) / timesteps
    cumulative = torch.cos((frac + s) / (1 + s) * math.pi * 0.5) ** 2
    cumulative = cumulative / cumulative[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - (cumulative[1:] / cumulative[:-1])
    return betas.clamp(min = 0, max = 0.999)

def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """Betas derived from a sigmoid-shaped cumulative-alpha curve, clipped to [0, 0.999].

    ``clamp_min`` is accepted but not used by the computation.
    """
    num_points = timesteps + 1
    frac = torch.linspace(0, timesteps, num_points, dtype = torch.float64) / timesteps
    sig_start = torch.tensor(start / tau).sigmoid()
    sig_end = torch.tensor(end / tau).sigmoid()
    # flipped, rescaled sigmoid
    ramp = ((frac * (end - start) + start) / tau).sigmoid()
    cumulative = (-ramp + sig_end) / (sig_end - sig_start)
    cumulative = cumulative / cumulative[0]
    betas = 1 - (cumulative[1:] / cumulative[:-1])
    return betas.clamp(min = 0, max = 0.999)

class GaussianDiffusion(Module):
    def __init__(
        self, trunk, *, image_size, timesteps = 1000, sampling_timesteps = None,
        objective = 'pred_v', beta_schedule = 'sigmoid', schedule_fn_kwargs = None,
        ddim_sampling_eta = 0., auto_normalize = True, offset_noise_strength = 0.,
        min_snr_loss_weight = False, min_snr_gamma = 5
    ):
        """Precompute the diffusion schedule and register it as float32 buffers.

        Args:
            trunk: denoising network; must expose ``channels`` and ``self_condition``.
            image_size: int (square) or ``(height, width)``.
            timesteps: number of diffusion steps.
            sampling_timesteps: steps used when sampling; fewer than
                ``timesteps`` switches sampling to DDIM.
            objective: one of ``'pred_noise'``, ``'pred_x0'``, ``'pred_v'``.
            beta_schedule: one of ``'linear'``, ``'cosine'``, ``'sigmoid'``.
            schedule_fn_kwargs: extra keyword arguments for the schedule function.
            ddim_sampling_eta: DDIM eta.
            auto_normalize: map data between [0, 1] and [-1, 1] around sampling.
            offset_noise_strength: stored on the instance.
            min_snr_loss_weight: clamp the SNR used for the loss weight at ``min_snr_gamma``.
            min_snr_gamma: the clamp value.

        Raises:
            ValueError: for an unknown ``objective`` or ``beta_schedule``.
        """
        super().__init__()

        # fail early: an unknown objective would otherwise leave `loss_weight`
        # unregistered and only blow up later, during prediction
        if objective not in ('pred_noise', 'pred_x0', 'pred_v'):
            raise ValueError(f'unknown objective {objective}')
        # None instead of a shared mutable default
        schedule_fn_kwargs = {} if schedule_fn_kwargs is None else schedule_fn_kwargs

        self.trunk = trunk
        self.channels = self.trunk.channels
        self.self_condition = self.trunk.self_condition

        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.objective = objective

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # shifted by one step, with alpha_bar = 1 before the first step
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # schedules are computed in float64 and stored as float32 buffers
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # coefficients for the forward process q(x_t | x_0) and its inversions
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # coefficients for the posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        register_buffer('posterior_variance', posterior_variance)
        # clamped before the log because the variance is 0 at the first step
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        self.offset_noise_strength = offset_noise_strength

        # per-timestep loss weight, optionally with the SNR clamped from above
        snr = alphas_cumprod / (1 - alphas_cumprod)
        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity

    @property
    def device(self):
        """Device the registered ``betas`` buffer lives on."""
        schedule = self.betas
        return schedule.device

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the clean sample from ``x_t`` and the noise it contains."""
        shape = x_t.shape
        xt_coef = extract(self.sqrt_recip_alphas_cumprod, t, shape)
        noise_coef = extract(self.sqrt_recipm1_alphas_cumprod, t, shape)
        return xt_coef * x_t - noise_coef * noise

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise in ``x_t`` given the clean sample ``x0``."""
        shape = x_t.shape
        xt_coef = extract(self.sqrt_recip_alphas_cumprod, t, shape)
        denom = extract(self.sqrt_recipm1_alphas_cumprod, t, shape)
        return (xt_coef * x_t - x0) / denom

    def predict_v(self, x_start, t, noise):
        """Compute the v target from the clean sample and the noise."""
        shape = x_start.shape
        noise_coef = extract(self.sqrt_alphas_cumprod, t, shape)
        start_coef = extract(self.sqrt_one_minus_alphas_cumprod, t, shape)
        return noise_coef * noise - start_coef * x_start

    def predict_start_from_v(self, x_t, t, v):
        """Recover the clean sample from ``x_t`` and a v prediction."""
        shape = x_t.shape
        xt_coef = extract(self.sqrt_alphas_cumprod, t, shape)
        v_coef = extract(self.sqrt_one_minus_alphas_cumprod, t, shape)
        return xt_coef * x_t - v_coef * v

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of the posterior at step ``t``."""
        shape = x_t.shape
        start_coef = extract(self.posterior_mean_coef1, t, shape)
        xt_coef = extract(self.posterior_mean_coef2, t, shape)
        mean = start_coef * x_start + xt_coef * x_t

        variance = extract(self.posterior_variance, t, shape)
        log_variance_clipped = extract(self.posterior_log_variance_clipped, t, shape)
        return mean, variance, log_variance_clipped

    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        """Run the trunk and convert its output into ``(pred_noise, pred_x_start)``.

        Only the channels of ``x`` from index ``c // 2`` onwards are treated as
        the noisy sample when inverting the prediction; the full ``x`` is what
        the trunk sees.
        """
        _, c, _, _ = x.shape
        model_output = self.trunk(x, t, x_self_cond)
        noisy = x[:, c // 2:, :, :]

        def clip(value):
            # clamp to [-1, 1] only when requested
            return torch.clamp(value, min = -1., max = 1.) if clip_x_start else identity(value)

        objective = self.objective

        if objective == 'pred_noise':
            pred_noise = model_output
            x_start = clip(self.predict_start_from_noise(noisy, t, pred_noise))
            if clip_x_start and rederive_pred_noise:
                # keep the noise consistent with the clipped x_start
                pred_noise = self.predict_noise_from_start(noisy, t, x_start)

        elif objective == 'pred_x0':
            x_start = clip(model_output)
            pred_noise = self.predict_noise_from_start(noisy, t, x_start)

        elif objective == 'pred_v':
            x_start = clip(self.predict_start_from_v(noisy, t, model_output))
            pred_noise = self.predict_noise_from_start(noisy, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
        """Posterior mean / variance / log-variance for one reverse step, plus the predicted x_start."""
        _, c, _, _ = x.shape
        x_start = self.model_predictions(x, t, x_self_cond).pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)   # in place

        # the noisy sample is the part of x from channel c // 2 onwards
        noisy = x[:, c // 2:, :, :]
        mean, variance, log_variance = self.q_posterior(x_start = x_start, x_t = noisy, t = t)
        return mean, variance, log_variance, x_start

    @torch.inference_mode()
    def p_sample(self, x, t: int, cond = None, x_self_cond = None):
        """Take one ancestral sampling step from timestep ``t``.

        ``cond``, when given, is concatenated in front of ``x`` along channels
        before the model is called. Returns ``(next sample, predicted x_start)``.
        No noise is added at ``t == 0``.
        """
        batch_size = x.shape[0]
        batched_times = torch.full((batch_size,), t, device = self.device, dtype = torch.long)

        model_in = torch.cat((cond, x), dim = 1) if exists(cond) else x
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x = model_in, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True
        )

        # Noise must match the posterior mean, which covers only the noisy
        # channels; drawing it like the concatenated model input gives a
        # shape mismatch (or a silent broadcast) whenever cond is present.
        noise = torch.randn_like(model_mean) if t > 0 else 0.
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.inference_mode()
    def p_sample_loop(self, shape, cond = None, return_all_timesteps = False):
        """Run the full reverse chain from pure noise.

        Returns the final sample, or every intermediate sample stacked along
        dim 1 when ``return_all_timesteps`` is set; either way unnormalised.
        """
        img = torch.randn(shape, device = self.device)
        trajectory = [img]
        x_start = None

        steps = reversed(range(0, self.num_timesteps))
        for t in tqdm(steps, desc = 'sampling loop time step', total = self.num_timesteps, disable = True):
            # feed the previous x_start back in when self-conditioning
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, t, cond = cond, x_self_cond = self_cond)
            trajectory.append(img)

        result = torch.stack(trajectory, dim = 1) if return_all_timesteps else img
        return self.unnormalize(result)

    @torch.inference_mode()
    def ddim_sample(self, shape, cond = None, return_all_timesteps = False):
        """Sample with DDIM over ``sampling_timesteps`` steps.

        Returns the final sample, or every intermediate sample stacked along
        dim 1 when ``return_all_timesteps`` is set; either way unnormalised.
        """
        batch = shape[0]
        device = self.device
        eta = self.ddim_sampling_eta

        # evenly spaced steps from T-1 down to -1, walked as (current, next) pairs
        schedule = torch.linspace(-1, self.num_timesteps - 1, steps = self.sampling_timesteps + 1)
        schedule = schedule.int().tolist()[::-1]
        time_pairs = list(zip(schedule[:-1], schedule[1:]))

        img = torch.randn(shape, device = device)
        trajectory = [img]
        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable = True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            self_cond = x_start if self.self_condition else None

            model_in = torch.cat((cond, img), dim = 1) if exists(cond) else img
            pred_noise, x_start, *_ = self.model_predictions(
                model_in, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True
            )

            if time_next < 0:
                # last step: the prediction itself is the sample
                img = x_start
            else:
                alpha = self.alphas_cumprod[time]
                alpha_next = self.alphas_cumprod[time_next]

                sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
                c = (1 - alpha_next - sigma ** 2).sqrt()

                noise = torch.randn_like(x_start)
                img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise

            trajectory.append(img)

        result = torch.stack(trajectory, dim = 1) if return_all_timesteps else img
        return self.unnormalize(result)

    @torch.inference_mode()
    def sample(self, batch_size = 16, cond = None, return_all_timesteps = False):
        """Draw samples in chunks of at most `batch_size`.

        With `cond` given, one sample is produced per conditioning entry
        (`cond.shape[0]` in total) and `cond` is passed through
        `self.normalize` first; otherwise exactly `batch_size` unconditional
        samples are drawn. The sampler is `ddim_sample` when
        `self.is_ddim_sampling` is set, else `p_sample_loop`.

        Args:
            batch_size: chunk size (and sample count when `cond` is None).
            cond: optional conditioning tensor, sliced along dim 0 per chunk.
            return_all_timesteps: forwarded to the underlying sampler.

        Returns:
            The per-chunk sampler outputs concatenated along dim 0.

        Raises:
            ValueError: if `batch_size` is not positive.
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

        (h, w), channels = self.image_size, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        cond = self.normalize(cond) if exists(cond) else None

        total_samples = cond.shape[0] if exists(cond) else batch_size
        all_samples = []

        for start_idx in range(0, total_samples, batch_size):
            end_idx = min(start_idx + batch_size, total_samples)
            current_cond = cond[start_idx:end_idx] if exists(cond) else None

            # The sampled tensor has half of `self.channels`; presumably the
            # other half is the conditioning frame concatenated on dim 1 -
            # TODO confirm against the model definition.
            batch_shape = (end_idx - start_idx, channels // 2, h, w)
            all_samples.append(sample_fn(batch_shape, cond = current_cond, return_all_timesteps = return_all_timesteps))

        # dim 0 is the batch dim with or without the per-step trajectory
        return torch.cat(all_samples, dim=0)

    @autocast('cuda', enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Return the noised version of `x_start` at timestep(s) `t`.

        The result is `sqrt_alphas_cumprod[t] * x_start +
        sqrt_one_minus_alphas_cumprod[t] * noise`, with both coefficients
        obtained through `extract`. `noise` falls back to fresh Gaussian noise
        shaped like `x_start`. CUDA autocast is disabled for this method.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def p_losses(self, stack, t, noise = None, offset_noise_strength = None):
        """Compute the weighted MSE training loss for one batch.

        Args:
            stack: 5-D tensor; index 0 along dim 1 is the conditioning frame,
                index 1 the frame to be denoised.
            t: timestep indices, one per batch element.
            noise: optional noise to use instead of fresh Gaussian noise. It is
                not modified.
            offset_noise_strength: overrides `self.offset_noise_strength` when
                given.

        Returns:
            Scalar loss: per-sample mean squared error between the model output
            and the objective's target, scaled by `self.loss_weight` at `t`,
            averaged over the batch.

        Raises:
            ValueError: if `stack` is not 5-D or `self.objective` is unknown.
        """
        if stack.ndim != 5:
            raise ValueError(f'expected a 5-D stack, got shape {tuple(stack.shape)}')

        cond = stack[:, 0]
        x_start = stack[:, 1]
        noise = default(noise, lambda: torch.randn_like(x_start))

        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
        if offset_noise_strength > 0.:
            # one offset per (batch, channel), broadcast over the spatial dims
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            offset = offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
            # Out-of-place add: `noise` may be a tensor owned by the caller.
            noise = noise + offset.to(noise.dtype)

        x = self.q_sample(x_start = x_start, t = t, noise = noise)
        x = torch.cat((cond, x), dim = 1)

        # With probability 0.5, feed the model its own (gradient-free)
        # x_start prediction as self-conditioning input.
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t).pred_x_start
                x_self_cond.detach_()

        model_out = self.trunk(x, t, x_self_cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            target = self.predict_v(x_start, t, noise)
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        """Validate a training batch, draw random timesteps and return the loss.

        Args:
            img: 5-D tensor unpacked as (batch, frames, channels, height,
                width). Exactly 2 frames and 2 channels are required and the
                spatial size must equal `self.image_size`.
            *args, **kwargs: forwarded to `p_losses`.

        Returns:
            Whatever `p_losses` returns for the normalized batch and uniformly
            sampled timesteps in `[0, num_timesteps)`.
        """
        b, f, c, h, w = img.shape
        device, img_size = img.device, self.image_size
        assert h == img_size[0] and w == img_size[1], f'height and width of image must be {img_size}'
        # the message used to say "3 frames" although the check is for 2
        assert f == 2, f'input image must have 2 frames, got {f}'
        assert c == 2, f'input image must have 2 channels, got {c}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        img = self.normalize(img)
        return self.p_losses(img, t, *args, **kwargs)
    
class GaussianDiffusionWithControl(GaussianDiffusion):
    """Diffusion model whose DDIM sampler is steered by an additive control network.

    At every DDIM sub-step the control network produces a correction `u`, and
    the denoiser is evaluated on `img + gamma * u` instead of `img`.
    """

    def __init__(self, *args, control_net: ControlNet, gamma=1.0, w_t=1.0, **kwargs):
        """Store the control network and its weights; the rest goes to the base class.

        Args:
            control_net: callable invoked as `control_net(x, cond, obs_like,
                mask_like, dt, max_dt, prev_u, snr, frame_idx)`.
            gamma: scale applied to the control output before it is added to
                the noisy image.
            w_t: weight of the terminal (observation-matching) loss.
        """
        super().__init__(*args,**kwargs)
        self.control_net=control_net
        self.gamma=gamma
        self.w_t=w_t

    def ddim_sample_control(
        self, shape, cond=None, obs=None, mask=None, dt=None,
        true_frames=None, return_all_timesteps=False, config=None
    ):
        """Autoregressively sample `dt.max()` frames with controlled DDIM steps.

        Each frame is generated from fresh noise, conditioned on the previous
        frame (initially `cond`). Observation columns are chosen from batch
        row 0 of `dt` only, so `dt` is presumably shared across the batch -
        TODO confirm against the dataset. Columns whose countdown has reached
        zero are dropped from `obs`, `mask`, `true_frames` and `dt` after each
        frame.

        Args:
            shape: shape of one noisy frame; `shape[0]` is the batch size.
            cond: initial conditioning frame (normalized internally).
            obs: observations, indexed along dim 1 (normalized internally).
            mask: observation masks, indexed along dim 1 like `obs`.
            dt: integer frame offsets of the observations, indexed along dim 1.
            true_frames: optional ground truth aligned with `obs`; enables the
                terminal loss.
            return_all_timesteps: return every generated frame instead of only
                the last one (this refers to frames, not diffusion steps).
            config: object providing `obs_mode` and `downsample_factor`.

        Returns:
            `(video, losses)` where `video` holds the unnormalized frames
            stacked on dim 1 (or just the last frame) and `losses['terminal']`
            is the mean accumulated terminal loss, 0 if none was accumulated.

        Raises:
            ValueError: if `cond`, `obs`, `mask`, `dt` or `config` is missing.
        """
        missing = [
            name for name, value in (
                ('cond', cond), ('obs', obs), ('mask', mask), ('dt', dt), ('config', config)
            ) if value is None
        ]
        if missing:
            # all of these are dereferenced unconditionally further down
            raise ValueError(f"ddim_sample_control requires: {', '.join(missing)}")

        B = shape[0]
        dev = self.device
        T = self.num_timesteps
        S = self.sampling_timesteps
        eta = self.ddim_sampling_eta

        cond = self.normalize(cond)
        obs = self.normalize(obs)
        if true_frames is not None: true_frames = self.normalize(true_frames)

        # DDIM schedule: (t, t_next) pairs from T-1 down to -1
        times = torch.linspace(-1, T-1, steps=S+1, device=dev)
        times = list(reversed(times.int().tolist()))
        pairs = list(zip(times[:-1], times[1:]))

        use_mask = str(config.obs_mode).startswith('mask')

        def unet_forward(inp, t_idx, self_cond, clip_x_start, rederive_pred_noise):
            # positional-only wrapper so the call can go through `checkpoint`
            return self.model_predictions(inp, t_idx, self_cond, clip_x_start, rederive_pred_noise)

        max_dt = int(dt.max().item())
        tot_t = torch.zeros((), device=dev, requires_grad=True)
        cnt_t = torch.zeros((), device=dev)
        video = []
        curr_cond = cond
        dt_curr = dt.clone()

        for frame_i in range(max_dt):
            last_u = None
            img = torch.randn(shape, device=dev).detach().requires_grad_()
            x_start = None
            dt_curr = dt_curr - 1

            # Pick the observation that arrives at this frame, else the
            # nearest one still in the future (decided on batch row 0).
            arrived = (dt_curr[0] == 0).nonzero(as_tuple=True)[0]
            if arrived.numel() > 0:
                idx = arrived[0].item()
            else:
                future = dt_curr[0].clone()
                future[future <= 0] = 1e9
                idx = future.argmin().item()

            obs_for_control  = obs[:, idx:idx+1]
            mask_for_control = mask[:, idx:idx+1]
            dt_for_control   = dt_curr[:, idx:idx+1]
            dt_active = int(dt_for_control[0,0].item())

            # When the observation lies in the future, roll the denoiser
            # forward `dt_active` times (single shot from the last timestep,
            # no gradients) to get a proxy of the state at observation time.
            # Both observation modes read this proxy, so it is built for
            # either mode; the downsample mode used to receive None here.
            x_future_proxy = None
            if dt_active > 0:
                with torch.no_grad():
                    probe = curr_cond.clone()
                    t_now = torch.tensor([self.num_timesteps - 1], device=dev, dtype=torch.long)
                    for _ in range(dt_active):
                        inp_pre = torch.cat([probe, torch.randn_like(probe)], dim=1)
                        _, x0_pre = self.model_predictions(inp_pre, t_now, None, True, True)
                        probe = x0_pre
                    x_future_proxy = probe.detach()

            for t_i, t_next in pairs:
                t_idx = torch.full((B,), t_i, dtype=torch.long, device=dev)
                self_cond = x_start if self.self_condition else None

                # Reference state the observation is blended with: the
                # current uncontrolled x0 estimate when the observation is
                # due now, otherwise the rolled-out proxy.
                if dt_active == 0:
                    inp_pre = torch.cat([curr_cond, img], dim=1)
                    _, x_start_pre, *_ = checkpoint(unet_forward, inp_pre, t_idx, self_cond, True, True)
                    x_ref = x_start_pre.detach()
                else:
                    x_ref = x_future_proxy

                if use_mask:
                    # observed pixels where the mask is set, reference elsewhere
                    m_flat = mask_for_control[:, 0].float()
                    y_flat = obs_for_control[:, 0]
                    y_tilde = m_flat * y_flat + (1.0 - m_flat) * x_ref
                    obs_like_for_control = y_tilde.unsqueeze(1)
                else:
                    # shift the reference so its block averages match the
                    # block averages of the observation
                    kernel = config.downsample_factor
                    y_low = F.avg_pool2d(obs_for_control[:, 0], kernel_size=kernel, stride=kernel)
                    x_low = F.avg_pool2d(x_ref, kernel_size=kernel, stride=kernel)
                    resid_low = y_low - x_low
                    resid_up  = F.interpolate(resid_low, size=(x_ref.shape[-2], x_ref.shape[-1]), mode='nearest')
                    x_proj    = x_ref + resid_up
                    obs_like_for_control = x_proj.unsqueeze(1)

                ctrl_prev = last_u if last_u is not None else torch.zeros_like(img)
                snr = self.alphas_cumprod[t_i] / (1 - self.alphas_cumprod[t_i] + 1e-8)
                u = self.control_net(
                    img, curr_cond, obs_like_for_control, mask_for_control,
                    dt_for_control, max_dt, ctrl_prev, snr, frame_i
                )
                last_u = u
                img_ctrl = img + self.gamma * u
                inp_c = torch.cat([curr_cond, img_ctrl], dim=1)
                pred_noise_c, x_start_c, *_ = checkpoint(unet_forward, inp_c, t_idx, self_cond, True, True)

                # Terminal loss, accumulated at every sub-step of a frame on
                # which an observation arrives. `idx` is that observation's
                # column; reusing it avoids `.item()` on an index tensor that
                # may hold zero or several entries.
                if true_frames is not None and arrived.numel() > 0:
                    gt_mid  = true_frames[:, idx]
                    if config.obs_mode == 'downsample':
                        kernel = config.downsample_factor
                        ds_pred = F.avg_pool2d(x_start_c, kernel_size=kernel, stride=kernel)
                        ds_gt   = F.avg_pool2d(gt_mid,     kernel_size=kernel, stride=kernel)
                        diff2_mid = (ds_pred - ds_gt).pow(2)
                        loss_mid = self.w_t * diff2_mid.flatten(1).mean(-1)
                    else:
                        mask_mid = mask[:, idx]
                        diff2_mid = (x_start_c - gt_mid).pow(2) * mask_mid
                        loss_mid = self.w_t * diff2_mid.flatten(1).sum(-1) / mask_mid.flatten(1).sum(-1).clamp_min(1.0)
                    tot_t = tot_t + loss_mid.sum()
                    cnt_t = cnt_t + loss_mid.numel()

                x_start = x_start_c

                if t_next < 0:
                    img = x_start
                else:
                    a_i = self.alphas_cumprod[t_i]
                    a_n = self.alphas_cumprod[t_next]
                    sigma = eta * ((1 - a_i / a_n) * (1 - a_n) / (1 - a_i)).sqrt()
                    c = (1 - a_n - sigma.pow(2)).sqrt()
                    noise = torch.randn_like(x_start)
                    img = x_start * a_n.sqrt() + c * pred_noise_c + sigma * noise

            video.append(self.unnormalize(img))
            curr_cond = img

            # drop observation columns that are no longer in the future
            keep1d = (dt_curr > 0)[0]
            obs         = obs[:, keep1d]
            mask        = mask[:, keep1d]
            if true_frames is not None:
                true_frames = true_frames[:, keep1d]
            dt_curr     = dt_curr[:, keep1d]

        video = torch.stack(video, dim=1)
        losses = {'terminal': tot_t / cnt_t if cnt_t > 0 else torch.tensor(0.0, device=dev)}

        if return_all_timesteps:
            return video, losses
        return video[:, -1], losses

class TrajectoryDataset(Dataset):
    """Random pairs of consecutive frames drawn from trajectories in an HDF5 file.

    The whole `x` dataset of the file is read into memory at construction.
    Frames are min-max scaled with the `min` / `max` entries of the stats file.
    """

    def __init__(self, file: Path, stats: Path, dt_max: int = 10):
        """Load the trajectories and normalization statistics.

        Args:
            file: HDF5 file containing a dataset named `x`.
            stats: `torch.load`-able mapping with `min` and `max` tensors.
            dt_max: largest value of the random offset drawn per item.
        """
        super().__init__()
        with h5py.File(file, mode='r') as f:
            self.data = f['x'][:]
        self.dt_max = dt_max
        stats = torch.load(stats)
        # Reshaped to (1, C, 1, 1) so the stats broadcast against a sequence;
        # presumably sequences are (T, C, H, W) - TODO confirm.
        self.channel_min = stats['min'].view(1, -1, 1, 1)
        self.channel_max = stats['max'].view(1, -1, 1, 1)

    def __len__(self) -> int:
        """Return the number of stored trajectories."""
        return len(self.data)

    def __getitem__(self, i: int) -> Tensor:
        """Return two consecutive normalized frames of trajectory `i`.

        A random offset `dt` in `[1, dt_max]` is drawn first; it only limits
        the range of the start index `j` (to `[0, T - 1 - dt]`). The returned
        frames are always `j` and `j + 1`.

        Returns:
            Tensor with the two frames stacked along a new leading dim.

        Raises:
            ValueError: if `T - 1 - dt < 1` for the drawn `dt`.
        """
        seq = torch.from_numpy(self.data[i])
        seq = (seq - self.channel_min) / (self.channel_max - self.channel_min)
        T = seq.shape[0]
        dt = torch.randint(1, self.dt_max + 1, ()).item()
        max_start = T - 1 - dt
        if max_start < 1: raise ValueError(f"Sequence length {T} too short for dt={dt}")
        j = torch.randint(0, max_start + 1, ()).item()
        pair = torch.stack([seq[j], seq[j + 1]], dim=0)
        return pair
    
class TrajectoryWindowDataset(Dataset):
    """Fixed-length leading windows of trajectories stored in an HDF5 file.

    The `x` dataset is loaded fully into memory. Every item is the first
    `window` frames of one trajectory, min-max scaled with the `min` / `max`
    entries of the stats file.
    """

    def __init__(self, file: Path, stats: Path, window: int):
        """Read the trajectories from `file` and the scaling stats from `stats`."""
        super().__init__()
        with h5py.File(file, mode='r') as handle:
            self.data = handle['x'][:]
        self.window = window
        loaded = torch.load(stats)
        # (1, C, 1, 1) layout so the stats broadcast against a sequence
        self.channel_min, self.channel_max = (
            loaded[key].view(1, -1, 1, 1) for key in ('min', 'max')
        )

    def __len__(self) -> int:
        """Number of trajectories held in memory."""
        return len(self.data)

    def __getitem__(self, i: int) -> Tensor:
        """Return the first `window` normalized frames of trajectory `i`.

        Raises:
            ValueError: if the trajectory has fewer than `window` frames.
        """
        frames = torch.from_numpy(self.data[i])
        value_range = self.channel_max - self.channel_min
        frames = (frames - self.channel_min) / value_range
        n_frames = frames.shape[0]
        if n_frames < self.window:
            raise ValueError(f"Sequence length {n_frames} too short for window={self.window}")
        # the window always starts at the first frame
        return frames[0 : 0 + self.window]

# trainer class

class Trainer:
    """Training harness for the diffusion model, built on `accelerate`.

    Builds the train/validation/test data loaders from hard-coded file names,
    an Adam optimizer, and (on the main process only) an EMA copy of the model.
    """

    def __init__(
        self, diffusion_model, *, train_batch_size = 16, gradient_accumulate_every = 1,
        train_lr = 1e-4, train_num_steps = 100000, ema_update_every = 10, ema_decay = 0.995,
        adam_betas = (0.9, 0.99), save_and_sample_every = 1000, num_samples = 25,
        amp = False, mixed_precision_type = 'fp16', split_batches = True, max_grad_norm = 1.
    ):
        """Set up the accelerator, datasets, optimizer and EMA.

        Args:
            diffusion_model: model exposing `channels` and `image_size`;
                calling it on a batch must return the loss.
            train_batch_size: batch size of all three data loaders.
            gradient_accumulate_every: micro-batches per optimizer step.
            train_lr, adam_betas: Adam hyper-parameters.
            train_num_steps: number of optimizer steps `train` runs up to.
            ema_update_every, ema_decay: EMA settings.
            save_and_sample_every: validation interval in steps.
            num_samples: stored on the instance.
            amp, mixed_precision_type: mixed precision is enabled only when
                `amp` is true.
            split_batches: forwarded to `Accelerator`.
            max_grad_norm: gradient clipping threshold.
        """
        super().__init__()
        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = mixed_precision_type if amp else 'no'
        )

        self.model = diffusion_model
        self.channels = diffusion_model.channels
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every
        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size
        self.max_grad_norm = max_grad_norm

        # Hardcoded paths for dataset
        self.ds = TrajectoryDataset(Path('train.h5'), Path('kolm_stats.pt'), dt_max=3)
        self.val_ds = TrajectoryDataset(Path('valid.h5'), Path('kolm_stats.pt'), dt_max=3)
        self.test_ds = TrajectoryWindowDataset(Path('test.h5'), Path('kolm_stats.pt'), window=180)

        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = 2)
        val_dl = DataLoader(self.val_ds, batch_size = train_batch_size, shuffle = False, pin_memory = True, num_workers = 2)
        test_dl = DataLoader(self.test_ds, batch_size = train_batch_size, shuffle = False, pin_memory = True, num_workers = 2)

        # every loader is wrapped in `cycle` and consumed with `next(...)`
        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)
        val_dl = self.accelerator.prepare(val_dl)
        self.val_dl = cycle(val_dl)
        test_dl = self.accelerator.prepare(test_dl)
        self.test_dl = cycle(test_dl)

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # the EMA copy exists on the main process only
        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
            self.ema.to(self.device)

        self.step = 0
        # NOTE(review): every model parameter is frozen here, which looks
        # intended for subclasses that train only an add-on network; with all
        # parameters frozen `train()` presumably cannot backpropagate - confirm.
        for p in self.model.parameters():
            p.requires_grad = False
        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    @property
    def device(self):
        """Device managed by the accelerator."""
        return self.accelerator.device

    def load(self, milestone):
        """Restore model, optimizer, step, EMA and scaler state from a checkpoint.

        Reads `./results/model-{milestone}.pt`. A missing file is reported and
        skipped. The EMA state is restored on the main process only; the
        scaler state is restored only when both the accelerator has a scaler
        and the checkpoint carries one.
        """
        path = Path(f'./results/model-{milestone}.pt')
        if not path.exists():
            # accelerator.print, like the other messages of this class
            self.accelerator.print(f"Checkpoint {path} not found, skipping load.")
            return

        data = torch.load(str(path), map_location=self.device, weights_only=True)
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data["ema"])
        # tolerate checkpoints written without a scaler entry
        scaler_state = data.get('scaler')
        if exists(self.accelerator.scaler) and exists(scaler_state):
            self.accelerator.scaler.load_state_dict(scaler_state)

    def train(self):
        """Run optimizer steps until `self.step` reaches `train_num_steps`.

        Each step accumulates `gradient_accumulate_every` micro-batches, clips
        gradients to `max_grad_norm`, and updates the optimizer. On the main
        process the EMA is updated every step and validation runs every
        `save_and_sample_every` steps (never at step 0).
        """
        accelerator = self.accelerator
        device      = accelerator.device

        # context manager: the bar is closed even if a step raises
        with tqdm(initial=self.step, total=self.train_num_steps, disable=not accelerator.is_main_process) as pbar:
            while self.step < self.train_num_steps:
                self.model.train()
                running_loss = 0.0

                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl)
                    data = data.to(device)

                    with accelerator.autocast():
                        loss = self.model(data)
                        # scale so the accumulated gradient is an average
                        loss = loss / self.gradient_accumulate_every

                    accelerator.backward(loss)
                    running_loss += loss.detach().item()

                accelerator.wait_for_everyone()
                accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.opt.step()
                self.opt.zero_grad()

                if accelerator.is_main_process:
                    self.ema.update()
                    if self.step and self.step % self.save_and_sample_every == 0:
                        self._validate_and_log(device, accelerator)

                self.step += 1
                pbar.set_description(f'loss: {running_loss:.4f}')
                pbar.update(1)

        accelerator.print('Training complete')

    def _validate_and_log(self, device, accelerator):
        """Sample one validation batch with the EMA model and print its RMSE.

        Frame 0 of each validation pair is the conditioning input, frame 1 the
        target the sample is compared against.
        """
        data = next(self.val_dl)
        cond, target = [x.to(device) for x in (data[:,0], data[:,1])]

        self.ema.ema_model.eval()
        with torch.inference_mode(), accelerator.autocast():
            pred = self.ema.ema_model.sample(batch_size=cond.shape[0], cond=cond)

        rmse = torch.sqrt(((pred - target)**2).mean()).item()
        accelerator.print(f'Step {self.step}: Validation RMSE: {rmse}')

class ControlTrainer(Trainer):
    """Trainer for the control network used by a `GaussianDiffusionWithControl` model.

    Reuses the base `Trainer` setup and adds a separate Adam optimizer and
    batch-yielding data loaders for the control network.
    """

    def __init__(
        self, diffusion_model: GaussianDiffusionWithControl, control_net: ControlNet,
        train_h5: str, val_h5: str, config: ControlConfig, ctrl_lr: float = 1e-4,
        num_workers: int = 4, **kwargs
    ):
        """Build the control optimizer and data loaders on top of `Trainer`.

        Args:
            diffusion_model: model providing `ddim_sample_control`.
            control_net: network to be trained.
            train_h5, val_h5: HDF5 files for the control data sets.
            config: control configuration, also passed to the sampler.
            ctrl_lr: learning rate of the control optimizer.
            num_workers: worker count of the control data loaders.
            **kwargs: forwarded to `Trainer.__init__`, except
                `pretrained_milestone`, which (when given) triggers
                `self.load(milestone)` after the base setup.
        """
        # `Trainer.__init__` does not accept this key, so it has to be removed
        # before forwarding; otherwise passing it raises a TypeError and the
        # checkpoint is never loaded.
        pretrained_milestone = kwargs.pop('pretrained_milestone', None)
        super().__init__(diffusion_model, **kwargs)
        if pretrained_milestone is not None:
            # NOTE(review): `load` restores with strict key matching plus
            # optimizer/step state; `load_diffusion` below is the lenient
            # variant - confirm which one is wanted for a backbone checkpoint.
            self.load(pretrained_milestone)

        self.control_net = control_net
        self.control_opt = torch.optim.Adam(control_net.parameters(), lr=ctrl_lr)
        self.config = config

        # The datasets already yield whole batches, hence batch_size=None.
        train_ds = ControlBatchIterableDataset(h5_file=train_h5, config=config, batch_size=self.batch_size, shuffle=True)
        val_ds   = ControlBatchIterableDataset(h5_file=val_h5, config=config, batch_size=self.batch_size, shuffle=False)
        train_dl = DataLoader(train_ds, batch_size=None, pin_memory=True, num_workers=num_workers)
        val_dl   = DataLoader(val_ds,   batch_size=None, pin_memory=True, num_workers=num_workers)
        self.control_net, self.control_opt, self.control_dl, self.control_val_dl = self.accelerator.prepare(
            self.control_net, self.control_opt, train_dl, val_dl
        )
        self.control_dl = cycle(self.control_dl)
        self.control_val_dl = cycle(self.control_val_dl)

        # The base class froze the parameters of the diffusion model; make
        # sure the control network itself stays trainable.
        for p in self.control_net.parameters():
            p.requires_grad = True

    def train_controlnet(self):
        """Optimize the control network for `train_num_steps` steps.

        Each step runs a full controlled sampling pass with gradients enabled
        and backpropagates its `terminal` loss into the control network.
        Validation runs on the main process every `save_and_sample_every`
        steps.
        """
        self.control_net.train()
        pbar = tqdm(range(self.train_num_steps), disable=not self.accelerator.is_main_process)

        for _ in pbar:
            prev, obs, mask, dt, true_frames = next(self.control_dl)
            prev, obs, mask, dt, true_frames = [t.to(self.device) for t in (prev,obs,mask,dt,true_frames)]

            with torch.enable_grad(), self.accelerator.autocast():
                video, losses = self.model.ddim_sample_control(
                    shape=prev.shape, cond=prev, obs=obs, mask=mask, dt=dt,
                    true_frames=true_frames, return_all_timesteps=True, config=self.config
                )
                total_loss = losses['terminal']

            self.accelerator.backward(total_loss)
            self.accelerator.wait_for_everyone()
            self.accelerator.clip_grad_norm_(self.control_net.parameters(), self.max_grad_norm)
            self.control_opt.step()
            self.control_opt.zero_grad()

            self.step += 1
            pbar.set_description(f'loss: {total_loss.item():.4f}')

            if self.accelerator.is_main_process:
                if self.step % self.save_and_sample_every == 0:
                    self.validate_controlnet()

    @torch.inference_mode()
    def validate_controlnet(self, num_batches: int = 10) -> float:
        """Compare controlled and uncontrolled rollouts on validation batches.

        For each of up to `num_batches` batches, the masked MSE at the
        observation frames is computed for (a) the controlled sampler and
        (b) an autoregressive rollout of the EMA model without control. Both
        averages are printed.

        Returns:
            The average controlled masked MSE.
        """
        # Remember the mode so a validation pass in the middle of training
        # does not leave the control network in eval mode.
        was_training = self.control_net.training
        self.control_net.eval()
        total_term_ctrl = 0.0
        total_term_unctrl = 0.0
        count = 0

        try:
            # `control_val_dl` cycles forever; the loop ends via the break
            for i, (prev, obs, mask, dt, true_frames) in enumerate(self.control_val_dl):
                if i >= num_batches: break
                prev, obs, mask, dt, true_frames = [t.to(self.device) for t in (prev, obs, mask, dt, true_frames)]
                B = prev.shape[0]
                dt_max = int(dt.max().item())

                with self.accelerator.autocast():
                    video_ctrl, _ = self.model.ddim_sample_control(
                        shape=prev.shape, cond=prev, obs=obs, mask=mask, dt=dt,
                        true_frames=true_frames, return_all_timesteps=True, config=self.config
                    )

                # frame index of each observation: offset dt is frame dt - 1
                idx = (dt - 1)
                gather_idx = idx[:, :, None, None, None].expand(-1, -1, prev.shape[1], prev.shape[2], prev.shape[3])

                pred_at_obs_ctrl = torch.gather(video_ctrl, 1, gather_idx)
                diff_ctrl = pred_at_obs_ctrl - true_frames
                obs_mse_ctrl = (diff_ctrl.pow(2) * mask).sum() / mask.sum().clamp_min(1.0)

                # uncontrolled baseline: feed each sample back as condition
                unctrl = torch.empty((B, dt_max, prev.shape[1], prev.shape[2], prev.shape[3]), device=self.device)
                cond_u = prev
                self.ema.ema_model.eval()
                for t in range(dt_max):
                    with self.accelerator.autocast():
                        sampled = self.ema.ema_model.sample(batch_size=B, cond=cond_u)
                    unctrl[:, t] = sampled
                    cond_u = sampled

                pred_at_obs_unctrl = torch.gather(unctrl, 1, gather_idx)
                diff_unctrl = pred_at_obs_unctrl - true_frames
                obs_mse_unctrl = (diff_unctrl.pow(2) * mask).sum() / mask.sum().clamp_min(1.0)

                total_term_ctrl += obs_mse_ctrl.item()
                total_term_unctrl += obs_mse_unctrl.item()
                count += 1
        finally:
            self.control_net.train(was_training)

        avg_ctrl = total_term_ctrl / max(count, 1)
        avg_unctrl = total_term_unctrl / max(count, 1)
        self.accelerator.print(f'Step {self.step}: Ctrl MSE: {avg_ctrl}, Unctrl MSE: {avg_unctrl}')
        return avg_ctrl

    def load_diffusion(self, milestone: int):
        """Leniently load model (and EMA, scaler) weights from a checkpoint.

        Reads `./results/model-{milestone}.pt` and returns silently when the
        file does not exist. State dicts are loaded with `strict=False`;
        optimizer state and the step counter are not touched.
        """
        path = Path(f'./results/model-{milestone}.pt')
        if not path.exists(): return
        ckpt = torch.load(str(path), map_location=self.device, weights_only=True)
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(ckpt['model'], strict=False)
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(ckpt['ema'], strict=False)
        # tolerate checkpoints written without a scaler entry
        scaler_state = ckpt.get('scaler')
        if exists(self.accelerator.scaler) and exists(scaler_state):
            self.accelerator.scaler.load_state_dict(scaler_state)
