from ast import Raise
import torch.nn as nn
import torch.nn.functional as F

from imagegym.config import cfg
from imagegym.models.act import act_dict
from torch.nn.utils import weight_norm
import torch

def process_cat_cond_params(mlp_params, dims=None, input_dim=None):
    """Fill in the layer widths of the conditioning MLP, in place.

    Args:
        mlp_params: mapping that must contain ``'layers'`` (a list of hidden
            widths) and ``'output_dim'``. It is mutated: ``'c_dim_list'`` and
            ``'dims'`` are written into it.
        dims: stored unchanged under ``mlp_params['dims']``.
        input_dim: width placed at the front of ``'c_dim_list'``.

    Returns:
        None. The result is ``mlp_params['c_dim_list'] ==
        [input_dim, *layers, output_dim]``.

    Raises:
        AssertionError: if ``'layers'`` is missing or is not a list.
        KeyError: if ``'output_dim'`` is missing.
    """
    assert 'layers' in mlp_params
    assert isinstance(mlp_params['layers'], list)

    hidden_widths = mlp_params['layers']
    final_width = mlp_params['output_dim']

    # Full chain of widths: input -> hidden layers -> output.
    mlp_params['c_dim_list'] = [input_dim, *hidden_widths, final_width]
    mlp_params['dims'] = dims


def process_cat_encoder_x_params(mlp_params, dims=None, coordinate_dim=None):
    """Fill in the layer widths of the coordinate-encoder MLP, in place.

    The first width is ``2 * cfg.params_encoding.num_frequencies`` when
    ``cfg.model.name_encoding == 'fourier'`` and ``coordinate_dim`` otherwise.
    No output width is appended: the last entry of ``'layers'`` is the final
    width of this MLP.

    Args:
        mlp_params: mapping that must contain ``'layers'`` (a list of widths).
            It is mutated: ``'c_dim_list'`` and ``'dims'`` are written into it.
        dims: stored unchanged under ``mlp_params['dims']``.
        coordinate_dim: input width used when the encoding is not ``'fourier'``.

    Raises:
        AssertionError: if ``'layers'`` is missing or is not a list.
    """
    assert 'layers' in mlp_params
    assert isinstance(mlp_params['layers'], list)

    uses_fourier = cfg.model.name_encoding == 'fourier'
    first_width = 2 * cfg.params_encoding.num_frequencies if uses_fourier else coordinate_dim

    mlp_params['c_dim_list'] = [first_width, *mlp_params['layers']]
    mlp_params['dims'] = dims

def process_cat_encoder_prior_params(mlp_params, dims=None, coordinate_dim=None):
    """Fill in the layer widths of the categorical *prior* MLP, in place.

    The input width is assembled from the global ``cfg``:

    * coordinate part: ``cfg.params_cat_x["layers"][-1]`` when
      ``cfg.model.cat_encoder_x`` is set; otherwise
      ``2 * cfg.params_encoding.num_frequencies`` for the ``'fourier'``
      encoding, or ``coordinate_dim`` for any other encoding;
    * plus ``cfg.model.dim_z`` unless ``cfg.model.simple_cat`` is set.

    The output width is ``cfg.params_k_mixture.K``. Config flags are
    evaluated by truthiness.

    Args:
        mlp_params: mapping that must contain ``'layers'`` (a list of hidden
            widths). It is mutated: ``'c_dim_list'`` and ``'dims'`` are
            written into it.
        dims: stored unchanged under ``mlp_params['dims']``.
        coordinate_dim: raw coordinate width, used only when neither the
            coordinate encoder nor the ``'fourier'`` encoding applies.

    Raises:
        AssertionError: if ``'layers'`` is missing or is not a list.
    """
    assert 'layers' in mlp_params
    assert isinstance(mlp_params['layers'], list)

    # Width of the coordinate features fed to the prior.
    if cfg.model.cat_encoder_x:
        coor_dim = cfg.params_cat_x["layers"][-1]  # projected coordinates
    elif cfg.model.name_encoding == 'fourier':
        coor_dim = 2 * cfg.params_encoding.num_frequencies
    else:
        coor_dim = coordinate_dim

    if cfg.model.simple_cat:
        input_dim = coor_dim  # p(c_d | x_d)
    else:
        input_dim = cfg.model.dim_z + coor_dim  # p(c_d | x_d, z)

    output_dim = cfg.params_k_mixture.K
    mlp_params['c_dim_list'] = [input_dim] + mlp_params['layers'] + [output_dim]
    mlp_params['dims'] = dims


def process_cat_encoder_post_params(mlp_params, dims=None, coordinate_dim=None):
    """Fill in the layer widths of the categorical *posterior* MLP, in place.

    Only the configuration with ``cfg.model.post_cat_has_z`` set is supported.
    The input width is assembled from the global ``cfg``:

    * ``cfg.model.dim_z``;
    * plus the coordinate part: ``cfg.params_cat_x["layers"][-1]`` when
      ``cfg.model.cat_encoder_x`` is set; otherwise
      ``2 * cfg.params_encoding.num_frequencies`` for the ``'fourier'``
      encoding, or ``coordinate_dim`` for any other encoding;
    * plus ``cfg.dataset.dims[0]`` unless ``cfg.model.simple_cat`` is set.

    The output width is ``cfg.params_k_mixture.K``. Config flags are
    evaluated by truthiness.

    Args:
        mlp_params: mapping that must contain ``'layers'`` (a list of hidden
            widths). It is mutated: ``'c_dim_list'`` and ``'dims'`` are
            written into it.
        dims: stored unchanged under ``mlp_params['dims']``.
        coordinate_dim: raw coordinate width, used only when neither the
            coordinate encoder nor the ``'fourier'`` encoding applies.

    Raises:
        AssertionError: if ``'layers'`` is missing or is not a list.
        NotImplementedError: if ``cfg.model.post_cat_has_z`` is not set.
    """
    assert 'layers' in mlp_params
    assert isinstance(mlp_params['layers'], list)

    # Fail early, and say why: a posterior without z has no implementation.
    if not cfg.model.post_cat_has_z:
        raise NotImplementedError(
            "process_cat_encoder_post_params requires cfg.model.post_cat_has_z=True"
        )

    # Width of the coordinate features fed to the posterior.
    if cfg.model.cat_encoder_x:
        coor_dim = cfg.params_cat_x["layers"][-1]  # projected coordinates
    elif cfg.model.name_encoding == 'fourier':
        coor_dim = 2 * cfg.params_encoding.num_frequencies
    else:
        coor_dim = coordinate_dim

    input_dim = cfg.model.dim_z + coor_dim  # q(c_d | x_d, z)
    if not cfg.model.simple_cat:
        input_dim += cfg.dataset.dims[0]  # q(c_d | x_d, y, z)

    output_dim = cfg.params_k_mixture.K
    mlp_params['c_dim_list'] = [input_dim] + mlp_params['layers'] + [output_dim]
    mlp_params['dims'] = dims


# Helper function for creating dense blocks
def basic_dense_block(in_dim, out_dim, activation, drop_rate=0.0, use_batch_norm=True, **kwargs):
    """Build one dense stage as an ``nn.Sequential`` of four modules.

    The stage is always ``Linear -> norm -> activation -> dropout``; disabled
    parts are replaced by ``nn.Identity`` so the sequence keeps four children.

    Args:
        in_dim: input features of the linear layer.
        out_dim: output features of the linear layer (and of the batch norm).
        activation: module inserted as the third child, used as given.
        drop_rate: dropout probability; dropout is only added when it is > 0.
        use_batch_norm: insert ``nn.BatchNorm1d(out_dim)`` when true.
        **kwargs: accepted and ignored.

    Returns:
        The assembled ``nn.Sequential``.
    """
    linear = nn.Linear(in_dim, out_dim)
    norm = nn.BatchNorm1d(out_dim) if use_batch_norm else nn.Identity()
    if drop_rate > 0:
        dropout = nn.Dropout(drop_rate, inplace=False)
    else:
        dropout = nn.Identity()
    return nn.Sequential(linear, norm, activation, dropout)


#MLP identity
class Identity_layer(nn.Module):
    """Parameter-free module whose forward pass only flattens its input.

    ``forward`` keeps the first dimension and collapses all remaining
    dimensions into one; no values are changed.
    """

    def __init__(self,
                 **kwargs):
        """Create the module; all keyword arguments are accepted and ignored."""
        super().__init__()
        # Plain Python list (not a registered sub-module); forward() does not
        # use it. Kept so the attribute still exists for existing callers.
        self.layers = [nn.Identity()]

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        Args:
            x: tensor with at least one dimension.

        Returns:
            A 2-D tensor holding the same values as ``x``.
        """
        # reshape (rather than view) so that non-contiguous inputs, e.g. the
        # result of a transpose, are flattened instead of raising; for
        # contiguous inputs it still returns a view.
        logits = x.reshape(x.size(0), -1)
        return logits

# class MLP(nn.Module):

#     def __init__(self, input_dim, hidden_dims,
#                  act='relu',
#                  batchnorm=False,
#                  dropout=0.0,
#                  dims=None,
#                  shared_layers=[0,1],
#                  **kwargs):


class MLP(nn.Module):
    """Multi-layer perceptron assembled from ``basic_dense_block`` stages.

    Consecutive entries of ``c_dim_list`` give the (in, out) widths of each
    stage. Hidden stages use the activation ``act_dict[act]``, ``dropout`` and
    (optionally) batch norm; the last stage has no activation, no batch norm,
    and uses ``output_dropout``.

    Besides the layers, the constructor records bookkeeping about parameter
    names (``w{i}`` / ``b{i}``) and which of them are shared or non-shared.
    That bookkeeping is not used by this class's own forward passes.
    """

    def __init__(self, c_dim_list,
                 act='prelu',
                 batchnorm=True,
                 dropout=0.0,
                 output_dropout=0.0,
                 l2norm=False,
                 shared_layers=None,
                 dims=None,
                 mode='',
                 **kwargs):
        """Build the layer stack and the parameter-name bookkeeping.

        Args:
            c_dim_list: list of widths ``[in, hidden..., out]``; it yields
                ``len(c_dim_list) - 1`` layers.
            act: key into ``act_dict`` selecting the hidden activation.
            batchnorm: use batch norm in every layer except the last.
            dropout: dropout rate of every layer except the last.
            output_dropout: dropout rate of the last layer.
            l2norm: if true, L2-normalise along dim 1 after every layer.
            shared_layers: indices of layers treated as shared; ``None`` (the
                default) means no shared layers. The sequence is copied.
            dims: optional trailing shape the output is viewed as.
            mode: iterated character by character; each character is used as
                a name prefix (e.g. ``'w'``, ``'b'``) for the non-shared layers.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if ``c_dim_list`` is not a list.
        """
        super().__init__()
        assert isinstance(c_dim_list, list)
        self.n_layers = len(c_dim_list) - 1
        self.mode = mode
        # Copy so that neither a caller's list nor a shared default object is
        # aliased by this instance.
        self.shared_layers = list(shared_layers) if shared_layers is not None else []
        self.non_shared_layers = [i for i in range(self.n_layers) if i not in self.shared_layers]
        self.all_layers = self.shared_layers + self.non_shared_layers

        self.all_layer_names = [f'w{i}' for i in range(self.n_layers)] + [f'b{i}' for i in range(self.n_layers)]
        self.non_shared_layer_names = []
        for prefix in mode:
            self.non_shared_layer_names += [f'{prefix}{i}' for i in self.non_shared_layers]
        # Everything not explicitly non-shared counts as shared. Filtering
        # all_layer_names keeps a deterministic order (a set difference would
        # yield an arbitrary order from run to run).
        non_shared_names = set(self.non_shared_layer_names)
        self.shared_layer_names = [name for name in self.all_layer_names if name not in non_shared_names]

        layers = []
        self.param_shapes = dict()

        for i, (c_in, c_out) in enumerate(zip(c_dim_list[:-1], c_dim_list[1:])):
            is_last = (i + 1) == self.n_layers
            if is_last:
                act_fn = nn.Identity()
                drop_rate = output_dropout
            else:
                # NOTE(review): if act_dict stores module instances, every
                # hidden layer receives the same activation object (including
                # any learnable parameters it has) - confirm this is intended.
                act_fn = act_dict[act]
                drop_rate = dropout

            layers.append(basic_dense_block(c_in, c_out,
                                            activation=act_fn,
                                            drop_rate=drop_rate,
                                            use_batch_norm=batchnorm if not is_last else False,
                                            ))

            self.param_shapes[f'w{i}'] = (c_in, c_out)
            self.param_shapes[f'b{i}'] = (1, c_out)

        self.layers = nn.ModuleList(layers)
        self.has_l2norm = l2norm
        self.dims = dims
        self.output_dropout = output_dropout
        # Always False here and never read inside this class.
        self.use_weight_norm = False

    def set_output_dims(self, dims):
        """Replace the trailing output shape used by the forward passes."""
        self.dims = dims

    def forward(self, x):
        """Run ``x`` through all layers using the module's own parameters.

        Returns:
            The output viewed as ``(x.size(0), *self.dims)`` when ``self.dims``
            is set, otherwise the raw output of the last layer.
        """
        for layer in self.layers:
            x = layer(x)
            if self.has_l2norm:
                # Applied after every layer, the last one included.
                x = F.normalize(x, p=2, dim=1)
        if self.dims is not None:
            logits = x.view(x.size(0), *self.dims)
        else:
            logits = x

        return logits

    def stateless_forward(self, x, weights, biases):
        """Run the network with externally supplied linear weights and biases.

        For each layer, the first child (the linear module) is replaced by
        ``F.linear(x, weights[i], biases[i])``; the remaining children (norm,
        activation, dropout) are applied as stored in the module.

        Args:
            x: input tensor.
            weights: indexable by layer index; ``weights[i]`` is passed to
                ``F.linear`` as its weight.
            biases: indexable by layer index; ``biases[i]`` is passed to
                ``F.linear`` as its bias.

        Returns:
            Same convention as ``forward``.
        """
        # NOTE(review): F.linear expects weight shaped (out_features,
        # in_features), while param_shapes records w{i} as (c_in, c_out) and
        # b{i} as (1, c_out) - presumably callers convert; confirm.
        for i, layer in enumerate(self.layers):
            for j, module in enumerate(layer.children()):
                if j == 0:
                    x = F.linear(x, weights[i], biases[i])
                else:
                    x = module(x)
            if self.has_l2norm:
                x = F.normalize(x, p=2, dim=1)
        if self.dims is not None:
            logits = x.view(x.size(0), *self.dims)
        else:
            logits = x
        return logits


# def process_mlp_params(input_dim, output_dim, mlp_params, dims=None):
#     assert 'dim_inner' in mlp_params
#     assert 'num_layers' in mlp_params
#     hidden_dims = [input_dim]
#     dim_inner = mlp_params['dim_inner']
#     layers_encoder = mlp_params['num_layers']
#     hidden_dims.extend([dim_inner for i in range(layers_encoder)])
#     hidden_dims.append(output_dim)
#     mlp_params['hidden_dims'] = hidden_dims
#     mlp_params['dims'] = dims