from pathlib import Path
import sys
import time
from torch import Tensor
from einops import einsum
from typing import List
import numpy as np
import matplotlib.pyplot as plt
sys.path.append('..')

import torch
from modeling.mamba2.modeling_mamba2_torch import Mamba2ForCausalLM
from transformers import AutoTokenizer
from get_prompt import get_long_prompt


def get_tensor_stats(t):
    """Summarize a tensor with five scalar statistics.

    Every statistic is reduced over all elements of ``t`` using torch's
    default settings for the corresponding reduction.

    Args:
        t: Tensor to summarize.

    Returns:
        A 5-tuple of 0-dim tensors in the order
        ``(mean, variance, minimum, maximum, median)``. Note that the median
        comes last, after the min/max pair.
    """
    return (
        t.mean(),
        t.var(),
        t.min(),
        t.max(),
        t.median(),
    )


@torch.no_grad()
def get_decays(model: "Mamba2ForCausalLM", input_ids: Tensor, chunk_size: int = 128, n_layers: int = 48):
    """Run ``model`` over ``input_ids`` chunk by chunk and collect its decays.

    The sequence is consumed in consecutive slices of ``chunk_size`` tokens.
    Each slice is given a leading batch dimension of 1 and passed to the model
    together with the state returned for the previous slice (``None`` for the
    first one), so the model sees the sequence as one continuous stream.

    Args:
        model: Callable invoked as
            ``model(inputs, states=..., return_decays=True)``; its result must
            be indexable with the keys ``'decays'`` and ``'states'``.
        input_ids: Token ids, sliced along the first dimension.
        chunk_size: Number of tokens fed to the model per call. The final
            chunk is shorter when the length is not a multiple of it.
        n_layers: Unused. Retained only so existing callers that pass it keep
            working.

    Returns:
        A list with one entry per chunk, each being ``output['decays']`` for
        that chunk, in sequence order.
    """
    seqlen = len(input_ids)
    print(f"Seq len: {seqlen}")

    all_decays = []
    cur_state = None  # recurrent state carried from one chunk to the next
    for start in range(0, seqlen, chunk_size):
        this_inputs = input_ids[start:start + chunk_size].unsqueeze(0)
        output = model(this_inputs, states=cur_state, return_decays=True)
        all_decays.append(output['decays'])
        cur_state = output['states']
    return all_decays


def smooth(
    x: np.ndarray,
    window_size: int = 4,
):
    """Return a trailing moving average of ``x`` along its first axis.

    ``y[i]`` is the mean of the last ``window_size`` samples ending at, and
    including, ``x[i]``. Near the start, where fewer than ``window_size``
    samples exist, the window is shortened to the samples that are available,
    so the output never contains NaN caused by an empty window.

    Args:
        x: Input array; smoothing runs along axis 0.
        window_size: Maximum number of samples averaged per output element.

    Returns:
        A new array with the same shape as ``x``. Floating/complex inputs keep
        their dtype; any other dtype (e.g. integers) is promoted to float64 so
        the averages are not truncated. ``x`` itself is not modified.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    y = np.empty(x.shape, dtype=out_dtype)
    for i in range(x.shape[0]):
        # Clamp at 0: a negative start would wrap around to the end of the
        # array and yield an empty (or wrong) window.
        lo = max(0, i - window_size + 1)
        y[i] = x[lo:i + 1].mean(axis=0)
    return y


def plot_stats(
    K: Tensor,
    layer_indices: List[int],
    dst_path: Path,
    bucket_size: int = 128,
    # max_len: int = 16 * 1024,
    max_len: int = 128,
):
    '''
    Plot the bucketed mean and variance of K over token position, one curve
    per layer, and save the two-panel figure to `dst_path`.

    For every requested layer the first `max_len` positions are taken and
    averaged over heads. The mean and variance over the last dimension are
    then computed per token, and finally averaged inside consecutive buckets
    of `bucket_size` tokens so the curves stay readable.

    A dashed red vertical line is drawn at the module-level global
    `train_len`, which must be set before this function is called (`main`
    does so).

    K: (L, T, nheads, N) torch tensor
    layer_indices: indices into the first dimension of K to plot
    dst_path: output file; missing parent directories are created
    bucket_size: number of consecutive tokens averaged into one plotted point
    max_len: number of leading token positions to include
    '''
    # A single figure is created and closed per call; this function is invoked
    # in a loop, so leaving figures open would accumulate them in memory.
    fig, axs = plt.subplots(1, 2, figsize=(6, 3))
    axs[0].set_title('Mean')
    axs[1].set_title('Variance')

    for layer_i in layer_indices:
        print(f"Plotting for layer {layer_i}")
        layer_k = K[layer_i, :max_len]  # (T, nheads, N)
        # Mean across heads
        layer_k = layer_k.mean(dim=1)  # (T, N)

        per_token_stats = [
            layer_k.mean(dim=1),  # (T)
            layer_k.var(dim=1),  # (T)
        ]
        for ax, ys in zip(axs, per_token_stats):
            buckets = torch.split(ys, bucket_size)
            ys = torch.stack([bucket.float().mean() for bucket in buckets])
            # Each point is placed at the first token position of its bucket.
            xs = torch.arange(len(ys)) * bucket_size
            ax.plot(xs, ys, label=f"Layer {layer_i}", alpha=0.6)

    # Axis decoration is independent of the layer, so apply it once per axis.
    for ax in axs:
        ax.set_xlabel(r'Token position ($t$)')
        ax.axvline(x=train_len, color='r', linestyle='--')
    axs[-1].legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    # plt.xlim(-10, max_len)

    fig.tight_layout()

    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(dst_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_heads(
    layer_K: Tensor,
    dst_path: Path,
    head_indices: List[int],
    bucket_size: int = 128,
    max_len: int = 12 * 1024,
):
    """
    Plot the bucketed mean and variance of one layer's K over token position,
    one curve per head, and save the two-panel figure to `dst_path`.

    For every requested head the first `max_len` positions are taken, the mean
    and variance over the last dimension are computed per token, and the
    results are averaged inside consecutive buckets of `bucket_size` tokens.

    A dashed red vertical line is drawn at the module-level global
    `train_len`, which must be set before this function is called (`main`
    does so).

    layer_K: (T, nheads, N)
    dst_path: output file; missing parent directories are created
    head_indices: indices into the head dimension of layer_K to plot
    bucket_size: number of consecutive tokens averaged into one plotted point
    max_len: number of leading token positions to include
    """
    # A single figure is created and closed per call; this function is invoked
    # once per head group, so leaving figures open would accumulate them.
    fig, axs = plt.subplots(1, 2, figsize=(6, 3))
    axs[0].set_title('Mean')
    axs[1].set_title('Variance')

    for head_i in head_indices:
        K = layer_K[:max_len, head_i]  # (max_len, N)

        per_token_stats = [
            K.mean(dim=1),  # (T)
            K.var(dim=1),  # (T)
        ]
        for ax, ys in zip(axs, per_token_stats):
            buckets = torch.split(ys, bucket_size)
            ys = torch.stack([bucket.float().mean() for bucket in buckets])
            # Each point is placed at the first token position of its bucket.
            xs = torch.arange(len(ys)) * bucket_size
            ax.plot(xs, ys, label=f"Head {head_i}", alpha=0.6)

    # Axis decoration is independent of the head, so apply it once per axis.
    for ax in axs:
        ax.set_xlabel(r'Token position ($t$)')
        ax.axvline(x=train_len, color='r', linestyle='--')

    axs[-1].legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    fig.tight_layout()

    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(dst_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def main():
    """Compute (or load cached) per-chunk decay coefficients and plot K.

    Steps:
      1. If no cache file exists, load the tokenizer and model, tokenize the
         long prompt, truncate it to a multiple of ``chunk_size``, run
         ``get_decays`` and save the result to the cache file. Otherwise load
         the cached result.
      2. For every layer, build ``K`` as the outer product of ``dt`` and ``B``
         per token and concatenate the chunks along the token dimension.
      3. Save per-layer plots (groups of 8 layers), per-head plots for layer
         20, and a raw trace of one (layer, head, state-index) entry; print
         per-head and per-128-token-chunk mean/variance to stdout.

    Side effects: sets the module-level global ``train_len`` (read by
    ``plot_stats`` and ``plot_heads``), and writes under ``./cache/K`` and
    ``./figs/K``.
    """
    # Must match the checkpoint selected below; it is only used to look up the
    # layer count passed to get_decays.
    model_size = '130m'
    global train_len
    train_len = 8 * 1024
    max_len = 20 * 1024
    ckpt_dir = Path('../ckpts')
    # Alternative checkpoints (paths relative to ckpt_dir):
    # pretrained_name = f'../../ckpts/mamba/mamba2-{model_size}'
    # pretrained_name = f'mamba2-512-8/T8192_B8_GA1_P1_SR1_RD0_lr0.0005/ckpt_100000'
    # pretrained_name = f'mamba2-370m/orig/ckpt_0'
    # pretrained_name = f'../ckpts/mamba2-370m/T8192_B1_GA1_P8_SR16_RD0_lr0.0005/ckpt_100000'
    pretrained_name = f'mamba2-130m/orig/ckpt_0'

    # prompt_name = 'nextlines'
    prompt_name = 'capital'
    run_name = pretrained_name.replace('/', '--')
    tok_path = '../tokenizers/mamba-tok'
    figs_dir = Path("./figs/K") / run_name / prompt_name
    cache_dir = Path('./cache/K') / run_name / prompt_name
    file_ext = 'pdf'
    device = 'cuda'
    dtype = torch.float32
    n_layers = {
        '130m': 24,
        '370m': 48,
        '780m': 48,
    }[model_size]
    chunk_size = 512

    cache_dir.mkdir(exist_ok=True, parents=True)
    decays_path = cache_dir / f'K.pt'
    # states_path = Path(f'states-{model_size}.pt')
    if not decays_path.exists():
        print('====================================')
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
        print(f"Loading model from {ckpt_dir / pretrained_name}")
        model = Mamba2ForCausalLM.from_pretrained(ckpt_dir / pretrained_name, device=device).to(dtype=dtype)

        print(model)

        print("Tokenizing prompt...")
        prompt = get_long_prompt()
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device=device)[0]
        print("Getting states...")
        # Drop the tail so every chunk fed to the model has exactly chunk_size tokens.
        truncate_len = input_ids.shape[0] // chunk_size * chunk_size
        print(f"Original length: {len(input_ids)}, truncated length: {truncate_len}")
        input_ids = input_ids[:truncate_len]
        decays = get_decays(model, input_ids, chunk_size=chunk_size, n_layers=n_layers)
        print(f"Saving stats to {decays_path}")
        torch.save(decays, decays_path)
    else:
        print(f"Loading states and stats from {decays_path}...")
        # NOTE: torch.load unpickles the file; only load caches this script wrote.
        # map_location lets a cache produced on GPU be re-plotted on a CPU-only machine.
        decays = torch.load(decays_path, map_location='cpu')

    # Layer count comes from the data itself, so it is right even if the
    # model_size table above is out of date.
    n_layers = len(decays[0])
    n_chunks = len(decays)
    print(f"# chunks: {n_chunks}")

    layer_to_K = []
    for layer_i in range(n_layers):
        chunk_Ks = []
        for chunk_decays in decays:
            layer_coeffs = chunk_decays[layer_i]
            dt = layer_coeffs['dt'][0]  # (chunk_size, nheads)
            B = layer_coeffs['B'][0]  # (chunk_size, N)
            K = einsum(dt, B, 'c h, c n -> c h n')  # (chunk_size, nheads, N)
            chunk_Ks.append(K)
        # Move to CPU layer by layer so the full (L, T, nheads, N) tensor never
        # has to fit on the compute device.
        layer_to_K.append(torch.cat(chunk_Ks).cpu().float())  # [(T, nheads, N)]

    K = torch.stack(layer_to_K)  # (L, T, nheads, N)
    print(K.shape)
    figs_dir.mkdir(exist_ok=True, parents=True)

    def plot_all_layers():
        layer_chunk_size = 8
        for layer_lo in range(0, n_layers, layer_chunk_size):
            # Clamp so a layer count that is not a multiple of the group size
            # does not index past the last layer.
            layer_hi = min(layer_lo + layer_chunk_size, n_layers)
            layer_indices = list(range(layer_lo, layer_hi))
            print(f"Plotting for layers: {layer_lo} - {layer_hi}")
            plot_stats(
                K,
                layer_indices,
                figs_dir / f"Layer{layer_lo}-{layer_hi}.{file_ext}",
                # figs_dir / f"Layer-30.{file_ext}",
                max_len=max_len,
            )

    plot_all_layers()

    nheads = K.shape[2]

    def plot_all_heads():
        head_chunk_size = 1
        print(f'# heads: {nheads}')
        # for target_layer in range(n_layers):
        for target_layer in [20]:
            print(f"Plotting the individual heads for layer {target_layer}")
            dst_dir = figs_dir / f'layer-{target_layer}'
            dst_dir.mkdir(exist_ok=True, parents=True)
            for head_lo in range(0, nheads, head_chunk_size):
                head_hi = min(head_lo + head_chunk_size, nheads)
                head_indices = list(range(head_lo, head_hi))
                print(f"Plotting for heads: {head_indices}")
                dst_path = dst_dir / f'Head-{head_lo}-{head_hi}.{file_ext}'
                plot_heads(
                    K[target_layer],
                    head_indices=head_indices,
                    dst_path=dst_path,
                    max_len=max_len,
                )

    plot_all_heads()

    # Single (state index, layer, head) entry whose raw trace is plotted below.
    ni = 91
    li = 20
    hi = 5

    # Use a separate loop variable: reusing `hi` here would silently change
    # which head is plotted (and named in the output file) afterwards.
    for head_i in range(nheads):
        head_vals = K[li, :, head_i]
        print('head', head_i, head_vals.mean(), head_vals.var())

    vals = K[li, :, hi, ni]  # (T)
    xs = torch.arange(len(vals))
    fig = plt.figure(figsize=(8, 4))
    plt.plot(xs, vals, linewidth=0.7, alpha=0.5)
    plt.savefig(figs_dir / f'L{li}_H{hi}_N{ni}.{file_ext}')
    plt.close(fig)

    # Mean/variance of the selected trace in consecutive 128-token windows.
    chunk_size = 128
    for i in range(0, len(vals), chunk_size):
        chunk = vals[i : i + chunk_size]
        print(i, i + chunk_size, chunk.mean(), chunk.var())


if __name__ == '__main__':
    main()
