import inspect
from collections import namedtuple
from copy import copy

import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from typing import Iterable

from theory.theory import create_embedding_matrix, EmbeddingType
from theory.ideal_weights import create_ideal_model_weights
from utils.common import print_array_shapes, _named_arrays

# Result container returned by SimplifiedLinearMamba.forward:
#   logits - the stacked per-position logits,
#   H      - tuple of the hidden-state matrices the forward pass tracked,
#   y      - tuple of the read-out vectors from the last processed position.
ExtendedCausalLMOutput = namedtuple("ExtendedCausalLMOutput", ["logits", "H", "y"])


def extract_model_weights(model: nn.Module, model_class: str = 'mamba_tiny', print_output_shapes=False):

    match model_class:

        case 'mamba_tiny':

            E_in = model.embedding.weight.cpu().detach().T
            P_in = model.layers[0].mixer.in_proj_x.weight.cpu().detach()
            W = model.layers[0].mixer.conv1d.weight.cpu().squeeze().detach()
            S_B = model.layers[0].mixer.x_proj_to_B.weight.cpu().detach()
            S_C = model.layers[0].mixer.x_proj_to_C.weight.cpu().detach()
            P_out = model.layers[0].mixer.out_proj_y.weight.cpu().detach()
            E_out = model.lm_head.weight.cpu().detach()

        case _:
            raise NotImplementedError

    if print_output_shapes:
        print_array_shapes(['E_in', 'P_in', 'W', 'S_B', 'S_C', 'P_out', 'E_out'])

    return E_in, P_in, W, S_B, S_C, P_out, E_out


def compute_effective_embeddings(E_in, P_in, W, S_B, S_C, P_out, E_out, print_output_shapes=False):
    """Fold the per-layer weights into effective input/output embeddings.

    Works on either NumPy arrays or torch tensors; subclasses (for example
    ``nn.Parameter``) are accepted, but the two libraries cannot be mixed.

    Args:
        E_in, P_in, W, S_B, S_C, P_out, E_out: Weight matrices, all NumPy
            arrays or all torch tensors. ``W`` is indexed as ``W[:, 0]``
            (weight on the previous token) and ``W[:, 1]`` (weight on the
            current token).
        print_output_shapes: When true, report the result shapes through
            ``print_array_shapes``.

    Returns:
        ``(E_hat_in, E_hat_out, E_tilde_q_in, E_tilde_k_in)`` where

        * ``E_hat_in``  = ``[diag(W[:, 0]) P_in E_in | diag(W[:, 1]) P_in E_in]``
          (previous-token block first, then current-token block),
        * ``E_hat_out`` = ``E_out @ P_out``,
        * ``E_tilde_q_in`` = ``S_C @ E_hat_in``,
        * ``E_tilde_k_in`` = ``S_B @ E_hat_in``.

    Raises:
        TypeError: If the inputs are not all NumPy arrays or all torch tensors.
    """

    vals = (E_in, P_in, W, S_B, S_C, P_out, E_out)

    # choose np or torch (both supported); isinstance keeps subclasses such as
    # nn.Parameter usable, and an explicit raise survives `python -O`
    if all(isinstance(v, np.ndarray) for v in vals):
        m = np
    elif all(isinstance(v, torch.Tensor) for v in vals):
        m = torch
    else:
        seen = sorted({type(v).__name__ for v in vals})
        raise TypeError(
            "all input matrices must be numpy arrays or all torch tensors, "
            f"got types: {', '.join(seen)}"
        )

    # compute effective matrices

    W_p = W[:, 0]  # prev
    W_c = W[:, 1]  # curr

    W_p_diag = m.diag(W_p)
    W_c_diag = m.diag(W_c)

    E_hat_p_in = W_p_diag @ P_in @ E_in
    E_hat_c_in = W_c_diag @ P_in @ E_in

    # column layout is [prev | curr]
    E_hat_in = m.hstack([E_hat_p_in, E_hat_c_in])

    E_hat_out = E_out @ P_out

    E_tilde_k_in = S_B @ E_hat_in
    E_tilde_q_in = S_C @ E_hat_in

    if print_output_shapes:
        # The helper is given variable *names*, so these locals must keep
        # exactly these names (presumably it resolves them in this frame).
        print_array_shapes(['E_hat_in', 'E_hat_out', 'E_tilde_q_in', 'E_tilde_k_in'])

    return E_hat_in, E_hat_out, E_tilde_q_in, E_tilde_k_in
    
    
def compute_projections(E_hat_in, E_hat_out, E_tilde_q_in, E_tilde_k_in, P_out, E_out, print_output_shapes=False):
    """Re-label the effective embeddings as query/key/value projections.

    Each result is a ``copy.copy`` of one of the inputs:

    * ``Pi_q_in``  <- ``E_tilde_q_in``
    * ``Pi_k_in``  <- ``E_tilde_k_in``
    * ``Pi_v_in``  <- ``E_hat_in``
    * ``Pi_v_out`` <- ``E_hat_out``

    ``P_out`` and ``E_out`` are accepted but not used by this variant. The
    original author noted an alternative they consider equivalent:
    ``Pi_v_in = P_out @ E_hat_in`` together with ``Pi_v_out = copy(E_out)``.

    Args:
        print_output_shapes: When true, report the result shapes through
            ``print_array_shapes``.

    Returns:
        ``(Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out)``.
    """

    # query / key side: the compressed input embeddings
    Pi_q_in, Pi_k_in = copy(E_tilde_q_in), copy(E_tilde_k_in)

    # value side: the uncompressed effective embeddings
    Pi_v_in, Pi_v_out = copy(E_hat_in), copy(E_hat_out)

    if print_output_shapes:
        # names are resolved by the helper, so the locals above keep these names
        print_array_shapes(['Pi_q_in', 'Pi_k_in', 'Pi_v_in', 'Pi_v_out'])

    return Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out


def compute_Gram_matrices(Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out):
    """Form the key-query and value-value product matrices.

    Returns:
        ``(G_kq, G_vv)`` with ``G_kq = Pi_k_in.T @ Pi_q_in`` and
        ``G_vv = Pi_v_out @ Pi_v_in``.
    """

    keys_transposed = Pi_k_in.T

    return keys_transposed @ Pi_q_in, Pi_v_out @ Pi_v_in


class SimplifiedLinearMamba(nn.Module):
    """Toy linear state-space model that runs three recurrences side by side.

    At every position the forward pass writes a key/value outer product into
    a hidden state and reads it back with a query, in three variants:

    * ``H``   - fixed one-hot projections (``k_proj`` / ``v_proj``), (V, V);
    * ``H_``  - the learnable compressed projections ``Pi_*``, (D, N);
    * ``H__`` - Gram matrices built from the ``Pi_*`` projections, (V, V).

    Only the compressed path (``H_``) produces the returned logits; the other
    two states are returned alongside for inspection.

    Args:
        V: Vocabulary size. The fixed key/value masks are built from two
            halves of length ``V // 2``, so they only have length ``V`` when
            ``V`` is even.
        D: Row count of ``Pi_v_in`` / column count of ``Pi_v_out``.
        N: Row count of ``Pi_q_in`` and ``Pi_k_in``.
        device: Device the parameters are created on.
        dtype: Parameter dtype (``None`` means the torch default).
    """

    def __init__(self, V: int, D: int, N: int, device: str = 'cpu', dtype = None):

        super().__init__()

        device = torch.device(device)
        # construction-time device only; forward() follows the parameters so
        # the model keeps working after `.to(...)`
        self.device = device

        self.V = V
        self.D = D
        self.N = N

        self.expand = 2  # we assume this
        self.D_in = int(D * self.expand)

        self.tensor_kwargs = {"dtype": dtype, "device": device}

        # learnable projections (creation order fixes the RNG draw order)
        self.Pi_q_in = nn.Parameter(torch.randn(N, 2 * V, **self.tensor_kwargs))
        self.Pi_k_in = nn.Parameter(torch.randn(N, 2 * V, **self.tensor_kwargs))
        self.Pi_v_in = nn.Parameter(torch.randn(D, 2 * V, **self.tensor_kwargs))
        self.Pi_v_out = nn.Parameter(torch.randn(V, D, **self.tensor_kwargs))

        # helper matrices and projections
        U = torch.eye(self.V, **self.tensor_kwargs)  # one-hot
        I_half_col = torch.ones(self.V // 2, **self.tensor_kwargs)
        Z_half_col = torch.zeros(self.V // 2, **self.tensor_kwargs)
        k_mask = torch.hstack([I_half_col, Z_half_col])  # (V,) - selects the first half of the vocabulary
        v_mask = torch.hstack([Z_half_col, I_half_col])  # (V,) - selects the second half of the vocabulary
        k_proj = torch.diag(k_mask)  # (V, V)
        v_proj = torch.diag(v_mask)  # (V, V)

        # fixed weights
        self.U = nn.Parameter(U, requires_grad=False)
        self.k_proj = nn.Parameter(k_proj, requires_grad=False)
        self.v_proj = nn.Parameter(v_proj, requires_grad=False)

    @staticmethod
    def _outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Batched outer product: (B, m) x (B, n) -> (B, m, n)."""
        return torch.bmm(a.unsqueeze(2), b.unsqueeze(1))

    @staticmethod
    def _read(state: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
        """Batched matrix-vector product: (B, m, n) x (B, n) -> (B, m)."""
        return torch.bmm(state, query.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _normalize(x):
        """Min-max scale ``x`` to [0, 1]; a constant input maps to all zeros."""
        lo = x.min()
        span = x.max() - lo
        if span == 0:
            return x - lo  # avoid 0 / 0 -> NaN
        return (x - lo) / span

    def forward(self, input_ids: torch.Tensor, plot_H: bool = False) -> ExtendedCausalLMOutput:
        """
        input_ids: (B, L) integer token ids
        plot_H:    when true, show the three hidden states after every step
        returns:   ExtendedCausalLMOutput(
                       logits=(B, L, V),
                       H=(H, H_, H__)  final hidden states,
                       y=(y, y_, y__)  read-outs of the last step,
                   )

        For an empty sequence (L == 0) the logits are empty and the states
        and read-outs are all zeros.
        """

        # follow the parameters for both device and dtype so that the hidden
        # states can be combined with them in torch.bmm
        ref = next(self.parameters())
        state_kwargs = {"dtype": ref.dtype, "device": ref.device}

        x_ids = input_ids.to(ref.device)
        B, L = x_ids.shape

        # Gram matrices
        G_kq, G_vv = compute_Gram_matrices(self.Pi_q_in, self.Pi_k_in, self.Pi_v_in, self.Pi_v_out)

        # hidden state
        H = torch.zeros(B, self.V, self.V, **state_kwargs)  # (B, V, V)
        H_ = torch.zeros(B, self.D, self.N, **state_kwargs)  # (B, D, N)
        H__ = torch.zeros(B, self.V, self.V, **state_kwargs)  # (B, V, V)

        # read-outs; overwritten on every step, zeros if there are no steps
        y = torch.zeros(B, self.V, **state_kwargs)  # (B, V)
        y_ = torch.zeros(B, self.D, **state_kwargs)  # (B, D)
        y__ = torch.zeros(B, self.V, **state_kwargs)  # (B, V)

        y_logits = []

        # predecessor of the first token is token id 0
        # NOTE(review): this yields the one-hot of token 0, not an all-zero
        # vector - confirm that is the intended padding.
        x_ids_0 = x_ids.new_zeros(B)

        for t in range(L):

            x_ids_curr = x_ids[:, t]
            x_ids_prev = x_ids[:, t - 1] if (t > 0) else x_ids_0

            x_prev = self.U[x_ids_prev]  # (B, V)
            x_curr = self.U[x_ids_curr]  # (B, V)

            # NOTE(review): xi is laid out as [curr | prev], whereas
            # compute_effective_embeddings stacks its columns as [prev | curr]
            # - confirm the two conventions are meant to differ.
            xi = torch.hstack([x_curr, x_prev])  # (B, 2V)

            # original q, k, v
            q = x_curr @ self.k_proj.T  # (B, V)
            k = x_prev @ self.k_proj.T  # (B, V)
            v = x_curr @ self.v_proj.T  # (B, V)

            # compressed q, k, v
            q_ = xi @ self.Pi_q_in.T  # (B, N)
            k_ = xi @ self.Pi_k_in.T  # (B, N)
            v_ = xi @ self.Pi_v_in.T  # (B, D)

            # decompressed k, v
            k__ = xi @ G_kq  # (B, 2V)
            k__ = k__[:, :self.V]  # (B, V) -> first block only (the curr block of xi)
            v__ = xi @ G_vv.T  # (B, V)

            # write
            H += self._outer(v, k)  # (B, V, V)
            H_ += self._outer(v_, k_)  # (B, D, N)
            H__ += self._outer(v__, k__)  # (B, V, V)

            if plot_H:
                matrices_to_plot = ['H', 'H_', 'H__']
                fig, axes = plt.subplots(nrows=1, ncols=len(matrices_to_plot), figsize=(12, 3.5), facecolor='w')
                # the helper is given variable *names*; it must be called from
                # this frame and the states must keep exactly these names
                for ax, (name, M) in zip(axes, _named_arrays(matrices_to_plot).items()):
                    first = M[0]  # first batch element only
                    if isinstance(first, torch.Tensor):
                        # matplotlib cannot convert grad-tracking or GPU tensors
                        first = first.detach().cpu()
                    ax.imshow(self._normalize(first), interpolation='none', aspect='auto', cmap='jet', vmin=0, vmax=1)
                    ax.set_title(name)
                plt.show()

            # read (the Gram path is read with the one-hot query q)
            y = self._read(H, q)  # (B, V)
            y_ = self._read(H_, q_)  # (B, D)
            y__ = self._read(H__, q)  # (B, V)

            # only the compressed path produces logits
            y_logits_t = y_ @ self.Pi_v_out.T  # (B, V)

            # store
            y_logits.append(y_logits_t)  # (B, V)

        # stack
        if y_logits:
            logits = torch.stack(y_logits, dim=1)  # (B, L, V)
        else:
            logits = torch.zeros(B, 0, self.V, **state_kwargs)  # (B, 0, V)

        H_matrices = (H, H_, H__)
        y_vectors = (y, y_, y__)

        return ExtendedCausalLMOutput(logits=logits, H=H_matrices, y=y_vectors)


def set_simplified_linear_model_ideal_weights(model: nn.Module, **kwargs):
    """Overwrite a model's four ``Pi_*`` parameters with the ideal construction.

    The ideal weights are created for the model's own ``V``, ``D`` and ``N``,
    folded into effective embeddings, re-labelled as projections and copied
    in place (without gradient tracking) into ``model.Pi_q_in``,
    ``model.Pi_k_in``, ``model.Pi_v_in`` and ``model.Pi_v_out``.

    Args:
        model: Object exposing ``V``, ``D``, ``N`` and the four ``Pi_*``
            parameters.
        **kwargs: Forwarded unchanged to ``create_ideal_model_weights``.

    Returns:
        None; ``model`` is modified in place.
    """

    E_in, P_in, W, S_B, S_C, P_out, E_out = create_ideal_model_weights(
        V=model.V, D=model.D, N=model.N, **kwargs,
    )

    effective = compute_effective_embeddings(E_in, P_in, W, S_B, S_C, P_out, E_out)
    projections = compute_projections(*effective, P_out, E_out)

    # compute_projections returns its results in exactly this order
    targets = ("Pi_q_in", "Pi_k_in", "Pi_v_in", "Pi_v_out")

    with torch.no_grad():
        for attr, value in zip(targets, projections):
            # NOTE(review): copy_ needs `value` to be broadcastable to the
            # parameter's shape - confirm the ideal weights match V, D, N.
            getattr(model, attr).copy_(value)


def plot_Gram_matrix_block_with_stats(
        M: torch.Tensor, M_name: str,
        left_axvline_values: list[float] | None = None,
        left_axhline_values: list[float] | None = None,
        right_axvline_values: list[float] | None = None,
        right_axhline_values: list[float] | None = None,
        figsize=(9, 4.5),
        print_stats: bool = True,
        stats_by_halves: bool = True,
):
    """Plot a square matrix next to its diagonal and print summary statistics.

    The left panel shows the matrix as an image, the right panel shows its
    diagonal as a bar chart. Optional reference lines can be drawn on either
    panel.

    Args:
        M: Square 2-D matrix. Tensors that require grad or live on another
            device are detached and moved to the CPU first; anything
            ``torch.as_tensor`` accepts (e.g. a NumPy array) also works.
        M_name: Name used in the panel titles and the printed header.
        left_axvline_values / left_axhline_values: x / y positions of red,
            half-transparent reference lines on the matrix panel.
        right_axvline_values / right_axhline_values: x positions (red) and
            y positions (blue) of reference lines on the diagonal panel.
        figsize: Figure size passed to ``plt.subplots``.
        print_stats: When true, print diagonal and off-diagonal statistics.
        stats_by_halves: When true, report the diagonal statistics separately
            for the first ``V // 2`` entries ("k") and the remaining entries
            ("v"); otherwise over the whole diagonal.

    Raises:
        ValueError: If ``M`` is not a square 2-D matrix.
    """

    # detached CPU copy: `.numpy()` refuses grad-tracking and non-CPU tensors
    M_t = torch.as_tensor(M).detach().cpu()

    if M_t.ndim != 2 or M_t.shape[0] != M_t.shape[1]:
        raise ValueError(f"M must be a square 2-D matrix, got shape {tuple(M_t.shape)}")

    V = M_t.shape[0]
    half_V = V // 2

    M_diags = torch.diag(M_t)
    M_off_diags = M_t[~torch.eye(V, dtype=torch.bool)]

    M_np = M_t.numpy()

    # plot

    fig, axes = plt.subplots(ncols=2, nrows=1, figsize=figsize)
    matrix_ax, diag_ax = axes[0], axes[1]

    matrix_ax.imshow(M_np, interpolation='none')
    for x in left_axvline_values or ():
        matrix_ax.axvline(x, color='r', linestyle="-", alpha=0.5)
    for y in left_axhline_values or ():
        matrix_ax.axhline(y, color='r', linestyle="-", alpha=0.5)
    matrix_ax.set_title(M_name + "\nMatrix Values")
    matrix_ax.set_aspect('auto')

    diag_ax.bar(range(V), np.diag(M_np))
    for x in right_axvline_values or ():
        diag_ax.axvline(x, color='r', linestyle="-", alpha=1)
    for y in right_axhline_values or ():
        diag_ax.axhline(y, color='b', linestyle="-", alpha=1)
    diag_ax.set_title(M_name + "\nDiagonal Values")
    diag_ax.set_aspect('auto')

    fig.show()

    # print

    if not print_stats:
        return

    print(f"stats: {M_name}")

    if stats_by_halves:
        k_diag, v_diag = M_diags[:half_V], M_diags[half_V:]
        print("\ndiag mean k value = ", float(k_diag.mean()))
        print("diag mean v value = ", float(v_diag.mean()))
        print("diag std k value = ", float(k_diag.std()))
        print("diag std v value = ", float(v_diag.std()))
    else:
        print("\ndiag mean value = ", float(M_diags.mean()))
        print("diag std value = ", float(M_diags.std()))

    off_diag_abs = torch.abs(M_off_diags)
    print("\noff-diag mean abs value = ", float(off_diag_abs.mean()))
    print("off-diag std value = ", float(M_off_diags.std()))
    print("off-diag max abs value = ", float(off_diag_abs.max()))

    print("\n" + "-"*50 + "\n")