import sys

import torch
from transformers import AutoTokenizer

sys.path.append('../..')
from modeling.mamba.mamba_torch import MambaMixer
from modeling.mamba import MambaConfig, MambaState, MambaModel, MambaForCausalLM


def compare_seq_and_par(seq_scan_impl='seq_triton', par_scan_impl='seq_triton'):
    """Run the same pretrained weights through two scan implementations and compare.

    Builds a ``MambaForCausalLM`` using ``seq_scan_impl`` and loads
    ``../../mamba-130m.pt`` into its ``.model`` submodule. It then copies that
    model's full state dict into a second ``MambaForCausalLM`` built with
    ``par_scan_impl``. Both models receive the same tokenized prompt, and each
    starts from its own freshly allocated empty ``MambaState``. The function
    prints both outputs, their max absolute difference and ``torch.allclose``.

    Args:
        seq_scan_impl: ``scan_impl`` for the reference model.
        par_scan_impl: ``scan_impl`` for the model under test.
            NOTE(review): the default equals ``seq_scan_impl``, as in the
            original script, so by default this compares an implementation
            against itself. Pass the parallel implementation's name to get a
            meaningful comparison.
    """
    B = 1
    D = 768
    E = 2
    N = 16
    L = 24
    dtype = torch.float32
    dtype_str = 'fp32'
    d_conv = 4

    tokenizer = AutoTokenizer.from_pretrained('state-spaces/mamba-130m-hf')
    input_ids = tokenizer('An increasing sequence: one, two, three, four,', return_tensors='pt').input_ids.cuda()

    state_dict = torch.load('../../mamba-130m.pt')

    # Reference model.
    seq_config = MambaConfig(n_layers=L, d_model=D, d_state=N, expand_factor=E, d_conv=d_conv, scan_impl=seq_scan_impl, dtype=dtype_str)
    seq_model = MambaForCausalLM(seq_config).cuda()
    seq_model.model.load_state_dict(state_dict)
    seq_model.eval()
    # Each run gets its own empty state. If the model updates the passed state
    # in place, a shared object would give the second run a different start.
    seq_states = MambaState.empty(L, B, D, E, N, d_conv, device='cuda', dtype=dtype)
    with torch.no_grad():
        # NOTE(review): unpacked as 2 values here, but compare_forward_and_step
        # unpacks forward() as (loss, logits, cache); confirm the return arity.
        seq_output, _new_cache = seq_model(input_ids=input_ids, states=seq_states)

    # Model under test. Use the same class as the reference so that
    # seq_model.state_dict() keys (including the 'model.' prefix) line up.
    par_config = MambaConfig(n_layers=L, d_model=D, d_state=N, expand_factor=E, d_conv=d_conv, scan_impl=par_scan_impl, dtype=dtype_str)
    par_model = MambaForCausalLM(par_config).cuda()
    par_model.load_state_dict(seq_model.state_dict())
    par_model.eval()
    par_states = MambaState.empty(L, B, D, E, N, d_conv, device='cuda', dtype=dtype)
    with torch.no_grad():
        par_output, _new_cache = par_model(input_ids=input_ids, states=par_states)

    print("==== sequential output =====")
    print(seq_output)
    print("==== parallel output =====")
    print(par_output)
    print("Max diff:", (seq_output - par_output).abs().max())
    print("All close:")
    print(torch.allclose(seq_output, par_output))


def compare_forward_and_step():
    """Compare a full-sequence forward pass with token-by-token ``step`` calls.

    Builds a randomly initialised single-layer ``MambaForCausalLM`` (no weights
    are loaded) and runs ``forward`` on the whole tokenized prompt. It then feeds
    the same prompt one time position at a time through ``step``, passing the
    state returned by each call into the next one. Prints the forward logits and
    the logits from the final step.
    """
    B = 1
    D = 1024
    E = 2
    N = 16
    L = 1
    dtype = torch.bfloat16
    d_conv = 4
    # NOTE(review): the state is bf16, but the config below passes no dtype, so
    # the model may be built in a different precision; confirm the
    # MambaConfig default.
    states = MambaState.empty(L, B, D, E, N, d_conv, device='cuda', dtype=dtype)
    tokenizer = AutoTokenizer.from_pretrained('state-spaces/mamba-130m-hf')
    # return_tensors='pt' yields a 2-D (batch, seq_len) tensor.
    input_ids = tokenizer('An increasing sequence: one, two, three, four,', return_tensors='pt').input_ids.cuda()

    config = MambaConfig(n_layers=L, d_model=D, d_state=N, expand_factor=E, d_conv=d_conv, scan_impl='seq')
    model = MambaForCausalLM(config).cuda()
    model.eval()

    with torch.no_grad():
        _loss, fw_logits, _new_cache = model.forward(input_ids=input_ids, states=None)

        # Steps: walk the time dimension (dim 1). Iterating input_ids directly
        # would walk the batch dimension and hand step() a whole (seq_len,) row.
        step_logits = None
        for input_id in input_ids.unbind(dim=1):
            # NOTE(review): assumes step() returns the updated state (the
            # original bound it as `cache`); feed it back into the next step.
            step_logits, states = model.step(input_ids=input_id, states=states)  # (B, V)

    print("===== fw logits =====")
    print(fw_logits)
    print("===== step logits =====")
    print(step_logits)


if __name__ == "__main__":
    # Only the sequential-vs-parallel comparison runs by default; uncomment
    # the line below to also run the forward-vs-step comparison.
    # compare_forward_and_step()
    compare_seq_and_par()
