import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from .cola_nn import dense_init

# helpers
class CappedList:
    """Append-only buffer that keeps at most ``max_len`` items, each moved to CPU.

    Once the buffer is full, further appends are silently dropped, so the
    buffer always holds the *first* ``max_len`` items it was given.
    """

    def __init__(self, max_len=1):
        self.max_len = max_len
        self.buffer = []

    def append(self, x):
        """Store ``x.cpu()`` unless the buffer already holds ``max_len`` items."""
        is_full = len(self.buffer) >= self.max_len
        if is_full:
            return
        self.buffer.append(x.cpu())

def pair(t):
    """Return ``t`` unchanged if it is already a tuple, otherwise ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return t, t

# classes
class RNNBlock(nn.Module):
    """Gated residual recurrence cell.

    Given an input ``x`` and a hidden state ``z`` (both with trailing size
    ``hidden_size``), ``forward`` returns::

        z + sigmoid(gate(z)) * linear2(act(linear1(x + norm(z))))

    where ``act`` is the activation looked up by ``nonlinearity``.
    """

    def __init__(self, hidden_size, nonlinearity='relu'):
        super().__init__()
        self.hidden_size = hidden_size
        # Activation is resolved by name on the top-level ``torch`` namespace
        # (e.g. 'relu' -> torch.relu).
        self.nonlinearity = getattr(torch, nonlinearity)
        # Creation order is kept as-is: it fixes parameter registration order
        # and the RNG draws made by the layers' default initialization.
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

        # The output projection and the gate start from dense_init's zero init.
        # This presumably makes the initial update zero; confirm in cola_nn.
        for layer in (self.linear2, self.gate):
            dense_init(layer, zero_init=True)

    def forward(self, x, z):
        """Advance hidden state ``z`` by one step using input ``x``."""
        normed_state = self.norm(z)
        candidate = self.linear2(self.nonlinearity(self.linear1(x + normed_state)))
        # The gate is computed from the raw (un-normalized) hidden state.
        keep = torch.sigmoid(self.gate(z))
        return z + keep * candidate

class RNN(nn.Module):
    """Image model that folds patch embeddings through a single shared RNNBlock.

    The image is cut into non-overlapping patches. Each patch is flattened and
    embedded to ``width`` features. The embeddings are consumed one at a time
    by the same :class:`RNNBlock`, and the final hidden state is projected to
    ``dim_out`` by ``mlp_head``.

    Args:
        dim_out: Output size of ``mlp_head``.
        width: Embedding and hidden size used throughout the model.
        image_size: Image side length, or an ``(height, width)`` tuple.
        patch_size: Patch side length, or an ``(height, width)`` tuple. It must
            divide ``image_size`` evenly in both dimensions.
        in_channels: Number of input image channels.
        dropout: Dropout probability applied to the patch embeddings.
        output_mult: Scalar multiplier applied to the head's output.
        emb_mult: Scalar multiplier applied to the patch embeddings.
        use_bias: Whether the embedding and head ``nn.Linear`` layers use a bias.
        **kwargs: Accepted for call-site compatibility; ignored.
    """

    def __init__(self, dim_out, width, image_size=32, patch_size=8, in_channels=3, dropout=0., output_mult=1, emb_mult=1, use_bias=True, **kwargs):
        super().__init__()
        self.emb_mult = emb_mult
        self.output_mult = output_mult

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Invalid patch size'

        # Each patch is flattened to (p1 * p2 * c) values before embedding.
        # The patch count itself is not needed: there is no positional
        # embedding, and forward() loops over whatever sequence length it gets.
        patch_dim = in_channels * patch_height * patch_width

        # Submodule creation order is intentional. It fixes parameter
        # registration order and the RNG draws made by default initialization.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, width, bias=use_bias),
            nn.LayerNorm(width),
        )

        self.rnn = RNNBlock(hidden_size=width, nonlinearity='relu')
        self.dropout = nn.Dropout(dropout)

        self.mlp_head = nn.Linear(width, dim_out, bias=use_bias)

        # Replace default initializations with dense_init.
        self.fix_init()

        # Eval-time feature logs:
        #   hs[0] holds patch embeddings, hs[1] holds outputs.
        self.clear_features()

    def fix_init(self):
        """Re-initialize every ``nn.Linear`` with ``dense_init``, then zero-init the head.

        Linear layers whose weight carries a truthy ``zero_init`` attribute are
        skipped. That flag is presumably set by ``dense_init(..., zero_init=True)``
        (as RNNBlock uses it); confirm in cola_nn.
        """
        for m in self.modules():
            if not isinstance(m, nn.Linear):
                continue
            if getattr(m.weight, 'zero_init', False):
                print(f"Skipping zero init: {m}")
                continue
            print(f"Fixing init: {m}")
            dense_init(m)
        # The loop above may already have re-initialized mlp_head. It is then
        # overwritten here with dense_init's zero init.
        print(f"Fixing output init: {self.mlp_head}")
        dense_init(self.mlp_head, zero_init=True)

    def forward(self, img):
        """Run the model on a batch of images.

        Args:
            img: Tensor of shape (b, c, H, W). The rank is fixed by the patch
                ``Rearrange`` pattern.

        Returns:
            Tensor of shape (b, dim_out), scaled by ``output_mult``.
        """
        x = self.to_patch_embedding(img)  # (b, num_patches, width)
        x = x * self.emb_mult
        x = self.dropout(x)

        if not self.training:
            self.hs[0].append(x.detach())

        # Seed the hidden state with the first patch embedding. Then fold the
        # remaining patches, in the row-major order produced by '(h w)',
        # through the shared RNNBlock.
        hidden = x[:, 0]
        for i in range(1, x.size(1)):
            hidden = self.rnn(x[:, i], hidden)

        y = self.mlp_head(hidden) * self.output_mult
        if not self.training:
            self.hs[1].append(y.detach())
        return y

    def get_features(self):
        """Return the feature logs ``[embeddings_log, outputs_log]`` recorded in eval mode."""
        return self.hs

    def clear_features(self):
        """Reset the feature logs to two empty, capped buffers."""
        self.hs = [CappedList() for _ in range(2)]