import torch
import torch.nn as nn

from . import register_wrapper
from ..architectures import get_architecture
from ..components import get_component

@register_wrapper("RecurrentModel")
class RecurrentModel(nn.Module):
    """
    Recurrent Neural Network for PDE modeling.

    This model implements a recurrent architecture with:
    1. An encoder that maps input to a higher-dimensional representation
    2. A recurrent core with stochastic depth during training
    3. A decoder that maps the representation back to the output domain

    The stochastic depth is implemented by:
    - Sampling the number of recurrent iterations (k) from a log-normal
      Poisson distribution (see ``_sample_k``)
    - Using truncated backpropagation through time for efficiency
    - Allowing deterministic depth at inference time

    Args:
        dimension: Spatial dimension of the problem
        in_channels: Number of input channels
        out_channels: Number of output channels
        hidden_channels: Number of channels in hidden representations
        encoder_class: Block class for encoder
        encoder_depth: Number of blocks in encoder
        encoder_kwargs: Extra keyword arguments for the encoder blocks;
            they override entries of ``**kwargs`` with the same name
        recurrent_class: Block class for recurrent core. Looked up in the
            architecture registry first, then in the component registry
        recurrent_depth: Number of blocks in recurrent core (at least one
            block is always created)
        recurrent_kwargs: Extra keyword arguments for the recurrent blocks;
            they override entries of ``**kwargs`` with the same name
        recurrent_z_distribution: Config for initial state distribution.
            Merged over the defaults ``{"distribution": "normal",
            "z_mean": 0.0, "z_std": 1.0}``; every key except
            ``"distribution"`` is registered as a buffer of that name
        recurrent_k_distribution: Config for iterations distribution.
            Merged over the defaults ``{"distribution": "poisson",
            "k_bar": 16.0, "k_sigma": 0.5}``; every key except
            ``"distribution"`` is registered as a buffer of that name
        recurrent_tbptt_steps: Steps for truncated backpropagation, i.e.
            the maximum number of trailing iterations that record gradients
        decoder_class: Block class for decoder
        decoder_depth: Number of blocks in decoder
        decoder_kwargs: Extra keyword arguments for the decoder blocks;
            they override entries of ``**kwargs`` with the same name
        skip: Whether to concatenate the encoder output with the final
            recurrent state (along dim 1) before decoding. When enabled the
            decoder is built with ``2 * hidden_channels`` input channels
        **kwargs: Keyword arguments shared by the encoder, recurrent and
            decoder blocks

    Raises:
        ValueError: If ``recurrent_class`` is found in neither registry.
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            hidden_channels: int,
            encoder_class: str = "ResidualBlock",
            encoder_depth: int = 2,
            encoder_kwargs: dict[str, object] | None = None,
            recurrent_class: str = "ResidualBlock",
            recurrent_depth: int = 4,
            recurrent_kwargs: dict[str, object] | None = None,
            recurrent_z_distribution: dict[str, object] | None = None,
            recurrent_k_distribution: dict[str, object] | None = None,
            recurrent_tbptt_steps: int = 8,
            decoder_class: str = "ResidualBlock",
            decoder_depth: int = 2,
            decoder_kwargs: dict[str, object] | None = None,
            skip: bool = False,
            **kwargs
        ):
        super().__init__()

        # User-supplied entries override the defaults key by key.
        self.recurrent_z_distribution = {
            "distribution": "normal",
            "z_mean": 0.0,
            "z_std": 1.0,
            **(recurrent_z_distribution or {}),
        }
        self.recurrent_k_distribution = {
            "distribution": "poisson",
            "k_bar": 16.0,
            "k_sigma": 0.5,
            **(recurrent_k_distribution or {}),
        }

        # Distribution parameters become buffers so that they follow the
        # module across devices and are stored in the state dict.
        for config in (self.recurrent_z_distribution, self.recurrent_k_distribution):
            for key, value in config.items():
                if key != "distribution":
                    self.register_buffer(key, torch.tensor(value))

        self.recurrent_tbptt_steps = recurrent_tbptt_steps
        self.skip = skip

        self.encoder = Encoder(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=hidden_channels,
            encoder_class=encoder_class,
            encoder_depth=encoder_depth,
            **{**kwargs, **(encoder_kwargs or {})}
        )

        # The recurrent block may be a full architecture or a plain component.
        try:
            recurrent_cls = get_architecture(recurrent_class)
        except KeyError:
            try:
                recurrent_cls = get_component(recurrent_class)
            except KeyError:
                # `from None` hides the two registry KeyErrors, which add
                # nothing to the message below.
                raise ValueError(
                    f"Recurrent block class {recurrent_class} not found in any registry."
                ) from None

        # NOTE: a "concat everywhere" variant (every block fed with
        # cat([x, z], dim=1) and built with in_channels = 2 * hidden_channels)
        # was prototyped here; the active design adds x to z once per iteration.
        block_kwargs = {**kwargs, **(recurrent_kwargs or {})}
        # max(..., 1): the core has always contained at least one block, even
        # for recurrent_depth < 1.
        self.recurrent_block = nn.ModuleList([
            recurrent_cls(
                dimension=dimension,
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                **block_kwargs
            )
            for _ in range(max(recurrent_depth, 1))
        ])

        self.decoder = Decoder(
            dimension=dimension,
            # With skip connections the decoder sees [encoder output, z].
            in_channels=hidden_channels * 2 if skip else hidden_channels,
            out_channels=out_channels,
            decoder_class=decoder_class,
            decoder_depth=decoder_depth,
            **{**kwargs, **(decoder_kwargs or {})}
        )

    def _sample_z(self, x):
        """
        Sample the initial recurrent state.

        Args:
            x: Tensor whose shape and device the sample must match

        Returns:
            Tensor of shape ``x.shape`` on ``x.device`` drawn from a normal
            distribution with mean ``z_mean`` and standard deviation ``z_std``

        Raises:
            ValueError: If the configured distribution is not ``"normal"``.
        """
        distribution = self.recurrent_z_distribution["distribution"]
        if distribution == "normal":
            return torch.normal(
                mean=self.z_mean,
                std=self.z_std,
                size=x.shape,
                device=x.device
            )
        raise ValueError(
            f"Unsupported recurrent_z_distribution: {distribution!r}"
        )

    def _sample_k(self):
        """
        Sample the number of recurrent iterations.

        Draws ``tau ~ Normal(log(k_bar) - k_sigma**2 / 2, k_sigma)`` and then
        ``k = Poisson(exp(tau)) + 1``. The mean shift makes the expected
        Poisson rate ``E[exp(tau)]`` equal to ``k_bar``; the ``+ 1`` guarantees
        at least one iteration.

        Returns:
            k: Number of iterations as an integer tensor

        Raises:
            ValueError: If the configured distribution is not ``"poisson"``.
        """
        distribution = self.recurrent_k_distribution["distribution"]
        if distribution == "poisson":
            mean = torch.log(self.k_bar) - 0.5 * self.k_sigma**2
            tau = torch.normal(mean=mean.detach(), std=self.k_sigma.detach())
            return (torch.poisson(torch.exp(tau)) + 1).int()
        raise ValueError(
            f"Unsupported recurrent_k_distribution: {distribution!r}"
        )

    def _recurrent_loop(self, x, z, k):
        """
        Apply the recurrent core ``k`` times.

        Each iteration re-injects the encoded input by adding ``x`` to the
        state and then passes the result through every recurrent block in
        order.

        Args:
            x: Encoded input, added to the state at every iteration
            z: Current recurrent state
            k: Number of iterations; values <= 0 leave ``z`` unchanged

        Returns:
            The recurrent state after ``k`` iterations
        """
        for _ in range(k):
            z = x + z
            for block in self.recurrent_block:
                z = block(z)
        return z

    def forward(self, x, k=None):
        """
        Forward pass of the recurrent model.

        In training mode the number of iterations is always sampled (any
        ``k`` passed in is ignored); only the last ``recurrent_tbptt_steps``
        iterations record gradients, earlier ones run under ``no_grad``.
        In evaluation mode exactly ``k`` iterations are run, defaulting to
        ``int(k_bar)``.

        Args:
            x: Input tensor
            k: Optional number of recurrent iterations (for evaluation)

        Returns:
            Output tensor
        """
        x = self.encoder(x)
        z = self._sample_z(x)

        if self.training:
            # Convert once to a Python int so the step arithmetic below does
            # not repeatedly read the value back from a tensor.
            k = int(self._sample_k())

            # Truncated BPTT: warm-up iterations carry no autograd graph.
            warmup_steps = max(k - self.recurrent_tbptt_steps, 0)
            with torch.no_grad():
                z = self._recurrent_loop(x, z, warmup_steps)

            z = self._recurrent_loop(x, z, min(k, self.recurrent_tbptt_steps))
        else:
            k = int(k if k is not None else self.k_bar)
            z = self._recurrent_loop(x, z, k)

        if self.skip:
            # x still holds the encoder output at this point.
            z = torch.cat([x, z], dim=1)

        return self.decoder(z)
    
class Encoder(nn.Module):
    """
    Encoder: a ``LiftingLayer`` followed by a stack of identical blocks.

    The lifting layer is built with ``in_channels -> out_channels``; each of
    the ``encoder_depth`` blocks is built with ``out_channels -> out_channels``.

    Args:
        dimension: Spatial dimension, forwarded to every sub-module
        in_channels: Number of channels given to the lifting layer
        out_channels: Number of channels produced by the lifting layer and
            used by every block
        encoder_class: Registry name of the block component
        encoder_depth: Number of blocks after the lifting layer
        **kwargs: Extra keyword arguments forwarded to every block
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            encoder_class: str = "ResidualBlock",
            encoder_depth: int = 2,
            **kwargs
    ):
        super().__init__()

        # Keep the construction arguments available on the instance.
        self.dimension = dimension
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.encoder_class = encoder_class
        self.encoder_depth = encoder_depth
        self.kwargs = kwargs

        # The lifting layer is created (and registered) before the blocks.
        lifting_cls = get_component('LiftingLayer')
        self.lift = lifting_cls(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=out_channels
        )

        # The component is looked up once per block, inside the comprehension.
        self.layers = nn.ModuleList([
            get_component(encoder_class)(
                dimension=dimension,
                in_channels=out_channels,
                out_channels=out_channels,
                **kwargs
            )
            for _ in range(encoder_depth)
        ])

    def forward(self, x):
        """Apply the lifting layer, then every block in order."""
        hidden = self.lift(x)
        for block in self.layers:
            hidden = block(hidden)
        return hidden
    
class Decoder(nn.Module):
    """
    Decoder: a stack of identical blocks followed by a ``ProjectionLayer``.

    Each of the ``decoder_depth`` blocks is built with
    ``in_channels -> in_channels``; the projection layer is built with
    ``in_channels -> out_channels``.

    Args:
        dimension: Spatial dimension, forwarded to every sub-module
        in_channels: Number of channels used by every block and given to
            the projection layer
        out_channels: Number of channels produced by the projection layer
        decoder_class: Registry name of the block component
        decoder_depth: Number of blocks before the projection layer
        **kwargs: Extra keyword arguments forwarded to every block
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            decoder_class: str = "ResidualBlock",
            decoder_depth: int = 2,
            **kwargs
    ):
        super().__init__()

        # Keep the construction arguments available on the instance.
        self.dimension = dimension
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.decoder_class = decoder_class
        self.decoder_depth = decoder_depth
        self.kwargs = kwargs

        # Blocks are created (and registered) before the projection layer.
        # The component is looked up once per block, inside the comprehension.
        self.layers = nn.ModuleList([
            get_component(decoder_class)(
                dimension=dimension,
                in_channels=in_channels,
                out_channels=in_channels,
                **kwargs
            )
            for _ in range(decoder_depth)
        ])

        projection_cls = get_component('ProjectionLayer')
        self.proj = projection_cls(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=out_channels
        )

    def forward(self, x):
        """Apply every block in order, then the projection layer."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.proj(hidden)