import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .cola_nn import dense_init
import torch



class CappedList:
    """Append-only buffer that retains at most ``max_len`` items.

    Each accepted item is stored via ``x.cpu()``. Once the buffer is full,
    further appends are ignored, so only the FIRST ``max_len`` values are kept.
    """

    def __init__(self, max_len=1):
        self.max_len = max_len
        self.buffer = []

    def append(self, x):
        """Store ``x.cpu()`` unless the buffer already holds ``max_len`` items."""
        is_full = len(self.buffer) >= self.max_len
        if is_full:
            # Drop the newcomer; existing entries are never evicted.
            return
        self.buffer.append(x.cpu())


class MLPBlockFact(nn.Module):
    """Feed-forward GELU block whose two projections come from ``linear_layer_fn``.

    Computes ``linear2(gelu(linear1(ln(x)))) * residual_mult``, where the ``ln``
    step is applied only when ``layer_norm`` is truthy, and adds the block input
    back when ``residual`` is truthy.

    Args:
        width: input/output feature width of the two projections.
        residual: whether to add the block input to the branch output.
        ln_width: normalized width of the LayerNorm (the width of the incoming
            features, which the caller reads from the previous layer).
        linear_layer_fn: factory called as ``fn(in, out, bias=...)``.
        layer_norm: whether to apply LayerNorm before ``linear1``.
        ffn_expansion: hidden width multiplier (hidden = ``width * ffn_expansion``).
        residual_mult: scale applied to the branch output before the residual add.
        use_bias: forwarded as ``bias=`` to both projections.
    """

    def __init__(
        self,
        width,
        residual,
        ln_width,
        linear_layer_fn,
        layer_norm,
        ffn_expansion,
        residual_mult,
        use_bias,
    ):
        super().__init__()
        hidden = width * ffn_expansion
        # Keep the original construction/init order: layer construction and
        # dense_init may draw from the global RNG, so reordering could change
        # the initial weights.
        self.linear1 = linear_layer_fn(width, hidden, bias=use_bias)
        self.linear2 = linear_layer_fn(hidden, width, bias=use_bias)
        self.residual_mult = residual_mult
        dense_init(self.linear1)
        # presumably zero-initializes the output projection (see dense_init) — confirm
        dense_init(self.linear2, zero_init=True)
        self.residual, self.layer_norm = residual, layer_norm
        if layer_norm:
            self.ln = nn.LayerNorm(ln_width)
        # Output width of this block; the caller uses it to size the next LayerNorm.
        self.ln_dim = self.linear2.out_features

    def forward(self, x):
        h = self.ln(x) if self.layer_norm else x
        h = F.gelu(self.linear1(h))
        h = self.linear2(h) * self.residual_mult
        return h + x if self.residual else h


class MLPFact(nn.Module):
    """MLP whose input and hidden projections may use a factored linear layer class.

    If ``fact_cls`` exposes a ``Linear`` attribute it is used as the layer factory
    for the input layer and every ``MLPBlockFact``; otherwise ``nn.Linear`` is used.
    The output layer is always a plain ``nn.Linear``.

    In eval mode (``not self.training``), ``forward`` records detached activations
    into ``self.hs``: index 0 is the embedding, ``1..depth`` the block outputs, and
    ``-1`` the final output. Each slot keeps only its first capture (``CappedList()``
    default).

    ``downsample_image`` and any extra keyword arguments (``**_``) are accepted but unused.
    """

    def __init__(self, dim_in, dim_out, depth, width, fact_cls, ffn_expansion, residual=True, layer_norm=True,
                 shuffle_pixels=True, attn_mult=1, output_mult=1, emb_mult=1, use_bias=True, downsample_image=-1, **_):
        super().__init__()
        self.shuffle_pixels = shuffle_pixels
        # Fixed seed -> the same pixel permutation on every construction.
        local_rng = np.random.default_rng(42)
        # Shuffle the pixels using the local random generator
        self.pixel_indices = local_rng.permutation(dim_in)
        self.output_mult = output_mult
        self.emb_mult = emb_mult
        # input layer
        if hasattr(fact_cls, "Linear"):
            linear_layer_fn = fact_cls.Linear
        else:
            linear_layer_fn = nn.Linear
        self.input_layer = linear_layer_fn(dim_in, width, bias=use_bias)
        dense_init(self.input_layer)
        # self.input_layer.weight.data *= 2**0.5  # relu gain
        # hidden layers
        self.hidden_layers = nn.ModuleList()
        # `aux` tracks the feature width actually produced by the previous layer
        # (read from `.out_features` / `ln_dim`); it sizes each block's LayerNorm
        # and, after the loop, the output layer's input.
        aux = self.input_layer.out_features
        for _ in range(depth):
            module = MLPBlockFact(width, residual, aux, linear_layer_fn, layer_norm, ffn_expansion=ffn_expansion,
                                  residual_mult=attn_mult, use_bias=use_bias)

            aux = module.ln_dim
            self.hidden_layers.append(module)
        # output layer
        self.output_layer = nn.Linear(aux, dim_out, bias=use_bias)
        dense_init(self.output_layer, zero_init=True)
        # logs
        # One capture slot per recorded point: embedding, each block, final output.
        self.hs = [CappedList() for _ in range(depth + 2)]

    def forward(self, x):
        """Flatten ``x`` per sample, optionally permute features, and run the MLP.

        Returns the output layer's result scaled by ``output_mult``.
        """
        x = x.reshape(x.shape[0], -1)
        if self.shuffle_pixels:
            x = x[:, self.pixel_indices]
        x = F.gelu(self.input_layer(x) * self.emb_mult)
        if not self.training:
            self.hs[0].append(x.detach())
        for i, layer in enumerate(self.hidden_layers):
            x = layer(x)
            if not self.training:
                self.hs[i + 1].append(x.detach())
        y = self.output_layer(x) * self.output_mult
        if not self.training:
            self.hs[-1].append(y.detach())
        return y

    def get_features(self):
        """Return the list of per-layer ``CappedList`` activation buffers."""
        return self.hs

    def clear_features(self):
        """Replace every activation buffer with a fresh, empty ``CappedList``."""
        self.hs = [CappedList() for _ in range(len(self.hs))]


class MLPB(nn.Module):
    """Single width->width ReLU layer with optional pre-LayerNorm and residual add.

    Computes ``relu(linear(ln(x)))`` (``ln`` only when ``layer_norm`` is truthy)
    and adds ``x`` back when ``residual`` is truthy. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(self, width, residual, layer_norm, use_bias=True, **kwargs):
        super().__init__()
        self.linear = nn.Linear(width, width, bias=use_bias)
        self.residual, self.layer_norm = residual, layer_norm
        if layer_norm:
            self.ln = nn.LayerNorm(width)

    def forward(self, x):
        normed = self.ln(x) if self.layer_norm else x
        out = F.relu(self.linear(normed))
        if not self.residual:
            return out
        return out + x


class MLPNoEx(nn.Module):
    """ReLU MLP built from ``MLPB`` blocks (one width->width linear per block).

    No ``dense_init`` is applied, so all layers keep PyTorch's default init.

    In eval mode (``not self.training``), ``forward`` records detached activations
    into ``self.hs``: index 0 is the embedding, ``1..depth`` the block outputs, and
    ``-1`` the final output. Each slot keeps only its first capture.

    Args:
        dim_in: flattened input size.
        dim_out: output size.
        depth: number of ``MLPB`` blocks.
        width: hidden width.
        use_bias: whether every linear layer, including the output layer, has a bias.
        residual: forwarded to each ``MLPB``.
        layer_norm: forwarded to each ``MLPB``.
        shuffle_pixels: whether to apply a fixed (seed 42) feature permutation.
        **kwargs: forwarded to each ``MLPB`` (which ignores them).
    """

    def __init__(self, dim_in, dim_out, depth, width, use_bias=True, residual=True, layer_norm=True, shuffle_pixels=True,
                 **kwargs):
        super().__init__()
        self.shuffle_pixels = shuffle_pixels
        # Fixed seed -> the same pixel permutation on every construction.
        local_rng = np.random.default_rng(42)
        self.pixel_indices = local_rng.permutation(dim_in)
        self.input_layer = nn.Linear(dim_in, width, bias=use_bias)
        self.hidden_layers = nn.ModuleList()
        for _ in range(depth):
            self.hidden_layers.append(MLPB(width, residual, layer_norm, use_bias, **kwargs))
        # Honor use_bias here too. The original always created a bias, unlike the
        # input/hidden layers and the output layers of MLP/MLPFact.
        self.output_layer = nn.Linear(width, dim_out, bias=use_bias)
        # One capture slot per recorded point: embedding, each block, final output.
        self.hs = [CappedList() for _ in range(depth + 2)]

    def forward(self, x):
        """Flatten ``x`` per sample, optionally permute features, and run the MLP."""
        x = x.reshape(x.shape[0], -1)
        if self.shuffle_pixels:
            x = x[:, self.pixel_indices]
        x = F.relu(self.input_layer(x))
        if not self.training:
            self.hs[0].append(x.detach())
        for i, layer in enumerate(self.hidden_layers):
            x = layer(x)
            if not self.training:
                self.hs[i + 1].append(x.detach())
        y = self.output_layer(x)
        if not self.training:
            self.hs[-1].append(y.detach())
        return y

    def get_features(self):
        """Return the list of per-layer ``CappedList`` activation buffers."""
        return self.hs

    def clear_features(self):
        """Replace every activation buffer with a fresh, empty ``CappedList``."""
        self.hs = [CappedList() for _ in range(len(self.hs))]


class MLPBlock(nn.Module):
    """Feed-forward GELU block with a fixed 4x hidden expansion.

    Computes ``linear2(gelu(linear1(ln(x)))) * residual_mult``, where ``ln`` is
    applied only when ``layer_norm`` is truthy, and adds ``x`` back when
    ``residual`` is truthy.
    """

    def __init__(self, width, residual, layer_norm, residual_mult=1, use_bias=True):
        super().__init__()
        hidden = width * 4
        # Construction/init order is kept exactly as before: nn.Linear draws its
        # default init from the global RNG (and dense_init may too), so reordering
        # would change the initial weights.
        self.linear1 = nn.Linear(width, hidden, bias=use_bias)
        self.linear2 = nn.Linear(hidden, width, bias=use_bias)
        self.residual_mult = residual_mult
        dense_init(self.linear1)
        # presumably zero-initializes the output projection (see dense_init) — confirm
        dense_init(self.linear2, zero_init=True)
        self.residual, self.layer_norm = residual, layer_norm
        if layer_norm:
            self.ln = nn.LayerNorm(width)
        self.gelu = nn.GELU()

    def forward(self, x):
        branch = self.ln(x) if self.layer_norm else x
        branch = self.linear2(self.gelu(self.linear1(branch))) * self.residual_mult
        return branch + x if self.residual else branch


class MLP(nn.Module):
    """Residual GELU MLP made of ``MLPBlock`` layers, with a LayerNorm'd embedding.

    Embedding: ``input_ln(gelu(input_layer(x) * emb_mult))``. Then ``depth``
    ``MLPBlock`` layers (``attn_mult`` is passed as their ``residual_mult``).
    Output: ``output_layer(h) * output_mult``.

    In eval mode (``not self.training``), ``forward`` records detached activations
    into ``self.hs``: index 0 is the embedding, ``1..depth`` the block outputs, and
    ``-1`` the final output. Each slot keeps only its first capture.

    ``downsample_image`` and any extra keyword arguments (``**_``) are accepted but unused.
    """

    def __init__(self, dim_in, dim_out, depth, width, residual=True, layer_norm=True, shuffle_pixels=True, attn_mult=1,
                 output_mult=1, emb_mult=1, use_bias=True, downsample_image=-1, **_):
        super().__init__()
        self.shuffle_pixels = shuffle_pixels
        # A private, fixed-seed generator yields the same permutation every time.
        self.pixel_indices = np.random.default_rng(42).permutation(dim_in)
        self.output_mult = output_mult
        self.emb_mult = emb_mult
        # Module creation and dense_init calls follow the original order so the
        # global-RNG draws (and hence initial weights) are unchanged.
        self.input_layer = nn.Linear(dim_in, width, bias=use_bias)
        self.input_ln = nn.LayerNorm(width)
        dense_init(self.input_layer)
        self.hidden_layers = nn.ModuleList(
            [MLPBlock(width, residual, layer_norm, residual_mult=attn_mult, use_bias=use_bias) for _ in range(depth)]
        )
        self.output_layer = nn.Linear(width, dim_out, bias=use_bias)
        dense_init(self.output_layer, zero_init=True)
        # Capture slots: embedding, one per block, final output.
        self.hs = [CappedList() for _ in range(depth + 2)]

    def forward(self, x):
        """Flatten ``x`` per sample, optionally permute features, and run the MLP."""
        recording = not self.training
        h = x.reshape(x.shape[0], -1)
        if self.shuffle_pixels:
            h = h[:, self.pixel_indices]
        h = self.input_ln(F.gelu(self.input_layer(h) * self.emb_mult))
        if recording:
            self.hs[0].append(h.detach())
        for slot, block in enumerate(self.hidden_layers, start=1):
            h = block(h)
            if recording:
                self.hs[slot].append(h.detach())
        out = self.output_layer(h) * self.output_mult
        if recording:
            self.hs[-1].append(out.detach())
        return out

    def get_features(self):
        """Return the list of per-layer ``CappedList`` activation buffers."""
        return self.hs

    def clear_features(self):
        """Reset every activation buffer to a fresh, empty ``CappedList``."""
        self.hs = [CappedList() for _ in self.hs]


class DenseNetBlock(nn.Module):
    """Densely connected MLP block.

    Each of the first ``block_depth - 1`` layers sees the concatenation of the
    block input and every previous layer's output (optionally LayerNorm'd) and
    emits ``width * width_factor`` GELU features. ``linear_last`` maps the full
    concatenation back to ``width``. No activation or residual is applied to the
    block output.

    Args:
        width: block input/output feature width.
        layer_norm: whether to LayerNorm each concatenated input before its linear.
        use_bias: forwarded as ``bias=`` to every linear layer.
        block_depth: total number of linear layers in the block (including ``linear_last``).
        width_factor: multiplier for the width of each intermediate layer's output.
    """

    def __init__(self, width, layer_norm, use_bias=True, block_depth=2, width_factor=1):
        super().__init__()

        self.layer_norm = layer_norm
        # nn.ModuleList (not a plain list) so these sub-layers are registered:
        # their parameters appear in .parameters()/state_dict(), get optimized,
        # and move with .to(device). With plain lists they were silently frozen
        # at their initial values and left on the original device.
        self.layer_norms = nn.ModuleList()
        self.layers = nn.ModuleList()
        self.block_depth = block_depth

        in_dim = width
        out_dim = width * width_factor

        for _ in range(self.block_depth - 1):
            if self.layer_norm:
                self.layer_norms.append(nn.LayerNorm(in_dim))
            self.layers.append(nn.Linear(in_dim, out_dim, bias=use_bias))
            # NOTE(review): intermediate layers are also zero_init'ed, unlike
            # MLPBlock where only the output projection is — confirm intended.
            dense_init(self.layers[-1], zero_init=True)
            # The next layer also sees this layer's output.
            in_dim += out_dim

        if self.layer_norm:
            self.layer_norms.append(nn.LayerNorm(in_dim))
        self.linear_last = nn.Linear(in_dim, width, bias=use_bias)
        dense_init(self.linear_last, zero_init=True)

        self.gelu = nn.GELU()

    def forward(self, x):
        """Run the dense block on ``x``, whose last dimension must be ``width``."""
        features = [x]
        for idx in range(self.block_depth - 1):
            # Concatenate on the feature (last) dim, the dim LayerNorm/Linear act on;
            # identical to dim=1 for 2-D (batch, width) input.
            inputs = torch.cat(features, dim=-1)
            if self.layer_norm:
                inputs = self.layer_norms[idx](inputs)
            features.append(self.gelu(self.layers[idx](inputs)))

        inputs = torch.cat(features, dim=-1)
        if self.layer_norm:
            inputs = self.layer_norms[-1](inputs)
        return self.linear_last(inputs)

class DenseNet(nn.Module):
    """MLP made of ``DenseNetBlock`` layers, with a LayerNorm'd GELU embedding.

    Embedding: ``input_ln(gelu(input_layer(x) * emb_mult))``. Then ``depth``
    ``DenseNetBlock`` layers (default ``block_depth``/``width_factor``). Output:
    ``output_layer(h) * output_mult``.

    In eval mode (``not self.training``), ``forward`` records detached activations
    into ``self.hs``: index 0 is the embedding, ``1..depth`` the block outputs, and
    ``-1`` the final output. Each slot keeps only its first capture.

    ``attn_mult``, ``downsample_image`` and any extra keyword arguments (``**_``,
    including a ``residual`` flag if a caller passes one) are accepted but unused.

    NOTE(review): blocks are chained as ``x = layer(x)`` with no residual, and each
    block's ``linear_last`` is built with ``dense_init(..., zero_init=True)``. If that
    zeroes the weights, every block initially emits only its bias. Confirm this
    start-up behavior is intended.
    """

    def __init__(self, dim_in, dim_out, depth, width, layer_norm=True, shuffle_pixels=True, attn_mult=1,
                 output_mult=1, emb_mult=1, use_bias=True, downsample_image=-1, **_):
        super().__init__()
        self.shuffle_pixels = shuffle_pixels
        # Fixed seed -> the same pixel permutation on every construction.
        local_rng = np.random.default_rng(42)
        # Shuffle the pixels using the local random generator
        self.pixel_indices = local_rng.permutation(dim_in)
        self.output_mult = output_mult
        self.emb_mult = emb_mult
        # input layer
        self.input_layer = nn.Linear(dim_in, width, bias=use_bias)
        # input LN
        self.input_ln = nn.LayerNorm(width)
        dense_init(self.input_layer)
        # self.input_layer.weight.data *= 2**0.5  # relu gain
        # hidden layers
        self.hidden_layers = nn.ModuleList()
        for _ in range(depth):
            self.hidden_layers.append(DenseNetBlock(width, layer_norm, use_bias=use_bias))
        # output layer
        self.output_layer = nn.Linear(width, dim_out, bias=use_bias)
        dense_init(self.output_layer, zero_init=True)
        # logs
        # One capture slot per recorded point: embedding, each block, final output.
        self.hs = [CappedList() for _ in range(depth + 2)]

    def forward(self, x):
        """Flatten ``x`` per sample, optionally permute features, and run the network."""
        x = x.reshape(x.shape[0], -1)
        if self.shuffle_pixels:
            x = x[:, self.pixel_indices]
        x = F.gelu(self.input_layer(x) * self.emb_mult)
        x = self.input_ln(x)
        if not self.training:
            self.hs[0].append(x.detach())
        for i, layer in enumerate(self.hidden_layers):
            x = layer(x)
            if not self.training:
                self.hs[i + 1].append(x.detach())
        y = self.output_layer(x) * self.output_mult
        if not self.training:
            self.hs[-1].append(y.detach())
        return y

    def get_features(self):
        """Return the list of per-layer ``CappedList`` activation buffers."""
        return self.hs

    def clear_features(self):
        """Replace every activation buffer with a fresh, empty ``CappedList``."""
        self.hs = [CappedList() for _ in range(len(self.hs))]



class CoLAMLPBlock(nn.Module):
    """ReLU block around one ``builder_fn``-constructed width->width projection.

    Computes ``relu(linear(ln(x)))`` (``ln`` only when ``layer_norm`` is truthy)
    and adds ``x`` back when ``residual`` is truthy. ``builder_fn`` is called as
    ``builder_fn(width, width, **kwargs)``.
    """

    def __init__(self, width, residual, layer_norm, builder_fn, **kwargs):
        super().__init__()
        self.linear = builder_fn(width, width, **kwargs)
        self.residual, self.layer_norm = residual, layer_norm
        if layer_norm:
            self.ln = nn.LayerNorm(width)

    def forward(self, x):
        pre = self.ln(x) if self.layer_norm else x
        act = F.relu(self.linear(pre))
        return act + x if self.residual else act


class CoLAMLP(nn.Module):
    """ReLU MLP whose input and hidden projections are produced by ``builder_fn``.

    ``builder_fn`` is called as ``builder_fn(in_features, out_features, **kwargs)``
    for the input layer and for each ``CoLAMLPBlock``'s projection. The output
    layer is always a plain ``nn.Linear``.

    Args:
        dim_in: flattened input size.
        dim_out: output size.
        depth: number of ``CoLAMLPBlock`` layers.
        width: hidden width.
        builder_fn: linear-layer factory (see above).
        residual: forwarded to each block.
        layer_norm: forwarded to each block.
        shuffle_pixels: whether to apply a fixed (seed 42) feature permutation.
        **kwargs: forwarded to every ``builder_fn`` call.
    """

    def __init__(self, dim_in, dim_out, depth, width, builder_fn, residual=True, layer_norm=True, shuffle_pixels=True, **kwargs):
        super().__init__()
        self.shuffle_pixels = shuffle_pixels
        # forward() indexes with self.pixel_indices whenever shuffle_pixels is True.
        # It was never assigned, so the default configuration raised AttributeError.
        # Use the same fixed-seed permutation as the other models in this module.
        local_rng = np.random.default_rng(42)
        self.pixel_indices = local_rng.permutation(dim_in)
        self.input_layer = builder_fn(dim_in, width, **kwargs)
        self.hidden_layers = nn.ModuleList()
        for _ in range(depth):
            self.hidden_layers.append(CoLAMLPBlock(width, residual, layer_norm, builder_fn, **kwargs))
        self.output_layer = nn.Linear(width, dim_out)

    def forward(self, x):
        """Flatten ``x`` per sample, optionally permute features, and run the MLP."""
        x = x.reshape(x.shape[0], -1)
        if self.shuffle_pixels:
            x = x[:, self.pixel_indices]
        x = F.relu(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        return self.output_layer(x)
