import warnings

import torch
from torch.nn import functional as F

from utils import recorder

_name = "selective_scan"  # for recording


def complex_log(input, eps=1e-12):
    """Return the complex logarithm of a real tensor.

    The real part is ``log(max(|input|, eps))`` so that zeros do not produce
    ``-inf``; the imaginary part is ``pi`` wherever ``input`` is negative and
    ``0`` elsewhere, so ``exp(result)`` recovers the sign of ``input``.

    Args:
        input: real-valued tensor.
        eps: lower bound applied to ``|input|`` before taking the log.

    Returns:
        Complex tensor with the same shape as ``input``.
    """
    # Materialise the floor with the same dtype/device as the input.
    floor = input.new_tensor(eps)
    log_magnitude = torch.log(torch.maximum(torch.abs(input), floor))

    # Negative entries carry a phase of pi (log(-a) = log(a) + i*pi for a > 0).
    is_negative = torch.lt(input, 0)
    phase = is_negative.to(input.dtype) * torch.pi

    return torch.complex(log_magnitude, phase)


def selective_scan(x, dt, A, B, C, D, config: dict, mode='logcumsumexp'):

    A = A.to(x.device)
    y_ssm_from_attention = None

    ssm_mode = config.get('ssm_mode', 'selective_scan')

    if config.get('compute_attention', False) or ssm_mode == "attention":

        # TODO
        # if config.get('init_A', None) != 1:
        #     warnings.warn("note that attention currently assumes fixed A[i,j]=1")

        attention_alpha, y_ssm_from_attention = compute_attention(x, A, B, C)

        recorder.record(_name, attention_alpha=attention_alpha)

    match ssm_mode:

        case "selective_scan":

            # discretize
            
            dB = torch.einsum('bld,bln->bldn', dt, B)
            dA = torch.einsum('bld,dn->bldn', dt, A)

            dB_x = torch.einsum('bld,bldn->bldn', x, dB)
            dA = dA.clamp(min=-20)

            recorder.record(_name, B_bar=dB, B_bar_x=dB_x, A_bar=dA)

            padding = (0, 0, 0, 0, 1, 0)  # never change this
            dA_star = F.pad(dA[:, 1:], padding).cumsum(1)

            match mode:

                case 'cumsum':

                    dA_cumsum = dA_star.exp()
                    h_ssm = dB_x / (dA_cumsum + 1e-12)
                    h_ssm = h_ssm.cumsum(1) * dA_cumsum
                    y_ssm_raw = torch.einsum('bldn,bln->bld', h_ssm, C)

                    recorder.record(_name, h_ssm=h_ssm)

                case 'logcumsumexp':  # more numerically stable (Heisen sequence)

                    dB_x_log = complex_log(dB_x)
                    h_log = torch.logcumsumexp(dB_x_log - dA_star, 1) + dA_star
                    y_ssm_raw = torch.einsum('bldn,bln->bld', h_log.real.exp() * torch.cos(h_log.imag), C)

                    recorder.record(_name, h_log_ssm=h_log)

                case _:
                    raise ValueError(f"unknown {mode = }")

            recorder.record(_name, y_ssm_scan=y_ssm_raw)

        case "attention":

                if y_ssm_from_attention is None:
                    attention_alpha, y_ssm_from_attention = compute_attention(x, A, B, C)

                y_ssm_raw = y_ssm_from_attention

        case "recurrent":

            h_ssm, y_ssm_raw = compute_recurrence(x, A, B, C)
            recorder.record(_name, y_ssm_recurrent=y_ssm_raw, h_ssm=h_ssm)

        case _:
            raise ValueError(f"unknown {config['ssm_mode'] = }")

    # output
    y_ssm_plus_u_D = y_ssm_raw + x * D
    recorder.record(_name, y_ssm_plus_u_D=y_ssm_plus_u_D)

    return y_ssm_plus_u_D


def compute_recurrence(x, A, B, C):
    """Step-by-step reference recurrence with an identity state transition.

    For every time step ``t`` (batched over the leading axis)::

        h_t = h_{t-1} + B_t x_t^T        # (N, D_in)
        y_t = h_t^T C_t                  # (D_in,)

    Args:
        x: input of shape ``(batch, L, D_in)``.
        A: accepted for signature parity with the other SSM paths but unused;
            the transition is treated as the identity.
        B: tensor indexed as ``B[:, t, :]`` with ``N = B.shape[2]``.
        C: tensor indexed as ``C[:, t, :]``; its last axis must match ``N``.

    Returns:
        ``(h, y)`` where ``h`` has shape ``(batch, L, N, D_in)`` and holds the
        state after every step, and ``y`` has shape ``(batch, L, D_in)``.
    """
    device = x.device

    # shapes
    batch_size, L, D_in = x.shape
    N = B.shape[2]

    # Allocate in the dtype the accumulation will actually have (a default-dtype
    # zero state plus the input product) so wider inputs such as float64 are not
    # silently downcast when stored.
    dtype = torch.promote_types(x.dtype, torch.get_default_dtype())

    # pre-allocate
    h = torch.zeros((batch_size, L, N, D_in), device=device, dtype=dtype)
    y = torch.zeros((batch_size, L, D_in), device=device, dtype=dtype)

    # carry h_t across time
    h_t = torch.zeros((batch_size, N, D_in), device=device, dtype=dtype)

    for t in range(L):
        # slice out time-t for the whole batch
        x_t = x[:, t, :, None]  # (batch, D_in, 1)
        B_t = B[:, t, :, None]  # (batch, N, 1)
        C_t = C[:, t, :, None]  # (batch, N, 1)

        # NOTE: a general A (matrix product for (N, N), element-wise for
        # (D_in, N)) is not applied here; A is assumed to act as identity.

        # --- write into memory:  h_t = h_t + B_t @ x_t^T  ---
        h_t = h_t + torch.matmul(B_t, x_t.transpose(-2, -1))  # (batch, N, D_in)

        # --- read from memory:  y_t = h_t^T @ C_t  ---
        #   h_t^T: (batch, D_in, N),  C_t: (batch, N, 1)  ->  (batch, D_in, 1)
        # Drop only the trailing singleton so size-1 batch / D_in axes survive.
        y_t = torch.matmul(h_t.transpose(1, 2), C_t).squeeze(-1)  # (batch, D_in)

        # save
        h[:, t] = h_t
        y[:, t] = y_t

    return h, y


def compute_attention(x, A, B, C):
    """Express the SSM output as a causal attention over the input.

    ``alpha[b, l, t] = <C[b, l], B[b, t]>`` for ``t <= l`` and ``0`` above the
    diagonal; the output is ``y[b, l] = sum_t alpha[b, l, t] * x[b, t]``.

    Args:
        x: input of shape ``(batch, L, D_in)``.
        A: not used by this function.
        B: tensor of shape ``(batch, L, N)`` (per the einsum subscripts).
        C: tensor of shape ``(batch, L, N)`` (per the einsum subscripts).

    Returns:
        ``(alpha, y_ssm_from_alpha)`` with shapes ``(batch, L, L)`` and
        ``(batch, L, D_in)``.
    """
    _, L, _ = x.shape

    # Pairwise read/write scores between every output and input position.
    scores = torch.einsum('bln,btn->blt', C, B)

    # Causal mask: position l may only attend to positions t <= l.
    positions = torch.arange(L, device=x.device)
    causal = positions.unsqueeze(1) - positions.unsqueeze(0) >= 0

    # Broadcast the (L, L) mask over the batch and zero out the upper triangle.
    alpha = scores * causal

    y_ssm_from_alpha = torch.einsum('btl,bld->btd', alpha, x)

    return alpha, y_ssm_from_alpha


# the mismatch between the cumsum and logcumsumexp modes will grow quickly as sequence length scales up
if __name__ == "__main__":
    # selective_scan takes a config mapping; an empty one selects the default
    # "selective_scan" path without the extra attention computation.
    config = {}

    for length in [4, 8, 16, 32, 64, 128, 256]:
        u = -1 + 2 * torch.rand(2, length, 32)  # inputs uniform in [-1, 1)
        dt = torch.ones(2, length, 32)
        A = -torch.rand(32, 16)  # non-positive entries
        B = torch.rand(2, length, 16)
        C = torch.rand(2, length, 16)
        D = torch.rand(32)

        output_cumsum = selective_scan(u, dt, A, B, C, D, config, mode='cumsum')
        output_logcumsumexp = selective_scan(u, dt, A, B, C, D, config, mode='logcumsumexp')

        print(f"mismatch at length {length} is {(output_cumsum - output_logcumsumexp).abs().max()}")
    