from collections import OrderedDict
from torch import tensor
from torch.nn import init
#from torch.nn.utils.parametrizations import weight_norm as wn

from ..elements.layers import *
from ..utils import SerializableModule, SerializableSequential as Sequential, unit_weight_norm
from .. import Hyperparams


def get_net(model):
    """
    Resolve a model specification into a network object.

    Accepted specifications:
    -None: yields an empty SerializableSequential
    -Hyperparams: its ``type`` entry ('mlp', 'conv' or 'deconv') selects the
     network class, which is then built through that class' ``from_hparams``
    -BlockNet, SerializableModule or SerializableSequential: returned as is
    -list of any of the above: each entry is resolved recursively and the
     results are chained in a SerializableSequential

    :param model: None, Hyperparams, SerializableModule, SerializableSequential, list
    :return: the resolved network
    :raises ValueError: if a Hyperparams config has no "type" entry
    :raises NotImplementedError: if the config type or the object type is not handled
    """
    if model is None:
        return Sequential()

    # Build the network described by a hyperparameter config
    if isinstance(model, Hyperparams):
        if "type" not in model.keys():
            raise ValueError("Model type not specified.")
        known_types = (("mlp", MLPNet), ("conv", ConvNet), ("deconv", DeconvNet))
        for type_name, net_class in known_types:
            if model.type == type_name:
                return net_class.from_hparams(model)
        raise NotImplementedError("Model type not supported.")

    # Already-built networks are passed through untouched
    if isinstance(model, (BlockNet, SerializableModule, Sequential)):
        return model

    # A list is resolved entry by entry and wrapped into one sequential net
    if isinstance(model, list):
        return Sequential(*[get_net(entry) for entry in model])

    raise NotImplementedError("Model type not supported.")


class MLPNet(SerializableModule):

    """
    Parametric multilayer perceptron network

    Chains nn.Linear layers of sizes input_size -> hidden_sizes -> output_size.
    The activation follows every hidden layer, and also the last layer when
    activate_output is set.

    :param input_size: int, the size of the input
    :param hidden_sizes: list (or tuple) of int, the sizes of the hidden layers
    :param output_size: int, the size of the output
    :param residual: bool, whether to add the input to the output of the layers
    :param activation: torch.nn.Module, the activation function to use
        (the same module instance is reused after every layer)
    :param activate_output: bool, whether to apply the activation after the last layer
    :param weight_norm: bool, whether to wrap every linear layer with unit_weight_norm
    :param init: str or None, initialization method passed to initialize_parameters;
        None keeps the default nn.Linear initialization
    """
    def __init__(self, input_size, hidden_sizes, output_size, residual=False, activation=nn.ReLU(), activate_output=True, weight_norm=False, init=None):
        super(MLPNet, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.activation = activation
        self.residual = residual
        self.activate_output = activate_output
        self.weight_norm = weight_norm

        # Unpacking accepts any sequence of hidden sizes (list or tuple).
        sizes = [input_size, *hidden_sizes, output_size]
        last_layer = len(sizes) - 2
        layers = []
        for index, (in_features, out_features) in enumerate(zip(sizes[:-1], sizes[1:])):
            linear = nn.Linear(in_features, out_features)
            if self.weight_norm:
                linear = unit_weight_norm(linear, dim=0)
            layers.append(linear)
            # Hidden layers are always activated; the last one only on request.
            if index < last_layer or self.activate_output:
                layers.append(self.activation)

        self.mlp_layers = nn.Sequential(*layers)
        self.init = init
        if init is not None:
            self.initialize_parameters(method=init)

    def forward(self, inputs):
        """
        Apply the layers to the inputs.

        :param inputs: tensor fed to the first linear layer
        :return: the output of the layers, plus the inputs when residual is set
        """
        hidden = self.mlp_layers(inputs)
        return inputs + hidden if self.residual else hidden

    @staticmethod
    def from_hparams(hparams: Hyperparams):
        """
        Build an MLPNet from a hyperparameter config.

        Reads input_size, hidden_sizes, output_size, activation, residual and
        activate_output; weight_norm and init keep their constructor defaults.

        :param hparams: Hyperparams
        :return: MLPNet
        """
        return MLPNet(
            input_size=hparams.input_size,
            hidden_sizes=hparams.hidden_sizes,
            output_size=hparams.output_size,
            activation=hparams.activation,
            residual=hparams.residual,
            activate_output=hparams.activate_output
        )

    def serialize(self):
        """
        :return: dict with the class ("type"), the state dict ("state_dict")
            and the constructor arguments ("params")
        """
        return dict(
            type=self.__class__,
            state_dict=self.state_dict(),
            params=dict(
                input_size=self.input_size,
                hidden_sizes=self.hidden_sizes,
                output_size=self.output_size,
                activation=self.activation,
                residual=self.residual,
                activate_output=self.activate_output,
                weight_norm=self.weight_norm
            )
        )

    @staticmethod
    def deserialize(serialized):
        """
        Rebuild an MLPNet from the output of serialize.

        :param serialized: dict with "params" and "state_dict" entries
        :return: MLPNet with the stored weights loaded
        """
        # "output_activation" is not a constructor argument, so it is dropped
        # (presumably a key written by an older format - TODO confirm).
        # The filtering is done on a copy so the caller's dict is left intact.
        params = {key: value for key, value in serialized["params"].items()
                  if key != "output_activation"}
        net = MLPNet(**params)
        net.load_state_dict(serialized["state_dict"])
        return net

    def initialize_parameters(self, method='xavier_uniform'):
        """
        Re-initialize in place every parameter whose name contains 'weight'.

        Parameters:
        - method (str): The initialization method. Default is 'xavier_uniform'.
                       Other options are 'xavier_normal', 'kaiming_uniform', 'kaiming_normal',
                       'orthogonal', 'uniform' (range [0, 1]) and 'normal' (mean 0, std 1).

        Returns:
        None

        Raises:
        ValueError: if method is not one of the supported names.
        """
        initializers = {
            'xavier_uniform': init.xavier_uniform_,
            'xavier_normal': init.xavier_normal_,
            'kaiming_uniform': lambda p: init.kaiming_uniform_(p, mode='fan_in', nonlinearity='relu'),
            'kaiming_normal': lambda p: init.kaiming_normal_(p, mode='fan_in', nonlinearity='relu'),
            'orthogonal': init.orthogonal_,
            'uniform': lambda p: init.uniform_(p, a=0.0, b=1.0),
            'normal': lambda p: init.normal_(p, mean=0.0, std=1.0),
        }
        if method not in initializers:
            raise ValueError(f"Unsupported initialization method: {method}")
        initializer = initializers[method]

        for name, param in self.named_parameters():
            if 'weight' in name:
                initializer(param)

    def extra_repr(self) -> str:
        return super().extra_repr() + f"residual={self.residual}, weight_norm={self.weight_norm}"
    


class ConvNet(SerializableModule):
    """
    Parametric convolutional network
    based on Efficient-VDVAE paper

    Layout: conv, then for every further entry of filters
    [norm] -> activation -> conv, and a final activation when
    activate_output is set.

    :param in_filters: int, the number of input filters
    :param filters: list of int, the number of filters for each layer (at least one)
    :param kernel_size: int (used for every layer), or list with one int or
        (int, int) pair per layer
    :param pools: collection of layer indices that use stride 2; only consulted
        when strides is None
    :param strides: list with one stride per layer, or None to derive the
        strides from pools (2 for pooled layers, 1 otherwise)
    :param activation: torch.nn.Module, the activation function to use
        (the same module instance is reused between all layers)
    :param activate_output: bool, whether to apply the activation after the last conv
    :param norm_class: callable building a normalization layer from a channel
        count (e.g. nn.BatchNorm2d), or None for no normalization
    :param init: str or None, initialization method passed to initialize_parameters;
        None keeps the default nn.Conv2d initialization
    :param weight_norm: bool, whether to wrap every conv with unit_weight_norm
    """
    def __init__(self, in_filters, filters, kernel_size, pools=None, strides=None,
                 activation=nn.Softplus(), activate_output=False,
                 norm_class = None, init=None, weight_norm=False):
        super(ConvNet, self).__init__()

        self.in_filters = in_filters
        self.filters = filters
        assert len(self.filters) > 0, "Must have at least one filter"

        # Normalize kernel sizes to one (height, width) pair per layer.
        if isinstance(kernel_size, int):
            kernel_size = [(kernel_size, kernel_size)
                           for _ in range(len(self.filters))]
        elif isinstance(kernel_size, list):
            kernel_size = [(ks, ks) if isinstance(ks, int) else ks
                           for ks in kernel_size]
        self.kernel_size = kernel_size

        if pools is None:
            pools = []
        self.pools = pools
        self.strides = strides
        self.activation = activation
        self.activate_output = activate_output
        self.norm_class = norm_class
        self.weight_norm = weight_norm

        layers = []
        for index, out_channels in enumerate(self.filters):
            if index == 0:
                in_channels = self.in_filters
            else:
                in_channels = self.filters[index - 1]
                # Every conv but the first is preceded by [norm] -> activation.
                if self.norm_class is not None:
                    layers.append(self.norm_class(in_channels))
                layers.append(self.activation)
            conv = nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             **self._conv_kwargs(index))
            if self.weight_norm:
                conv = unit_weight_norm(conv, dim=0)
            layers.append(conv)
        if self.activate_output:
            layers.append(self.activation)

        self.convs = nn.Sequential(*layers)
        self.init = init
        if init is not None:
            self.initialize_parameters(method=init)

    def _conv_kwargs(self, i: int):
        """Return the stride, padding and kernel_size arguments of conv layer i."""
        kernel = self.kernel_size[i]
        if self.strides is not None:
            stride = self.strides[i]
        else:
            stride = 2 if i in self.pools else 1

        if stride == 1:
            padding = "same"
        else:
            # Explicit padding of (kernel - 1) // 2 per dimension for strided convs.
            padding = ((kernel[0] - 1) // 2, (kernel[1] - 1) // 2)

        return dict(stride=stride, padding=padding, kernel_size=kernel)

    def forward(self, inputs):
        """
        :param inputs: tensor fed to the first conv layer
        :return: the output of the last layer
        """
        return self.convs(inputs)

    def initialize_parameters(self, method='xavier_uniform'):
        """
        Re-initialize the convolution weights in place.

        A parameter is initialized when its name contains 'weight' but not
        'norm' and it has at least two dimensions. The dimension check is what
        protects normalization layers: inside nn.Sequential their parameters
        get index-based names (e.g. 'convs.1.weight'), so the name test alone
        does not exclude them, and the fan-based and orthogonal initializers
        raise on tensors with fewer than two dimensions.

        Parameters:
        - method (str): The initialization method. Default is 'xavier_uniform'.
                       Other options are 'xavier_normal', 'kaiming_uniform', 'kaiming_normal',
                       'orthogonal', 'uniform' (range [0, 1]) and 'normal' (mean 0, std 1).

        Returns:
        None

        Raises:
        ValueError: if method is not one of the supported names.
        """
        initializers = {
            'xavier_uniform': init.xavier_uniform_,
            'xavier_normal': init.xavier_normal_,
            'kaiming_uniform': lambda p: init.kaiming_uniform_(p, mode='fan_in', nonlinearity='relu'),
            'kaiming_normal': lambda p: init.kaiming_normal_(p, mode='fan_in', nonlinearity='relu'),
            'orthogonal': init.orthogonal_,
            'uniform': lambda p: init.uniform_(p, a=0.0, b=1.0),
            'normal': lambda p: init.normal_(p, mean=0.0, std=1.0),
        }
        if method not in initializers:
            raise ValueError(f"Unsupported initialization method: {method}")
        initializer = initializers[method]

        for name, param in self.named_parameters():
            if 'weight' not in name or 'norm' in name:
                continue
            if param.dim() < 2:
                # Per-channel scale of a normalization layer: keep its default.
                continue
            initializer(param)

    @staticmethod
    def from_hparams(hparams):
        """
        Build a ConvNet from a hyperparameter config.

        Required entries: in_filters, filters, kernel_size, activation,
        activate_output. Optional entries: pools (default []), strides
        (default None), init (default "xavier_uniform"), and either batchnorm
        or norm_class to select the normalization layer. weight_norm keeps its
        constructor default.

        :param hparams: Hyperparams
        :return: ConvNet
        """
        keys = hparams.keys()
        pools = hparams.pools if "pools" in keys else []
        # NOTE(review): the mere presence of a "batchnorm" entry selects
        # nn.BatchNorm2d, whatever its value - confirm this is intended.
        if "batchnorm" in keys:
            norm_class = nn.BatchNorm2d
        elif "norm_class" in keys:
            norm_class = hparams.norm_class
        else:
            norm_class = None
        strides = hparams.strides if "strides" in keys else None
        init_method = hparams.init if "init" in keys else "xavier_uniform"
        return ConvNet(
            in_filters=hparams.in_filters,
            filters=hparams.filters,
            kernel_size=hparams.kernel_size,
            pools=pools,
            norm_class=norm_class,
            strides=strides,
            init=init_method,
            activation=hparams.activation,
            activate_output=hparams.activate_output,
        )

    def serialize(self):
        """
        :return: dict with the class ("type"), the state dict ("state_dict")
            and the constructor arguments ("params")
        """
        return dict(
            type=self.__class__,
            state_dict=self.state_dict(),
            params=dict(
                in_filters=self.in_filters,
                filters=self.filters,
                kernel_size=self.kernel_size,
                pools=self.pools,
                norm_class=self.norm_class,
                strides=self.strides,
                init=self.init,
                activation=self.activation,
                activate_output=self.activate_output,
                weight_norm=self.weight_norm
            )
        )

    @staticmethod
    def deserialize(serialized):
        """
        Rebuild a ConvNet from the output of serialize.

        :param serialized: dict with "params" and "state_dict" entries
        :return: ConvNet with the stored weights loaded
        """
        net = ConvNet(**serialized["params"])
        net.load_state_dict(serialized["state_dict"])
        return net

    def extra_repr(self) -> str:
        return super().extra_repr() + f" init={self.init}"


class DeconvNet(SerializableModule):
    """
    Parametric network mixing transposed convolutions (unpooling) with
    regular convolutions
    based on Efficient-VDVAE paper

    For every filter transition an activation and a "same"-padded conv are
    added; layers listed in unpools additionally get a ConvTranspose2d in
    front of their conv.

    :param in_filters: int, the number of input filters (only consumed by the
        transposed conv of layer 0, i.e. when layer 0 is listed in unpools)
    :param filters: list of int, the number of filters for each layer (at least one)
    :param kernel_size: int (used for every layer), or list with one int or
        (int, int) tuple per layer
    :param unpools: list of (layer_index, stride) pairs selecting the layers
        that are preceded by a transposed conv and the stride of that conv
    :param activation: torch.nn.Module, the activation function to use
        (the same module instance is reused between all layers)
    :param activate_output: bool, whether to apply the activation after the last conv
    """
    def __init__(self, in_filters, filters, kernel_size, unpools=None,
                 activation=nn.Softplus(), activate_output=False):
        super(DeconvNet, self).__init__()

        self.in_filters = in_filters
        self.filters = filters
        assert len(self.filters) > 0, "Must have at least one filter"

        # Normalize kernel sizes to one (height, width) pair per layer.
        if isinstance(kernel_size, int):
            kernel_size = [(kernel_size, kernel_size)] * (len(self.filters))
        elif isinstance(kernel_size, list):
            kernel_size = [(ks, ks) if not isinstance(ks, tuple) else ks for ks in kernel_size]
        self.kernel_size = kernel_size
        # Output padding of the transposed convs: 0 for even kernels, 1 for odd ones.
        output_paddings = [0 if kernel[0] % 2 == 0 else 1
                           for kernel in self.kernel_size]

        if unpools is None:
            unpools = []
        self.unpools = unpools
        unpool_layers = [entry[0] for entry in self.unpools]
        unpool_strides = [entry[1] for entry in self.unpools]

        def unpool_stride(layer):
            # Stride registered for this very layer (first matching entry wins),
            # rather than the stride of the first unpool entry.
            return unpool_strides[unpool_layers.index(layer)]

        self.activation = activation
        self.activate_output = activate_output

        layers = []
        # NOTE(review): when layer 0 is not listed in unpools, no layer maps
        # in_filters to filters[0]; the first conv then expects filters[0]
        # input channels - confirm this is intended.
        if 0 in unpool_layers:
            layers.append(
                nn.ConvTranspose2d(in_channels=self.in_filters,
                                   out_channels=self.filters[0],
                                   kernel_size=self.kernel_size[0],
                                   stride=unpool_stride(0),
                                   output_padding=output_paddings[0], padding=1),
            )
            if len(filters) != 1 or self.activate_output:
                layers.append(self.activation)
                # NOTE(review): indexes kernel_size[1], which does not exist
                # for a single-entry kernel list - confirm the intended index.
                layers.append(nn.Conv2d(in_channels=self.filters[0],
                                        out_channels=self.filters[0],
                                        kernel_size=self.kernel_size[1],
                                        padding="same"))

        for i in range(len(self.filters) - 1):
            layers.append(self.activation)

            if i + 1 in unpool_layers:
                # NOTE(review): kernel size and output padding of the first
                # layer are reused for every transposed conv, and no padding is
                # set here (unlike the layer-0 transposed conv) - confirm.
                layers.append(
                    nn.ConvTranspose2d(in_channels=self.filters[i],
                                       out_channels=self.filters[i],
                                       kernel_size=self.kernel_size[0],
                                       stride=unpool_stride(i + 1),
                                       output_padding=output_paddings[0]),
                )
                layers.append(self.activation)
            layers.append(nn.Conv2d(in_channels=self.filters[i],
                                    out_channels=self.filters[i + 1],
                                    kernel_size=self.kernel_size[i + 1],
                                    padding="same"))

        if self.activate_output:
            layers.append(self.activation)

        self.convs = nn.Sequential(*layers)

    def forward(self, inputs):
        """
        :param inputs: tensor fed to the first layer
        :return: the output of the last layer
        """
        return self.convs(inputs)

    @staticmethod
    def from_hparams(hparams):
        """
        Build a DeconvNet from a hyperparameter config providing in_filters,
        filters, kernel_size, unpools, activation and activate_output.

        :param hparams: Hyperparams
        :return: DeconvNet
        """
        return DeconvNet(
            in_filters=hparams.in_filters,
            filters=hparams.filters,
            kernel_size=hparams.kernel_size,
            unpools=hparams.unpools,
            activation=hparams.activation,
            activate_output=hparams.activate_output,
        )

    def serialize(self):
        """
        :return: dict with the class ("type"), the state dict ("state_dict")
            and the constructor arguments ("params")
        """
        return dict(
            type=self.__class__,
            state_dict=self.state_dict(),
            params=dict(
                in_filters=self.in_filters,
                filters=self.filters,
                kernel_size=self.kernel_size,
                unpools=self.unpools,
                activation=self.activation,
                activate_output=self.activate_output,
            )
        )

    @staticmethod
    def deserialize(serialized):
        """
        Rebuild a DeconvNet from the output of serialize.

        :param serialized: dict with "params" and "state_dict" entries
        :return: DeconvNet with the stored weights loaded
        """
        net = DeconvNet(**serialized["params"])
        net.load_state_dict(serialized["state_dict"])
        return net
    
    

class MlpMixer(SerializableModule):

    """
    MLP-Mixer block

    0. input: (batch_size, num_patches, channels)
    1. transpose: (batch_size, channels, num_patches)
    2. token_mixing: (batch_size, channels, num_patches)
    3. transpose: (batch_size, num_patches, channels)
    4. add: (batch_size, num_patches, channels)
    5. channel_mixing: (batch_size, num_patches, channels)
    6. add: (batch_size, num_patches, channels)

    :param layer_norm: bool, whether to normalize the input of each mixing step
    :param mlp_params: keyword arguments forwarded to MLPNet; the same arguments
        build both the token-mixing and the channel-mixing network
    """

    def __init__(self, layer_norm=True, **mlp_params):
        super(MlpMixer, self).__init__()
        self.layer_norm = layer_norm
        self.mlp_params = mlp_params
        self.token_mixing = MLPNet(**mlp_params)
        self.channel_mixing = MLPNet(**mlp_params)

    @staticmethod
    def _normalize(x):
        """
        Layer-normalize x over all dimensions but the first (eps=1e-6).

        The functional form carries no scale/shift parameters, which matches a
        freshly built nn.LayerNorm (scale 1, shift 0) while following the
        device and dtype of x instead of allocating parameters on every call.
        """
        return nn.functional.layer_norm(x, tuple(x.shape[1:]), eps=1e-6)

    def forward(self, x):
        """
        :param x: tensor of shape (batch_size, num_patches, channels)
        :return: tensor of shape (batch_size, num_patches, channels)
        """
        # x: (batch_size, num_patches, channels)
        if self.layer_norm:
            x = self._normalize(x)
        y = torch.transpose(x, 1, 2)
        # y: (batch_size, channels, num_patches)
        y = self.token_mixing(y)
        y = torch.transpose(y, 1, 2)
        # y: (batch_size, num_patches, channels)
        x = x + y
        if self.layer_norm:
            x = self._normalize(x)
        y = self.channel_mixing(x)
        return x + y

    def serialize(self):
        """
        :return: dict with the class ("type"), the state dict ("state_dict")
            and the constructor arguments ("params"), where the MLPNet
            arguments are nested under "mlp_params"
        """
        return dict(
            type=self.__class__,
            state_dict=self.state_dict(),
            params=dict(
                layer_norm=self.layer_norm,
                mlp_params=self.mlp_params
            )
        )

    @staticmethod
    def deserialize(serialized):
        """
        Rebuild an MlpMixer from the output of serialize.

        :param serialized: dict with "params" and "state_dict" entries
        :return: MlpMixer with the stored weights loaded
        """
        # serialize nests the MLPNet arguments under "mlp_params", while the
        # constructor expects them as plain keyword arguments: flatten them
        # (on a copy, so the caller's dict is left intact).
        params = dict(serialized["params"])
        mlp_params = params.pop("mlp_params", None)
        if mlp_params is not None:
            params.update(mlp_params)
        net = MlpMixer(**params)
        net.load_state_dict(serialized["state_dict"])
        return net


class BlockNet(SerializableModule):
    """
    Network assembled from named blocks.

    The keyword arguments map an output name to a block. The first InputBlock
    found becomes the entry point, the first OutputBlock found selects the
    returned value, and every other block is run in the order given.

    :param blocks: output name -> block; must contain an InputBlock
    :raises TypeError: if no InputBlock is given
    """

    def __init__(self, **blocks):
        # Imported here rather than at module level
        # (presumably to avoid a circular import - TODO confirm).
        from ..block import InputBlock
        super(BlockNet, self).__init__()

        input_entry = next(((name, block) for name, block in blocks.items()
                            if isinstance(block, InputBlock)), None)
        if input_entry is None:
            # Fail with an explicit message instead of an unpacking error.
            raise TypeError("BlockNet requires an InputBlock among its blocks.")
        input_name, self.input_block = input_entry
        self.input_block.set_output(input_name)
        # Stays None when no OutputBlock is given; forward and serialize need one.
        self.output_block = next((block for block in blocks.values()
                                  if isinstance(block, self.OutputBlock)), None)

        self.blocks = nn.ModuleDict()
        for name, block in blocks.items():
            if not isinstance(block, (InputBlock, self.OutputBlock)):
                block.set_output(name)
                self.blocks[name] = block

    def forward(self, inputs):
        """
        Run the input block, then every intermediate block, then the output block.

        :param inputs: passed to the input block
        :return: the value selected by the output block
        """
        computed = self.input_block(inputs)
        computed = self.propogate_blocks(computed)
        return self.output_block(computed)

    def propogate_blocks(self, computed):
        """
        Run the intermediate blocks in order, threading ``computed`` through them.

        :param computed: value handed to each block as ``computed=``
        :return: the value returned by the last block (unchanged input when
            there are no intermediate blocks)
        """
        for block in self.blocks.values():
            result = block(computed=computed)
            if isinstance(result, tuple):
                # Blocks may return a pair; only its first item is carried on.
                computed, _ = result
            else:
                computed = result
        return computed

    def serialize(self):
        """
        :return: dict with the class ("type") and the serialized input,
            intermediate and output blocks, in that order ("blocks")
        """
        blocks = [self.input_block.serialize()]
        blocks.extend(block.serialize() for block in self.blocks.values())
        blocks.append(self.output_block.serialize())
        return dict(
            type=self.__class__,
            blocks=blocks
        )

    @staticmethod
    def deserialize(serialized):
        """
        Rebuild a BlockNet from the output of serialize.

        Each serialized block is rebuilt by the ``deserialize`` of the class
        stored under its "type" entry and registered under its "output" entry.

        :param serialized: dict with a "blocks" list
        :return: BlockNet
        """
        blocks = OrderedDict()
        for block in serialized["blocks"]:
            blocks[block["output"]] = block["type"].deserialize(block)
        return BlockNet(**blocks)

    class OutputBlock(SerializableModule):
        """
        Final block of the model
        Functions like a SimpleBlock
        Only for use in BlockNet

        :param input_id: key of the computed value to return
        """
        def __init__(self, input_id):
            super(BlockNet.OutputBlock, self).__init__()
            self.input = input_id

        def forward(self, computed: dict) -> (tensor, dict, tuple):
            """
            :param computed: dict of the values computed so far
            :return: the entry of computed stored under the block's input id
            """
            return computed[self.input]

        def serialize(self):
            """
            :return: dict with the class ("type"), the input id ("input") and
                the fixed output name "output"
            """
            return dict(
                type=self.__class__,
                input=self.input,
                output="output"
            )

        @staticmethod
        def deserialize(serialized: dict):
            """
            :param serialized: dict with an "input" entry
            :return: BlockNet.OutputBlock
            """
            return BlockNet.OutputBlock(input_id=serialized["input"])




