from transformers import PretrainedConfig

from distflow.utils.extras.device import get_torch_device

# ``config.model_type`` values that FlopsCounter recognizes. Any other model
# type makes FlopsCounter print a warning at construction time.
VALID_CONFIG_TYPE = {
    # Dense decoder-only language models
    "llama",
    "qwen2",
    "qwen3",
    # Mixture-of-experts language models
    "qwen3_moe",
    "deepseek_v3",
    # Vision-language models (FLOPs estimated from the language backbone)
    "qwen2_vl",
    "qwen2_5_vl",
    "internvl_chat",
    # OpenVLA family
    "openvla",
    "openvla-oft",
}


def get_device_flops(unit="T"):
    """Return the hard-coded peak FLOPS of the current accelerator.

    The device is identified by substring matching on
    ``get_torch_device().get_device_name()``. Unknown devices report
    ``float("inf")``, so any ratio taken against this value evaluates to 0.

    Args:
        unit: Scale of the returned value. One of "B" (raw FLOPS), "K", "M",
            "G", "T" (default) or "P"; each step is a factor of 1000.

    Returns:
        float: Peak FLOPS of the device expressed in ``unit``.

    Raises:
        ValueError: If ``unit`` is not one of the supported units.
    """
    units = ["B", "K", "M", "G", "T", "P"]
    # Previously an unknown unit silently divided by 1000**6; fail loudly instead.
    if unit not in units:
        raise ValueError(f"Unsupported unit {unit!r}; expected one of {units}")

    # Checked in order and the first matching entry wins, so names that contain
    # another entry's pattern must come first (e.g. "H200" is matched before "H20").
    peak_flops_table = (
        (("MI300X",), 1336e12),
        (("H100", "H800", "H200"), 989e12),
        (("A100", "A800"), 312e12),
        (("L40",), 181.05e12),
        (("L20",), 119.5e12),
        (("H20",), 148e12),
        (("910B",), 354e12),
    )

    device_name = get_torch_device().get_device_name()
    flops = float("inf")  # INF flops for unknown gpu type
    for patterns, peak in peak_flops_table:
        if any(pattern in device_name for pattern in patterns):
            flops = peak
            break

    # Divide step by step (rather than by 1000**k at once) to keep results
    # bit-identical to the previous implementation.
    for _ in range(units.index(unit)):
        flops /= 1000
    return flops


class FlopsCounter:
    """
    Used to count mfu during training loop.

    Estimates the achieved throughput (in TFLOPS) of a step from per-sequence
    token counts and the wall-clock time the step took. Linear layers cost
    ``scaling_law_coff`` FLOPs per parameter per token: 2 for a forward pass,
    6 for forward + backward. Self-attention adds a term that is quadratic in
    each sequence length, for the ``QK^T`` and ``scores @ V`` matmuls.

    Example:
        flops_counter = FlopsCounter(config)
        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)

    Args:
        config: Model config; ``config.model_type`` selects the estimator.
        forward_only: Count forward-pass FLOPs only, instead of forward + backward.
    """

    def __init__(self, config: "PretrainedConfig", forward_only: bool = False):
        if config.model_type not in VALID_CONFIG_TYPE:
            print(
                f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. MFU will always be zero."
            )

        self.estimate_func = {
            "qwen2": self._estimate_qwen2_flops,
            "llama": self._estimate_qwen2_flops,
            "qwen2_vl": self._estimate_qwen2_flops,
            "qwen2_5_vl": self._estimate_qwen2_flops,
            "internvl_chat": self._estimate_internvl_flops,
            "qwen3": self._estimate_qwen2_flops,
            "qwen3_moe": self._estimate_qwen3_moe_flops,
            "deepseek_v3": self._estimate_deepseek_v3_flops,
            "openvla": self._estimate_openvla_flops,
            "openvla-oft": self._estimate_openvla_flops,
        }
        self.config = config
        self.forward_only = forward_only
        # FLOPs per parameter per token: 2 for forward, 6 for forward + backward
        # (backward counted as twice the forward cost).
        self.scaling_law_coff = 2 if self.forward_only else 6

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _head_dim(cfg):
        """Return ``cfg.head_dim`` if set (not None), else hidden_size // num_attention_heads."""
        head_dim = getattr(cfg, "head_dim", None)
        if head_dim is None:
            head_dim = cfg.hidden_size // cfg.num_attention_heads
        return head_dim

    @staticmethod
    def _gqa_attn_linear_N(cfg, head_dim):
        """Parameter count of one layer's q/k/v/o projections (grouped-query attention)."""
        q_size = cfg.num_attention_heads * head_dim
        kv_size = cfg.num_key_value_heads * head_dim
        # q_proj + k_proj + v_proj + o_proj (o_proj maps q_size back to hidden_size).
        return cfg.hidden_size * (q_size + kv_size + kv_size + q_size)

    def _attn_flops(self, batch_seqlens, head_dim, num_heads, num_layers):
        """FLOPs of the ``QK^T`` and ``scores @ V`` matmuls, summed over all layers.

        Each matmul costs ``2 * seqlen**2 * head_dim`` FLOPs per head in the
        forward pass. The coefficient is therefore ``2 * scaling_law_coff``:
        4 for forward only, 12 with backward. Sequences are independent, and
        no discount is applied for causal masking.
        """
        seqlen_square_sum = sum(seqlen * seqlen for seqlen in batch_seqlens)
        return (
            2
            * self.scaling_law_coff
            * seqlen_square_sum
            * head_dim
            * num_heads
            * num_layers
        )

    @staticmethod
    def _to_tflops(flops, delta_time):
        """Convert a FLOP count spent over ``delta_time`` seconds into TFLOPS."""
        return flops * (1.0 / delta_time) / 1e12

    def _llm_flops(
        self, cfg, per_layer_N, head_dim, tokens_sum, batch_seqlens, delta_time
    ):
        """TFLOPS of a stack of ``cfg.num_hidden_layers`` identical decoder layers.

        Args:
            cfg: Language-model config (hidden_size, vocab_size, num_hidden_layers,
                num_attention_heads).
            per_layer_N: Linear-layer parameters per decoder layer that are
                active for each token.
            head_dim: Attention head dimension.
            tokens_sum: Total number of tokens in the batch.
            batch_seqlens: Per-sequence token counts (for the attention term).
            delta_time: Elapsed time in seconds.
        """
        # Input embedding + lm_head, counted twice regardless of weight tying.
        emd_and_lm_head_N = cfg.vocab_size * cfg.hidden_size * 2
        dense_N = per_layer_N * cfg.num_hidden_layers + emd_and_lm_head_N
        dense_N_flops = self.scaling_law_coff * dense_N * tokens_sum
        attn_qkv_flops = self._attn_flops(
            batch_seqlens, head_dim, cfg.num_attention_heads, cfg.num_hidden_layers
        )
        return self._to_tflops(dense_N_flops + attn_qkv_flops, delta_time)

    def _estimate_dense_llm_flops(self, cfg, tokens_sum, batch_seqlens, delta_time):
        """Dense decoder with a SwiGLU MLP and GQA attention (Qwen2/Qwen3/LLaMA style)."""
        head_dim = self._head_dim(cfg)
        # SwiGLU MLP: gate, up and down projections.
        mlp_N = cfg.hidden_size * cfg.intermediate_size * 3
        per_layer_N = mlp_N + self._gqa_attn_linear_N(cfg, head_dim)
        return self._llm_flops(
            cfg, per_layer_N, head_dim, tokens_sum, batch_seqlens, delta_time
        )

    # --------------------------------------------------------------- estimators

    def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
        return 0

    def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
        return self._estimate_dense_llm_flops(
            self.config, tokens_sum, batch_seqlens, delta_time
        )

    def _estimate_internvl_flops(self, tokens_sum, batch_seqlens, delta_time):
        # TODO consider vit: only the language backbone (llm_config) is counted.
        # All hyper-parameters, including head_dim, come from llm_config.
        return self._estimate_dense_llm_flops(
            self.config.llm_config, tokens_sum, batch_seqlens, delta_time
        )

    def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time):
        """MoE decoder with multi-head latent attention (MLA).

        The first ``first_k_dense_replace`` layers use a dense SwiGLU MLP. The
        remaining layers use a router plus the top-k routed experts and the
        shared experts.
        """
        cfg = self.config
        hidden_size = cfg.hidden_size
        num_hidden_layers = cfg.num_hidden_layers
        first_k_dense_replace = cfg.first_k_dense_replace
        num_query_heads = cfg.num_attention_heads

        # Router: hidden -> n_routed_experts logits.
        moe_gate_N = hidden_size * cfg.n_routed_experts
        # SwiGLU MLPs (3 matrices each) of the activated routed experts + shared experts.
        moe_expertmlp_N = (
            hidden_size
            * cfg.moe_intermediate_size
            * (cfg.num_experts_per_tok + cfg.n_shared_experts)
            * 3
        )

        # MLA projections.
        q_head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
        if cfg.q_lora_rank is None:
            attn_linear_N = hidden_size * num_query_heads * q_head_dim
        else:
            # Low-rank query: hidden -> q_lora_rank -> heads * q_head_dim.
            attn_linear_N = (
                hidden_size * cfg.q_lora_rank
                + num_query_heads * q_head_dim * cfg.q_lora_rank
            )
        # KV down-projection (compressed kv + decoupled rope key).
        attn_linear_N += hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
        # KV up-projection to per-head non-rope key and value.
        attn_linear_N += (
            num_query_heads
            * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim)
            * cfg.kv_lora_rank
        )
        # Output projection.
        attn_linear_N += num_query_heads * cfg.v_head_dim * hidden_size

        emd_and_lm_head_N = cfg.vocab_size * hidden_size * 2
        moe_N = (
            (moe_gate_N + moe_expertmlp_N + attn_linear_N)
            * (num_hidden_layers - first_k_dense_replace)
            + (hidden_size * cfg.intermediate_size * 3 + attn_linear_N)
            * first_k_dense_replace
            + emd_and_lm_head_N
        )
        dense_N_flops = self.scaling_law_coff * moe_N * tokens_sum

        # NOTE: both attention matmuls are approximated with q_head_dim, although
        # the scores @ V matmul actually uses v_head_dim.
        attn_qkv_flops = self._attn_flops(
            batch_seqlens, q_head_dim, num_query_heads, num_hidden_layers
        )
        return self._to_tflops(dense_N_flops + attn_qkv_flops, delta_time)

    def _estimate_qwen3_moe_flops(self, tokens_sum, batch_seqlens, delta_time):
        """MoE decoder with GQA attention and top-k routed SwiGLU experts."""
        cfg = self.config
        head_dim = self._head_dim(cfg)
        # Activated experts' SwiGLU MLPs + router (hidden -> num_experts).
        moe_mlp_N = (
            cfg.hidden_size * cfg.num_experts_per_tok * cfg.moe_intermediate_size * 3
            + cfg.hidden_size * cfg.num_experts
        )
        per_layer_N = moe_mlp_N + self._gqa_attn_linear_N(cfg, head_dim)
        return self._llm_flops(
            cfg, per_layer_N, head_dim, tokens_sum, batch_seqlens, delta_time
        )

    def _estimate_openvla_flops(self, tokens_sum, batch_seqlens, delta_time):
        """
        Estimate FLOPs for OpenVLA/OpenVLA-OFT models.

        Only the language model described by ``config.text_config`` is counted,
        using the dense SwiGLU/GQA formula. Vision encoder and projector FLOPs
        are not included. Returns 0 (with a warning) if ``text_config`` is absent.
        """
        if not hasattr(self.config, "text_config"):
            print("Warning: OpenVLA config missing text_config, cannot estimate FLOPs")
            return 0
        return self._estimate_dense_llm_flops(
            self.config.text_config, tokens_sum, batch_seqlens, delta_time
        )

    # ----------------------------------------------------------------- public

    @classmethod
    def _flatten_to_ints(cls, data):
        """Recursively flatten nested lists/tuples/iterables into a flat list of ints.

        Scalars (including 0-d tensors/arrays, which are not iterable) are
        converted with ``int()``; floats are truncated.
        """
        # str/bytes are iterable but must not be split into characters
        # (that would recurse forever), so convert them directly.
        if isinstance(data, (int, float, str, bytes)):
            return [int(data)]
        if isinstance(data, (list, tuple)):
            result = []
            for item in data:
                result.extend(cls._flatten_to_ints(item))
            return result
        try:
            items = list(data)
        except TypeError:
            # Not iterable (e.g. a 0-d tensor or numpy scalar): treat as one item.
            return [int(data)]
        return cls._flatten_to_ints(items)

    def estimate_flops(self, batch_seqlens, delta_time):
        """
        Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.

        Args:
            batch_seqlens (List[int] or int): A list where each element represents the number of valid tokens in the current batch,
                or a single integer representing total tokens. Can also handle nested lists.
            delta_time (float): The time taken to process the batch, in seconds.

        Returns:
            estimated_flops (float): The estimated throughput in TFLOPS (0 for unsupported model types).
            promised_flops (float): The peak throughput of the current device in TFLOPS.
        """
        batch_seqlens = self._flatten_to_ints(batch_seqlens)
        tokens_sum = sum(batch_seqlens)
        func = self.estimate_func.get(
            self.config.model_type, self._estimate_unknown_flops
        )
        estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
        promised_flops = get_device_flops()
        return estimated_flops, promised_flops
