from os import PathLike
from pathlib import Path
from typing import Mapping, Optional, Union

import torch
import torch.nn as nn
import yaml
from safetensors.torch import load_file as safe_load_file

from dae.utils.generic_utils import ModuleBuildArg
from dae.utils.train_utils import init_weights

from ...flow.flow_samplers import FLOW_SAMPLERS, FMEulerSampler
from ...flow.flow_trainer import FLOW_TRAINERS, FlowMatchingTrainer
from ...registers import AUTOENCODERS, DECODERS, ENCODERS
from ..blocks.diag_gauss import DiagonalGaussianDistribution
from ..dataclasses import TrainStepResult


@AUTOENCODERS.register("ssdd")
class SSDD(nn.Module):
    """Autoencoder pairing an encoder with a flow-matching decoder.

    The encoder maps an input to a latent distribution; the decoder is driven
    by a flow-matching trainer (training loss) and a flow-matching sampler
    (inference-time generation conditioned on the latent ``z``).
    """

    def __init__(
        self,
        encoder: ModuleBuildArg | nn.Module = "vq_encoder",
        decoder: ModuleBuildArg | nn.Module = "uvit",
        fm_trainer: ModuleBuildArg | FlowMatchingTrainer = "flow_matching",
        fm_sampler: ModuleBuildArg | FMEulerSampler = "euler",
        init: Optional[Mapping] = None,
    ):
        """Build the submodules from registry specs and initialize weights.

        Args:
            encoder: Build argument (or module) resolved through ``ENCODERS``.
            decoder: Build argument (or module) resolved through ``DECODERS``.
            fm_trainer: Build argument (or instance) resolved through
                ``FLOW_TRAINERS``.
            fm_sampler: Build argument (or instance) resolved through
                ``FLOW_SAMPLERS``.
            init: Optional keyword arguments forwarded to ``init_weights``.
        """
        super().__init__()

        ### Submodules ###

        self.encoder = ENCODERS.build(encoder, required=True)
        self.decoder = DECODERS.build(decoder, required=True)

        ### Flow-matching ###

        self.fm_trainer = FLOW_TRAINERS.build(fm_trainer)
        self.fm_sampler = FLOW_SAMPLERS.build(fm_sampler)

        ### Weights init ###

        self.init_weights(**(init or {}))

    def init_weights(self, method="kaiming_normal", **kwargs) -> None:
        """Initialize this module's weights via the project ``init_weights`` helper."""
        # Inside the method body, ``init_weights`` resolves to the module-level
        # helper imported at the top of the file, not to this method.
        init_weights(self, method=method, **kwargs)

    def encode(self, x) -> DiagonalGaussianDistribution:
        """Run the encoder on ``x`` and return its output distribution."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor, steps: Optional[int] = None, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample a reconstruction from the latent ``z`` with the flow sampler.

        Args:
            z: 4-D latent tensor ``(B, C, zH, zW)``.
            steps: Forwarded to the sampler as ``n_steps``.
            noise: Forwarded to the sampler as ``noise``.

        Returns:
            The sampler output, requested with shape
            ``(B, decoder.out_dim, zH * patch_size, zW * patch_size)``.
        """
        batch, _, z_h, z_w = z.shape
        # Each latent position covers a patch_size x patch_size area of the output.
        patch = self.encoder.patch_size
        out_h, out_w = z_h * patch, z_w * patch

        return self.fm_sampler.sample(
            self.decoder,
            self.fm_trainer,
            shape=(batch, self.decoder.out_dim, out_h, out_w),
            n_steps=steps,
            fn_kwargs={"z": z},
            noise=noise,
        )

    def forward(self, gt_x: torch.Tensor, steps: Optional[int] = None, noise: Optional[torch.Tensor] = None, z: Optional[torch.Tensor] = None) -> Union[torch.Tensor, TrainStepResult]:
        """Encode ``gt_x`` (unless ``z`` is given), then decode or compute losses.

        In eval mode the latent is decoded with the sampler and the sampled
        tensor is returned. In training mode a flow-matching loss is computed
        and a ``TrainStepResult`` is returned.

        Args:
            gt_x: Ground-truth input batch.
            steps: Sampler step count; only used in eval mode.
            noise: Sampler noise; only used in eval mode. In training mode the
                noise reported in the result is the one returned by the
                flow-matching trainer.
            z: Optional precomputed latent. When given, the encoder is skipped
                and no KL loss is reported.
        """
        # Encoder
        encoded = None
        if z is None:
            encoded = self.encode(gt_x)
            z = encoded.sample() if self.training else encoded.mode()

        # Decoder (inference): sample a reconstruction from the latent
        if not self.training:
            return self.decode(z, steps=steps, noise=noise)

        # Decoder (training)
        t = self.fm_trainer.sample_t(gt_x.shape[0], device=gt_x.device)

        # Use decoder to get a diffusion reconstruction loss
        diff_loss, (x_t, train_noise, noise_t, v_pred) = self.fm_trainer.loss(
            self.decoder, x=gt_x, t=t, fn_kwargs={"z": z}
        )

        # Prediction derived from the noisy sample and predicted velocity
        # (presumably an estimate of the clean sample; see fm_trainer.step).
        x0_pred = self.fm_trainer.step(x_t, v_pred, noise_t)

        losses = {"diffusion": diff_loss}
        # Explicit None check: the distribution object's truthiness must not
        # decide whether the KL term is reported.
        if encoded is not None:
            losses["kl"] = encoded.kl().mean()

        return TrainStepResult(
            x0_gt=gt_x,
            x0_pred=x0_pred,
            xt=x_t,
            t=t,
            z=z,
            noise=train_noise,
            losses=losses,
        )

    def get_last_layer_weight(self):
        """Return the weight of the decoder's ``conv_out`` layer."""
        return self.decoder.conv_out.weight

    ### Loading / Checkpointing ###

    def load(self, weights: Union[str, Path, Mapping], strict: bool = True, freeze=False, eval=None):
        """Load a state dict into the model and optionally freeze it.

        Args:
            weights: A state-dict mapping, or a path to a safetensors file.
            strict: Forwarded to ``load_state_dict``.
            freeze: If true, disable gradients on all parameters.
            eval: If true, switch to eval mode. If ``None``, eval mode is
                enabled only when ``freeze`` is true.

        Returns:
            ``self``, to allow chaining.
        """
        if not isinstance(weights, Mapping):
            weights = safe_load_file(weights)
        self.load_state_dict(weights, strict=strict)

        if eval or (eval is None and freeze):
            self.eval()
        if freeze:
            self.requires_grad_(False)
        return self

    @classmethod
    def build(cls, config, checkpoint=None, freeze=True, eval=True):
        """Build the model from a YAML config path or a mapping of arguments.

        Args:
            config: Path (``str`` or ``os.PathLike``) to a YAML file whose
                ``"model"`` entry holds the constructor arguments, or a
                mapping of constructor arguments used as-is.
            checkpoint: Optional weights (path or mapping) passed to ``load``.
            freeze: If true, disable gradients on all parameters.
            eval: If true, switch the model to eval mode.

        Raises:
            ValueError: If ``config`` is neither a path nor a mapping.
        """
        if isinstance(config, (str, PathLike)):
            # Explicit encoding so parsing does not depend on the platform locale.
            with open(config, "r", encoding="utf-8") as yaml_file:
                model_args = yaml.safe_load(yaml_file)["model"]
        elif isinstance(config, Mapping):
            model_args = config
        else:
            raise ValueError(f"Invalid config type: {type(config)}. Expected model size, path, or a mapping.")

        model = cls(**model_args)

        if checkpoint:
            model.load(checkpoint)
        if eval:
            model.eval()
        if freeze:
            model.requires_grad_(False)
        return model
