

import os
import re
import warnings
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    MistralForSequenceClassification,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from verl.models.registry import ModelRegistry
from verl.utils.import_utils import is_trl_available

class LambdaLayer(nn.Module):
    """Adapter that exposes an arbitrary callable as an ``nn.Module``.

    The callable is stored on ``self.fn`` and is invoked with exactly the
    positional and keyword arguments given to ``forward``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped callable and return its result unchanged."""
        wrapped = self.fn
        return wrapped(*args, **kwargs)

def squeeze(x):
    """Remove the last dimension of ``x`` when that dimension has size 1."""
    squeezed = torch.squeeze(x, dim=-1)
    return squeezed

def update_model_config(module_config, override_config_kwargs):
    """Apply a (possibly nested) mapping of overrides to a config object in place.

    Args:
        module_config: Object whose attributes are overwritten.
        override_config_kwargs: Mapping of attribute name to new value. A
            ``dict`` value is applied recursively to the attribute of the same
            name, which must already exist on ``module_config``.
    """
    for attr_name, new_value in override_config_kwargs.items():
        if not isinstance(new_value, dict):
            setattr(module_config, attr_name, new_value)
            continue
        # A dict addresses a sub-config: descend into the existing attribute.
        update_model_config(getattr(module_config, attr_name), new_value)

def get_huggingface_actor_config(
    model_name: str, override_config_kwargs=None, trust_remote_code=False
) -> "PretrainedConfig":
    """Load a HuggingFace config and apply optional attribute overrides.

    Args:
        model_name: Name or path forwarded to ``AutoConfig.from_pretrained``.
        override_config_kwargs: Optional dict of overrides applied with
            ``update_model_config`` (nested dicts update sub-configs).
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The loaded config object with the overrides applied. This is the
        object returned by ``AutoConfig.from_pretrained``, not a ``dict``.

    Raises:
        AssertionError: If ``override_config_kwargs`` is not a dict.
    """
    if override_config_kwargs is None:
        override_config_kwargs = {}
    assert isinstance(override_config_kwargs, dict), (
        f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}"
    )
    module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    update_model_config(module_config, override_config_kwargs)

    return module_config

def get_generation_config(
    model: str,
    trust_remote_code: bool = False,
) -> Optional[GenerationConfig]:
    """Return a generation config for ``model``, or ``None`` if none can be built.

    First tries ``GenerationConfig.from_pretrained``. If that raises
    ``OSError``, falls back to deriving one from the model config via
    ``GenerationConfig.from_model_config``. An ``OSError`` from the fallback
    results in ``None``.

    Args:
        model: Name or path of the model.
        trust_remote_code: Used only when loading the model config in the
            fallback path.
    """
    try:
        generation_config = GenerationConfig.from_pretrained(model)
    except OSError:
        # Presumably no standalone generation config is available; try to
        # build one from the model config instead.
        try:
            model_config = get_huggingface_actor_config(
                model,
                trust_remote_code=trust_remote_code,
            )
            generation_config = GenerationConfig.from_model_config(model_config)
        except OSError:
            generation_config = None
    return generation_config

def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
    """Instantiate a causal LM from its (optionally overridden) config.

    The model is built with ``AutoModelForCausalLM.from_config``, i.e. from
    the config rather than from a checkpoint.

    Args:
        model_name: Name or path used to load the config.
        override_config_kwargs: Optional dict of config overrides.
        automodel_kwargs: Optional dict forwarded to ``from_config``; its
            ``trust_remote_code`` entry (default ``False``) is also used when
            loading the config.

    Returns:
        The constructed module.
    """
    overrides = {} if override_config_kwargs is None else override_config_kwargs
    model_kwargs = {} if automodel_kwargs is None else automodel_kwargs
    assert isinstance(overrides, dict), (
        f"override_config_kwargs must be a dict, got {type(overrides)}"
    )
    allow_remote_code = model_kwargs.get("trust_remote_code", False)
    hf_config = get_huggingface_actor_config(model_name, overrides, trust_remote_code=allow_remote_code)
    actor: nn.Module = AutoModelForCausalLM.from_config(hf_config, **model_kwargs)
    return actor

def create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module:
    """Build an actor model and replace its ``lm_head`` with a scalar value head.

    The new head is ``Linear(hidden_size, 1)`` followed by a squeeze of the
    last dimension. Its dtype is ``automodel_kwargs["torch_dtype"]`` when
    given, otherwise ``torch.float32``.

    Args:
        model_name: Name or path used to load the config.
        override_config_kwargs: Optional dict of config overrides.
        automodel_kwargs: Optional dict forwarded to the actor factory.

    Returns:
        The module with the replaced ``lm_head``.
    """
    critic: nn.Module = create_huggingface_actor(
        model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs
    )
    model_kwargs = {} if automodel_kwargs is None else automodel_kwargs
    head_dtype = model_kwargs.get("torch_dtype", torch.float32)
    value_projection = nn.Linear(critic.config.hidden_size, 1, dtype=head_dtype)
    critic.lm_head = nn.Sequential(value_projection, LambdaLayer(fn=squeeze))
    return critic

def get_model_size(model: nn.Module, scale="auto"):
    """Count the parameters of ``model`` and express the count in a unit.

    Args:
        model: Module whose parameters are counted.
        scale: One of ``"auto"``, ``"B"``, ``"M"``, ``"K"`` or ``""``. With
            ``"auto"`` the largest unit the count strictly exceeds is chosen
            (``""`` if the count does not exceed 1e3).

    Returns:
        ``(n_params, scale)``: the count divided by the unit (a float), or
        the raw integer count when the unit is ``""``, plus the unit used.

    Raises:
        NotImplementedError: If ``scale`` is not a recognised unit.
    """
    units = (("B", 1e9), ("M", 1e6), ("K", 1e3))
    total = sum(param.numel() for param in model.parameters())

    if scale == "auto":
        scale = next((unit for unit, divisor in units if total > divisor), "")

    if scale == "":
        return total, scale

    for unit, divisor in units:
        if scale == unit:
            return total / divisor, scale

    raise NotImplementedError(f"Unknown scale {scale}")

def print_model_size(model: nn.Module, name: str = None):
    """Print the auto-scaled parameter count of ``model``.

    Args:
        model: Module to measure.
        name: Label used in the message; defaults to the model's class name.
    """
    count, unit = get_model_size(model, scale="auto")
    label = model.__class__.__name__ if name is None else name
    print(f"{label} contains {count:.2f}{unit} parameters")

def create_random_mask(
    input_ids: torch.Tensor,
    max_ratio_of_valid_token: float,
    max_ratio_of_left_padding: float,
    min_ratio_of_valid_token: float = 0,
):
    """Create a random 0/1 mask with left padding, a valid span, and right padding.

    For every row, a number of left-padding positions and a number of valid
    positions are drawn with ``np.random.randint``. The mask is 1 on the
    valid span and 0 elsewhere.

    Args:
        input_ids: 2-D tensor; only its shape (and device) is used.
        max_ratio_of_valid_token: Upper bound on the valid-token fraction,
            in ``(0, 1]``.
        max_ratio_of_left_padding: Upper bound on the left-padding fraction,
            in ``[0, 1)``.
        min_ratio_of_valid_token: Lower bound on the valid-token fraction.
            At least one token is always valid.

    Returns:
        An ``int64`` tensor with the same shape as ``input_ids``.

    Raises:
        AssertionError: If the ratios are out of range or incompatible with
            the sequence length.
    """
    assert 0 < max_ratio_of_valid_token <= 1.0
    assert 0 <= max_ratio_of_left_padding < 1.0
    assert min_ratio_of_valid_token <= max_ratio_of_valid_token

    batch_size, sequence_length = input_ids.shape
    max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)
    min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))
    max_left_padding = int(sequence_length * max_ratio_of_left_padding)
    assert max_num_valid_tokens + max_left_padding <= sequence_length
    # The bound applies to the token count. The previous check compared the
    # ratio against the sequence length, which could never fail.
    assert 0 < max_num_valid_tokens <= sequence_length
    masks = torch.ones_like(input_ids, dtype=torch.int64)

    for row in range(batch_size):
        # Draw order (left padding first, then valid length) is kept so that
        # seeded runs yield the same masks as before.
        num_left_padding = int(np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64))
        num_valid = int(
            np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)
        )

        # Zero everything outside [num_left_padding, num_left_padding + num_valid).
        masks[row, :num_left_padding] = 0
        masks[row, num_left_padding + num_valid :] = 0
    return masks

def compute_position_id_with_mask(mask):
    """Derive position ids from a mask along the last dimension.

    Each position gets ``cumsum(mask) - 1`` along the last axis, clamped
    below at 0.
    """
    running_count = torch.cumsum(mask, dim=-1)
    return torch.clamp(running_count - 1, min=0)

def convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedModel):
    """Rename state-dict keys by applying the model's conversion mapping in reverse.

    If ``model`` has no ``_checkpoint_conversion_mapping`` attribute, the
    input ``state_dict`` is returned as-is. Otherwise the mapping is
    inverted. Each mapping value becomes a regex pattern. Each mapping key,
    with leading ``^`` characters and any parenthesised part removed,
    becomes the replacement text. For every key, only the first pattern
    that matches is applied.

    Args:
        state_dict: Mapping of parameter name to tensor.
        model: Object that may carry ``_checkpoint_conversion_mapping``.

    Returns:
        A new dict with renamed keys, or ``state_dict`` itself when the model
        has no mapping attribute.
    """
    if not hasattr(model, "_checkpoint_conversion_mapping"):
        return state_dict

    inverted = {target: source for source, target in model._checkpoint_conversion_mapping.items()}
    # Turn each original regex into plain replacement text: drop the anchor
    # and anything between the first "(" and the last ")".
    rewrite_rules = [
        (pattern, re.sub(r"\(.*\)", "", template.lstrip("^")))
        for pattern, template in inverted.items()
    ]

    original_weights = {}
    for key, value in state_dict.items():
        restored_key = key
        for pattern, replacement in rewrite_rules:
            restored_key, n_replace = re.subn(pattern, replacement, restored_key)
            if n_replace > 0:
                break
        original_weights[restored_key] = value

    return original_weights

def check_exclude_modules(config, key: str) -> bool:
    """Tell whether module name ``key`` is excluded by ``config.exclude_modules``.

    A string ``exclude_modules`` is treated as a regex that must fully match
    ``key``. Otherwise ``key`` is excluded when it is contained in
    ``exclude_modules`` or ends with ``"." + entry`` for some entry. A missing
    or empty ``exclude_modules`` excludes nothing.
    """
    excluded = getattr(config, "exclude_modules", None)
    if not excluded:
        return False

    if isinstance(excluded, str):
        return re.fullmatch(excluded, key) is not None

    if key in excluded:
        return True
    return any(key.endswith(f".{entry}") for entry in excluded)

def check_target_modules(config, key: str) -> bool:
    """Tell whether module name ``key`` is targeted by ``config.target_modules``.

    * A string ``target_modules`` is a regex that must fully match ``key``;
      the result of ``re.fullmatch`` (a match object or ``None``) is returned.
    * If ``key`` is contained in ``target_modules`` the result is ``True``.
    * Otherwise ``key`` matches when it ends with ``"." + entry`` for some
      entry. In that case, if ``config.layers_to_transform`` is set (and is
      not an empty list), the layer index parsed from ``key`` must also equal
      or be contained in it. The index is located using
      ``config.layers_pattern`` when provided, otherwise with a generic
      ``.<name>.<digits>.`` pattern.
    """
    targets = config.target_modules
    if isinstance(targets, str):
        return re.fullmatch(targets, key)
    if key in targets:
        return True

    suffix_hit = any(key.endswith(f".{entry}") for entry in targets)

    wanted_layers = getattr(config, "layers_to_transform", None)
    patterns = getattr(config, "layers_pattern", None)

    if wanted_layers is None:
        restrict_layers = False
    elif isinstance(wanted_layers, list):
        restrict_layers = len(wanted_layers) != 0
    else:
        restrict_layers = True

    if not (restrict_layers and suffix_hit):
        return suffix_hit

    if patterns is None or len(patterns) == 0:
        # No explicit pattern: take the digits that follow any dotted name.
        index_match = re.match(r".*\.[^.]*\.(\d+)\.", key)
    else:
        if isinstance(patterns, str):
            patterns = [patterns]
        index_match = None
        for candidate in patterns:
            index_match = re.match(rf".*\.{candidate}\.(\d+)\.", key)
            if index_match is not None:
                break

    if index_match is None:
        return False

    found_index = int(index_match.group(1))
    if isinstance(wanted_layers, int):
        return found_index == wanted_layers
    return found_index in wanted_layers

def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"):
    """Shift the layer index embedded in a parameter name by the pipeline offset.

    The offset comes from ``get_transformer_layer_offset(pp_rank, vpp_rank,
    transformer_config)``. If ``layer_name`` occurs in ``name``, the dotted
    component right after the first component equal to ``layer_name`` is
    increased by that offset. Names not containing ``layer_name`` are
    returned unchanged.

    Raises:
        AssertionError: If no numeric component follows the located position.
    """
    from verl.utils.megatron_utils import get_transformer_layer_offset

    layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config)

    if layer_name not in name:
        return name

    parts = name.split(".")
    # Falls back to the last position when no component equals `layer_name`
    # (substring-only match); the assertions below then reject the name.
    marker_pos = len(parts) - 1
    for pos, part in enumerate(parts):
        if part == layer_name:
            marker_pos = pos
            break
    layer_num_idx = marker_pos + 1

    assert len(parts) >= layer_num_idx + 1, f"split_name = {parts}"
    assert parts[layer_num_idx].isdigit(), f"split_name = {parts}"

    parts[layer_num_idx] = str(int(parts[layer_num_idx]) + layer_offset)
    return ".".join(parts)

def _uniform_layer_offset(pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers):
    """Return the global index of the first layer in chunk (pp_rank, vpp_rank), assuming an even split."""
    layers_per_pp = num_hidden_layers // pp_size
    if vpp_size > 1:
        # Interleaved layout: each virtual stage spans all pp ranks before
        # the next virtual stage starts.
        layers_per_vpp = layers_per_pp // vpp_size
        return layers_per_vpp * pp_rank + layers_per_vpp * pp_size * vpp_rank
    return layers_per_pp * pp_rank


def _offset_layer_index(name, layer_offset, layer_name):
    """Add ``layer_offset`` to the numeric component following ``layer_name`` in a dotted name."""
    parts = name.split(".")
    if layer_name not in parts:
        return name
    index_pos = parts.index(layer_name) + 1
    if index_pos >= len(parts) or not parts[index_pos].isdigit():
        raise ValueError(f"expected a layer index after '{layer_name}' in parameter name '{name}'")
    parts[index_pos] = str(int(parts[index_pos]) + layer_offset)
    return ".".join(parts)


def normalize_pp_vpp_params(params, num_hidden_layers, layer_name="layers"):
    """Yield ``(global_name, param)`` pairs from per-pp / per-vpp named parameters.

    ``params[pp_rank][vpp_rank]`` is a mapping of chunk-local parameter name
    to parameter. Layer indices in the names are shifted from chunk-local to
    global numbering.

    The previous implementation passed ``pp_size``, ``vpp_size`` and
    ``num_hidden_layers`` to ``normalize_model_name``, whose signature takes a
    transformer config instead, so it raised ``TypeError`` on the first item.
    The offset is therefore computed here from the values this function has.

    NOTE(review): this assumes ``num_hidden_layers`` is split evenly over the
    pipeline and virtual-pipeline ranks. Confirm that no uneven pipeline
    layout is in use by callers.

    Args:
        params: Sequence (pp ranks) of sequences (vpp chunks) of name->param
            mappings.
        num_hidden_layers: Total number of transformer layers in the model.
        layer_name: Name component that precedes the layer index.

    Yields:
        ``(normalized_name, param)`` in pp-major, vpp-minor order.

    Raises:
        ValueError: If a name has a ``layer_name`` component that is not
            followed by a numeric component.
    """
    pp_size = len(params)
    for pp_rank, pp_chunks in enumerate(params):
        vpp_size = len(pp_chunks)
        for vpp_rank, named_params in enumerate(pp_chunks):
            layer_offset = _uniform_layer_offset(pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers)
            for name, param in named_params.items():
                yield _offset_layer_index(name, layer_offset, layer_name), param

def get_parallel_model_from_config(
    config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False
):
    """Instantiate the registered parallel model class for ``config``.

    Args:
        config: Model config whose ``architectures`` select the model class.
        megatron_config: Must be a ``megatron.core.ModelParallelConfig``.
        pre_process, post_process, share_embeddings_and_output_weights:
            Forwarded unchanged to the model class constructor.
        value: Forwarded to the architecture lookup.

    Returns:
        The constructed model instance.

    Raises:
        AssertionError: If ``megatron_config`` is not a ``ModelParallelConfig``.
    """
    from megatron.core import ModelParallelConfig

    assert isinstance(megatron_config, ModelParallelConfig)
    architecture_cls = _get_parallel_model_architecture_from_config(config, value)

    return architecture_cls(
        config,
        megatron_config,
        pre_process=pre_process,
        post_process=post_process,
        share_embeddings_and_output_weights=share_embeddings_and_output_weights,
    )

def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]:
    """Return the first registered model class matching ``config.architectures``.

    Args:
        config: Config whose ``architectures`` attribute lists candidate
            architecture names. A missing or ``None`` attribute is treated as
            an empty list.
        value: Forwarded to ``ModelRegistry.load_model_cls``.

    Returns:
        The first non-``None`` class returned by the registry.

    Raises:
        ValueError: If no listed architecture resolves to a model class.
    """
    # `getattr(..., [])` alone is not enough: the attribute can exist and be
    # None, which would fail with a bare TypeError when iterated.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch, value)
        print("after load model cls")
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. Supported architectures: "
        f"{ModelRegistry.get_supported_archs()}"
    )

def _load_hf_model(config, model_config, is_value_model, local_cache_path):
    """Resolve the checkpoint location and load the HuggingFace model and its state dict.

    Args:
        config: Object exposing ``config.model.path`` and ``config.model.get``
            (used to read ``use_shm``).
        model_config: Must have an ``architectures`` attribute.
        is_value_model: Returned unchanged, except that it is forced to
            ``True`` when ``config.model.path`` contains ``"mistral7b-rm"``.
        local_cache_path: Cache directory (``~`` is expanded) used when the
            model path starts with ``"hdfs:"``.

    Returns:
        ``(architectures, model, state_dict, is_value_model)``.

    Raises:
        AssertionError: If ``model_config`` has no ``architectures`` attribute.
    """
    from accelerate import init_empty_weights
    from megatron.core import parallel_state as mpu

    from verl.models.mcore.saver import _megatron_calc_global_rank

    assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!"
    architectures = getattr(model_config, "architectures", [])
    local_cache_path = os.path.expanduser(local_cache_path)

    # "hdfs:" paths are first copied to the local cache; any other path is
    # used directly.
    if config.model.path.startswith("hdfs:"):
        from verl.utils.fs import copy_to_local

        print(f"start download from {config.model.path}")
        local_model_path = copy_to_local(
            src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False)
        )
        print("finish download")
    else:
        local_model_path = config.model.path
        print(f"load from local dir {local_model_path}")

    # Only the rank equal to `src_rank` loads under a CPU device context;
    # every other rank loads under accelerate's `init_empty_weights` context.
    # NOTE(review): presumably so that only one rank materialises the real
    # weights - confirm against the weight loaders that consume `state_dict`.
    src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank())
    cpu_init_weights = lambda: torch.device("cpu")
    init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights
    with init_context(), warnings.catch_warnings():
        # All warnings raised while loading are suppressed.
        warnings.simplefilter("ignore")

        # Special case keyed on the configured path (not the resolved local
        # path): load as a sequence classifier and adapt its state dict.
        if "mistral7b-rm" in config.model.path:
            model = MistralForSequenceClassification.from_pretrained(
                local_model_path,
                torch_dtype="auto",

            )
            state_dict = model.state_dict()
            # Expose the classifier's score weight under the LM-head key.
            state_dict["lm_head.weight"] = state_dict["score.weight"]
            # Keep only the first 32000 rows of the embedding matrix.
            state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][
                :32000
            ]
            is_value_model = True
        else:
            model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                torch_dtype="auto",

            )
            state_dict = model.state_dict()

    return architectures, model, state_dict, is_value_model

def get_hf_model_path(config, local_cache_path="~/.cache/verl/rlhf"):
    """Return a local path for the model referenced by ``config.model.path``.

    Paths starting with ``"hdfs:"`` are copied into ``local_cache_path``
    (``~`` expanded) via ``copy_to_local``. Any other path is returned as is.
    """
    cache_dir = os.path.expanduser(local_cache_path)
    if not config.model.path.startswith("hdfs:"):
        return config.model.path

    from verl.utils.fs import copy_to_local

    return copy_to_local(src=config.model.path, cache_dir=cache_dir, use_shm=config.model.get("use_shm", False))

def load_megatron_model_weights(
    config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf"
):
    """Load HuggingFace weights into ``parallel_model`` via the registered weight loaders.

    The HuggingFace model and state dict are obtained from ``_load_hf_model``.
    For every architecture it reports, the weight loader returned by
    ``get_weight_loader`` is invoked on the same state dict.

    Args:
        config: Run config passed through to ``_load_hf_model``.
        model_config: Supplies ``architectures`` and ``tie_word_embeddings``.
        parallel_model: Passed to the loader as ``wrapped_models``.
        params_dtype: Passed to the loader unchanged.
        is_value_model: Initial flag; ``_load_hf_model`` may override it.
        local_cache_path: Cache directory for remote checkpoints.

    Returns:
        The ``config`` attribute of the loaded HuggingFace model.
    """
    loaded = _load_hf_model(config, model_config, is_value_model, local_cache_path)
    architectures, model, state_dict, is_value_model = loaded

    from verl.models.weight_loader_registry import get_weight_loader

    print(f"before weight loader: architectures = {architectures}...")
    for arch in architectures:
        print(f"call weight loader arch = {arch}, model config = {model.config}")
        load_weights = get_weight_loader(arch)
        load_weights(
            state_dict=state_dict,
            wrapped_models=parallel_model,
            config=model.config,
            params_dtype=params_dtype,
            is_value_model=is_value_model,
            tie_word_embeddings=model_config.tie_word_embeddings,
        )
    return model.config

def load_megatron_gptmodel_weights(
    config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf"
):
    """Load HuggingFace weights into Megatron GPT model(s).

    The HuggingFace model and state dict come from ``_load_hf_model`` and are
    handed to ``load_state_dict_to_megatron_gptmodel``. Nothing is returned.

    Args:
        config: Run config passed through to ``_load_hf_model``.
        model_config: Passed through to ``_load_hf_model``.
        parallel_model: Passed to the loader as ``wrapped_models``.
        params_dtype: Passed to the loader unchanged.
        is_value_model: Initial flag; ``_load_hf_model`` may override it.
        local_cache_path: Cache directory for remote checkpoints.
    """
    _unused_archs, hf_model, hf_state_dict, is_value_model = _load_hf_model(
        config, model_config, is_value_model, local_cache_path
    )

    from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel

    load_state_dict_to_megatron_gptmodel(
        state_dict=hf_state_dict,
        wrapped_models=parallel_model,
        config=hf_model.config,
        params_dtype=params_dtype,
        is_value_model=is_value_model,
    )
    # Drop the local references to the HuggingFace objects once loading is done.
    del hf_state_dict, hf_model

def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size):
    """Pad packed (unpadded) tokens so their count is a multiple of ``size``.

    When padding is needed, zeros are appended along dimension 0 of
    ``unpad_tokens`` and one extra entry is appended to ``cu_seqlens`` so the
    padding counts as an additional sequence. ``max_seqlen_in_batch`` is
    raised to the padding length if that is larger.

    Args:
        unpad_tokens: 1-D or 2-D tensor whose first dimension is the total
            number of tokens.
        cu_seqlens: 1-D tensor of cumulative sequence lengths.
        max_seqlen_in_batch: Current maximum sequence length.
        size: The token count is padded up to a multiple of this value.

    Returns:
        ``(unpad_tokens, cu_seqlens, max_seqlen_in_batch)``, unchanged when no
        padding is required.

    Raises:
        NotImplementedError: If padding is required and ``unpad_tokens`` is
            neither 1-D nor 2-D.
    """
    F = nn.functional

    total_nnz = unpad_tokens.shape[0]
    pad_size = (-total_nnz) % size
    if pad_size == 0:
        return unpad_tokens, cu_seqlens, max_seqlen_in_batch

    if unpad_tokens.ndim == 1:
        unpad_tokens = F.pad(unpad_tokens, (0, pad_size))
    elif unpad_tokens.ndim == 2:
        # F.pad lists the last dimension first: leave dim 1 alone, extend dim 0.
        unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size))
    else:
        # `ndim` is a property; calling it raised TypeError and hid this error.
        raise NotImplementedError(f"Padding dim {unpad_tokens.ndim} is not supported")

    # The padding becomes one more "sequence" ending at the new total length.
    cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1])
    max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size)

    return unpad_tokens, cu_seqlens, max_seqlen_in_batch

def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False):
    """Load a Megatron distributed checkpoint into every model in ``parallel_model``.

    Each model is unwrapped and its sharded state dict is loaded from
    ``dist_weight_path`` with ``StrictHandling.ASSUME_OK_UNEXPECTED``. For a
    value model, entries whose key contains ``"output_layer"`` are removed
    from the sharded state dict before loading.

    Args:
        parallel_model: Iterable of (wrapped) model chunks.
        dist_weight_path: Location of the distributed checkpoint.
        is_value_model: Whether to skip ``output_layer`` entries.
    """
    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.serialization import StrictHandling

    from verl.utils.megatron_utils import unwrap_model

    strict_mode = StrictHandling.ASSUME_OK_UNEXPECTED
    for model_chunk in parallel_model:
        sharded_sd = unwrap_model(model_chunk).sharded_state_dict()
        if is_value_model:
            # Collect first, then remove, so the mapping is not mutated while
            # it is being iterated.
            output_layer_keys = [k for k in sharded_sd.keys() if "output_layer" in k]
            for k in output_layer_keys:
                sharded_sd.pop(k)
        dist_checkpointing.load(sharded_sd, dist_weight_path, strict=strict_mode)

def get_parallel_gptmodel_from_config(
    tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False
):
    """Build a Megatron ``GPTModel`` from a transformer config and a HuggingFace config.

    The decoder block spec is requested with Transformer Engine enabled.
    Vocabulary size, maximum sequence length and rotary base are read from
    ``hf_config``. When ``hf_config.rope_scaling`` is set, its ``factor`` is
    passed as ``seq_len_interpolation_factor``. When both ``post_process``
    and ``value`` are truthy, the output layer is replaced by a
    ``LinearForLastLayer`` with a single output.

    Args:
        tfconfig: Megatron transformer config; must use ``"RMSNorm"``.
        hf_config: HuggingFace model config.
        pre_process, post_process, share_embeddings_and_output_weights:
            Forwarded to ``GPTModel``.
        value: Whether to install the single-output head.

    Returns:
        The constructed model.

    Raises:
        AssertionError: If normalization is not ``"RMSNorm"`` or the rope
            scaling type is not ``"linear"``.
    """
    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
    from megatron.core.models.gpt.gpt_model import GPTModel

    assert tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
    layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=True)

    extra_rope_kwargs = {}
    rope_scaling = hf_config.rope_scaling
    if rope_scaling is not None:
        assert rope_scaling["type"] == "linear", "only linear scaling is supported for now"
        extra_rope_kwargs["seq_len_interpolation_factor"] = rope_scaling["factor"]

    gpt_model = GPTModel(
        config=tfconfig,
        transformer_layer_spec=layer_spec,
        vocab_size=hf_config.vocab_size,
        max_sequence_length=hf_config.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        share_embeddings_and_output_weights=share_embeddings_and_output_weights,
        position_embedding_type="rope",
        rotary_base=hf_config.rope_theta,
        **extra_rope_kwargs,
    )

    if not (post_process and value):
        return gpt_model

    from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer

    gpt_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, output_size=1, config=tfconfig)
    return gpt_model

def patch_valuehead_model(model) -> None:
    """Patch a trl value-head model in place with methods and attributes it lacks.

    Installs bound ``tie_weights``, ``get_input_embeddings`` and
    ``get_output_embeddings`` methods that delegate to
    ``model.pretrained_model`` when that is a ``PreTrainedModel``, and a
    ``can_generate`` method that returns ``False``. It also sets
    ``_keys_to_ignore_on_save`` to every parameter name containing
    ``"pretrained_model"`` and copies ``_no_split_modules`` from the wrapped
    model (default ``[]``).
    """
    from types import MethodType

    from transformers import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead

    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        backbone = self.pretrained_model
        if isinstance(backbone, PreTrainedModel):
            backbone.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        backbone = self.pretrained_model
        if not isinstance(backbone, PreTrainedModel):
            return None
        return backbone.get_input_embeddings()

    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        backbone = self.pretrained_model
        if not isinstance(backbone, PreTrainedModel):
            return None
        return backbone.get_output_embeddings()

    def can_generate(self):
        return False

    model._keys_to_ignore_on_save = [
        param_name for param_name, _ in model.named_parameters() if "pretrained_model" in param_name
    ]
    # Bind each replacement to this instance under its own function name.
    for replacement in (tie_weights, get_input_embeddings, get_output_embeddings, can_generate):
        setattr(model, replacement.__name__, MethodType(replacement, model))
    model._no_split_modules = getattr(model.pretrained_model, "_no_split_modules", [])

def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):
    """Load a value-head model, falling back to a trl wrapper when necessary.

    First tries ``AutoModelForTokenClassification.from_pretrained``. If that
    fails and trl is available, the base model is loaded instead. The base
    class is ``AutoModelForVision2Seq`` when the config type is registered
    for it, otherwise ``AutoModelForCausalLM``. The base model is then
    wrapped in ``AutoModelForCausalLMWithValueHead`` and patched with
    ``patch_valuehead_model``.

    Args:
        local_path: Local checkpoint path.
        torch_dtype: Dtype forwarded to ``from_pretrained``.
        model_config: HuggingFace config forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        The loaded (and, in the fallback path, patched) model.

    Raises:
        RuntimeError: If the token-classification load fails and trl is not
            installed; the original error is chained as the cause.
    """
    from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq

    load_kwargs = dict(
        pretrained_model_name_or_path=local_path,
        torch_dtype=torch_dtype,
        config=model_config,
        attn_implementation="flash_attention_2",
        trust_remote_code=trust_remote_code,
    )

    try:
        return AutoModelForTokenClassification.from_pretrained(**load_kwargs)
    except Exception as e:
        # `Exception`, not `BaseException`: KeyboardInterrupt / SystemExit must
        # propagate instead of silently triggering the trl fallback.
        if not is_trl_available():
            raise RuntimeError(
                f"model({local_path}) is not a value head model, please install trl to make it valid"
            ) from e

    from trl import AutoModelForCausalLMWithValueHead

    if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():
        module_class = AutoModelForVision2Seq
    else:
        module_class = AutoModelForCausalLM
    ori_model = module_class.from_pretrained(**load_kwargs)
    model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model)
    patch_valuehead_model(model)
    return model

@dataclass
class CausalLMOutputForPPO(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with two optional tensors, both defaulting to ``None``."""

    # NOTE(review): presumably log-probabilities produced alongside the
    # forward pass for PPO - confirm shape and meaning against the producer.
    log_probs: Optional[torch.FloatTensor] = None
    # NOTE(review): presumably the entropy of the output distribution -
    # confirm shape and meaning against the producer.
    entropy: Optional[torch.FloatTensor] = None
