import itertools
import os
import pathlib
from dataclasses import dataclass
from typing import Optional

import torch
import torch.autograd
import torch.onnx
import torch.utils.checkpoint
import transformers
from peft import LoraConfig, TaskType, prepare_model_for_kbit_training, PeftModel, get_peft_model
from transformers import LlamaConfig, AutoTokenizer
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from ..dataset.redpajama import RedPajamaDataset
from .llama3 import LlamaForCausalLM, LlamaSdpaAttention


@dataclass
class ModelConfig:
    """Options consumed by ``load_model``.

    Groups device placement, HiP / InfiniGen attention settings, LoRA,
    quantization and RoPE-scaling options. Fields whose default is ``None``
    are typed ``Optional[...]`` to reflect that ``None`` is a meaningful value.
    """

    # Process-local GPU index; load_model overwrites it from the LOCAL_RANK
    # environment variable when that variable is set.
    local_rank: Optional[int] = None
    # Load weights with device_map='cpu' and drop bitsandbytes quantization
    # when training.
    using_fsdp: bool = False
    # Use deepspeed.checkpointing.checkpoint for gradient checkpointing
    # (ignored when using_fsdp is set).
    using_deepspeed: bool = False
    # Load with device_map="auto". NOTE(review): declared bool, but load_model
    # also uses it as a per-rank GPU count via range(model_parallel).
    model_parallel: bool = False
    # Hugging Face model id or local path for the config and base weights.
    model: str = 'meta-llama/Llama-2-7b-hf'
    # Tokenizer id/path; load_model defaults it to the resolved model id.
    tokenizer_id: Optional[str] = None
    # Hugging Face access token passed to the from_pretrained calls.
    hf_token: Optional[str] = None
    # Wrap the model in peft LoRA adapters.
    use_lora: bool = False
    # HiP settings are applied only to attention layers with index >= this.
    dense_layers: int = 3
    # HiP attention parameters, copied onto attention layers >= dense_layers.
    hip_block_size_q: int = 16
    hip_block_skip_q: int = 1
    hip_block_size_k: int = 2
    hip_top_k_elems: int = 1024
    hip_group_size_q: int = 1
    start_sink_tokens: int = 32
    end_sink_tokens: int = 32
    split_offset: int = 0
    mix_mode: str = 'hip_orig_mask'
    # Enable InfiniGen; load_model then runs one calibration forward pass with
    # compute_infinigen_skew_matrix=True.
    use_infinigen: bool = False
    infinigen_topk_ratio: float = 0.3
    # LoRA rank; load_model sets lora_alpha to lora_r // 2.
    lora_r: int = 32
    # Without LoRA: model id/path to load weights from. With LoRA: a peft
    # adapter directory or a torch.load-able state-dict file.
    init_from_checkpoint: Optional[str] = None
    # When False, dynamic RoPE scaling with rope_scale_factor is configured
    # (original_max_position_embeddings=4096).
    disable_rope_scaling: bool = False
    rope_scale_factor: float = 16.0
    # Copied into use_dense on layers >= dense_layers (when InfiniGen is off).
    disable_hip: bool = False
    # 4-bit NF4 bitsandbytes quantization (with double quantization).
    use_quantization: bool = False
    # Load pre-quantized GPTQ weights via HiPLlamaGPTQForCausalLM.
    gptq: bool = False
    # Forwarded to attention layers when use_infinigen is set.
    simulated_cache_size: Optional[int] = None
    # Name of a torch dtype attribute, resolved with getattr(torch, dtype).
    dtype: str = 'bfloat16'


# Module-level holder for the HfDeepSpeedConfig that load_model creates when a
# DeepSpeed config is supplied. load_model assigns it via `global dschf` with the
# note "keep this object alive"; presumably transformers only keeps a weak
# reference to it (see the HF DeepSpeed integration docs) — confirm.
dschf = None


def _resolve_device_map(model_config: ModelConfig, ds_config):
    """Return the ``(device_map, max_memory)`` arguments for ``from_pretrained``.

    With ``ds_config``, a ``HfDeepSpeedConfig`` is stored in the module-level
    ``dschf`` global and both values are ``None``. Otherwise placement is
    derived from ``LOCAL_RANK`` / ``local_rank``, ``model_parallel`` and
    ``using_fsdp``.

    Side effect (non-DeepSpeed path only): ``model_config.local_rank`` is
    overwritten from the ``LOCAL_RANK`` environment variable when it is set.
    """
    global dschf
    if ds_config is not None:
        dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive
        return None, None

    max_memory = None
    if os.environ.get('LOCAL_RANK', None) is not None:
        model_config.local_rank = int(os.environ['LOCAL_RANK'])

    if model_config.local_rank is not None:
        if model_config.model_parallel:
            device_map = "auto"
            # NOTE(review): model_parallel is typed bool but used here as the
            # number of GPUs per rank (True -> 1 GPU); confirm intended usage.
            max_memory = {
                model_config.local_rank * model_config.model_parallel + i: "70GiB"
                for i in range(model_config.model_parallel)
            }
        else:
            device_map = {"": f'cuda:{model_config.local_rank}'}
    elif model_config.model_parallel:
        device_map = "auto"
    else:
        device_map = {"": torch.cuda.current_device()}

    if model_config.using_fsdp:
        # Under FSDP weights are loaded on CPU (presumably so FSDP can
        # shard / move them itself).
        device_map = 'cpu'
    return device_map, max_memory


def _apply_attention_and_checkpointing(model, model_config: ModelConfig):
    """Push HiP / InfiniGen settings onto attention layers and enable checkpointing.

    ``LlamaSdpaAttention`` modules are counted in ``model.modules()`` order.
    All of them receive the InfiniGen settings when ``use_infinigen`` is set.
    Only those with index >= ``dense_layers`` receive ``use_dense`` and the HiP
    parameters. Every module exposing ``gradient_checkpointing`` gets it enabled
    with a DeepSpeed or torch checkpoint function.
    """
    layer_idx = 0
    for m in model.modules():
        if isinstance(m, LlamaSdpaAttention):
            if model_config.use_infinigen:
                m.use_infinigen = True
                m.infinigen_topk_ratio = model_config.infinigen_topk_ratio
                m.hip_top_k_elems = model_config.hip_top_k_elems
                m.end_sink_tokens = model_config.end_sink_tokens
                m.simulated_cache_size = model_config.simulated_cache_size
            if layer_idx >= model_config.dense_layers:
                if model_config.use_infinigen:
                    m.use_dense = False
                else:
                    m.use_dense = model_config.disable_hip
                    m.hip_block_size_q = model_config.hip_block_size_q
                    m.hip_block_skip_q = model_config.hip_block_skip_q
                    m.hip_block_size_k = model_config.hip_block_size_k
                    m.hip_top_k_elems = model_config.hip_top_k_elems
                    m.split_offset = model_config.split_offset
                    m.mix_mode = model_config.mix_mode
                    m.start_sink_tokens = model_config.start_sink_tokens
                    m.end_sink_tokens = model_config.end_sink_tokens
                    m.hip_group_size_q = model_config.hip_group_size_q
            layer_idx += 1

        if hasattr(m, 'gradient_checkpointing'):
            m.gradient_checkpointing = True
            # FSDP takes precedence over DeepSpeed; both fall back to torch.
            if model_config.using_deepspeed and not model_config.using_fsdp:
                import deepspeed
                m._gradient_checkpointing_func = deepspeed.checkpointing.checkpoint
            else:
                m._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint


def _remove_key_prefix(state_dict, prefix):
    """Return a new dict with ``prefix`` removed from the start of every key.

    Keys not starting with ``prefix`` are kept unchanged. Unlike
    ``str.strip(prefix)``, which strips a *set of characters* from both ends,
    this removes exactly one leading occurrence. If two keys map to the same
    name, the later one wins.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(model_config: ModelConfig, for_gptq=False, for_training=True, ds_config=None):
    """Load a HiP-patched Llama causal LM and its tokenizer.

    Args:
        model_config: Loading / attention / LoRA options. Mutated in place:
            ``local_rank`` may be set from ``LOCAL_RANK`` and ``tokenizer_id``
            defaults to the resolved model id.
        for_gptq: Wrap the result in auto_gptq's ``LlamaGPTQForCausalLM`` with
            a fixed 4-bit quantization config.
        for_training: When True, bitsandbytes quantization is dropped for
            non-LoRA or FSDP training, and ``prepare_model_for_kbit_training``
            runs when ``use_lora`` and ``use_quantization`` are set.
        ds_config: Optional DeepSpeed config. When given, ``device_map`` and
            ``max_memory`` are both ``None``.

    Returns:
        ``(model, tokenizer)``. ``model`` is a peft model when ``use_lora`` is set.
    """
    device_map, max_memory = _resolve_device_map(model_config, ds_config)
    print("Device map:", device_map, "max_memory:", max_memory)

    model_id = model_config.model

    kwargs = {}
    if not model_config.disable_rope_scaling:
        kwargs['rope_scaling'] = {
            "factor": model_config.rope_scale_factor,
            "original_max_position_embeddings": 4096,
            "type": "dynamic",
        }
    config = LlamaConfig.from_pretrained(
        model_id,
        token=model_config.hf_token,
        **kwargs,
    )
    config._attn_implementation = config.attn_implementation = 'sdpa'

    # Without LoRA, init_from_checkpoint is a full model to load weights from;
    # with LoRA it is an adapter checkpoint handled further below.
    if not model_config.use_lora and model_config.init_from_checkpoint is not None:
        print(f"Loading model from {model_config.init_from_checkpoint}")
        model_id = model_config.init_from_checkpoint

    dtype = getattr(torch, model_config.dtype)

    if model_config.gptq:
        from .llama3_gptq import HiPLlamaGPTQForCausalLM
        model = HiPLlamaGPTQForCausalLM.from_quantized(
            model_id,
            config=config,
            device_map=device_map,
            max_memory=max_memory,
            torch_dtype=torch.float16,
            token=model_config.hf_token
        )
    else:
        quant_config = None
        if model_config.use_quantization:
            quant_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                llm_int8_skip_modules=['tree_avgpool_scaler', 'lm_head'],
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        # Only warn when there actually was a quantization config to drop.
        if (for_training and quant_config is not None
                and (not model_config.use_lora or model_config.using_fsdp)):
            print("Warning: disabling quantization for non-lora training")
            quant_config = None

        model = LlamaForCausalLM.from_pretrained(
            model_id,
            config=config,
            device_map=device_map,
            max_memory=max_memory,
            quantization_config=quant_config,
            torch_dtype=dtype,
            trust_remote_code=True,
            token=model_config.hf_token
        )

    _apply_attention_and_checkpointing(model, model_config)

    if model_config.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=model_config.lora_r,
            lora_alpha=model_config.lora_r // 2,
            lora_dropout=0.05,
            target_modules=['q_proj', 'k_proj', 'v_proj',
                            'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
            modules_to_save=['embed_tokens', 'lm_head'],
        )

        if for_training and model_config.use_quantization:
            # NOTE(review): this also runs when quantization was disabled above
            # (e.g. under FSDP); confirm that is intended.
            model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

        if model_config.init_from_checkpoint is not None:
            print(f"Loading peft model from {model_config.init_from_checkpoint}")

            if pathlib.Path(model_config.init_from_checkpoint).is_dir():
                model = PeftModel.from_pretrained(model, model_config.init_from_checkpoint)
                model.print_trainable_parameters()

            else:
                model = get_peft_model(model, peft_config)
                model.print_trainable_parameters()

                print('loading from', model_config.init_from_checkpoint)
                # NOTE(review): torch.load unpickles arbitrary objects; only load
                # checkpoints from trusted sources.
                state_dict = torch.load(model_config.init_from_checkpoint, map_location='cpu')
                if 'state_dict' in state_dict:
                    state_dict = state_dict['state_dict']
                # Keys presumably carry a leading 'model.' from a wrapping module
                # (e.g. a Lightning checkpoint) — remove exactly that prefix.
                state_dict = _remove_key_prefix(state_dict, 'model.')
                try:
                    result = model.load_state_dict(state_dict, strict=False)
                except RuntimeError as e:
                    # strict=False can still raise (e.g. on shape mismatches);
                    # keep loading best-effort but report the failure.
                    print('failed to load lora checkpoint from',
                          model_config.init_from_checkpoint, ':', e)
                else:
                    print('load result', result)
                    print('lora checkpoint loaded from', model_config.init_from_checkpoint)

        else:
            model = get_peft_model(model, peft_config)
            model.print_trainable_parameters()

    if for_gptq:
        from auto_gptq import BaseQuantizeConfig
        from auto_gptq.modeling import LlamaGPTQForCausalLM
        quantize_config = BaseQuantizeConfig(
            bits=4,  # quantize model to 4-bit
            group_size=128,  # it is recommended to set the value to 128
            damp_percent=0.1,
            desc_act=True,  # set to False can significantly speed up inference but the perplexity may slightly bad
            static_groups=False,
            sym=True,
            true_sequential=True,
        )

        model = LlamaGPTQForCausalLM(model, False, quantize_config)

    if model_config.tokenizer_id is None:
        model_config.tokenizer_id = model_id
    tokenizer = AutoTokenizer.from_pretrained(model_config.tokenizer_id, token=model_config.hf_token)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'right'

    if model_config.use_infinigen:
        with torch.no_grad():
            # Initialize skewing: one forward pass over the first RedPajama
            # example (second dataset argument presumably the sequence length).
            print("Initializing Infinigen")

            dataset = RedPajamaDataset(tokenizer, 4096)
            examples = [{'input_ids': item[0], "attention_mask": torch.ones_like(item[0])}
                        for item in itertools.islice(dataset, 1)]
            input_ids = torch.nn.utils.rnn.pad_sequence(
                [example['input_ids'] for example in examples],
                batch_first=True,
                padding_value=tokenizer.pad_token_id,
            )

            model(
                input_ids=input_ids.cuda(),
                compute_infinigen_skew_matrix=True
            )

    return model, tokenizer

