import torch
import torch.nn as nn
import numpy as np
from math import sqrt
import torch.nn.functional as F
from layers.complex_func import ComplexLinear, ComplexLayerNorm


class TriangularCausalMask:
    """Boolean upper-triangular mask that hides future positions.

    ``mask[b, 0, i, j]`` is ``True`` exactly when ``j > i``. Attention kernels
    fill those entries before the softmax, so query ``i`` only attends to
    ``j <= i``. The mask is a plain ``torch.bool`` tensor; nothing about it is
    specific to complex inputs.

    Args:
        B: Batch size.
        L: Sequence length (the mask is ``L x L`` per batch element).
        device: Device the mask tensor is moved to.
    """

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones((B, 1, L, L), dtype=torch.bool)
            # diagonal=1 keeps the main diagonal unmasked (a token sees itself).
            self._mask = all_true.triu(diagonal=1).to(device)

    @property
    def mask(self):
        """The ``(B, 1, L, L)`` boolean mask; ``True`` means masked out."""
        return self._mask


class ComplexDropout(nn.Module):
    """Dropout for complex tensors that masks the two Cartesian parts separately.

    The real and imaginary components each pass through their own
    ``nn.Dropout`` with the same probability ``p``. In training mode they are
    therefore zeroed (and rescaled) by independent random masks. In eval mode
    both sub-modules are identities, so the input is returned unchanged.

    Args:
        p: Probability of zeroing each real / imaginary element.
    """

    def __init__(self, p=0.1):
        super().__init__()
        self.dropout_real = nn.Dropout(p)
        self.dropout_imag = nn.Dropout(p)

    def forward(self, x):
        """Return ``dropout_real(x.real) + 1j * dropout_imag(x.imag)``."""
        # The tuple is evaluated left to right: real part first, then imaginary.
        parts = (self.dropout_real(x.real), self.dropout_imag(x.imag))
        return torch.complex(*parts)


class ComplexFullAttention(nn.Module):
    """Scaled dot-product attention over complex queries, keys and values.

    Scores are Hermitian inner products ``q . conj(k)`` scaled by ``scale``
    (default ``1 / sqrt(E)``). The softmax is taken over the score magnitudes
    only. Each resulting real weight is then multiplied by the unit phasor
    ``score / (|score| + eps)``. The attention matrix therefore keeps the phase
    of every score, and before dropout its magnitudes are the softmax weights
    (up to the ``eps`` shrinkage).

    Args:
        mask_flag: If True, mask the logits before the softmax. With no
            ``attn_mask`` given, a causal ``TriangularCausalMask`` is built.
            NOTE: that default mask is ``(B, 1, L, L)``, so it only broadcasts
            against the scores when the key length ``S`` equals ``L``.
        scale: Optional override of the ``1 / sqrt(E)`` factor (a falsy value
            falls back to the default).
        attention_dropout: Dropout probability applied to the complex weights.
        output_attention: If True, ``forward`` also returns the weights.
        eps: Added to ``|score|`` so the phasor stays finite for zero scores.
    """

    def __init__(
        self,
        mask_flag=True,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
        eps=1e-8,
    ):
        super(ComplexFullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.eps = eps
        self.dropout = ComplexDropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        """Attend ``queries`` over ``keys``/``values``.

        Args:
            queries: ``(B, L, H, E)`` complex tensor (rank fixed by the unpack).
            keys: ``(B, S, H, E)`` complex tensor (per the einsum spec).
            values: ``(B, S, H, D)`` complex tensor (per the einsum spec).
            attn_mask: Object exposing ``.mask``. A bool mask is filled with
                ``-inf``; any other dtype is added to the magnitude logits.
            tau, delta: Accepted for interface compatibility; unused here.

        Returns:
            ``(V, attn)``, where ``V`` is a contiguous ``(B, L, H, D)`` tensor.
            ``attn`` is the ``(B, H, L, S)`` complex weight matrix (after
            dropout) when ``output_attention`` is set, otherwise ``None``.
        """
        B, L, H, E = queries.shape
        scale = self.scale or 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, torch.conj(keys)) * scale

        # |score| feeds both the softmax logits and the phasor; compute it once.
        magnitude = torch.abs(scores)
        logits = magnitude

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            mask_tensor = attn_mask.mask
            if mask_tensor.dtype == torch.bool:
                logits = logits.masked_fill(mask_tensor, float("-inf"))
            else:
                logits = logits + mask_tensor

        weights = torch.softmax(logits, dim=-1)
        phase = scores / (magnitude + self.eps)

        attn = weights.to(dtype=scores.dtype) * phase
        attn = self.dropout(attn)

        V = torch.einsum("bhls,bshd->blhd", attn, values)

        return (V.contiguous(), attn if self.output_attention else None)


class ComplexFullAttentionLayer(nn.Module):
    """Multi-head wrapper around a complex attention kernel.

    The inputs are projected with ``ComplexLinear`` into ``n_heads`` heads. The
    inner attention runs on the split heads, and the merged result is projected
    back to ``d_model``.

    Args:
        attention: Inner attention module, called as
            ``attention(q, k, v, attn_mask, tau, delta)`` and returning
            ``(out, attn)``.
        d_model: Feature width of the inputs and of the returned output.
        n_heads: Number of heads.
        d_keys: Per-head query/key width (default ``d_model // n_heads``).
        d_values: Per-head value width (default ``d_model // n_heads``).
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(ComplexFullAttentionLayer, self).__init__()
        key_width = (d_keys or d_model // n_heads) * n_heads
        value_width = (d_values or d_model // n_heads) * n_heads

        # Registration order is kept stable so parameter init and state_dict
        # key order are unchanged.
        self.inner_attention = attention
        self.query_projection = ComplexLinear(d_model, key_width)
        self.key_projection = ComplexLinear(d_model, key_width)
        self.value_projection = ComplexLinear(d_model, value_width)
        self.out_projection = ComplexLinear(value_width, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        """Project to heads, attend, merge heads and project back.

        ``queries`` and ``keys`` must be rank-3 (enforced by the unpacking).
        Presumably the shapes are ``(B, L, d_model)`` and ``(B, S, d_model)``;
        confirm against callers. Returns ``(output, attn)``, where ``attn`` is
        whatever the inner attention returned.
        """
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        merged, attn = self.inner_attention(q, k, v, attn_mask, tau, delta)

        projected = self.out_projection(merged.view(batch, q_len, -1))
        return projected, attn


class ComplexConv1d(nn.Module):
    """1D convolution with a complex kernel ``W = W_r + i * W_i``.

    The layer is built from two real ``nn.Conv1d`` layers using the product rule
    ``(a + ib)(W_r + iW_i) = (W_r a - W_i b) + i (W_i a + W_r b)``. With
    ``bias=True`` each real conv adds its own bias, so the effective complex
    bias is ``(b_r - b_i) + i (b_r + b_i)``.

    Args:
        in_channels, out_channels, kernel_size, bias: Forwarded unchanged to
            both underlying ``nn.Conv1d`` layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True):
        super(ComplexConv1d, self).__init__()

        def make_conv():
            return nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)

        # Real part first, then imaginary, so parameter init order is unchanged.
        self.conv_real = make_conv()
        self.conv_imag = make_conv()

    def forward(self, x):
        """Convolve complex ``x`` (channels on dim 1, as ``nn.Conv1d`` expects)."""
        a, b = x.real, x.imag
        out_re = self.conv_real(a) - self.conv_imag(b)
        out_im = self.conv_imag(a) + self.conv_real(b)
        return torch.complex(out_re, out_im)


def complex_relu(inp):
    """Apply ReLU independently to the real and imaginary parts of ``inp``.

    Returns ``relu(inp.real) + 1j * relu(inp.imag)``, so both components of the
    result are non-negative.
    """
    real_part, imag_part = (F.relu(component) for component in (inp.real, inp.imag))
    return torch.complex(real_part, imag_part)


def complex_gelu(inp):
    """Apply GELU independently to the real and imaginary parts of ``inp``.

    Returns ``gelu(inp.real) + 1j * gelu(inp.imag)`` using ``F.gelu`` with its
    default settings.
    """
    real_part, imag_part = (F.gelu(component) for component in (inp.real, inp.imag))
    return torch.complex(real_part, imag_part)


class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder block for complex-valued features.

    The data flow is ``h = norm1(x + dropout(attn(x)))`` followed by
    ``out = norm2(h + dropout(conv2(dropout(act(conv1(h))))))``. Both convs are
    pointwise (kernel size 1) complex convolutions.

    Args:
        attention: Self-attention module, called as
            ``attention(x, x, x, attn_mask=..., tau=..., delta=...)`` and
            returning ``(output, attn)``.
        d_model: Feature width of the block's input and output.
        d_ff: Hidden width of the feed-forward part (default ``4 * d_model``).
        dropout: Dropout probability used after attention and twice in the FFN.
        activation: ``"relu"`` selects ``complex_relu``; any other value
            selects ``complex_gelu``.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        hidden = d_ff or 4 * d_model

        # Registration order is kept stable for parameter init / state_dict.
        self.attention = attention
        self.conv1 = ComplexConv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = ComplexConv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = ComplexLayerNorm(d_model)
        self.norm2 = ComplexLayerNorm(d_model)
        self.dropout = ComplexDropout(dropout)
        if activation == "relu":
            self.activation = complex_relu
        else:
            self.activation = complex_gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(encoded, attn)``.

        ``x`` is assumed to be ``(B, L, d_model)``; confirm against callers.
        """
        attended, attn = self.attention(
            x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        residual = self.norm1(x + self.dropout(attended))

        # Conv1d wants channels on dim 1: move features there, then back.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(residual + ffn_out), attn


class Encoder(nn.Module):
    """Stack of encoder layers, optionally interleaved with conv layers.

    Without ``conv_layers``, every attention layer is called with the same
    ``attn_mask``, ``tau`` and ``delta``.

    With ``conv_layers``, the code runs three steps:

    * Attention and conv layers are paired via ``zip``, running
      ``attn_layers[i]`` then ``conv_layers[i]``.
    * ``delta`` is passed only to the first pair.
    * ``attn_layers[-1]`` then runs once more, without ``attn_mask``.

    NOTE(review): because of the zip, callers are expected to pass one more
    attention layer than conv layers. With equal lengths, the last attention
    layer would run twice.

    Args:
        attn_layers: Layers returning ``(x, attn)``.
        conv_layers: Optional modules applied between attention layers.
        norm_layer: Optional module whose ``normalized_shape[0]`` sizes a
            fresh ``ComplexLayerNorm`` applied to the final output. The module
            passed in is not itself used.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is None:
            self.conv_layers = None
        else:
            self.conv_layers = nn.ModuleList(conv_layers)
        if norm_layer:
            self.norm = ComplexLayerNorm(norm_layer.normalized_shape[0])
        else:
            self.norm = None

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run the stack and return ``(x, attns)``, with one attn entry per layer call."""
        attns = []
        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
        else:
            for idx, (layer, conv) in enumerate(zip(self.attn_layers, self.conv_layers)):
                layer_delta = delta if idx == 0 else None
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=layer_delta)
                x = conv(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns
