from pathlib import Path
import sys
import time
from torch import Tensor
from typing import List
import numpy as np
import matplotlib.pyplot as plt
sys.path.append('..')

import torch
from modeling.mamba2.modeling_mamba2_torch import Mamba2ForCausalLM
from transformers import AutoTokenizer


def get_long_prompt():
    """Return a long synthetic prompt built by repeating one three-sentence passage.

    The passage is repeated 2400 times (an earlier note recorded that 1200
    repetitions tokenize to roughly 25K tokens).
    """
    passage = 'The capital of China is Beijing. The capital of USA is Washington. The capital of Norway is Oslo. '
    n_repeats = 2400
    return ''.join([passage] * n_repeats)


def get_tensor_stats(t):
    """Summarize a tensor.

    Returns a ``(mean, var, min, max, median)`` tuple, each computed over all
    elements of ``t`` with the corresponding torch reduction at its default
    settings.
    """
    return (
        torch.mean(t),
        torch.var(t),
        torch.min(t),
        torch.max(t),
        torch.median(t),
    )


@torch.no_grad()
def get_decays(model: Mamba2ForCausalLM, input_ids: Tensor, chunk_size: int = 128, n_layers: int = 48):
    """Run ``model`` over ``input_ids`` chunk by chunk and collect its decay outputs.

    Args:
        model: called as ``model(inputs, states=..., return_decays=True)``; the
            result is indexed with ``'decays'``.
        input_ids: token ids sliced along their first axis; each slice gets a
            leading batch axis of size 1 (``main`` passes a 1-D tensor).
        chunk_size: tokens per forward call; the last chunk may be shorter.
        n_layers: unused; kept so existing callers keep working.

    Returns:
        A list with one ``output['decays']`` entry per chunk, in sequence order.
    """
    # NOTE(review): every chunk is run with states=None -- no recurrent state is
    # carried from one chunk to the next, so each chunk starts from scratch. If
    # the model returns its final state, feed it back here to make chunked
    # processing equivalent to a single long pass (confirm the output key).
    seqlen = len(input_ids)
    print(f"Seq len: {seqlen}")
    all_decays = []
    for start in range(0, seqlen, chunk_size):
        chunk = input_ids[start:start + chunk_size].unsqueeze(0)  # add batch axis
        output = model(chunk, states=None, return_decays=True)
        all_decays.append(output['decays'])
    return all_decays


def smooth(
    x: np.ndarray,
    window_size: int = 4,
):
    """Trailing moving average over the preceding ``window_size`` values.

    ``y[i]`` is the mean of ``x[max(0, i - window_size):i]``, i.e. of up to
    ``window_size`` values strictly before position ``i``. ``y[0]`` has no
    predecessors and is kept equal to ``x[0]``. ``x`` is not modified.

    Args:
        x: 1-D array to smooth.
        window_size: number of preceding values to average; must be >= 1.

    Returns:
        A new array of the same length. Integer input is promoted to float64 so
        the window means are not truncated.

    Raises:
        ValueError: if ``window_size`` < 1 (every window would be empty).
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    # Keep float/complex dtypes; promote integers so means are not truncated on assignment.
    y = x.copy() if np.issubdtype(x.dtype, np.inexact) else x.astype(np.float64)
    for i in range(1, x.shape[0]):
        # Clamp the start: a negative start would wrap to the end of the array and
        # give an empty (NaN) or wrong window for i < window_size.
        start = max(0, i - window_size)
        y[i] = np.mean(x[start:i])
    return y


def plot_stats(
    layer_to_dA: np.ndarray,
    layer_indices: List[int],
    dst_path: Path,
    bucket_size: int = 128,
    max_len: int = 128,
    train_len: "int | None" = None,
):
    '''
    Plot the head-averaged, smoothed decay of several layers on one figure and save it.

    Args:
        layer_to_dA: decay values of shape (L, T, nheads).
        layer_indices: layers (indices into the first axis) to draw, one line each.
        dst_path: output file; missing parent directories are created.
        bucket_size: smoothing window passed to ``smooth`` and the sampling
            stride of the plotted points, in tokens.
        max_len: only the first ``max_len`` token positions are used.
        train_len: x position of a dashed red vertical marker. When omitted,
            the module-level ``train_len`` global (set by ``main``) is used; if
            that is also absent, no marker is drawn.
    '''
    if train_len is None:
        # The parameter shadows the module global, so look the global up explicitly.
        train_len = globals().get('train_len')

    plt.figure(figsize=(4, 3))
    for layer_i in layer_indices:
        print(f"Plotting for layer {layer_i}")
        dA = np.mean(layer_to_dA[layer_i], axis=1)  # (T, nheads) -> (T,)
        ys = smooth(dA[:max_len], window_size=bucket_size)
        # One point per bucket; stride over the data actually present so that
        # T < max_len does not index past the end.
        ys = ys[::bucket_size]
        xs = [i * bucket_size for i in range(len(ys))]
        plt.plot(xs, ys, label=f"Layer {layer_i}", alpha=0.6)
    if train_len is not None:
        plt.axvline(x=train_len, color='r', linestyle='--')
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.xlim(-10, max_len)
    plt.xlabel(r'Token position ($t$)')
    plt.ylabel(r'Decay ($\alpha_t$)')

    plt.tight_layout()
    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')
    # Release the figure; otherwise pyplot keeps every figure from every call alive.
    plt.close()


def plot_heads(
    layer_dA: np.ndarray,
    dst_path: Path,
    head_indices: List[int],
    bucket_size: int = 128,
    max_len: int = 12 * 1024,
    train_len: "int | None" = None,
):
    """
    Plot the bucket-averaged decay of individual heads of one layer and save it.

    Args:
        layer_dA: decay values of shape (T, nheads); ``main`` passes a NumPy array.
        dst_path: output file; missing parent directories are created.
        head_indices: heads (indices into the last axis) to draw, one line each.
        bucket_size: number of consecutive tokens averaged into each plotted point.
        max_len: bucket start positions cover ``[0, min(T, max_len))``.
        train_len: x position of a dashed red vertical marker. When omitted,
            the module-level ``train_len`` global (set by ``main``) is used; if
            that is also absent, no marker is drawn.
    """
    if train_len is None:
        # The parameter shadows the module global, so look the global up explicitly.
        train_len = globals().get('train_len')

    # Buckets starting past the data would average an empty slice (NaN + warning).
    n_tokens = min(len(layer_dA), max_len)
    plt.figure(figsize=(4, 3))
    for head_i in head_indices:
        ys = [np.mean(layer_dA[i: i + bucket_size, head_i]) for i in range(0, n_tokens, bucket_size)]
        xs = [bucket_size * x for x in range(len(ys))]
        plt.plot(xs, ys, label=f'Head {head_i}', alpha=0.6)
    if train_len is not None:
        plt.axvline(x=train_len, color='r', linestyle='--')
    # Labels are attached to every line, so show them (matches plot_stats).
    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    plt.xlabel(r"Token position ($t$)")
    plt.ylabel(r"Decay ($\alpha_t$)")
    print(f"Saving to {dst_path}")
    dst_path.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')
    # close() (not clf()) so repeated calls do not accumulate open figures.
    plt.close()


def _decays_to_layer_dA(decays) -> np.ndarray:
    """Convert the per-chunk decay records from ``get_decays`` into one array.

    ``decays[chunk_i][layer_i]`` is read as a mapping with ``'dt'`` (indexed
    with ``[0]``, presumably dropping the batch axis -- verify) and ``'A'``.
    Returns ``exp(dt * A)`` concatenated over chunks along the token axis and
    stacked over layers: a float32 array of shape (L, T, nheads).
    """
    n_layers = len(decays[0])
    per_layer = [[] for _ in range(n_layers)]
    for chunk in decays:
        for layer_i in range(n_layers):
            layer_coeffs = chunk[layer_i]
            # Upcast before the exp: the model is cast to bfloat16, whose spacing
            # just below 1.0 is ~0.004, which would quantize decays close to 1.
            dt = layer_coeffs['dt'][0].float()  # (T, nheads)
            A = layer_coeffs['A'].float()  # (nheads)
            per_layer[layer_i].append(torch.exp(dt * A))  # (T, nheads)
    return torch.stack([torch.cat(chunks, dim=0).cpu() for chunks in per_layer]).numpy()


def main():
    """Compute (or load cached) decay coefficients of a Mamba2 checkpoint and plot them.

    Decays are cached under ``./cache/decays/<run_name>/decay-coeffs.pt``, so a
    re-run only redoes the plotting. Figures are written under
    ``./figs/decays/<run_name>``: one per group of 8 layers, and one per group
    of 16 heads for every layer.
    """
    model_size = '780m'
    # plot_stats / plot_heads read this module-level global for their vertical marker.
    global train_len
    train_len = 8 * 1024
    max_len = 32 * 1024
    ckpt_dir = Path('../ckpts')
    # Alternative checkpoint names (previously assigned and immediately overwritten):
    #   f'../../ckpts/mamba/mamba2-{model_size}'
    #   'mamba2-512-8/T8192_B8_GA1_P1_SR1_RD0_lr0.0005/ckpt_100000'
    #   '../ckpts/mamba2-370m/T8192_B1_GA1_P8_SR16_RD0_lr0.0005/ckpt_100000'
    pretrained_name = 'mamba2-370m/orig/ckpt_0'
    run_name = pretrained_name.replace('/', '--')
    tok_path = '../tokenizers/mamba-tok'
    figs_dir = Path("./figs/decays") / run_name
    file_ext = 'pdf'
    device = 'cuda'
    dtype = torch.bfloat16
    # NOTE(review): model_size is '780m' while the checkpoint above is a 370m
    # model. Both map to 48 layers here, and n_layers is re-derived from the
    # decays before plotting anyway.
    n_layers = {
        '130m': 24,
        '370m': 48,
        '780m': 48,
    }[model_size]
    chunk_size = 512

    cache_dir = Path('./cache/decays') / run_name
    cache_dir.mkdir(exist_ok=True, parents=True)
    decays_path = cache_dir / 'decay-coeffs.pt'
    if not decays_path.exists():
        print('====================================')
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
        print("Loading model...")
        model = Mamba2ForCausalLM.from_pretrained(ckpt_dir / pretrained_name, device=device).to(dtype=dtype)

        print(model)

        print("Tokenizing prompt...")
        prompt = get_long_prompt()
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device=device)[0]
        print("Getting states...")
        # Drop the tail so the sequence splits into whole chunks.
        truncate_len = input_ids.shape[0] // chunk_size * chunk_size
        print(f"Original length: {len(input_ids)}, truncated length: {truncate_len}")
        input_ids = input_ids[:truncate_len]
        decays = get_decays(model, input_ids, chunk_size=chunk_size, n_layers=n_layers)
        print(f"Saving stats to {decays_path}")
        torch.save(decays, decays_path)
    else:
        print(f"Loading states and stats from {decays_path}...")
        # torch.load unpickles; only point it at caches written by this script.
        decays = torch.load(decays_path)

    n_chunks = len(decays)
    print(f"# chunks: {n_chunks}")
    layer_to_dA = _decays_to_layer_dA(decays)  # (L, T, nheads)
    print(layer_to_dA)
    print(layer_to_dA.shape)
    n_layers, _, nheads = layer_to_dA.shape

    figs_dir.mkdir(exist_ok=True, parents=True)

    def plot_all_layers():
        layer_chunk_size = 8
        for layer_lo in range(0, n_layers, layer_chunk_size):
            layer_hi = layer_lo + layer_chunk_size
            # Clamp so a layer count that is not a multiple of the group size
            # does not index past the last layer.
            layer_indices = list(range(layer_lo, min(layer_hi, n_layers)))
            print(f"Plotting for layers: {layer_lo} - {layer_hi}")
            plot_stats(
                layer_to_dA,
                layer_indices,
                figs_dir / "dA" / f"Layer{layer_lo}-{layer_hi}.{file_ext}",
                max_len=max_len,
            )

    plot_all_layers()

    def plot_all_heads():
        head_chunk_size = 16
        print(f'# heads: {nheads}')
        for target_layer in range(n_layers):
            print(f"Plotting the individual heads for layer {target_layer}")
            for head_lo in range(0, nheads, head_chunk_size):
                head_hi = head_lo + head_chunk_size
                # Clamp so a head count that is not a multiple of the group size
                # does not index past the last head.
                head_indices = list(range(head_lo, min(head_hi, nheads)))
                print(f"Plotting for heads: {head_indices}")
                dst_dir = figs_dir / f'layer-{target_layer}'
                dst_dir.mkdir(exist_ok=True, parents=True)
                dst_path = dst_dir / "dA" / f'Head-{head_lo}-{head_hi}.{file_ext}'
                plot_heads(
                    layer_to_dA[target_layer],
                    head_indices=head_indices,
                    dst_path=dst_path,
                    max_len=max_len,
                )

    plot_all_heads()


if __name__ == '__main__':
    # Run the analysis only when executed as a script, not when imported.
    main()
