import torch
import torch.nn as nn
import torch.cuda
import json
import os
import numpy as np
import random
from pathlib import Path
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from transformers import AutoTokenizer

from qat.replace_module import replace_with_learnable_binarylinear

def load_json(data_p):
    """Read the UTF-8 JSON file at ``data_p`` and return the decoded object."""
    path = Path(data_p)
    with path.open('r', encoding='utf-8') as fh:
        return json.load(fh)

def save_json(json_data, save_p):
    """Serialize ``json_data`` to ``save_p`` as UTF-8 JSON (4-space indent, non-ASCII kept as-is)."""
    target = Path(save_p)
    with target.open('w', encoding='utf-8') as fh:
        json.dump(json_data, fh, indent=4, ensure_ascii=False)

def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def print_memory_usage():
    """Print ``torch.cuda.memory_allocated()`` converted to MiB (labelled "MB")."""
    allocated_mb = torch.cuda.memory_allocated() / 1024 / 1024
    print(f"memory_allocated: {allocated_mb} MB")


def print_trainable_parameters(model):
    """
    Print every parameter name prefixed with '+' (requires grad) or '-'
    (frozen), then a summary line with trainable / total parameter counts
    and the trainable percentage.

    A model with no parameters reports a trainable percentage of 0.0 instead
    of raising ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for key, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            print('+', key)
            trainable_params += count
        else:
            print('-', key)
    # Guard the ratio: a parameter-less module would otherwise divide by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )


def prepare_model_for_training(model):
    """
    Freeze all parameters of ``model``, upcast fp16/bf16 parameters to
    float32, and make the input-embedding output require grad.

    The last step uses the model's own ``enable_input_require_grads`` when it
    has one. Otherwise it registers a forward hook on
    ``model.get_input_embeddings()`` that marks the output as requiring grad,
    even though the embedding weights themselves are frozen.

    Returns the same (mutated) ``model``.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    for param in model.parameters():
        # freeze base model's layers
        param.requires_grad = False
        if param.dtype in half_dtypes:
            param.data = param.data.to(torch.float32)

    # For backward compatibility with models lacking enable_input_require_grads
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
        return model

    def _force_output_grad(module, input, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(_force_output_grad)
    # gradient checkpointing is intentionally left disabled here
    return model


def prepare_model_for_eval(model):
    """Put ``model`` in eval mode, freeze every parameter and cast it to float16; return ``model``."""
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
        param.data = param.data.to(torch.float16)
    return model


def get_bnn_meta(model):
    """
    Map the qualified name of every ``quant.BinaryInterface`` submodule of
    ``model`` to that submodule's class name.

    NOTE(review): ``quant`` does not appear in this file's visible imports;
    unless it is bound elsewhere, this raises NameError as soon as any
    submodule is type-checked. Confirm which module provides it.
    """
    return {
        module_name: submodule.__class__.__name__
        for module_name, submodule in model.named_modules()
        if isinstance(submodule, quant.BinaryInterface)
    }


def get_bnn_weights(model):
    """
    Collect the save dicts of every ``quant.BinaryInterface`` submodule.

    Each entry of a submodule's ``get_save_weight_dict()`` is stored under
    ``"<module name>_<entry key>"``. ``load_bnn`` reads ``"<name>_weight"``
    and ``"<name>_bias"`` back from this layout. Later modules overwrite
    earlier ones on key collision.

    NOTE(review): ``quant`` does not appear in this file's visible imports;
    unless it is bound elsewhere, this raises NameError on the first module.
    """
    collected = {}
    for module_name, submodule in model.named_modules():
        if not isinstance(submodule, quant.BinaryInterface):
            continue
        for entry_key, value in submodule.get_save_weight_dict().items():
            collected[module_name + "_" + entry_key] = value
    return collected


def save_bnn(model, save_path):
    """
    Save the binarized layers of ``model`` into directory ``save_path``.

    Writes ``meta.json`` (module name -> class name, from ``get_bnn_meta``)
    and ``weights.pth`` (the dict from ``get_bnn_weights``, via
    ``torch.save``). The directory is created if missing. ``save_path`` may be
    a ``str`` or a path-like object.
    """
    print(f"saving bnn model to {save_path}")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_path, exist_ok=True)
    meta = get_bnn_meta(model)
    weights = get_bnn_weights(model)
    # Use a context manager so the JSON file is flushed and closed deterministically.
    with open(os.path.join(save_path, "meta.json"), "w") as meta_file:
        json.dump(meta, meta_file)
    torch.save(weights, os.path.join(save_path, "weights.pth"))


def load_bnn(model, load_path):
    """
    Swap ``nn.Linear`` submodules of ``model`` for binarized layers saved by ``save_bnn``.

    ``load_path`` (a ``str`` or path-like) must contain:
      * ``meta.json``: qualified module name -> class name looked up on ``quant``;
      * ``weights.pth``: tensors keyed ``"<name>_weight"`` and ``"<name>_bias"``.

    Every ``nn.Linear`` whose name appears in the meta is rebuilt as
    ``getattr(quant, class_name)(weight, bias)`` and set on its parent module
    in place. Linears not listed in the meta are left untouched.

    Returns the same (mutated) ``model``.

    Raises KeyError if a listed layer is missing its ``_weight``/``_bias`` entry.

    NOTE(review): ``quant`` does not appear in this file's visible imports;
    unless it is bound elsewhere, any matching layer raises NameError.
    NOTE(review): ``torch.load`` unpickles ``weights.pth``, so only load
    checkpoints from trusted sources.
    """
    print(f"loading bnn model from {load_path}")
    with open(os.path.join(load_path, "meta.json"), "r") as meta_file:
        bnn_meta = json.load(meta_file)
    bnn_weights = torch.load(os.path.join(load_path, "weights.pth"))
    print(bnn_weights.keys())

    # Snapshot name -> module before any replacement so parents can be
    # resolved by name prefix while children are being swapped.
    module_name_dict = dict(model.named_modules())
    for name, module in module_name_dict.items():
        if not isinstance(module, nn.Linear) or name not in bnn_meta:
            continue
        # "a.b.c" -> parent "a.b", child "c"; a top-level "c" -> parent "" (the model itself).
        parent_name, _, child_name = name.rpartition(".")
        father = module_name_dict[parent_name]
        # choose binarization method
        binarization_method = bnn_meta[name]
        weight = bnn_weights[name + "_weight"]
        bias = bnn_weights[name + "_bias"]
        qlinear = getattr(quant, binarization_method)(weight, bias)

        setattr(father, child_name, qlinear)
        print(f"replace layer {name} with {qlinear}")
    return model


def generate_sample_test(model, tokenizer):
    """
    Smoke-test generation by running a fixed prompt through ``model.generate``
    (``max_length=60``) and printing the decoded text.

    The prompt ids are moved to the device of the model's first parameter.
    Models built by ``load_bimamba2_ckpts`` are placed on CUDA, while the
    tokenizer returns CPU tensors. A parameter-less model leaves the ids where
    they are.
    """
    # prompt = "Hey, are you conscious? Can you talk to me?"
    prompt = "Hey, is llama the best language model?"
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)
    generate_ids = model.generate(input_ids, max_length=60)
    outputs = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    print(outputs)


def load_mamba2(model_name):
    """
    Load a pretrained ``MambaLMHeadModel`` from ``model_name`` and pair it with
    the ``EleutherAI/gpt-neox-20b`` tokenizer.

    Returns ``(model, tokenizer)``.
    """
    tokenizer_source = "EleutherAI/gpt-neox-20b"
    model = MambaLMHeadModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    return model, tokenizer


def _bimamba2_config(model_size):
    """Return the ``MambaConfig`` for a supported BiMamba2 size; raise ValueError otherwise."""
    # (d_model, n_layer) per size; every other field is shared.
    shapes = {
        "3B": (2560, 64),
        "1.3B": (2048, 48),
        "780M": (1536, 48),
        "370M": (1024, 48),
    }
    if model_size not in shapes:
        raise ValueError(
            f"model_size must be one of {sorted(shapes)}, got {model_size!r}"
        )
    d_model, n_layer = shapes[model_size]
    # The 3B config never set tie_embeddings explicitly (it uses MambaConfig's
    # default); the smaller sizes set it to True. Keep that distinction.
    extra = {} if model_size == "3B" else {"tie_embeddings": True}
    return MambaConfig(
        d_model=d_model,
        d_intermediate=0,
        n_layer=n_layer,
        vocab_size=32000,
        ssm_cfg={"layer": "Mamba2"},
        attn_cfg={},
        attn_layer_idx=[],
        pad_vocab_size_multiple=16,
        **extra,
    )


def load_bimamba2_ckpts(model_size, ckpt_dir, exist_extra_para, keep_parts, scaling_pattern):
    """
    Build a BiMamba2 ``MambaLMHeadModel`` on CUDA and load weights from the
    ``*.bin`` shards in ``ckpt_dir``.

    Args:
        model_size: one of "370M", "780M", "1.3B", "3B".
        ckpt_dir: directory (str or path-like) holding ``*.bin`` state-dict shards.
        exist_extra_para: if true, the model is passed through
            ``replace_with_learnable_binarylinear`` before loading. The
            checkpoint must then match the replaced layers, because
            ``load_state_dict`` is strict.
        keep_parts, scaling_pattern: forwarded to
            ``replace_with_learnable_binarylinear``.

    Returns:
        ``(model, tokenizer)``, where the tokenizer is loaded from
        ``meta-llama/Llama-2-7b-hf``.

    Raises:
        ValueError: if ``model_size`` is not supported.
    """
    config = _bimamba2_config(model_size)
    ckpt_dir = Path(ckpt_dir)

    print(model_size)
    print(config)
    model = MambaLMHeadModel(config).to('cuda')
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    if exist_extra_para:
        model = replace_with_learnable_binarylinear(model, scaling_pattern, keep_parts)

    # Merge all shards; sorting makes the merge order (and which shard wins on
    # a duplicate key) deterministic instead of filesystem-dependent.
    weight_dict = {}
    for shard_path in sorted(p for p in ckpt_dir.iterdir() if p.suffix == '.bin'):
        weight_dict.update(torch.load(shard_path))

    model.load_state_dict(weight_dict)

    return model, tokenizer
