import torch
import torch.nn as nn

from . import register_wrapper
from ..architectures import get_architecture
from ..components import get_component

@register_wrapper("RecurrentModel")
class RecurrentModel(nn.Module):
    """
    Recurrent Neural Network for PDE modeling.

    The model is made of three stages:
    1. An encoder that maps the input to ``hidden_channels`` channels
    2. A recurrent core that repeatedly updates a state ``z`` conditioned on
       the encoder output ``x``
    3. A decoder that maps the final state (optionally concatenated with the
       encoder output) to ``out_channels`` channels

    Stochastic depth during training:
    - The number of recurrent iterations ``k`` is sampled on every forward pass
    - The last ``recurrent_tbptt_steps`` iterations run with gradients
      (truncated backpropagation through time); the
      ``k - recurrent_tbptt_steps`` iterations before them run under
      ``torch.no_grad()``
    - At inference time the depth is deterministic (``k`` argument of
      ``forward`` or the ``k_bar`` buffer)

    Args:
        dimension: Spatial dimension of the problem. Must be 1, 2 or 3 for the
            "projection" and "channel_weighted_add" combination methods.
        in_channels: Number of input channels
        out_channels: Number of output channels
        hidden_channels: Number of channels in hidden representations
        encoder_class: Block class for encoder
        encoder_depth: Number of blocks in encoder
        encoder_kwargs: Extra keyword arguments for the encoder; they take
            precedence over ``**kwargs``
        recurrent_class: Block class for recurrent core (looked up in the
            architecture registry first, then in the component registry)
        recurrent_depth: Number of blocks in recurrent core
        recurrent_kwargs: Extra keyword arguments for the recurrent blocks;
            they take precedence over ``**kwargs``
        recurrent_z_distribution: Config for initial state distribution.
            Merged over ``{"distribution": "normal", "z_mean": 0.0,
            "z_std": 1.0}``; every entry except "distribution" is registered
            as a buffer under its key.
        recurrent_k_distribution: Config for iterations distribution.
            Merged over ``{"distribution": "poisson", "k_bar": 16.0,
            "k_sigma": 0.5}``; every entry except "distribution" is registered
            as a buffer under its key.
        recurrent_tbptt_steps: Steps for truncated backpropagation
        decoder_class: Block class for decoder
        decoder_depth: Number of blocks in decoder
        decoder_kwargs: Extra keyword arguments for the decoder; they take
            precedence over ``**kwargs``
        skip: Whether to concatenate the encoder output with the final state
            before decoding
        combine_method: Method for combining x and z at the start of every
            recurrent iteration. One of "concat", "add", "weighted_add",
            "projection" or "channel_weighted_add".
        **kwargs: Keyword arguments forwarded to the encoder, the recurrent
            blocks and the decoder

    Raises:
        ValueError: If ``combine_method`` is unknown, if ``recurrent_class``
            is in neither registry, or if ``dimension`` is unsupported by the
            selected combination method.
    """

    # Combination methods understood by `_combine_x_z`.
    _COMBINE_METHODS = (
        "concat",
        "add",
        "weighted_add",
        "projection",
        "channel_weighted_add",
    )

    # Point-wise convolution used by the "projection" method, per dimension.
    _CONV_BY_DIMENSION = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}

    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            hidden_channels: int,
            encoder_class: str = "ResidualBlock",
            encoder_depth: int = 2,
            encoder_kwargs: dict[str, object] | None = None,
            recurrent_class: str = "ResidualBlock",
            recurrent_depth: int = 4,
            recurrent_kwargs: dict[str, object] | None = None,
            recurrent_z_distribution: dict[str, object] | None = None,
            recurrent_k_distribution: dict[str, object] | None = None,
            recurrent_tbptt_steps: int = 8,
            decoder_class: str = "ResidualBlock",
            decoder_depth: int = 2,
            decoder_kwargs: dict[str, object] | None = None,
            skip: bool = False,
            combine_method: str = "channel_weighted_add",
            **kwargs
        ):
        super().__init__()

        # Fail at construction time rather than on the first forward pass.
        if combine_method not in self._COMBINE_METHODS:
            raise ValueError(f"Unknown combination method: {combine_method}")

        self.recurrent_z_distribution = {
            "distribution": "normal",
            "z_mean": 0.0,
            "z_std": 1.0,
            **(recurrent_z_distribution or {}),
        }
        self._register_distribution_buffers(self.recurrent_z_distribution)

        self.recurrent_k_distribution = {
            "distribution": "poisson",
            "k_bar": 16.0,
            "k_sigma": 0.5,
            **(recurrent_k_distribution or {}),
        }
        self._register_distribution_buffers(self.recurrent_k_distribution)

        self.recurrent_tbptt_steps = recurrent_tbptt_steps
        self.skip = skip
        self.combine_method = combine_method
        self.hidden_channels = hidden_channels

        self.encoder = Encoder(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=hidden_channels,
            encoder_class=encoder_class,
            encoder_depth=encoder_depth,
            **{**kwargs, **(encoder_kwargs or {})}
        )

        # Architectures take precedence over components with the same name.
        try:
            RecurrentClass = get_architecture(recurrent_class)
        except KeyError:
            try:
                RecurrentClass = get_component(recurrent_class)
            except KeyError as err:
                raise ValueError(
                    f"Recurrent block class {recurrent_class} not found in any registry."
                ) from err

        block_kwargs = {**kwargs, **(recurrent_kwargs or {})}
        if combine_method == "concat":
            # Only the first block sees the concatenated [x, z] tensor, so it
            # is the only one with a doubled input width. It is always
            # created, even when recurrent_depth < 1.
            block_in_channels = (
                [hidden_channels * 2] + [hidden_channels] * (recurrent_depth - 1)
            )
        else:
            # Every other method keeps the width at hidden_channels.
            block_in_channels = [hidden_channels] * recurrent_depth

        self.recurrent_block = nn.ModuleList([
            RecurrentClass(
                dimension=dimension,
                in_channels=channels,
                out_channels=hidden_channels,
                **block_kwargs
            )
            for channels in block_in_channels
        ])

        # Learnable parameters for the combination methods that need them
        # ("concat" and "add" have none).
        if combine_method == "weighted_add":
            # Scalar weights, shared by all channels.
            self.weight_x = nn.Parameter(torch.ones(1))
            self.weight_z = nn.Parameter(torch.ones(1))
        elif combine_method == "projection":
            if dimension not in self._CONV_BY_DIMENSION:
                raise ValueError(f"Unsupported dimension: {dimension}")
            self.projection = self._CONV_BY_DIMENSION[dimension](
                hidden_channels * 2, hidden_channels, kernel_size=1, bias=False
            )
            # Start from [I | I] so that the projection initially computes
            # x + z, i.e. the same output as channel_weighted_add at init.
            with torch.no_grad():
                weight = self.projection.weight
                weight.copy_(
                    torch.eye(hidden_channels).repeat(1, 2).reshape(weight.shape)
                )
        elif combine_method == "channel_weighted_add":
            if dimension not in self._CONV_BY_DIMENSION:
                raise ValueError(f"Unsupported dimension: {dimension}")
            # One weight per channel, broadcast over batch and spatial axes.
            weight_shape = (1, hidden_channels) + (1,) * int(dimension)
            self.weight_x_channel = nn.Parameter(torch.ones(weight_shape))
            self.weight_z_channel = nn.Parameter(torch.ones(weight_shape))

        self.decoder = Decoder(
            dimension=dimension,
            in_channels=hidden_channels if not skip else hidden_channels * 2,
            out_channels=out_channels,
            decoder_class=decoder_class,
            decoder_depth=decoder_depth,
            **{**kwargs, **(decoder_kwargs or {})}
        )

    def _register_distribution_buffers(self, config):
        """
        Register every entry of a distribution config as a buffer.

        The "distribution" entry (the distribution name) is skipped; all
        other values are wrapped in a tensor and registered under their key,
        which makes them reachable as ``self.<key>``.

        Args:
            config: Distribution config dictionary
        """
        for key, value in config.items():
            if key != "distribution":
                self.register_buffer(str(key), torch.tensor(value))

    def _sample_z(self, x):
        """
        Sample the initial state.

        Args:
            x: Input tensor to match device and shape

        Returns:
            Tensor with the shape of ``x``, drawn from a normal distribution
            with mean ``z_mean`` and standard deviation ``z_std``

        Raises:
            ValueError: If the configured distribution is not "normal"
        """
        distribution = self.recurrent_z_distribution["distribution"]
        if distribution != "normal":
            raise ValueError(f"Unsupported z distribution: {distribution}")
        return torch.normal(
            mean=self.z_mean,
            std=self.z_std,
            size=x.shape,
            device=x.device
        )

    def _sample_k(self):
        """
        Sample the number of recurrent iterations.

        The rate of the Poisson distribution is itself random:
        ``tau ~ Normal(log(k_bar) - k_sigma**2 / 2, k_sigma)`` and
        ``k = Poisson(exp(tau)) + 1``, so ``k`` is always at least 1.

        Returns:
            k: Number of iterations, as an integer tensor

        Raises:
            ValueError: If the configured distribution is not "poisson"
        """
        distribution = self.recurrent_k_distribution["distribution"]
        if distribution != "poisson":
            raise ValueError(f"Unsupported k distribution: {distribution}")
        mean = torch.log(self.k_bar) - 0.5 * self.k_sigma**2
        tau = torch.normal(mean=mean.detach(), std=self.k_sigma.detach())
        return (torch.poisson(torch.exp(tau)) + 1).int()

    def _combine_x_z(self, x, z):
        """
        Combine input x and state z using the configured method.

        Args:
            x: Input tensor (encoder output)
            z: Current state tensor

        Returns:
            Combined tensor according to ``self.combine_method``

        Raises:
            ValueError: If ``self.combine_method`` is not a known method
        """
        if self.combine_method == "concat":
            # Concatenate along the channel dimension.
            return torch.cat([x, z], dim=1)
        if self.combine_method == "add":
            return x + z
        if self.combine_method == "weighted_add":
            # Scalar learnable weights.
            return self.weight_x * x + self.weight_z * z
        if self.combine_method == "projection":
            # Point-wise (kernel size 1) convolution over the concatenation.
            return self.projection(torch.cat([x, z], dim=1))
        if self.combine_method == "channel_weighted_add":
            # Per-channel learnable weights.
            return self.weight_x_channel * x + self.weight_z_channel * z
        raise ValueError(f"Unknown combination method: {self.combine_method}")

    def _recurrent_loop(self, x, z, k):
        """
        Run ``k`` iterations of the recurrent core.

        Each iteration combines ``x`` with the current state and then passes
        the result through every block of ``self.recurrent_block`` in order.

        Args:
            x: Input tensor (encoder output), injected at every iteration
            z: Initial state tensor
            k: Number of iterations; zero or a negative value leaves ``z``
                unchanged

        Returns:
            State tensor after ``k`` iterations
        """
        for _ in range(k):
            z = self._combine_x_z(x, z)
            for block in self.recurrent_block:
                z = block(z)
        return z

    def forward(self, x, k=None):
        """
        Forward pass of the recurrent model.

        Args:
            x: Input tensor
            k: Optional number of recurrent iterations. Only used in
                evaluation mode, where it defaults to ``k_bar``; in training
                mode it is ignored and ``k`` is sampled instead.

        Returns:
            Output tensor
        """
        features = self.encoder(x)
        z = self._sample_z(features)

        if self.training:
            k = int(self._sample_k())

            # Truncated backpropagation through time: the first
            # k - tbptt_steps iterations run without gradient (none at all
            # when k <= tbptt_steps), then exactly tbptt_steps iterations run
            # with gradient. The total depth is therefore
            # max(k, tbptt_steps), e.g.
            #   k = 10, tbptt_steps = 4  -> 6 without + 4 with gradient
            #   k = 4,  tbptt_steps = 10 -> 10 with gradient
            with torch.no_grad():
                z = self._recurrent_loop(
                    features, z, k - self.recurrent_tbptt_steps
                )
            z = self._recurrent_loop(features, z, self.recurrent_tbptt_steps)
        else:
            k = int(k if k is not None else self.k_bar)
            z = self._recurrent_loop(features, z, k)

        if self.skip:
            # Skip connection from the encoder output to the decoder input.
            z = torch.cat([features, z], dim=1)

        return self.decoder(z)
    
class Encoder(nn.Module):
    """
    Lifting layer followed by a stack of blocks of constant width.

    Args:
        dimension: Spatial dimension, forwarded to every sub-module
        in_channels: Number of channels fed to the lifting layer
        out_channels: Number of channels produced by the lifting layer and
            kept by every following block
        encoder_class: Name of the registered component used for the blocks
        encoder_depth: Number of blocks after the lifting layer
        **kwargs: Extra keyword arguments forwarded to every block
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            encoder_class: str = "ResidualBlock",
            encoder_depth: int = 2,
            **kwargs
    ):
        super().__init__()

        # Keep the construction arguments available on the instance.
        self.kwargs = kwargs
        self.encoder_depth = encoder_depth
        self.encoder_class = encoder_class
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.dimension = dimension

        # in_channels -> out_channels
        lifting_cls = get_component('LiftingLayer')
        self.lift = lifting_cls(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=out_channels
        )

        # encoder_depth blocks, each out_channels -> out_channels
        self.layers = nn.ModuleList([
            get_component(encoder_class)(
                dimension=dimension,
                in_channels=out_channels,
                out_channels=out_channels,
                **kwargs
            )
            for _ in range(encoder_depth)
        ])

    def forward(self, x):
        """Apply the lifting layer, then every block in order."""
        hidden = self.lift(x)
        for block in self.layers:
            hidden = block(hidden)
        return hidden
    
class Decoder(nn.Module):
    """
    Stack of blocks of constant width followed by a projection layer.

    Args:
        dimension: Spatial dimension, forwarded to every sub-module
        in_channels: Number of channels received and kept by every block
        out_channels: Number of channels produced by the projection layer
        decoder_class: Name of the registered component used for the blocks
        decoder_depth: Number of blocks before the projection layer
        **kwargs: Extra keyword arguments forwarded to every block
    """
    def __init__(
            self,
            dimension: int,
            in_channels: int,
            out_channels: int,
            decoder_class: str = "ResidualBlock",
            decoder_depth: int = 2,
            **kwargs
    ):
        super().__init__()

        # Keep the construction arguments available on the instance.
        self.kwargs = kwargs
        self.decoder_depth = decoder_depth
        self.decoder_class = decoder_class
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.dimension = dimension

        # decoder_depth blocks, each in_channels -> in_channels
        self.layers = nn.ModuleList([
            get_component(decoder_class)(
                dimension=dimension,
                in_channels=in_channels,
                out_channels=in_channels,
                **kwargs
            )
            for _ in range(decoder_depth)
        ])

        # in_channels -> out_channels
        projection_cls = get_component('ProjectionLayer')
        self.proj = projection_cls(
            dimension=dimension,
            in_channels=in_channels,
            out_channels=out_channels
        )

    def forward(self, x):
        """Apply every block in order, then the projection layer."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.proj(hidden)