import torch
import torch.nn as nn

from .layers import ConditionAdaLNODE, ConditionAdaLN, ConditionAdaLNTransformer, ConditionAdaLNv2, ConditionAdaLNv3, ConditionAdaLNv4
from ...base.mixin import PropertyMixIn
from ...diffusion.edm import EDM


class ArbitraryConditionModel(PropertyMixIn, nn.Module):
    """EDM diffusion model conditioned on column values plus an observed-column mask.

    The backbone receives a single condition tensor built as
    ``torch.cat([condition, observed_column], dim=1)``, so its condition input
    width is ``condition_dim + n_columns``.
    """

    def __init__(self, input_dim: int, condition_dim: int, n_columns: int, d_model: int, d_time: int, d_cond: int,
                 ode=True, **kwargs):
        """Build the EDM wrapper and its conditional backbone.

        Args:
            input_dim: Width of the generated target vectors.
            condition_dim: Width of the condition-value vector.
            n_columns: Width of the observed-column mask.
            d_model: Backbone hidden width (passed as ``ode_channels``).
            d_time: Time-embedding width.
            d_cond: Condition-embedding width.
            ode: If True, use ``ConditionAdaLNODE`` (with
                ``ode_time_channels=min(32, d_time)``); otherwise ``ConditionAdaLN``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        self.n_columns = n_columns
        condition_in_channels = condition_dim + n_columns  # input dim + mask dim
        if ode:
            backbone = ConditionAdaLNODE(
                in_channels=input_dim,
                ode_channels=d_model,
                time_channels=d_time,
                ode_time_channels=min(32, d_time),
                condition_in_channels=condition_in_channels,
                condition_channels=d_cond
            )
        else:
            backbone = ConditionAdaLN(
                in_channels=input_dim,
                ode_channels=d_model,
                time_channels=d_time,
                condition_in_channels=condition_in_channels,
                condition_channels=d_cond
            )
        self.diffloss = EDM(backbone)

    def _get_condition(self, bsz, condition=None, observed_column=None):
        """Return ``[condition | observed_column]`` concatenated along dim 1.

        Missing pieces are replaced by zeros of shape ``(bsz, condition_dim)`` /
        ``(bsz, n_columns)`` on the model's device and dtype. The mask is cast
        to the device/dtype of ``condition`` before concatenation.
        """
        if condition is None:
            condition = torch.zeros(bsz, self.condition_dim, device=self.device, dtype=self.dtype)
        if observed_column is None:
            observed_column = torch.zeros(bsz, self.n_columns, device=self.device, dtype=self.dtype)
        return torch.cat([condition, observed_column.to(condition)], dim=1)

    def forward(self, target, condition=None, observed_column=None):
        """Run the EDM wrapper on ``target`` with the concatenated condition.

        Returns whatever ``EDM.__call__`` returns (presumably the training loss — confirm).
        """
        bsz = len(target)
        condition = self._get_condition(bsz, condition, observed_column)
        return self.diffloss(target, condition=condition)

    def sample(self, *, n=None, condition=None, observed_column=None, latents=None, temperature=0., cfg=1.,
               condition_u=None, observed_column_u=None, num_steps=50):
        """Generate ``n`` samples with the EDM sampler.

        Args:
            n: Number of samples; ignored when ``condition`` is given
                (``len(condition)`` is used instead).
            condition, observed_column: Conditional-branch inputs; zeros if omitted.
            latents: Initial noise, reshaped to ``(n, input_dim)``. If omitted,
                float64 standard-normal noise is drawn on ``condition``'s device,
                or on the model's device when ``condition`` is None.
            temperature: Forwarded to ``EDM.sample`` as ``S_churn``.
            cfg: Forwarded to ``EDM.sample`` as ``cfg`` (presumably the
                classifier-free guidance scale — confirm).
            condition_u, observed_column_u: Inputs for the second
                (unconditional) branch; zeros if omitted.
            num_steps: Number of sampler steps.

        Raises:
            ValueError: If neither ``n`` nor ``condition`` is provided.
        """
        n = len(condition) if condition is not None else n
        if n is None:
            raise ValueError("sample() requires either `n` or `condition`")
        if latents is None:
            # Keep the noise on the same device as the conditioning tensors built below.
            device = condition.device if condition is not None else self.device
            latents = torch.randn(n, self.input_dim, device=device, dtype=torch.float64)

        condition = self._get_condition(n, condition, observed_column)
        condition_u = self._get_condition(n, condition_u, observed_column_u)
        latents = latents.view(n, self.input_dim)
        imputed = self.diffloss.sample(latents,
                                       condition=condition,
                                       condition_u=condition_u,
                                       num_steps=num_steps, S_churn=temperature, cfg=cfg)
        return imputed


class ConditionAdaLNTransformerModel(PropertyMixIn, nn.Module):
    """EDM diffusion model with a ``ConditionAdaLNTransformer`` backbone.

    Unlike the concatenating models, condition values and the observed-column
    mask are passed to the backbone as two separate tensors. The condition
    vector has the same width as the target (``condition_dim == input_dim``).
    """

    def __init__(
            self,
            input_dim: int,
            n_columns: int,
            d_model: int,
            d_time: int,
            d_cond: int,
            *,
            numerical_dim: int,
            n_categories_per_columns: list[int],
            transformer_layers: int = 2,
            transformer_nheads: int | None = None,
            use_mask_token: bool = True,
            cond_pool: str = 'mean'
    ):
        """Build the EDM wrapper around a ``ConditionAdaLNTransformer`` backbone.

        All arguments except ``n_columns`` are forwarded to the backbone
        (``d_model`` as ``d_model``, ``d_time`` as ``time_channels``, ``d_cond``
        as ``condition_channels``). ``n_columns`` sizes the default zero mask.
        """
        super().__init__()
        self.input_dim = input_dim
        self.condition_dim = input_dim  # condition values length == column_dim
        self.n_columns = n_columns
        self.diffloss = EDM(
            ConditionAdaLNTransformer(
                in_channels=input_dim,
                d_model=d_model,
                time_channels=d_time,
                condition_channels=d_cond,
                numerical_dim=numerical_dim,
                n_categories_per_columns=n_categories_per_columns,
                transformer_layers=transformer_layers,
                transformer_nheads=transformer_nheads,
                use_mask_token=use_mask_token,
                cond_pool=cond_pool,
            ),
        )

    def _get_condition_pair(self, bsz, condition=None, observed_column=None):
        """Return ``(condition, observed_column)``, filling missing ones with zeros.

        Defaults are ``(bsz, condition_dim)`` / ``(bsz, n_columns)`` zeros on the
        model's device and dtype; the mask is cast to ``condition``'s device/dtype.
        """
        if condition is None:
            condition = torch.zeros(bsz, self.condition_dim, device=self.device, dtype=self.dtype)
        if observed_column is None:
            observed_column = torch.zeros(bsz, self.n_columns, device=self.device, dtype=self.dtype)
        return condition, observed_column.to(condition)

    def forward(self, target, condition=None, observed_column=None):
        """Run the EDM wrapper on ``target`` with separate condition and mask.

        Returns whatever ``EDM.__call__`` returns (presumably the training loss — confirm).
        """
        bsz = len(target)
        condition, observed_column = self._get_condition_pair(bsz, condition, observed_column)
        return self.diffloss(target, condition=condition, observed_column=observed_column)

    def sample(self, *, n=None, condition=None, observed_column=None, latents=None, temperature=0., cfg=1.,
               condition_u=None, observed_column_u=None, num_steps=50):
        """Generate ``n`` samples with the EDM sampler.

        Args:
            n: Number of samples; ignored when ``condition`` is given.
            condition, observed_column: Conditional-branch inputs; zeros if omitted.
            latents: Initial noise, reshaped to ``(n, input_dim)``. If omitted,
                float64 standard-normal noise is drawn on ``condition``'s device,
                or on the model's device when ``condition`` is None.
            temperature: Forwarded to ``EDM.sample`` as ``S_churn``.
            cfg: Forwarded as ``cfg`` (presumably the guidance scale — confirm).
            condition_u, observed_column_u: Second (unconditional) branch inputs;
                zeros if omitted.
            num_steps: Number of sampler steps (default 50).

        Raises:
            ValueError: If neither ``n`` nor ``condition`` is provided.
        """
        n = len(condition) if condition is not None else n
        if n is None:
            raise ValueError("sample() requires either `n` or `condition`")
        if latents is None:
            # Match the device of the (possibly default) conditioning tensors.
            device = condition.device if condition is not None else self.device
            latents = torch.randn(n, self.input_dim, device=device, dtype=torch.float64)

        condition, observed_column = self._get_condition_pair(n, condition, observed_column)
        condition_u, observed_column_u = self._get_condition_pair(n, condition_u, observed_column_u)
        latents = latents.view(n, self.input_dim)
        imputed = self.diffloss.sample(
            latents,
            condition=condition,
            observed_column=observed_column,
            condition_u=condition_u,
            observed_column_u=observed_column_u,
            num_steps=num_steps,
            S_churn=temperature,
            cfg=cfg
        )
        return imputed


class ArbitraryConditionModelv2(PropertyMixIn, nn.Module):
    """EDM diffusion model with a ``ConditionAdaLNv2`` backbone.

    Like ``ArbitraryConditionModel``, the backbone receives
    ``torch.cat([condition, observed_column], dim=1)`` as a single condition.
    """

    def __init__(
            self,
            input_dim: int,
            condition_dim: int,
            n_columns: int,
            d_model: int,
            d_time: int,
            d_cond: int,
            *,
            numerical_dim: int,
            n_categories_per_columns: list[int],
            hidden_per_col: int = 4,
            **kwargs):
        """Build the EDM wrapper around a ``ConditionAdaLNv2`` backbone.

        ``numerical_dim``, ``n_categories_per_columns`` and ``hidden_per_col`` are
        forwarded to the backbone. Extra ``**kwargs`` are accepted and ignored.
        """
        super().__init__()
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        self.n_columns = n_columns
        self.diffloss = EDM(
            ConditionAdaLNv2(
                in_channels=input_dim,
                ode_channels=d_model,
                time_channels=d_time,
                condition_in_channels=condition_dim + n_columns,  # input dim + mask dim
                condition_channels=d_cond,
                numerical_dim=numerical_dim,
                n_categories_per_columns=n_categories_per_columns,
                hidden_per_col=hidden_per_col,
            ),
        )

    def _get_condition(self, bsz, condition=None, observed_column=None):
        """Return ``[condition | observed_column]`` concatenated along dim 1.

        Missing pieces are zeros on the model's device/dtype; the mask is cast
        to ``condition``'s device/dtype before concatenation.
        """
        if condition is None:
            condition = torch.zeros(bsz, self.condition_dim, device=self.device, dtype=self.dtype)
        if observed_column is None:
            observed_column = torch.zeros(bsz, self.n_columns, device=self.device, dtype=self.dtype)
        return torch.cat([condition, observed_column.to(condition)], dim=1)

    def forward(self, target, condition=None, observed_column=None):
        """Run the EDM wrapper on ``target`` with the concatenated condition.

        Returns whatever ``EDM.__call__`` returns (presumably the training loss — confirm).
        """
        bsz = len(target)
        condition = self._get_condition(bsz, condition, observed_column)
        return self.diffloss(target, condition=condition)

    def sample(self, *, n=None, condition=None, observed_column=None, latents=None, temperature=0., cfg=1.,
               condition_u=None, observed_column_u=None, num_steps=50):
        """Generate ``n`` samples with the EDM sampler.

        Args:
            n: Number of samples; ignored when ``condition`` is given.
            condition, observed_column: Conditional-branch inputs; zeros if omitted.
            latents: Initial noise, reshaped to ``(n, input_dim)``. If omitted,
                float64 standard-normal noise is drawn on ``condition``'s device,
                or on the model's device when ``condition`` is None.
            temperature: Forwarded to ``EDM.sample`` as ``S_churn``.
            cfg: Forwarded as ``cfg`` (presumably the guidance scale — confirm).
            condition_u, observed_column_u: Second (unconditional) branch inputs;
                zeros if omitted.
            num_steps: Number of sampler steps (default 50).

        Raises:
            ValueError: If neither ``n`` nor ``condition`` is provided.
        """
        n = len(condition) if condition is not None else n
        if n is None:
            raise ValueError("sample() requires either `n` or `condition`")
        if latents is None:
            # Previously used len(condition)/condition.device, which crashed when
            # condition was None; use the resolved n and a safe device fallback.
            device = condition.device if condition is not None else self.device
            latents = torch.randn(n, self.input_dim, device=device, dtype=torch.float64)

        condition = self._get_condition(n, condition, observed_column)
        condition_u = self._get_condition(n, condition_u, observed_column_u)
        latents = latents.view(n, self.input_dim)
        imputed = self.diffloss.sample(latents,
                                       condition=condition,
                                       condition_u=condition_u,
                                       num_steps=num_steps, S_churn=temperature, cfg=cfg)
        return imputed


class ArbitraryConditionModelv3(PropertyMixIn, nn.Module):
    """EDM diffusion model with a ``ConditionAdaLNv3`` backbone.

    Condition values (width ``condition_dim``) and the observed-column mask
    (width ``n_columns``) are passed to the backbone as separate tensors.
    """

    def __init__(
            self,
            input_dim: int,
            condition_dim: int,
            n_columns: int,
            d_model: int,
            d_time: int,
            d_cond: int,
            *,
            numerical_dim: int,
            n_categories_per_columns: list[int],
            **kwargs):
        """Build the EDM wrapper around a ``ConditionAdaLNv3`` backbone.

        ``numerical_dim``, ``n_categories_per_columns`` and ``**kwargs`` are
        accepted for interface compatibility but not used by this class.
        """
        super().__init__()
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        self.n_columns = n_columns
        # ConditionAdaLNv3: projects the condition values and modulates them with observed_column.
        self.diffloss = EDM(
            ConditionAdaLNv3(
                in_channels=input_dim,
                ode_channels=d_model,
                time_channels=d_time,
                condition_val_in_channels=condition_dim,
                condition_channels=d_cond,
                mask_in_channels=n_columns,
            ),
        )

    def _get_condition_pair(self, bsz, condition=None, observed_column=None):
        """Return ``(condition, observed_column)``, filling missing ones with zeros.

        Defaults are zeros on the model's device/dtype; the mask is cast to
        ``condition``'s device/dtype.
        """
        if condition is None:
            condition = torch.zeros(bsz, self.condition_dim, device=self.device, dtype=self.dtype)
        if observed_column is None:
            observed_column = torch.zeros(bsz, self.n_columns, device=self.device, dtype=self.dtype)
        return condition, observed_column.to(condition)

    def forward(self, target, condition=None, observed_column=None):
        """Run the EDM wrapper on ``target`` with separate condition and mask.

        Returns whatever ``EDM.__call__`` returns (presumably the training loss — confirm).
        """
        bsz = len(target)
        condition, observed_column = self._get_condition_pair(bsz, condition, observed_column)
        return self.diffloss(target, condition=condition, observed_column=observed_column)

    def sample(self, *, n=None, condition=None, observed_column=None, latents=None, temperature=0., cfg=1.,
               condition_u=None, observed_column_u=None, num_steps=50):
        """Generate ``n`` samples with the EDM sampler.

        Args:
            n: Number of samples; ignored when ``condition`` is given.
            condition, observed_column: Conditional-branch inputs; zeros if omitted.
            latents: Initial noise, reshaped to ``(n, input_dim)``. If omitted,
                float64 standard-normal noise is drawn on ``condition``'s device,
                or on the model's device when ``condition`` is None.
            temperature: Forwarded to ``EDM.sample`` as ``S_churn``.
            cfg: Forwarded as ``cfg`` (presumably the guidance scale — confirm).
            condition_u, observed_column_u: Second (unconditional) branch inputs;
                zeros if omitted.
            num_steps: Number of sampler steps (default 50).

        Raises:
            ValueError: If neither ``n`` nor ``condition`` is provided.
        """
        n = len(condition) if condition is not None else n
        if n is None:
            raise ValueError("sample() requires either `n` or `condition`")
        if latents is None:
            # Match the device of the (possibly default) conditioning tensors.
            device = condition.device if condition is not None else self.device
            latents = torch.randn(n, self.input_dim, device=device, dtype=torch.float64)

        condition, observed_column = self._get_condition_pair(n, condition, observed_column)
        condition_u, observed_column_u = self._get_condition_pair(n, condition_u, observed_column_u)
        latents = latents.view(n, self.input_dim)
        imputed = self.diffloss.sample(
            latents,
            condition=condition,
            observed_column=observed_column,
            condition_u=condition_u,
            observed_column_u=observed_column_u,
            num_steps=num_steps,
            S_churn=temperature,
            cfg=cfg
        )
        return imputed


class ArbitraryConditionModelv4(PropertyMixIn, nn.Module):
    """EDM diffusion model with a ``ConditionAdaLNv4`` backbone.

    Condition values (width ``condition_dim``) and the observed-column mask
    (width ``n_columns``) are passed to the backbone as separate tensors.
    """

    def __init__(
            self,
            input_dim: int,
            condition_dim: int,
            n_columns: int,
            d_model: int,
            d_time: int,
            d_cond: int,
            *,
            numerical_dim: int,
            n_categories_per_columns: list[int],
            mask_mod_with_time: bool = True,
            mod_time: bool = True,
            **kwargs):
        """Build the EDM wrapper around a ``ConditionAdaLNv4`` backbone.

        ``mask_mod_with_time`` and ``mod_time`` are forwarded to the backbone.
        ``numerical_dim``, ``n_categories_per_columns`` and ``**kwargs`` are
        accepted for interface compatibility but not used by this class.
        """
        super().__init__()
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        self.n_columns = n_columns
        # v4 backbone: uses mask + time to jointly modulate the condition and time embeddings.
        self.diffloss = EDM(
            ConditionAdaLNv4(
                in_channels=input_dim,
                ode_channels=d_model,
                time_channels=d_time,
                condition_val_in_channels=condition_dim,
                condition_channels=d_cond,
                mask_in_channels=n_columns,
                mask_mod_with_time=mask_mod_with_time,
                mod_time=mod_time,
            ),
        )

    def _get_condition_pair(self, bsz, condition=None, observed_column=None):
        """Return ``(condition, observed_column)``, filling missing ones with zeros.

        Defaults are zeros on the model's device/dtype; the mask is cast to
        ``condition``'s device/dtype.
        """
        if condition is None:
            condition = torch.zeros(bsz, self.condition_dim, device=self.device, dtype=self.dtype)
        if observed_column is None:
            observed_column = torch.zeros(bsz, self.n_columns, device=self.device, dtype=self.dtype)
        return condition, observed_column.to(condition)

    def forward(self, target, condition=None, observed_column=None):
        """Run the EDM wrapper on ``target`` with separate condition and mask.

        Returns whatever ``EDM.__call__`` returns (presumably the training loss — confirm).
        """
        bsz = len(target)
        condition, observed_column = self._get_condition_pair(bsz, condition, observed_column)
        return self.diffloss(target, condition=condition, observed_column=observed_column)

    def sample(self, *, n=None, condition=None, observed_column=None, latents=None, temperature=0., cfg=1.,
               condition_u=None, observed_column_u=None, num_steps=50):
        """Generate ``n`` samples with the EDM sampler.

        Args:
            n: Number of samples; ignored when ``condition`` is given.
            condition, observed_column: Conditional-branch inputs; zeros if omitted.
            latents: Initial noise, reshaped to ``(n, input_dim)``. If omitted,
                float64 standard-normal noise is drawn on ``condition``'s device,
                or on the model's device when ``condition`` is None.
            temperature: Forwarded to ``EDM.sample`` as ``S_churn``.
            cfg: Forwarded as ``cfg`` (presumably the guidance scale — confirm).
            condition_u, observed_column_u: Second (unconditional) branch inputs;
                zeros if omitted.
            num_steps: Number of sampler steps (default 50).

        Raises:
            ValueError: If neither ``n`` nor ``condition`` is provided.
        """
        n = len(condition) if condition is not None else n
        if n is None:
            raise ValueError("sample() requires either `n` or `condition`")
        if latents is None:
            # Match the device of the (possibly default) conditioning tensors.
            device = condition.device if condition is not None else self.device
            latents = torch.randn(n, self.input_dim, device=device, dtype=torch.float64)

        condition, observed_column = self._get_condition_pair(n, condition, observed_column)
        condition_u, observed_column_u = self._get_condition_pair(n, condition_u, observed_column_u)
        latents = latents.view(n, self.input_dim)
        imputed = self.diffloss.sample(
            latents,
            condition=condition,
            observed_column=observed_column,
            condition_u=condition_u,
            observed_column_u=observed_column_u,
            num_steps=num_steps,
            S_churn=temperature,
            cfg=cfg
        )
        return imputed
