from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch

####################################################################
# Data classes
####################################################################


@dataclass
class TrainStepResult:
    """Optional outputs collected from a single training step.

    Every field defaults to ``None``, so a producer can fill in only the
    values it actually has. The annotations are ``Optional[...]`` so that
    they agree with those ``None`` defaults.

    NOTE(review): the per-field descriptions below are inferred from the
    field names only; confirm them against the code that populates this
    class.
    """

    x0_gt: Optional[torch.Tensor] = None  # presumably the ground-truth clean sample
    x0_pred: Optional[torch.Tensor] = None  # presumably the model's estimate of x0
    xt: Optional[torch.Tensor] = None  # presumably the sample at time ``t``
    t: Optional[torch.Tensor] = None  # presumably the sampled time values
    z: Optional[torch.Tensor] = None  # meaning not evident from this file; TODO confirm
    noise: Optional[torch.Tensor] = None  # presumably the noise used to build ``xt``
    losses: Optional[Dict[str, torch.Tensor]] = None  # loss tensors keyed by name


@dataclass
class FlowSampleOut:
    x: torch.Tensor
    x_steps: List[torch.Tensor]
    timesteps: torch.Tensor
    conditions: List[Any]

    def concat(self, other):
        assert (((self.timesteps - other.timesteps) ** 2) < 1e-9).all()
        return FlowSampleOut(
            x=torch.cat([self.x, other.x], dim=0),
            x_steps=[torch.cat([x1, x2]) for x1, x2 in zip(self.x_steps, other.x_steps)],
            timesteps=self.timesteps,
            conditions=self.conditions + other.conditions,
        )

    def __getitem__(self, key):
        return FlowSampleOut(
            x=self.x[key],
            x_steps=[x[key] for x in self.x_steps],
            timesteps=self.timesteps,
            conditions=self.conditions[key],
        )

    @torch.compiler.disable(recursive=False)
    def x_apply(self, f):
        return FlowSampleOut(
            x=f(self.x),
            x_steps=[f(x) for x in self.x_steps],
            timesteps=self.timesteps,
            conditions=self.conditions,
        )
