import torch
from math import sqrt, isqrt
from hlb_utils import cosine_similarity
from torch.nn.functional import pad

class VTB(torch.nn.Module):
    """Block-diagonal matrix binding of flat vectors.

    For a vector ``y`` of length ``d = d_model ** 2`` the module forms the
    ``d_model x d_model`` matrix ``Y = d ** 0.25 * y.reshape(d_model, d_model)``
    and the ``d x d`` block-diagonal matrix ``V_y`` holding ``d_model`` copies
    of ``Y``. :meth:`binding` returns ``V_y @ x``; :meth:`unbinding` returns
    the same product with every diagonal block transposed (``Y.T``).

    Vectors shorter than ``d`` are zero-padded on the right before the product
    and the result is cropped back to the original length.

    Args:
        batch_size: leading dimension of the registered ``mask`` and ``ones``
            buffers (kept at this shape so existing state dicts still load).
            Inputs may use any batch size because the mask is broadcast.
        d_model: side length of each diagonal block.
    """

    def __init__(self, batch_size, d_model):
        super().__init__()
        self.d_model = d_model

        # d_model all-ones blocks of size d_model x d_model along the diagonal,
        # replicated once per batch row.
        block_mask = torch.kron(torch.eye(d_model), torch.ones((d_model, d_model)))
        mask = block_mask.repeat(batch_size, 1, 1)
        ones = torch.ones((batch_size, d_model, d_model))

        self.register_buffer("mask", mask)
        self.register_buffer("ones", ones)

    def block_diagonal(self, x, n):
        """Place ``n`` copies of each matrix in ``x`` on a block diagonal.

        Args:
            x: tensor of shape ``(batch, k, k)`` with ``n * k == d_model ** 2``.
            n: number of copies along each axis.

        Returns:
            Tensor of shape ``(batch, n * k, n * k)``; off-diagonal blocks are 0.
        """
        tiled = torch.tile(x, dims=[n, n])
        # Every batch slice of the mask is identical, so broadcasting slice 0
        # lets the batch exceed the ``batch_size`` the module was built with.
        return tiled * self.mask[0]

    def _block_matvec(self, x, y, transpose):
        """Multiply each row of ``x`` by the block-diagonal matrix built from ``y``.

        Args:
            x: tensor of shape ``(batch, d)`` with ``d == d_model ** 2``.
            y: tensor with ``batch * d`` elements, reshaped per row to a
                ``sqrt(d) x sqrt(d)`` block.
            transpose: if true, each block is transposed before tiling.

        Returns:
            Tensor of shape ``(batch, d)``.
        """
        d = int(x.size(1))
        d_prime = isqrt(d)
        vy_prime = (d ** 0.25) * torch.reshape(y, (x.shape[0], d_prime, d_prime))
        if transpose:
            vy_prime = vy_prime.transpose(1, 2)
        vy = self.block_diagonal(vy_prime, d_prime)
        # Squeeze only the trailing vector axis so a batch of 1 keeps its batch dim.
        return torch.matmul(vy, x.unsqueeze(-1)).squeeze(-1)

    def bind_single_dim(self, x, y):
        """Return ``V_y @ x`` row-wise for ``x``, ``y`` of shape ``(batch, d_model ** 2)``."""
        return self._block_matvec(x, y, transpose=False)

    def unbind_single_dim(self, x, y):
        """Return ``V_y' @ x`` row-wise, where ``V_y'`` tiles ``Y.T`` instead of ``Y``."""
        return self._block_matvec(x, y, transpose=True)

    def _pad_to_model(self, t):
        """Zero-pad the last dimension of ``t`` on the right up to ``d_model ** 2``.

        Raises:
            ValueError: if the last dimension already exceeds ``d_model ** 2``.
        """
        target = self.d_model ** 2
        missing = target - t.shape[-1]
        if missing < 0:
            raise ValueError(
                f"last dimension {t.shape[-1]} exceeds d_model ** 2 = {target}"
            )
        return pad(t, (0, missing)) if missing else t

    @staticmethod
    def _match_batch(y, batch):
        """Broadcast a single-row ``y`` across ``batch`` rows; otherwise return it unchanged."""
        if y.shape[0] == 1 and batch != 1:
            return y.expand(batch, *y.shape[1:])
        return y

    def _bind_channels(self, x, y, ch, transpose):
        """Shared implementation of :meth:`binding` / :meth:`unbinding`."""
        org_d = x.shape[-1]
        x = self._pad_to_model(x)
        y = self._pad_to_model(y)
        batch = x.shape[0]
        y = self._match_batch(y, batch)

        side = self.d_model
        x = torch.reshape(x, (batch, -1, side, side))
        y = torch.reshape(y, (batch, -1, side, side))
        # zeros_like keeps x's dtype and device; channels >= ch stay zero.
        out = torch.zeros_like(x)
        size = (batch, side, side)
        for i in range(ch):
            out[:, i, :, :] = self._block_matvec(
                x[:, i].flatten(1), y[:, i].flatten(1), transpose
            ).reshape(size)
        out = torch.reshape(out, (batch, -1))
        return out[:, :org_d]

    def binding(self, x, y, ch=1):
        """Bind ``x`` with ``y``.

        Args:
            x: tensor of shape ``(batch, length)`` (or ``(batch, channels,
                length)``) with ``length <= d_model ** 2``.
            y: tensor with the same layout as ``x``; a batch of 1 is broadcast
                across ``x``'s batch.
            ch: number of leading channels to process; later channels are 0.

        Returns:
            Tensor of shape ``(batch, length)``.

        Raises:
            ValueError: if ``length`` exceeds ``d_model ** 2``.

        NOTE(review): with more than one channel the flattened result is still
        cropped to the first ``length`` entries, so only channel 0 is
        returned. Confirm this is intended.
        """
        return self._bind_channels(x, y, ch, transpose=False)

    def unbinding(self, x, y, ch=1):
        """Unbind ``y`` from ``x`` (same layout and rules as :meth:`binding`).

        Uses the transposed blocks ``Y.T``. When ``d ** 0.25 * Y`` is
        orthogonal this exactly inverts :meth:`binding`.
        """
        return self._bind_channels(x, y, ch, transpose=True)

class Orthogonal:
    """Sampler of vectors cut from random orthogonal matrices.

    Each sample is the Q factor of the QR decomposition of a ``d' x d'``
    Gaussian matrix, flattened row-major and cropped to the requested length.
    Here ``d'`` is the smallest integer with ``d' ** 2 >= size[-1]``.

    Args:
        mean: mean of the Gaussian entries fed to the QR decomposition.
        std: standard deviation of those entries.
            NOTE(review): a positive rescaling of the Gaussian matrix leaves
            its Q factor essentially unchanged, so with ``mean == 0`` this
            likely has little effect. Confirm intent.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def tensor(self, size):
        """Draw ``size[0]`` vectors of length ``size[-1]``.

        Only the first and last entries of ``size`` are used.

        Returns:
            Tensor of shape ``(size[0], size[-1])``.
        """
        org_d = size[-1]
        # Exact integer ceil(sqrt(org_d)); avoids float rounding in math.sqrt.
        d_prime = isqrt(org_d)
        if d_prime * d_prime != org_d:
            d_prime += 1
        shape = (size[0], 1, d_prime, d_prime)
        random = torch.normal(mean=self.mean, std=self.std, size=shape)
        q, _ = torch.linalg.qr(random)
        q = torch.reshape(q, (size[0], d_prime ** 2))
        return q[:, :org_d]

def vtb(batch_size, input_dim):
    """Build a sampler and a ``VTB`` module for vectors of length ``input_dim``.

    ``input_dim`` is rounded up to the next perfect square ``side ** 2``.

    Returns:
        ``(sampler, module)``, where ``sampler`` is
        ``Orthogonal(mean=0., std=1. / side ** 2)`` and ``module`` is
        ``VTB(batch_size=batch_size, d_model=side)``.
    """
    # Exact integer ceil(sqrt(input_dim)); avoids float rounding in math.sqrt.
    side = isqrt(input_dim)
    if side * side != input_dim:
        side += 1
    padded_dim = side * side
    sampler = Orthogonal(mean=0., std=1. / padded_dim)
    module = VTB(batch_size=batch_size, d_model=side)
    return sampler, module
