import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table.

    A ``(1, max_len, d_model)`` table is precomputed once and stored in the
    ``pe`` buffer: even channels hold ``sin(position * div_term)`` and odd
    channels hold the matching ``cos``.

    Args:
        d_model: Number of channels per position. Odd values are supported.
        dropout: Accepted for signature compatibility; this module never
            uses it.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 2048):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd channel than even
        # channels, so only the first d_model // 2 angle columns get a cosine.
        pe[0, :, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the table rows for the first ``x.shape[1]`` positions.

        Only ``x.shape[1]`` is read; ``x`` itself is not added to the result.
        If ``x.shape[1]`` exceeds ``max_len`` the slice simply stops at
        ``max_len`` rows.

        Returns:
            Tensor of shape ``(1, min(x.shape[1], max_len), d_model)``.
        """
        return self.pe[:, :x.shape[1], :]
    
class PositionalEmbedding(nn.Module):
    """Learned position embedding backed by an ``nn.Embedding`` table.

    Args:
        d_model: Width of each position vector.
        max_len: Number of rows (positions) in the embedding table.
    """

    def __init__(self, d_model, max_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Look up embeddings for positions ``0 .. x.shape[1] - 1``.

        Only the size of ``x``'s second dimension and its device are used.

        Returns:
            Tensor of shape ``(x.shape[1], d_model)``.
        """
        num_positions = x.shape[1]
        index = torch.arange(num_positions, dtype=torch.long, device=x.device)
        return self.embedding(index)

class FrequencyEncoding(nn.Module):
    """Encode each input channel as sines and cosines at fixed frequencies.

    Each of the ``num_variables + 1`` channels is multiplied by every entry
    of ``freq_bands``; the sines of those products are followed by the
    cosines, and the per-channel chunks are concatenated and zero-padded
    (or cut) to ``d_model`` features.

    Args:
        d_model: Width of the produced encoding.
        num_variables: The module encodes ``num_variables + 1`` channels.
        max_freq: Stored on the instance; not used by any computation here.
        base: Base of the geometric frequency schedule.
    """

    def __init__(self, d_model, num_variables, max_freq=10000.0, base=10000.0):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables
        self.max_freq = max_freq
        self.base = base
        self.freq_bands = self._get_freq_bands()

    def _get_freq_bands(self):
        """Return ``1 / base ** (2 * k / d_model)`` for each band index ``k``."""
        band_count = self.d_model // (2 * (self.num_variables + 1))
        return torch.tensor(
            [1.0 / (self.base ** (2 * k / self.d_model)) for k in range(band_count)]
        )

    def forward(self, x):
        """Encode ``x`` (rank 3, at least ``num_variables + 1`` channels last).

        Returns:
            Tensor whose last dimension is ``d_model``.
        """
        _batch, _steps, _channels = x.shape  # unpacking requires a rank-3 input
        bands = self.freq_bands[None, None, :].to(x.device)

        def _encode_channel(idx):
            phase = x[:, :, idx].unsqueeze(-1) * bands  # (batch, seq, bands)
            return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)

        encoding = torch.cat(
            [_encode_channel(idx) for idx in range(self.num_variables + 1)],
            dim=-1,
        )

        width = encoding.shape[-1]
        if width >= self.d_model:
            return encoding[..., :self.d_model]

        filler = torch.zeros(*encoding.shape[:-1], self.d_model - width).to(x.device)
        return torch.cat([encoding, filler], dim=-1)

    def inverse(self, encoding):
        """Recover channel values from the first frequency band of each chunk.

        The value is ``atan2(sin, cos)`` of the first band divided by
        ``freq_bands[0]``; since ``atan2`` returns a principal angle, inputs
        whose phase left that range are not recovered exactly.

        Returns:
            Tensor of shape ``(batch, seq, num_variables + 1)``.
        """
        batch_size, seq_len, _ = encoding.shape
        slots = self.num_variables + 1
        restored = torch.zeros(batch_size, seq_len, slots).to(encoding.device)

        half = self.freq_bands.shape[0]
        chunk_width = 2 * half  # sines followed by cosines

        for slot in range(slots):
            offset = slot * chunk_width
            chunk = encoding[:, :, offset:offset + chunk_width]
            sines = chunk[..., :half]
            cosines = chunk[..., half:chunk_width]
            phase = torch.atan2(sines, cosines)
            restored[:, :, slot] = phase[:, :, 0] / self.freq_bands[0]

        return restored

    def get_start_token(self):
        """Generate start token: a zero vector of length ``num_variables + 1``."""
        return torch.zeros(1 + self.num_variables)
    
class SubspaceEncoding(nn.Module):
    """Learned linear map from ``num_variables + 1`` channels to ``d_model``.

    Args:
        d_model: Width of the produced encoding.
        num_variables: The module encodes ``num_variables + 1`` channels.
        max_freq: Accepted for signature parity; unused.
        base: Accepted for signature parity; unused.
    """

    def __init__(self, d_model, num_variables, max_freq=10000.0, base=10000.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_variables + 1, d_model))
        self.d_model = d_model
        # get_start_token() reads this attribute, so it has to be stored.
        self.num_variables = num_variables

    def forward(self, x):
        """Project ``x`` (last dim ``num_variables + 1``) with ``weight``."""
        return x @ self.weight

    def inverse(self, encoding):
        """Project ``encoding`` back onto the unit-normalised weight rows.

        NOTE(review): this equals the original input only when the rows of
        ``weight`` are orthonormal; otherwise it is an approximation.
        """
        normalized_weight = self.weight / self.weight.norm(dim=-1, keepdim=True)
        return encoding @ normalized_weight.T

    def get_start_token(self):
        """Return a zero vector of length ``num_variables + 1``."""
        return torch.zeros(1 + self.num_variables)
    

class FrequencyDecoder(nn.Module):
    """Recover channel values from a sine/cosine frequency encoding.

    The input is read as ``num_variables + 1`` consecutive chunks, each
    holding ``len(freq_bands)`` sines followed by the same number of
    cosines. That is the layout produced by ``FrequencyEncoding`` for the
    same ``d_model`` and ``num_variables``.

    Args:
        d_model: Width of the incoming encoding.
        num_variables: ``num_variables + 1`` values are decoded per position.
        max_freq: Stored on the instance; not used by any computation here.
        base: Base of the geometric frequency schedule.
    """

    def __init__(self, d_model, num_variables, max_freq=10000.0, base=10000.0):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables
        self.max_freq = max_freq
        self.base = base
        self.freq_bands = self._get_freq_bands()
        # Not used by forward(); kept so the module's parameters and
        # state_dict keys stay the same.
        self.fc = nn.Linear(d_model, 1 + num_variables)

    def _get_freq_bands(self):
        """Return ``1 / base ** (2 * k / d_model)`` for each band index ``k``."""
        freq_bands = torch.zeros(self.d_model // (2 * (self.num_variables + 1)))
        for i in range(len(freq_bands)):
            freq_bands[i] = 1.0 / (self.base ** (2 * i / self.d_model))
        return freq_bands

    def forward(self, x):
        """Decode ``x`` (rank 3, encoding on the last axis).

        Each value is ``atan2(sin, cos)`` of the chunk's first band divided
        by ``freq_bands[0]``.

        Returns:
            Tensor of shape ``(batch, seq, num_variables + 1)``.
        """
        half = self.freq_bands.shape[0]
        # Chunk width must follow the band count (2 * half), not
        # d_model // (num_variables + 1): the two differ whenever that
        # quotient is odd, which misaligns the sine and cosine slices.
        chunk_width = 2 * half

        decoded = []
        for slot in range(self.num_variables + 1):
            offset = slot * chunk_width
            first_sin = x[:, :, offset]
            first_cos = x[:, :, offset + half]
            decoded.append(torch.atan2(first_sin, first_cos) / self.freq_bands[0])

        return torch.stack(decoded, dim=-1)
    
class PolynomialEncoding(nn.Module):
    """Encode each input channel as its normalised powers.

    For each of the ``num_variables + 1`` channels the powers
    ``0 .. max_degree - 1`` are computed and divided by the largest absolute
    power (plus ``1e-8``). Chunks are concatenated and zero-padded or cut to
    ``d_model`` features.

    Args:
        d_model: Width of the produced encoding.
        num_variables: The module encodes ``num_variables + 1`` channels.
        max_degree: Number of powers computed per channel.
    """

    def __init__(self, d_model, num_variables, max_degree=10):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables
        self.max_degree = max_degree

    def forward(self, x):
        """Encode ``x`` (rank 3, at least ``num_variables + 1`` channels last).

        Returns:
            Tensor whose last dimension is ``d_model``.
        """
        _batch, _steps, _channels = x.shape  # unpacking requires a rank-3 input
        exponents = torch.arange(self.max_degree).float().to(x.device)

        def _scaled_powers(idx):
            powers = x[:, :, idx].unsqueeze(-1).pow(exponents)
            peak = powers.abs().max(dim=-1, keepdim=True)[0]
            return powers / (peak + 1e-8)

        encoding = torch.cat(
            [_scaled_powers(idx) for idx in range(self.num_variables + 1)],
            dim=-1,
        )

        width = encoding.shape[-1]
        if width >= self.d_model:
            return encoding[..., :self.d_model]

        filler = torch.zeros(*encoding.shape[:-1], self.d_model - width).to(x.device)
        return torch.cat([encoding, filler], dim=-1)

    def inverse(self, encoding):
        """Recover each value as degree-1 entry over degree-0 entry (+1e-8).

        Returns:
            Tensor of shape ``(batch, seq, num_variables + 1)``.
        """
        batch_size, seq_len, _ = encoding.shape
        slots = self.num_variables + 1
        restored = torch.zeros(batch_size, seq_len, slots).to(encoding.device)

        chunk_width = self.max_degree
        for slot in range(slots):
            offset = slot * chunk_width
            chunk = encoding[:, :, offset:offset + chunk_width]
            constant_term = chunk[:, :, 0]
            linear_term = chunk[:, :, 1]
            restored[:, :, slot] = linear_term / (constant_term + 1e-8)

        return restored

    def get_start_token(self):
        """Return a zero vector of length ``num_variables + 1``."""
        return torch.zeros(1 + self.num_variables)


class PolynomialDecoder(nn.Module):
    """Linear decoder that sums ``max_degree`` outputs per decoded value.

    Args:
        d_model: Width of the incoming features.
        num_variables: ``num_variables + 1`` values are decoded per position.
        max_degree: Number of linear outputs summed into each value.
    """

    def __init__(self, d_model, num_variables, max_degree=10):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables
        self.max_degree = max_degree
        self.fc = nn.Linear(d_model, (1 + num_variables) * max_degree)

    def forward(self, x):
        """Project ``x``, group the outputs per value, and sum each group.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``num_variables + 1``.
        """
        flat = self.fc(x)
        leading = flat.shape[:-1]
        grouped = flat.view(*leading, 1 + self.num_variables, self.max_degree)
        return grouped.sum(dim=-1)

class GaussianEncoding(nn.Module):
    """Encode each input channel as activations of Gaussian bumps.

    ``d_model // (num_variables + 1)`` bump centres are initialised evenly
    over ``[-max_value, max_value]`` and stored as the learnable parameter
    ``means``. Every channel is scored against all centres; chunks are
    concatenated and zero-padded or cut to ``d_model`` features.

    Args:
        d_model: Width of the produced encoding.
        num_variables: The module encodes ``num_variables + 1`` channels.
        max_value: Half-width of the interval the centres initially span.
        sigma: Standard deviation shared by all bumps.
    """

    def __init__(self, d_model, num_variables, max_value=10, sigma=1.0):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables
        self.sigma = sigma
        self.num_gaussians = d_model // (num_variables + 1)
        centres = torch.linspace(-max_value, max_value, self.num_gaussians)
        self.means = nn.Parameter(centres)

    def forward(self, x):
        """Encode ``x`` (rank 3, at least ``num_variables + 1`` channels last).

        Returns:
            Tensor whose last dimension is ``d_model``.
        """
        _batch, _steps, _channels = x.shape  # unpacking requires a rank-3 input

        def _activations(idx):
            column = x[:, :, idx].unsqueeze(-1)
            return torch.exp(-0.5 * ((column - self.means) / self.sigma)**2)

        encoding = torch.cat(
            [_activations(idx) for idx in range(self.num_variables + 1)],
            dim=-1,
        )

        width = encoding.shape[-1]
        if width >= self.d_model:
            return encoding[..., :self.d_model]

        filler = torch.zeros(*encoding.shape[:-1], self.d_model - width).to(x.device)
        return torch.cat([encoding, filler], dim=-1)

    def inverse(self, encoding):
        """Return, per channel, the centre with the largest activation.

        Returns:
            Tensor of shape ``(batch, seq, num_variables + 1)``.
        """
        batch_size, seq_len, _ = encoding.shape
        slots = self.num_variables + 1
        restored = torch.zeros(batch_size, seq_len, slots).to(encoding.device)

        chunk_width = self.num_gaussians
        for slot in range(slots):
            offset = slot * chunk_width
            chunk = encoding[:, :, offset:offset + chunk_width]
            strongest = chunk.argmax(dim=-1)
            restored[:, :, slot] = self.means[strongest]

        return restored

    def get_start_token(self):
        """Return a zero vector of length ``num_variables + 1``."""
        return torch.zeros(1 + self.num_variables)

class GaussianDecoder(nn.Module):
    """Decode values as a weighted sum of learnable centres.

    Args:
        d_model: Width of the incoming features.
        num_variables: ``num_variables + 1`` values are decoded per position.
        num_gaussians: Number of centres (and linear outputs) per value.
    """

    def __init__(self, d_model, num_variables, num_gaussians=64):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables
        self.num_gaussians = num_gaussians
        self.fc = nn.Linear(d_model, (1 + num_variables) * num_gaussians)
        # Centres start evenly spaced over [-10, 10] and are trainable.
        self.means = nn.Parameter(torch.linspace(-10, 10, num_gaussians))

    def forward(self, x):
        """Project ``x``, group outputs per value, and weight the centres.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``num_variables + 1``.
        """
        flat = self.fc(x)
        leading = flat.shape[:-1]
        per_value = flat.view(*leading, 1 + self.num_variables, self.num_gaussians)
        return (per_value * self.means).sum(dim=-1)

class HybridEncoding(nn.Module):
    """Concatenate frequency, polynomial and Gaussian encodings.

    The first two sub-encoders each produce ``d_model // 3`` features and
    the Gaussian sub-encoder produces the remainder.

    Args:
        d_model: Total width of the produced encoding.
        num_variables: The module encodes ``num_variables + 1`` channels.
        max_freq: Forwarded to ``FrequencyEncoding``.
        max_degree: Forwarded to ``PolynomialEncoding``.
        num_gaussians: Forwarded as ``GaussianEncoding``'s ``max_value``.
    """

    def __init__(self, d_model, num_variables, max_freq=10000.0, max_degree=5, num_gaussians=32):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables

        third = d_model // 3
        remainder = d_model - 2 * third
        self.freq_encoding = FrequencyEncoding(third, num_variables, max_freq)
        self.poly_encoding = PolynomialEncoding(third, num_variables, max_degree)
        # NOTE(review): GaussianEncoding's third parameter is max_value, so
        # num_gaussians sets the range of the centres rather than their
        # count -- confirm this is intended.
        self.gauss_encoding = GaussianEncoding(remainder, num_variables, max_value=num_gaussians)

    def forward(self, x):
        """Run all three sub-encoders on ``x`` and join them on the last axis."""
        pieces = [
            self.freq_encoding(x),
            self.poly_encoding(x),
            self.gauss_encoding(x),
        ]
        return torch.cat(pieces, dim=-1)

    def inverse(self, encoding):
        """Invert each third with its sub-encoder and average the results."""
        third = self.d_model // 3

        freq_inv = self.freq_encoding.inverse(encoding[..., :third])
        poly_inv = self.poly_encoding.inverse(encoding[..., third:2 * third])
        gauss_inv = self.gauss_encoding.inverse(encoding[..., 2 * third:])

        return (freq_inv + poly_inv + gauss_inv) / 3

    def get_start_token(self):
        """Return a zero vector of length ``num_variables + 1``."""
        return torch.zeros(1 + self.num_variables)

class HybridDecoder(nn.Module):
    """Decode with three sub-decoders and blend them with a learned layer.

    The input's last axis is split into two slices of ``d_model // 3``
    features and one slice holding the remainder; each slice goes to its own
    sub-decoder and the three results are mixed by ``nn.Linear(3, 1)``.

    Args:
        d_model: Total width of the incoming features.
        num_variables: ``num_variables + 1`` values are decoded per position.
        max_degree: Forwarded to ``PolynomialDecoder``.
        num_gaussians: Forwarded to ``GaussianDecoder``.
    """

    def __init__(self, d_model, num_variables, max_degree=5, num_gaussians=32):
        super().__init__()
        self.d_model = d_model
        self.num_variables = num_variables
        self.freq_decoder = FrequencyDecoder(d_model // 3, num_variables)
        self.poly_decoder = PolynomialDecoder(d_model // 3, num_variables, max_degree)
        self.gauss_decoder = GaussianDecoder(d_model - 2 * (d_model // 3), num_variables, num_gaussians)
        self.combine = nn.Linear(3, 1)

    def forward(self, x):
        """Decode ``x`` into one value per variable.

        Returns:
            Tensor shaped like the sub-decoder outputs, i.e. with
            ``num_variables + 1`` on the last axis.
        """
        third = self.d_model // 3
        freq_part = x[..., :third]
        poly_part = x[..., third:2 * third]
        gauss_part = x[..., 2 * third:]

        freq_decoded = self.freq_decoder(freq_part)
        poly_decoded = self.poly_decoder(poly_part)
        gauss_decoded = self.gauss_decoder(gauss_part)

        # (..., num_variables + 1, 3): one column per sub-decoder.
        combined = torch.stack([freq_decoded, poly_decoded, gauss_decoded], dim=-1)
        # combine() leaves a trailing size-1 axis; that is the one to drop.
        # Squeezing -2 instead targets the variable axis and is a no-op
        # whenever num_variables > 0.
        return self.combine(combined).squeeze(-1)