import sys
import time
from pathlib import Path
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import Tensor
sys.path.append('..')

import torch
from modeling.mamba2.modeling_mamba2_torch import Mamba2ForCausalLM
from transformers import AutoTokenizer
from get_prompt import get_long_prompt


@torch.no_grad()
def get_decays(model: Mamba2ForCausalLM, input_ids: Tensor, chunk_size: int = 128, n_layers: int = 48):
    """Run ``model`` over ``input_ids`` chunk by chunk and collect each chunk's ``decays``.

    Args:
        model: called as ``model(ids, states=..., return_decays=True)``; its
            output must support ``output['decays']``.
        input_ids: 1-D tensor of token ids; each slice of ``chunk_size`` ids
            gets a leading batch dimension of 1 before the forward pass.
        chunk_size: tokens per forward pass (the final chunk may be shorter).
        n_layers: currently unused; kept so existing callers keep working.

    Returns:
        A list with one ``output['decays']`` entry per chunk, in input order.
    """
    seqlen = len(input_ids)
    print(f"Seq len: {seqlen}")
    # NOTE(review): cur_state is never updated from `output`, so every chunk is
    # run with states=None and (presumably) starts from a fresh recurrent state
    # instead of continuing from the previous chunk. If carrying state over is
    # intended, assign the model's returned state here — confirm the output key.
    cur_state = None
    all_decays = []
    for start in range(0, seqlen, chunk_size):
        chunk_ids = input_ids[start:start + chunk_size].unsqueeze(0)  # (1, <=chunk_size)
        output = model(chunk_ids, states=cur_state, return_decays=True)
        all_decays.append(output['decays'])
    return all_decays


def plot_stats(
    outers: np.ndarray,
    layer_indices: List[int],
    dst_path: Path,
    bucket_size: int = 128,
    max_len: int = 30 * 1024,
    train_len: Optional[int] = None,
):
    '''
    Plot the per-layer mean and variance of bucketed outer products over token position.

    Args:
        outers: array/tensor indexed as (n_buckets, L, nheads, P, N); bucket
            ``t`` is drawn at x = ``t * bucket_size``.
        layer_indices: layers to draw, one line per layer on each subplot.
        dst_path: output file; missing parent directories are created.
        bucket_size: tokens per bucket, used to scale the x axis.
        max_len: only the first ``max_len // bucket_size`` buckets are used,
            and both x axes are limited to ``[-10, max_len]``.
        train_len: x position of the red dashed reference line. When None,
            the module-level ``train_len`` (set by ``main``) is used; if that
            is also missing, no line is drawn.
    '''
    if train_len is None:
        train_len = globals().get('train_len')

    # Only one figure is created (the original also opened an unused
    # plt.figure() here, leaking an empty figure per call).
    fig, axs = plt.subplots(1, 2, figsize=(6, 3))
    axs[0].set_title('Mean')
    axs[1].set_title('Variance')
    for layer_i in layer_indices:
        print(f"Plotting for layer {layer_i}")
        outer = outers[:max_len // bucket_size, layer_i]  # (T, H, P, N)

        mean = outer.mean(axis=(1, 2, 3))
        var = outer.var(axis=(1, 2, 3))

        for ax, metric in zip(axs, [mean, var]):
            xs = torch.arange(len(metric)) * bucket_size
            print(f"Layer {layer_i=} {xs.shape=}")
            ax.plot(xs, metric, label=f"Layer {layer_i}", alpha=0.6)

    # Per-axis decorations are applied once, to both subplots (plt.xlim would
    # only affect the current, i.e. last, axes).
    for ax in axs:
        ax.set_xlabel(r'Token position ($t$)')
        ax.set_xlim(-10, max_len)
        if train_len is not None:
            ax.axvline(x=train_len, color='r', linestyle='--')
    axs[-1].legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    fig.tight_layout()

    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(dst_path, dpi=300, bbox_inches='tight')
    # Close (not just clear) so repeated calls do not accumulate open figures.
    plt.close(fig)


def plot_heads(
    outers: Tensor,
    dst_path: Path,
    head_indices: List[int],
    bucket_size: int = 128,
    max_len: int = 30 * 1024,
    train_len: Optional[int] = None,
):
    """
    Plot the per-head mean and variance (over P and N) of one layer's bucketed outer products.

    Args:
        outers: ONE layer's values, indexed as (n_buckets, nheads, P, N);
            ``main`` passes ``outers[layer]`` of its (L, T, H, P, N) tensor.
        dst_path: output file; missing parent directories are created.
        head_indices: heads to draw, one line per head on each subplot.
        bucket_size: tokens per bucket, used to scale the x axis.
        max_len: only the first ``max_len // bucket_size`` buckets are used.
        train_len: x position of the red dashed reference line. When None,
            the module-level ``train_len`` (set by ``main``) is used; if that
            is also missing, no line is drawn.
    """
    if train_len is None:
        train_len = globals().get('train_len')

    fig, axs = plt.subplots(1, 2, figsize=(6, 3))
    axs[0].set_title('Mean')
    axs[1].set_title('Variance')
    n_buckets = max_len // bucket_size
    for head_i in head_indices:
        # Truncate the time axis first, then select the head -> (T // B, P, N).
        # (Indexing as outers[:, :n, head_i] would slice heads and pick a P row.)
        outer = outers[:n_buckets, head_i]

        mean = outer.mean(axis=(1, 2))  # (T // B,)
        var = outer.var(axis=(1, 2))  # (T // B,)
        xs = torch.arange(len(mean)) * bucket_size

        for ax, ys in zip(axs, [mean, var]):
            ax.plot(xs, ys, label=f"Head {head_i}", alpha=0.6)

    for ax in axs:
        ax.set_xlabel(r'Token position ($t$)')
        if train_len is not None:
            ax.axvline(x=train_len, color='r', linestyle='--')
    axs[-1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
    fig.tight_layout()

    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(dst_path, dpi=300, bbox_inches='tight')
    # Close (not just clear) so looping over many layers/heads does not leak figures.
    plt.close(fig)


def _save_series_plot(
    xs,
    ys,
    dst_path: Path,
    figsize=(2.4, 2.4),
    vline=None,
    tight: bool = False,
    **plot_kwargs,
):
    """Draw ``ys`` against ``xs`` on a new figure, save it to ``dst_path``, then close it.

    ``vline``, when not None, is marked with a red dashed vertical line.
    ``tight`` calls ``plt.tight_layout()`` before saving. Any extra keyword
    arguments are forwarded to ``plt.plot``.
    """
    plt.figure(figsize=figsize)
    plt.plot(xs, ys, **plot_kwargs)
    if vline is not None:
        plt.axvline(x=vline, color='r', linestyle='--')
    if tight:
        plt.tight_layout()
    print(f"Saving to {dst_path}")
    plt.savefig(dst_path, bbox_inches='tight')
    plt.close()


def _bucketed_outers(decays, n_layers: int, bucket_size: int) -> Tensor:
    """Return dt * (x ⊗ B) for every layer, averaged over buckets of consecutive tokens.

    ``decays[chunk][layer]`` must map ``'dt'``, ``'B'`` and ``'x'`` to tensors
    with a leading batch dimension (only batch 0 is used). Within each chunk the
    token axis is split into buckets of ``bucket_size`` (the last bucket of a
    chunk may be shorter) and each bucket is averaged.

    Returns:
        float32 CPU tensor of shape (n_layers, n_buckets, H, P, N).
    """
    per_layer = []
    for layer_i in range(n_layers):
        per_chunk = []
        for chunk in decays:
            coeffs = chunk[layer_i]
            dt = coeffs['dt'][0]  # (chunk_len, H)
            B = coeffs['B'][0]  # (chunk_len, N)
            x = coeffs['x'][0]  # (chunk_len, H, P)
            outer = einsum(dt, B, x, 'c h, c n, c h p -> c h p n')
            buckets = torch.split(outer, bucket_size)
            per_chunk.append(torch.stack([b.mean(dim=0) for b in buckets]))  # (n_b, H, P, N)
        per_layer.append(torch.cat(per_chunk))  # (T_b, H, P, N)
    return torch.stack(per_layer, dim=0).float().cpu()


def main():
    """Compute (or load cached) per-chunk SSM coefficients and plot dt * (x ⊗ B) statistics.

    1. Loads ``cache_dir / 'outer.pt'`` if present; otherwise runs the model over
       ``get_long_prompt()`` in chunks via ``get_decays`` and caches the result.
    2. Writes dt * B[ni] * x[pi] of the target (layer, head) to ``all_dbx.txt``.
       While ``stop_after_dbx_dump`` is True the function returns here.
    3. Otherwise builds the (L, T, H, P, N) tensor of bucket-averaged outer
       products and writes per-head and per-channel figures into ``figs_dir``.

    Sets the module-level ``train_len`` that ``plot_stats`` / ``plot_heads`` read.
    """
    model_size = '780m'
    global train_len
    train_len = 8 * 1024
    max_len = 30 * 1024
    ckpt_dir = Path('../ckpts')
    # Other checkpoints this script has been pointed at:
    #   f'../../ckpts/mamba/mamba2-{model_size}'
    #   'mamba2-512-8/T8192_B8_GA1_P1_SR1_RD0_lr0.0005/ckpt_100000'
    #   'mamba2-370m/orig/ckpt_0'
    #   '../ckpts/mamba2-370m/T8192_B1_GA1_P8_SR16_RD0_lr0.0005/ckpt_100000'
    pretrained_name = 'mamba2-130m/orig/ckpt_0'

    prompt_name = 'nextlines'
    run_name = pretrained_name.replace('/', '--')
    tok_path = '../tokenizers/mamba-tok'
    figs_dir = Path("./figs/outer") / run_name / prompt_name
    cache_dir = Path('./cache/outer') / run_name / prompt_name
    file_ext = 'pdf'
    device = 'cuda'
    dtype = torch.float32
    # NOTE(review): model_size ('780m') does not match pretrained_name (130m);
    # here it only selects the n_layers value handed to get_decays.
    n_layers = {
        '130m': 24,
        '370m': 48,
        '780m': 48,
    }[model_size]
    # Tokens per forward pass; the prompt is truncated to a multiple of this.
    chunk_size = 512
    # Tokens averaged into one row of the outer-product tensor.
    bucket_size = 1

    # Channel inspected by the all_dbx.txt dump and the kv-* plots.
    target_layer = 20
    target_head = 11
    pi = 8
    ni = 91

    # Return right after writing all_dbx.txt (no figures are produced).
    # Set to False to run the plotting below.
    stop_after_dbx_dump = True

    cache_dir.mkdir(exist_ok=True, parents=True)
    decays_path = cache_dir / 'outer.pt'
    if not decays_path.exists():
        print('====================================')
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
        print(f"Loading model from {ckpt_dir / pretrained_name}")
        model = Mamba2ForCausalLM.from_pretrained(ckpt_dir / pretrained_name, device=device).to(dtype=dtype)

        print("Tokenizing prompt...")
        prompt = get_long_prompt()
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device=device)[0]
        print("Getting states...")
        truncate_len = input_ids.shape[0] // chunk_size * chunk_size
        print(f"Original length: {len(input_ids)}, truncated length: {truncate_len}")
        input_ids = input_ids[:truncate_len]
        decays = get_decays(model, input_ids, chunk_size=chunk_size, n_layers=n_layers)
        print(f"Saving stats to {decays_path}")
        torch.save(decays, decays_path)
    else:
        print(f"Loading states and stats from {decays_path}...")
        decays = torch.load(decays_path)

    n_layers = len(decays[0])
    n_chunks = len(decays)
    print(f"# chunks: {n_chunks}")

    def dump_dbx() -> Tensor:
        """Write dt * B[ni] * x[pi] of (target_layer, target_head) to all_dbx.txt, one token per line."""
        per_chunk = []
        for chunk in decays:
            coeffs = chunk[target_layer]
            dt = coeffs['dt'][0, :, target_head]  # (chunk_len,)
            B = coeffs['B'][0, :, ni]  # (chunk_len,)
            x = coeffs['x'][0, :, target_head, pi]  # (chunk_len,)
            per_chunk.append(dt * B * x)
        vals = torch.cat(per_chunk).float().cpu()  # (T,)
        with open('all_dbx.txt', 'w') as f:
            f.writelines(f"{v}\n" for v in vals.tolist())
        return vals

    def plot_angles(vals: Tensor):
        """Plot the dumped dbx series, then the (target_layer, target_head, pi) row of the outer product in groups of N."""
        dst_path = figs_dir / f'dbx_L{target_layer}_H{target_head}_P{pi}_N{ni}.{file_ext}'
        plt.figure(figsize=(4, 4))
        plt.plot(torch.arange(len(vals)), vals)
        plt.grid(True)
        plt.ylim(-0.00005, 0.00005)
        plt.tight_layout()
        print(f"Saving to {dst_path}")
        plt.savefig(dst_path, bbox_inches='tight')
        plt.close()
        print(vals)

        # Only the target layer is needed, so only it is computed.
        rows = []
        for chunk in decays:
            coeffs = chunk[target_layer]
            outer = einsum(coeffs['dt'][0], coeffs['B'][0], coeffs['x'][0], 'c h, c n, c h p -> c h p n')
            rows.append(outer[:, target_head, pi])  # (C, N)
        outers = torch.cat(rows).float().cpu()  # (T, N)
        print(outers.shape)

        n_group = 8
        n_state = outers.shape[1]
        xs = torch.arange(len(outers))
        plt.figure(figsize=(4, 4))
        for n_lo in range(0, n_state, n_group):
            n_hi = min(n_lo + n_group, n_state)
            dst_path = figs_dir / f'ni-{n_lo}-{n_hi}.{file_ext}'
            for n_idx in range(n_lo, n_hi):
                plt.plot(xs, outers[:, n_idx], label=f"{n_idx}", alpha=0.5)
            plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
            plt.tight_layout()
            print(f"Saving to {dst_path}")
            plt.savefig(dst_path, bbox_inches='tight')
            plt.clf()
        plt.close()

    def plot_dtx(p_idx: int = 20):
        """Plot dt * x for (target_layer, target_head, p_idx) over token position."""
        rows = []
        for chunk in decays:
            coeffs = chunk[target_layer]
            dtx = einsum(coeffs['dt'][0], coeffs['x'][0], 'c h, c h p -> c h p')
            rows.append(dtx[:, target_head, p_idx])  # (C,)
        dtx_all = torch.cat(rows).float().cpu()  # (T,)
        print(dtx_all.shape)
        _save_series_plot(
            torch.arange(len(dtx_all)),
            dtx_all,
            figs_dir / f'dtx_L{target_layer}_H{target_head}_p{p_idx}.{file_ext}',
            figsize=(4, 4),
            vline=train_len,
            tight=True,
            alpha=0.5,
        )

    vals = dump_dbx()
    if stop_after_dbx_dump:
        return

    figs_dir.mkdir(exist_ok=True, parents=True)
    plot_angles(vals)
    # plot_dtx()

    # NOTE(review): this holds L * T_b * H * P * N floats; with bucket_size=1
    # T_b equals the token count, which may not fit in memory for long prompts.
    outers = _bucketed_outers(decays, n_layers, bucket_size)  # (L, T, H, P, N)

    def plot_all_layers():
        """Plot mean/variance per layer, in groups of layers."""
        layer_chunk_size = 8
        # plot_stats indexes (T, L, H, P, N); `outers` is (L, T, H, P, N).
        outers_tl = outers.transpose(0, 1)
        for layer_lo in range(0, n_layers, layer_chunk_size):
            layer_hi = min(layer_lo + layer_chunk_size, n_layers)
            print(f"Plotting for layers: {layer_lo} - {layer_hi}")
            plot_stats(
                outers_tl,
                list(range(layer_lo, layer_hi)),
                figs_dir / f"Layer-{layer_lo}-{layer_hi}.{file_ext}",
                bucket_size=bucket_size,
                max_len=max_len,
            )

    # plot_all_layers()

    def plot_all_heads():
        """For every layer, plot mean/variance per head, in groups of heads."""
        head_chunk_size = 6
        nheads = outers.shape[2]
        print(f'# heads: {nheads}')
        for layer_i in range(n_layers):
            print(f"Plotting the individual heads for layer {layer_i}")
            dst_dir = figs_dir / f'layer-{layer_i}'
            dst_dir.mkdir(exist_ok=True, parents=True)
            for head_lo in range(0, nheads, head_chunk_size):
                head_hi = min(head_lo + head_chunk_size, nheads)
                plot_heads(
                    outers[layer_i],  # (T, H, P, N)
                    head_indices=list(range(head_lo, head_hi)),
                    dst_path=dst_dir / f'Head-{head_lo}-{head_hi}.{file_ext}',
                    bucket_size=bucket_size,
                    max_len=max_len,
                )

    plot_all_heads()

    # `outers` is (L, T, H, P, N): select the layer first and keep the time axis.
    vals = outers[target_layer, :, target_head, pi, ni]  # (T,)
    _save_series_plot(
        torch.arange(len(vals)) * bucket_size,
        vals,
        figs_dir / f'kv-P{pi}-N{ni}.{file_ext}',
        figsize=(4, 4),
        vline=train_len,
        tight=True,
    )

    vecs = outers[target_layer, :, target_head, pi]  # (T, N)
    vars_ = vecs.var(dim=1)  # (T,), unbiased like Tensor.var()
    means = vecs.mean(dim=1)  # (T,)
    xs = np.arange(len(vars_)) * bucket_size
    _save_series_plot(xs, vars_, figs_dir / f'kv-p{pi}_var.{file_ext}', vline=train_len)
    _save_series_plot(xs, means, figs_dir / f'kv-p{pi}_mean.{file_ext}', vline=train_len)

    # Similarity between consecutive time steps. Clamp before acos so float
    # rounding just outside [-1, 1] does not produce NaN angles.
    cossims = F.cosine_similarity(vecs[:-1], vecs[1:], dim=1)  # (T - 1,)
    angles = torch.acos(cossims.clamp(-1.0, 1.0))
    # Same bucket_size as the data itself (the x axis was previously scaled by 512).
    xs = np.arange(len(cossims)) * bucket_size
    _save_series_plot(xs, angles, figs_dir / f'kv-p{pi}_angle.{file_ext}', vline=train_len)
    _save_series_plot(xs, cossims, figs_dir / f'kv-p{pi}_cossim.{file_ext}', vline=train_len)


if __name__ == "__main__":
    # Entry point: only run the analysis when executed as a script, not on import.
    main()
