from pathlib import Path
import sys
import time
from torch import Tensor
from typing import List
from einops import einsum
import numpy as np
import matplotlib.pyplot as plt
sys.path.append('..')

import torch
from modeling.mamba2.modeling_mamba2_torch import Mamba2ForCausalLM
from transformers import AutoTokenizer
from get_prompt import get_long_prompt


def get_tensor_stats(t):
    """Return summary statistics of tensor ``t`` over all of its elements.

    Returns:
        A 5-tuple ``(mean, var, min, max, median)`` of 0-d tensors. ``var`` uses
        torch's default (unbiased) estimator. Note that the median comes last.
    """
    return (
        t.mean(),
        t.var(),
        t.min(),
        t.max(),
        t.median(),
    )


@torch.no_grad()
def get_decays(model: Mamba2ForCausalLM, input_ids: Tensor, chunk_size: int = 128, n_layers: int = 48):
    """Run ``model`` over ``input_ids`` chunk by chunk and collect each chunk's decays.

    ``input_ids`` is truncated to a multiple of ``chunk_size`` and fed to the
    model one chunk at a time. ``output['states']`` is carried over as the
    ``states=`` argument of the next chunk. For every chunk, ``output['decays']``
    is appended to ``all_decays``. One scalar of the recurrent SSM state is
    also recorded per chunk:
    ``output['states'][li].ssm_state[0, hi][pi, ni]``. These scalars are
    plotted to ``temp-manual.pdf`` at the end.

    NOTE(review): this is debugging code.
      * ``exit()`` is called before ``return all_decays``, so as written the
        function never returns. The process terminates after the plot is saved.
      * The manual per-token recurrence in the inner loop is disabled by an
        unconditional ``break``.
      * The module-level global ``train_len`` (set in ``main``) is read for the
        vertical line in the plot.
      * ``n_layers``, ``H``, ``P`` and ``N`` are unused.
      * ``state`` is allocated on ``'cuda'`` unconditionally.

    Args:
        model: Mamba2 model, called as
            ``model(ids, states=..., return_decays=True)``.
        input_ids: token ids. Presumably 1-D, since they are sliced along dim 0
            and ``unsqueeze(0)``-ed to add a batch dim. Verify against the caller.
        chunk_size: number of tokens per forward pass.
        n_layers: unused.

    Returns:
        list: one ``output['decays']`` entry per chunk. Unreachable, see the
        NOTE above.
    """
    # Truncate to multiple of chunk_size (for simpler parallel processing)
    truncate_len = input_ids.shape[0] // chunk_size * chunk_size
    input_ids = input_ids[:truncate_len]
    # NOTE(review): input_ids is already truncated here, so the
    # "Original length" printed below is actually the truncated length.
    print(f"Original length: {len(input_ids)}, truncated length: {truncate_len}")

    cur_state = None
    all_decays = []
    seqlen = len(input_ids)
    # Head count / head dim / state dim; currently unused.
    H = 24
    P = 64
    N = 128

    # Scalar accumulator for the (disabled) manual recurrence below.
    state = torch.tensor(0.0, device='cuda', dtype=torch.float32)
    # state = torch.zeros((P, N), device='cuda', dtype=torch.float32)
    # print(state)
    # exit()
    vals = []
    # Element of the SSM state to trace: layer li, head hi, state[pi, ni].
    li = 20
    hi = 11
    ni = 91
    pi = 8

    print(f"Seq len: {seqlen}")
    for i in range(0, seqlen, chunk_size):
        this_inputs = input_ids[i:i + chunk_size].unsqueeze(0)
        print(f"Chunk {i} - {i + chunk_size}")
        output = model(this_inputs, states=cur_state, return_decays=True)
        decays = output['decays']
        all_decays.append(decays)

        # Carry the recurrent state into the next chunk.
        cur_state = output['states']
        
        layer_state = output['states'][li].ssm_state[0, hi]  # (P, N)
        # Model-computed state value at the end of this chunk.
        vals.append(layer_state[pi, ni])
        # print(layer_state.shape)
        # exit()
        comps = decays[li]

        # NOTE(review): the unconditional `break` disables this manual
        # recurrence; every statement after it in the loop body is dead code.
        for t in range(chunk_size):
            break
            A = comps['A']  # (H)
            dt = comps['dt']  # (B, T, H)
            B = comps['B']  # (B, T, N)
            x = comps['x']  # (B, T, H, P)
            # print(A.shape, dt.shape, B.shape, x.shape)
            # print('=======================')
            # print('A', A[hi])  # (1)
            # print('dt', dt[0, t, hi])  # (1)
            # print('B', B[0, t])  # (N)
            # print('x', x[0, t, hi])  # (P)
            # exit()
            
            # A = A[hi]
            # dt = dt[0, t, hi]
            # B = B[0, t]
            # x = x[0, t, hi]
            # state = state * torch.exp(dt * A) + dt * torch.einsum('n,p->pn', B, x)
            
            # Get target head and layer
            A = A[hi]  # (1)
            dt = dt[0, t, hi]  # (1)
            B = B[0, t, ni]  # (1)
            x = x[0, t, hi, pi]  # (1)
            
            # Scalar version of state <- state * exp(dt * A) + dt * B * x
            dBx = dt * B * x
            state = state * torch.exp(dt * A) + dBx
            
            vals.append(state)
    
        # print(state.shape, layer_state.shape)
        # print("====================")
        # print(layer_state)
        # # print(state[pi, ni])
        # print(state)
        # print("====================")
        # print(layer_state[pi, ni])
        # exit()
        # assert layer_state[pi, ni] == vals[-1], f"{layer_state[pi, ni]}    {vals[-1]}"
    
        # compute_rec()
    # One traced value per chunk, plotted at x = chunk_index * chunk_size.
    vals = torch.stack(vals).cpu().float()
    xs = torch.arange(len(vals)) * chunk_size
    plt.figure(figsize=(4, 4))
    plt.plot(xs, vals)
    # Reads the module-level global `train_len` set by main().
    plt.axvline(x=train_len, color='r', linestyle='--')
    plt.tight_layout()
    plt.savefig("temp-manual.pdf", bbox_inches='tight')
    print('===============')
    print(layer_state.shape)
    print(layer_state[pi, ni])
    print(layer_state.var(), layer_state.mean())
    print("======")
    print(vals)
    # NOTE(review): debug exit -- terminates the process, so the return below
    # is unreachable and main() never gets the decays.
    exit()
    return all_decays


def smooth(
    x: np.ndarray,
    window_size: int = 4,
):
    """Trailing moving average: ``y[i]`` is the mean of the samples preceding ``i``.

    For ``i >= window_size`` this is ``mean(x[i - window_size:i])``. Near the
    start, the window is clipped at index 0 instead of wrapping around. A
    negative slice start made the slice empty, which gave NaN. ``y[0]`` has
    no predecessors, so it is ``x[0]`` itself. Intended for 1-D input.

    Args:
        x: input array. Iterated along axis 0.
        window_size: number of preceding samples to average. Must be >= 1.

    Returns:
        A new array with the same shape and dtype as ``x``. The input is not
        modified.

    Raises:
        ValueError: if ``window_size < 1`` (the window would always be empty).
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    y = x.copy()
    for i in range(x.shape[0]):
        # Clip the start at 0. For i == 0, widen the stop to 1 so the window
        # is never empty (np.mean of an empty slice is NaN).
        start = max(0, i - window_size)
        stop = max(i, 1)
        y[i] = np.mean(x[start:stop])
    return y


def plot_stats(
    layer_to_dA: np.ndarray,
    layer_indices: List[int],
    dst_path: Path,
    bucket_size: int = 128,
    # max_len: int = 16 * 1024,
    max_len: int = 128,
):
    '''
    Plot the bucketed mean decay against token position, one line per layer.

    Steps for each ``layer_i`` in ``layer_indices``:
      1. Average ``layer_to_dA[layer_i]`` over heads.
      2. Truncate it to its first ``max_len`` positions.
      3. Split it into consecutive buckets of ``bucket_size`` positions (the
         last bucket may be shorter).
      4. Plot each bucket's mean at x = bucket_index * bucket_size.

    A dashed red vertical line marks the module-level ``train_len`` (set by
    ``main``). The figure is saved to ``dst_path``, creating parent
    directories as needed, and then closed.

    dA: (L, T, nheads)

    NOTE: despite the ``np.ndarray`` annotation, the per-layer entries are used
    with torch APIs (``.mean(dim=1)``, ``torch.split``), so they must be
    torch tensors.
    '''
    plt.figure(figsize=(4, 3))
    for layer_i in layer_indices:
        print(f"Plotting for layer {layer_i}")
        dA = layer_to_dA[layer_i]  # (T, nheads)
        dA = dA.mean(dim=1)  # (T)
        dA = dA[:max_len]  # (T)

        buckets = torch.split(dA, bucket_size)
        ys = torch.stack([bucket.float().mean() for bucket in buckets])
        xs = torch.arange(len(ys)) * bucket_size
        plt.plot(xs, ys, label=f"Layer {layer_i}", alpha=0.6)
    plt.axvline(x=train_len, color='r', linestyle='--')
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.xlim(-10, max_len)
    plt.xlabel(r'Token position ($t$)')
    plt.ylabel(r'Decay ($\alpha_t$)')

    plt.tight_layout()
    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close()


def plot_heads(
    layer_dA: Tensor,
    dst_path: Path,
    head_indices: List[int],
    bucket_size: int = 512,
    max_len: int = 12 * 1024,
):
    """
    Plot the bucketed mean decay of selected heads of one layer.

    Steps for each ``head_i`` in ``head_indices``:
      1. Take the first ``max_len`` positions of ``layer_dA[:, head_i]``.
      2. Split them into buckets of ``bucket_size`` positions (the last
         bucket may be shorter).
      3. Plot each bucket's mean at x = bucket_index * bucket_size.

    A dashed red vertical line marks the module-level ``train_len`` (set by
    ``main``). The y-axis is fixed to [0.999, 1.0]. The figure is saved to
    ``dst_path``, creating parent directories as needed, and then closed.

    dA: (T, nheads)
    """
    plt.figure(figsize=(4, 3))
    for head_i in head_indices:
        # Honor max_len; previously this parameter was silently ignored.
        head_dA = layer_dA[:max_len, head_i]  # (T',)
        buckets = torch.split(head_dA, bucket_size)  # tuple of 1-D chunks
        ys = torch.stack([bucket.float().mean() for bucket in buckets])  # (n_buckets,)
        xs = torch.arange(len(ys)) * bucket_size
        plt.plot(xs, ys, label=f'Head {head_i}', alpha=0.6)
    plt.axvline(x=train_len, color='r', linestyle='--')
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.xlabel(r"Token position ($t$)")
    plt.ylabel(r"Decay ($\alpha_t$)")
    plt.ylim(0.999, 1.0)
    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')
    # close() rather than clf(): clf() leaves the (empty) figure open.
    plt.close()


def plot_seq_mean(
    seq_mean,
    dst_path: Path,
):
    '''
    Scatter-plot every per-head value against its layer index.

    seq_mean: (L, nheads). Indexed as ``seq_mean[i, j]``, so any 2-D
    tensor/array works. Entry (i, j) is drawn at x = i.

    The figure is saved to ``dst_path``, creating parent directories as
    needed, and then closed.
    '''
    n_layers, n_heads = seq_mean.shape
    # One point per (layer, head) pair, in layer-major order.
    xs = [i for i in range(n_layers) for _ in range(n_heads)]
    ys = [seq_mean[i, j] for i in range(n_layers) for j in range(n_heads)]
    plt.figure(figsize=(4, 3))
    plt.xlabel('Layer index')
    plt.ylabel(r'Decay ($\alpha_t$)')
    plt.scatter(xs, ys, alpha=0.4)
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')
    # close() rather than clf(): clf() leaves the figure open, leaking one per call.
    plt.close()


def main():
    """Trace one element of a Mamba2 SSM state by hand and plot its statistics.

    Steps:
      1. Load the tokenizer and model and run ``get_decays`` over a long
         prompt. Cache the per-chunk decay components to ``decays_path``, or
         load them from the cache (currently always recomputed, see the
         ``True or`` switch).
      2. For layer ``li`` / head ``head_i``, replay the recurrence
         ``state <- state * exp(A * dt_t) + dt_t * outer(x_t, B_t)`` on CPU.
         Record the state's variance and mean after the last token of every
         chunk.
      3. Save plots under ``figs_dir``:
         * the per-token series ``k = dt * B[:, ni]``;
         * the per-chunk state variance;
         * the per-chunk state mean.

    Sets the module-level global ``train_len``, which the plotting helpers read.

    NOTE(review): ``get_decays`` as currently written calls ``exit()``, so
    steps 1-3 past that call only run once that debug exit is removed.
    """
    # NOTE(review): model_size only selects n_layers; it does not match the
    # 130m checkpoint selected below (get_decays ignores n_layers anyway).
    model_size = '780m'
    global train_len
    train_len = 8 * 1024
    ckpt_dir = Path('../ckpts')
    # Alternative checkpoints:
    # pretrained_name = f'../../ckpts/mamba/mamba2-{model_size}'
    # pretrained_name = 'mamba2-512-8/T8192_B8_GA1_P1_SR1_RD0_lr0.0005/ckpt_100000'
    # pretrained_name = 'mamba2-370m/orig/ckpt_0'
    # pretrained_name = 'mamba2-370m/T8192_B1_GA1_P8_SR16_RD0_lr0.0005/ckpt_100000'
    # pretrained_name = 'mamba2-130m/T2048_B32_GA1_P1_SR1_RD0_lr0.0005/ckpt_100000'
    # pretrained_name = 'mamba2-130m/T16384_B4_GA1_P1_SR1_RD0_lr0.0005/ckpt_100000'
    # pretrained_name = 'mamba2-130m/T4096_B16_GA1_P1_SR1_RD0_lr0.0005/ckpt_100000'
    pretrained_name = 'mamba2-130m/orig/ckpt_0'

    prompt_name = 'nextlines'
    run_name = pretrained_name.replace('/', '--')
    tok_path = '../tokenizers/mamba-tok'
    figs_dir = Path("./figs/manual") / run_name / prompt_name
    cache_dir = Path('./cache/manual') / run_name / prompt_name

    file_ext = 'pdf'
    device = 'cuda'
    dtype = torch.float32
    n_layers = {
        '130m': 24,
        '370m': 48,
        '780m': 48,
    }[model_size]
    chunk_size = 128

    # Create both output directories up front. Previously figs_dir was only
    # created after the first savefig into it, which failed on a fresh run.
    cache_dir.mkdir(exist_ok=True, parents=True)
    figs_dir.mkdir(exist_ok=True, parents=True)
    decays_path = cache_dir / 'decay-coeffs.pt'
    # `True or` forces recomputation; drop it to reuse the cached decays.
    if True or not decays_path.exists():
        print('====================================')
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
        print("Loading model...")
        model = Mamba2ForCausalLM.from_pretrained(ckpt_dir / pretrained_name, device=device).to(dtype=dtype)

        print("Tokenizing prompt...")
        prompt = get_long_prompt()
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device=device)[0]
        print("Getting states...")
        decays = get_decays(model, input_ids, chunk_size=chunk_size, n_layers=n_layers)
        decays_path.parent.mkdir(exist_ok=True, parents=True)
        print(f"Saving stats to {decays_path}")
        torch.save(decays, decays_path)
    else:
        print(f"Loading states and stats from {decays_path}...")
        decays = torch.load(decays_path)

    n_chunks = len(decays)
    # Element of the SSM state to trace: layer li, head head_i, state[pi, ni].
    head_i = 11
    li = 20
    ni = 91
    # pi is kept for reference: it selected the traced v = x[:, pi] series,
    # which is no longer collected.
    pi = 8
    state = torch.zeros((64, 128))  # (P, N)
    vars_ = []
    means = []
    ks = []
    for chunk_i in range(n_chunks):
        layer_coeffs = decays[chunk_i][li]
        dt = layer_coeffs['dt'][0][:, head_i].cpu().float()  # (chunk_size)
        A = layer_coeffs['A'][head_i].cpu().float()  # (1)
        B = layer_coeffs['B'][0].cpu().float()  # (chunk_size, N)
        x = layer_coeffs['x'][0][:, head_i].cpu().float()  # (chunk_size, P)

        # Per-token "key" feeding state column ni.
        ks.append(dt * B[:, ni])

        for t in range(chunk_size):
            add = dt[t] * einsum(B[t], x[t], 'n, p -> p n')
            state = state * torch.exp(A * dt[t]) + add
        # Record state statistics once per chunk, after its last token.
        vars_.append(state.var())
        means.append(state.mean())

    ks = torch.cat(ks)  # (T,)

    plt.figure(figsize=(4, 4))
    # ks has one value per token (no bucketing), so x is the token index.
    xs = torch.arange(len(ks))
    plt.plot(xs, ks, label='k')
    plt.legend()
    plt.grid(True)
    plt.axvline(x=train_len, color='r', linestyle='--')
    plt.tight_layout()
    print(f"Saving to {figs_dir}")
    plt.savefig(figs_dir / f"L{li}_H{head_i}_kv.{file_ext}", bbox_inches='tight')
    plt.close()

    vars_ = torch.stack(vars_)  # (n_chunks,)
    means = torch.stack(means)  # (n_chunks,)
    # One sample per chunk, so chunk c is plotted at token position c * chunk_size.
    chunk_xs = torch.arange(n_chunks) * chunk_size
    for stat_name, series in (('vars', vars_), ('means', means)):
        plt.figure(figsize=(4, 4))
        plt.plot(chunk_xs, series)
        plt.axvline(x=train_len, color='r', linestyle='--')
        plt.tight_layout()
        plt.grid(True)
        print(f"Saving to {figs_dir}")
        plt.savefig(figs_dir / f"L{li}_H{head_i}_{stat_name}.{file_ext}")
        plt.close()


if __name__ == '__main__':
    main()
