from torch import nn

# from utils.config import instantiate_name, instantiate_cls
import src.utils as utils
from src.models.nn.components import Normalization
from src.models.sequence import SequenceModule
from src.models.sequence.pool import registry as pool_registry
from src.models.nn.residual import registry as residual_registry
import src.utils.registry as registry


class SequenceResidualBlock(SequenceModule):
    """Wrap a sequence layer with optional norm, residual, dropout and pooling.

    Data flow of ``forward``::

        y = norm(x) if prenorm else x
        y, state = layer(y, state=state)
        y = residual(x, drop(y))   # only if a residual is configured
        y = norm(y)                # only if not prenorm
        y = pool(y)                # only if a pool is configured

    Without a residual, the layer output itself is the block output. This
    matches ``d_residual = layer.d_output`` as computed in ``__init__``.

    Args:
        d_input: Feature dimension of the block input.
        i_layer: Index of this block. It is only forwarded to the residual
            constructor; certain residuals (e.g. Decay) need it.
        prenorm: If True, normalize the layer input. Otherwise normalize the
            output after the residual.
        dropout: Dropout probability applied to the layer output before the
            residual combination in ``forward``. ``0.0`` disables dropout.
        layer: Config for the wrapped module, instantiated from
            ``registry.layer`` with ``d_input``.
        residual: Config for the residual function from the residual
            registry, or ``None`` for no residual connection.
        norm: ``None`` (no normalization), a normalization name string, or a
            dict of keyword arguments for ``Normalization``.
        pool: Config for a pooling module from the pool registry. The
            resulting ``self.pool`` may be ``None``.
    """

    def __init__(
            self,
            d_input,
            i_layer=None, # Only needs to be passed into certain residuals like Decay
            prenorm=True,
            dropout=0.0,
            layer=None, # Config for black box module
            residual=None, # Config for residual function
            norm=None, # Config for normalization layer
            pool=None,
        ):
        super().__init__()

        self.i_layer = i_layer
        self.d_input = d_input
        self.layer = utils.instantiate(registry.layer, layer, d_input)
        self.prenorm = prenorm

        # Residual. d_residual is the feature dimension after the residual step.
        if residual is None:
            self.residual = None
            self.d_residual = self.layer.d_output
        else:
            self.residual = utils.instantiate(
                residual_registry, residual, i_layer, d_input, self.layer.d_output
            )
            self.d_residual = self.residual.d_output

        # Normalization. Prenorm sees the block input; postnorm sees the
        # residual output. Normalization is built directly rather than through
        # utils.instantiate because it has some special cases.
        d_norm = d_input if self.prenorm else self.d_residual
        if norm is None:
            self.norm = None
        elif isinstance(norm, str):
            self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm)
        else:
            self.norm = Normalization(d_norm, transposed=self.transposed, **norm)

        # Pool operates on the post-residual features.
        self.pool = utils.instantiate(pool_registry, pool, self.d_residual, transposed=self.transposed)

        # Dropout: Dropout2d is used when the wrapped layer is transposed.
        drop_cls = nn.Dropout2d if self.transposed else nn.Dropout
        self.drop = drop_cls(dropout) if dropout > 0.0 else nn.Identity()

    @property
    def transposed(self):
        """Whether the wrapped layer is transposed (defaults to False)."""
        return getattr(self.layer, 'transposed', False)

    @property
    def d_output(self):
        """Output feature dimension: the pool's output if pooling, else d_residual."""
        return self.pool.d_output if self.pool is not None else self.d_residual

    @property
    def d_state(self):
        """State dimension of the wrapped layer."""
        return self.layer.d_state

    @property
    def state_to_tensor(self):
        """The wrapped layer's ``state_to_tensor``, passed through unchanged."""
        return self.layer.state_to_tensor

    def default_state(self, *args, **kwargs):
        """Delegate initial-state construction to the wrapped layer."""
        return self.layer.default_state(*args, **kwargs)

    def forward(self, x, *args, state=None, **kwargs):
        """Run the block over a full input.

        Extra positional and keyword arguments are forwarded to the wrapped
        layer.

        Returns:
            Tuple ``(output, state)``. ``state`` is whatever the wrapped
            layer returned.
        """
        y = x

        # Pre-norm: normalize only the layer input; x stays raw for the residual.
        if self.norm is not None and self.prenorm:
            y = self.norm(y)

        # Black box module
        y, state = self.layer(y, *args, state=state, **kwargs)

        # Residual: combine the raw input with the dropped-out layer output.
        # With no residual, the layer output is carried through as-is.
        if self.residual is not None:
            y = self.residual(x, self.drop(y), self.transposed)

        # Post-norm
        if self.norm is not None and not self.prenorm:
            y = self.norm(y)

        # Pool
        if self.pool is not None:
            y = self.pool(y)

        return y, state

    def _step_norm(self, z):
        """Apply ``self.norm`` to a single-step tensor.

        For transposed layers, a trailing axis is added before normalizing and
        removed afterwards. This presumably supplies the length axis that a
        transposed Normalization expects. TODO(review): confirm against
        Normalization.
        """
        if self.transposed:
            z = z.unsqueeze(-1)
        z = self.norm(z)
        if self.transposed:
            z = z.squeeze(-1)
        return z

    def step(self, x, state, *args, **kwargs):
        """Run one recurrent step of the block.

        It mirrors ``forward`` but calls ``self.layer.step`` and applies no
        dropout. TODO: transpose handling still needs review.

        Returns:
            Tuple ``(output, state)`` from the wrapped layer's step.
        """
        y = x

        # Pre-norm
        if self.norm is not None and self.prenorm:
            y = self._step_norm(y)

        # Black box module
        y, state = self.layer.step(y, state, *args, **kwargs)

        # Residual. Called with transposed=False for single steps.
        # TODO: this would not work with a concat residual.
        if self.residual is not None:
            y = self.residual(x, y, transposed=False)

        # Post-norm: normalize the tensor that is actually returned.
        if self.norm is not None and not self.prenorm:
            y = self._step_norm(y)

        # Pool
        if self.pool is not None:
            y = self.pool(y)

        return y, state

