import logging
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn

import ipdb
import math
from einops import rearrange, repeat
from itertools import zip_longest


class sgpt(torch.nn.Module):
    """Transformer-style regressor using softmax-free, L1-normalised attention.

    The forward pass:

    1. ``_combine`` concatenates ``x`` and ``y`` feature-wise into one token
       per position, zeroing the label slot of the last token.
    2. ``random_project1`` (bias-free linear map, frozen when
       ``read_in_fixed`` is true) lifts every token to ``hid_dim`` features.
    3. Each of the ``n_layers`` layers applies multi-head attention in which
       the hidden state itself is used as query, key and value (there are no
       separate Q/K/V projections), followed by an output projection, a
       residual connection, and a ``Linear -> GELU`` block with a second
       residual connection.
    4. ``_read_out`` maps the hidden state to ``output_dim`` values and only
       the last token's output is returned.

    No attention mask is applied, so every token attends to every token.

    Args:
        d: Input width of the read-in projection. Because the projection is
            applied to the output of ``_combine``, this must equal the
            feature width of ``x`` plus one.
        hid_dim: Hidden width used by every layer.
        output_dim: Width of the read-out.
        n_layers: Number of attention + MLP layers.
        n_heads: Number of attention heads; must divide ``hid_dim``.
        read_in_fixed: If true, the read-in projection is excluded from
            gradient updates.

    Raises:
        ValueError: If ``n_heads`` is not a positive divisor of ``hid_dim``.
    """

    def __init__(
        self,
        d: int,
        hid_dim: int,
        output_dim: int = 1,
        n_layers: int = 2,
        n_heads: int = 1,
        read_in_fixed: bool = True,
    ) -> None:
        super().__init__()

        # forward() splits the hidden width evenly across heads; fail early
        # with a clear message instead of an opaque view() error later.
        if n_heads < 1 or hid_dim % n_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be a positive divisor of "
                f"hid_dim ({hid_dim})"
            )

        self.n_heads = n_heads
        self.hid_dim = hid_dim

        # Read-in projection with N(0, 1/hid_dim) entries. The weights are
        # created on the default device rather than being pinned to CUDA, so
        # the model can be built on CPU-only machines and moved as a whole
        # with ``model.to(device)``.
        self.random_project1 = nn.Linear(d, hid_dim, bias=False)
        fixed_weights = (1 / math.sqrt(hid_dim)) * torch.randn((hid_dim, d))
        self.random_project1.weight = nn.Parameter(
            fixed_weights, requires_grad=not read_in_fixed
        )

        # Per-layer attention output projections and feed-forward blocks.
        self.c_projs = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for _ in range(n_layers):
            self.mlps.append(
                nn.Sequential(
                    nn.Linear(hid_dim, hid_dim, bias=True),
                    nn.GELU(),
                )
            )
            self.c_projs.append(nn.Linear(hid_dim, hid_dim))

        self._read_out = nn.Linear(hid_dim, output_dim, bias=True)

    @staticmethod
    def _combine(xs_b: torch.Tensor, ys_b: torch.Tensor) -> torch.Tensor:
        """Stack each x with its y into a single token.

        ``ys_b`` gains a trailing axis and is concatenated to ``xs_b`` along
        dim 2, so a ``(B, T, f)`` input paired with ``(B, T)`` labels yields a
        ``(B, T, f + 1)`` tensor. The label slot of the last token (the query)
        is set to zero so its target is not visible to the model.

        The inputs are not modified; the zeroing happens on the newly
        concatenated tensor.
        """
        zs = torch.cat((xs_b, ys_b.unsqueeze(2)), dim=2)
        zs[:, -1, -1].zero_()
        return zs

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Predict the output for the last token of each sequence.

        Args:
            x: Inputs, combined with ``y`` by ``_combine`` (3-D, with a
                feature width of ``d - 1``).
            y: Labels, one per token (2-D); the last one is masked out.

        Returns:
            The read-out of the last token, of shape ``(B, output_dim)``.
        """
        tokens = self._combine(x, y)
        H = self.random_project1(tokens)

        B, N, k = H.size()
        head_dim = k // self.n_heads

        for mlp, c_proj in zip(self.mlps, self.c_projs):
            # (B, N, k) -> (B, heads, N, head_dim)
            heads = H.view(B, N, self.n_heads, head_dim).permute(0, 2, 1, 3)

            # Similarity of every token with every token, per head.
            # NOTE: the scale is sqrt(hid_dim), i.e. the full hidden width
            # rather than the per-head width.
            attn_weights = torch.matmul(heads, heads.transpose(-1, -2))
            attn_weights = attn_weights / torch.full(
                [], self.hid_dim ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

            # L1-normalise each row instead of applying a softmax; weights
            # may therefore be negative.
            attn_weights = F.normalize(attn_weights, p=1, dim=-1)
            attn_output = torch.matmul(attn_weights, heads)

            # (B, heads, N, head_dim) -> (B, N, heads * head_dim)
            attn_output = attn_output.permute(0, 2, 1, 3).reshape(B, N, k)
            hidden_states = c_proj(attn_output) + H

            H = mlp(hidden_states) + hidden_states

        prediction = self._read_out(H)

        return prediction[:, -1]

