import re
import time
from collections import defaultdict
from typing import Optional, Tuple, Union
import os
import pickle

import ray
import torch
import numpy as np
from loguru import logger
from omegaconf.dictconfig import DictConfig
from ray.util.placement_group import PlacementGroup, placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.nn.utils.rnn import pad_sequence

from thinker_task.exp_engine.accelerators.inference.vllm_engine import LLMActor
from thinker_task.exp_engine.accelerators.inference.mul_llm import MulLLMActor
from thinker_task.exp_engine.accelerators.inference.sum_llm import SumLLMActor
from thinker_task.ppo.openrlhf_deepspeed import DeepspeedStrategy

# Ray-remote versions of the three vLLM actor classes (plain, summary, multi-attempt);
# `create_vllm_engines` picks one of them.
LLMRayActor, SumLLMRayActor, MulLLMRayActor = (
    ray.remote(actor_cls) for actor_cls in (LLMActor, SumLLMActor, MulLLMActor)
)

def get_train_ds_config(
    offload,
    adam_offload=True,
    stage=2,
    bf16=True,
    max_norm=1.0,
    zpg=8,
    grad_accum_dtype=None,
    disable_trace_cache=False,
):
    """Return the DeepSpeed config dict used for training.

    Args:
        offload: offload parameters to CPU ("cpu") instead of keeping them on device ("none").
        adam_offload: offload optimizer state to CPU (pinned memory, ratio 0.6).
        stage: ZeRO stage.
        bf16: enable bf16 training.
        max_norm: gradient clipping norm.
        zpg: ZeRO++ hierarchical partition size (``zero_hpz_partition_size``).
        grad_accum_dtype: gradient accumulation dtype; falls back to "fp32" when falsy.
        disable_trace_cache: pin the stage-3 prefetch bucket size, max live parameters and
            max reuse distance to 0 instead of "auto".
    """
    param_device = "cpu" if offload else "none"
    optimizer_device = "cpu" if adam_offload else "none"
    # Value shared by the three stage-3 knobs tied to the parameter trace cache.
    trace_knob = 0 if disable_trace_cache else "auto"

    zero_optimization = {
        "stage": stage,
        "offload_param": {"device": param_device},
        "offload_optimizer": {
            "device": optimizer_device,
            "pin_memory": True,
            "ratio": 0.6,
        },
        "sub_group_size": "auto",
        "stage3_max_live_parameters": trace_knob,
        "stage3_max_reuse_distance": trace_knob,
        "stage3_param_persistence_threshold": "auto",
        "stage3_prefetch_bucket_size": trace_knob,
        "reduce_bucket_size": "auto",
        # ZeRO++ settings
        "zero_hpz_partition_size": zpg,
        "zero_quantized_weights": False,
        "zero_quantized_gradients": False,
    }

    config = {"steps_per_print": 100}
    config["zero_optimization"] = zero_optimization
    config["bf16"] = {"enabled": bf16}
    config["gradient_clipping"] = max_norm
    config["prescale_gradients"] = False
    config["wall_clock_breakdown"] = False
    config["data_types"] = {"grad_accum_dtype": grad_accum_dtype or "fp32"}
    return config


def get_eval_ds_config(
    offload,
    stage=0,
    bf16=True,
):
    """Return the DeepSpeed config dict used for inference-only (eval/reference) models.

    Args:
        offload: offload parameters to CPU with pinned memory ("cpu") or not ("none").
        stage: ZeRO stage.
        bf16: enable bf16.
    """
    offload_param = {"device": "cpu" if offload else "none", "pin_memory": True}
    zero_optimization = {
        "stage": stage,
        "stage3_param_persistence_threshold": "auto",
        "offload_param": offload_param,
    }
    return dict(
        steps_per_print=100,
        zero_optimization=zero_optimization,
        bf16={"enabled": bf16},
        gradient_clipping=1.0,
        prescale_gradients=False,
        wall_clock_breakdown=False,
    )


class Timer:
    """Async context manager that logs how long the body of an ``async with`` block took.

    Usage::

        async with Timer("generate"):
            ...

    On exit, a line "<message>, time cost: X.XXs" is logged via ``logger``. The caller's frame
    is reported (``opt(depth=1)``). Exceptions raised in the body are not suppressed.
    """

    def __init__(self, message):
        # ANSI 256-colour escape (colour 208) so the timing line stands out in the log.
        self.message = f"\033[38;5;208m{message}\033[0m"

    async def __aenter__(self):
        # Wall-clock start, kept for any caller that reads it.
        self.start_time = time.time()
        # Durations are measured with a monotonic clock so system clock adjustments
        # (NTP, manual changes) cannot produce negative or skewed timings.
        self._perf_start = time.perf_counter()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self._perf_start
        logger.opt(depth=1).info(f"{self.message}, time cost: {elapsed:.2f}s")


def _validate_args(args: DictConfig):
    """Sanity-check the experiment config before training starts.

    The checks raise ``AssertionError`` explicitly instead of using ``assert`` statements.
    This keeps them active under ``python -O``, and callers see the same exception type
    and messages as before.

    Raises:
        AssertionError: if ZeRO-3 is requested without vLLM engines; if neither a reward
            model nor a custom reward fn is configured; if ``packing_max_len`` cannot hold
            prompt + generation; or if micro batch sizes are not 1 (required for packing).
    """
    if args.zero_stage == 3 and args.vllm_num_engines <= 0:
        raise AssertionError("ZeRO-3 is only supported when vLLM enabled")
    if args.reward_pretrain is None and not args.use_compute_reward_fn:
        raise AssertionError("at least one of reward model or custom reward fn should be set.")
    if args.packing_max_len < args.prompt_max_len + args.generate_max_len:
        raise AssertionError(
            "packing_max_len should be set greater than prompt_max_len + generate_max_len when packing samples is True"
        )
    if args.micro_forward_batch_size != 1 or args.micro_train_batch_size != 1:
        raise AssertionError(
            "micro_forward_batch_size and micro_train_batch_size should be 1 when packing samples is True"
        )


def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor:
    """Average ``tensor`` along ``dim``, weighting each position by ``mask``.

    With ``mask=None`` this is a plain ``tensor.mean``. Otherwise the masked sum is divided
    by ``mask.sum`` along the same dim. A slice whose mask sums to zero therefore yields
    NaN or inf.
    """
    if mask is not None:
        weighted_total = (tensor * mask).sum(axis=dim)
        weight = mask.sum(axis=dim)
        return weighted_total / weight
    return tensor.mean(axis=dim)


@torch.no_grad()
def compute_approx_kl(
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: Optional[torch.Tensor] = None,
    use_kl_estimator_k3: bool = False,
    use_abs_kl: bool = False,
) -> torch.Tensor:
    """Per-token approximate KL divergence between the policy and a base distribution.

    See Schulman's note on KL estimators: http://joschu.net/blog/kl-approx.html

    Args:
        log_probs: log-probabilities under the new (policy) distribution.
        log_probs_base: log-probabilities under the base (reference) distribution.
        action_mask: optional multiplier applied to the log-ratio before any estimator, so
            masked positions contribute 0.
        use_kl_estimator_k3: use the non-negative k3 estimator ``exp(-d) - 1 + d``, where
            ``d = log_probs - log_probs_base``. Otherwise the raw log-ratio ``d`` (k1) is used.
        use_abs_kl: return the absolute value of the chosen estimate.

    Returns:
        Tensor with the same shape as ``log_probs``.
    """
    estimate = log_probs - log_probs_base
    if action_mask is not None:
        estimate = estimate * action_mask

    if use_kl_estimator_k3:
        # k3 with log r = -(log p - log q): (r - 1) - log r, non-negative and unbiased.
        neg_log_ratio = -estimate
        estimate = neg_log_ratio.exp() - 1 - neg_log_ratio

    return estimate.abs() if use_abs_kl else estimate


@ray.remote(num_cpus=1)
@torch.no_grad()
def compute_reward(
    r: Optional[Union[torch.Tensor, float]],
    kl_coef: float,
    kl: Union[torch.Tensor, list[torch.Tensor]],
    custom_rewards: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
    action_mask: Optional[torch.Tensor] = None,
    num_actions: Optional[Union[int, list[int]]] = None,
    reward_clip_range: Tuple[float, float] = None,
    use_kl_loss: bool = False,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Build the per-token reward tensor from an outcome reward, a KL penalty and optional per-token rewards.

    Runs as a Ray remote task (invoke with ``compute_reward.remote(...)``).

    Two layouts are supported:
      * unpacked (``action_mask`` given): ``r`` is scattered onto the last position where
        ``action_mask`` is nonzero in each row;
      * packed (``action_mask`` is None): ``r`` is added at column
        ``cumsum(num_actions) - 1``, i.e. the last action of each packed sample.

    Args:
        r: outcome reward per sample, or None. Clamped into ``reward_clip_range`` when that
            is set. Presumably shape (batch,) — TODO confirm (the unpacked path calls
            ``r.unsqueeze(1)``).
        kl_coef: KL penalty coefficient; values <= 0 are treated as 0.
        kl: per-token KL estimates. If None, a zero tensor shaped like ``custom_rewards``
            (or like ``r`` when there are no custom rewards) is used.
        custom_rewards: optional per-token rewards, presumably a list of per-sample tensors.
            They are right-padded with 0 into one tensor (unpacked) or concatenated and given
            a leading dim of 1 (packed), then added to the reward.
        action_mask: mask of valid action tokens; selects the unpacked layout when not None.
        num_actions: per-sample action counts (packed layout only).
        reward_clip_range: optional (min, max) bounds applied to ``r``.
        use_kl_loss: if True, the KL term is NOT folded into the reward (the KL part is all
            zeros). Presumably the KL is applied as a separate loss instead.

    Returns:
        A tensor: ``-kl_coef * kl`` (or zeros when ``use_kl_loss``), plus ``r`` at each
        sample's last action, plus ``custom_rewards`` when given.

    NOTE(review): if ``kl`` and ``custom_rewards`` are both None, ``kl`` takes the shape of
    ``r``. The unpacked scatter along dim=1 then presumably fails for a 1-D ``r`` — confirm
    callers never hit this.
    """
    if kl_coef <= 0.0:
        kl_coef = 0.0

    if r is not None and reward_clip_range:
        r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1])

    # Collate per-sample custom rewards into the layout matching `kl`.
    if custom_rewards is not None:
        if action_mask is not None:
            custom_rewards = pad_sequence(custom_rewards, batch_first=True, padding_value=0.0)
        else:
            custom_rewards = torch.cat(custom_rewards, dim=0).unsqueeze(0)

    if kl is None:
        kl = torch.zeros_like(custom_rewards) if custom_rewards is not None else torch.zeros_like(r)

    if action_mask is not None:
        if not use_kl_loss:
            kl_reward = -kl_coef * kl
        else:
            kl_reward = torch.zeros_like(kl)
        # The following code is equivalent to:
        #
        # last_reward = torch.zeros_like(kl)
        # for i in range(last_reward.size(0)):
        #     for t in reversed(range(last_reward.size(1))):
        #         if action_mask[i][t] > 0.5:
        #             last_reward[i][t] = r[i]
        #             break
        #
        if r is not None:
            # argmax over the flipped mask finds the last nonzero mask position in each row.
            eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True)
            last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype))
            reward = last_reward + kl_reward
        else:
            reward = kl_reward
        if custom_rewards is not None:
            reward = reward + custom_rewards
    else:
        # Packed layout: presumably kl is (1, sum(num_actions)) — TODO confirm.
        if not use_kl_loss:
            kl_reward = -kl_coef * kl
        else:
            kl_reward = torch.zeros_like(kl)
        if r is not None:
            # Column of each packed sample's last action token.
            kl_reward[:, torch.tensor(num_actions).cumsum(dim=-1) - 1] += r

        if custom_rewards is not None:
            reward = kl_reward + custom_rewards
        else:
            reward = kl_reward

    return reward


@ray.remote(num_cpus=1)
@torch.no_grad()
def get_advantages_and_returns(
    values: Optional[torch.Tensor],
    rewards: torch.Tensor,
    action_mask: torch.Tensor,
    num_actions: Optional[torch.Tensor],
    gamma: Union[float, torch.Tensor],
    lambd: float,
    packing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function that computes advantages and returns from rewards and values (GAE).
    Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
    Note that rewards may include a KL divergence loss term.
    Runs as a Ray remote task (invoke with ``get_advantages_and_returns.remote(...)``).

    Advantages looks like this:
    Adv1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
            - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

    Returns looks like this:
    Ret1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
                + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

    Input:
    - values: Tensor of shape (batch_size, response_size), or None. With None there is no
      critic: each delta is just the reward, advantages are the (γλ)-discounted sum of
      future rewards, and returns equal advantages.
    - rewards: Tensor of shape (batch_size, response_size)
    - action_mask: optional mask multiplied into values and rewards before the recursion.
    - num_actions: per-sample action counts; only used when ``packing`` is True.
    - gamma: discount; a scalar, or a tensor indexed as ``gamma[:, t]`` (per-position discount).
    - lambd: GAE λ.
    - packing: if True, the recursion is reset (``lastgaelam = 0``, next value 0) at column
      ``cumsum(num_actions)[k] - 1`` for every sample k, so credit does not leak across packed
      samples. This presumably assumes samples' actions are laid out back-to-back starting at
      column 0 — TODO confirm against the packing code.

    Output:
    - advantages: Tensor of shape (batch_size, response_size), detached
    - returns: Tensor of shape (batch_size, response_size); advantages + (masked) values
    """

    # fp32 mode: run the recursion in float32 regardless of input dtype.
    if values is not None:
        values = values.float()
    rewards = rewards.float()

    if packing:
        # End (exclusive) column of each packed sample; walked from the last sample backwards.
        accum_reverse_num_actions = torch.cumsum(torch.tensor(num_actions), dim=0)
        sample_idx = len(num_actions) - 1
    else:
        accum_reverse_num_actions = None

    lastgaelam = 0
    advantages_reversed = []
    response_length = rewards.size(1)

    # Mask invalid responses
    if action_mask is not None:
        if values is not None:
            values = action_mask * values
        rewards = action_mask * rewards

    for t in reversed(range(response_length)):
        # Bootstrap value V(t+1); 0 past the final column.
        if values is not None:
            nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0
        else:
            nextvalues = 0.0
        # At the last column of a packed sample, start a fresh recursion for that sample.
        if packing and sample_idx >= 0 and t + 1 == accum_reverse_num_actions[sample_idx]:
            sample_idx -= 1
            lastgaelam = 0
            nextvalues = 0.0
        if isinstance(gamma, torch.Tensor):
            gamma_ = gamma[:, t]
        else:
            gamma_ = gamma
        # TD residual δ_t = r_t + γ V(t+1) - V(t)  (just r_t without a critic).
        if values is not None:
            delta = rewards[:, t] + gamma_ * nextvalues - values[:, t]
        else:
            delta = rewards[:, t]
        # A_t = δ_t + γ λ A_{t+1}
        lastgaelam = delta + gamma_ * lambd * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    if values is not None:
        returns = advantages + values
    else:
        returns = advantages
    return advantages.detach(), returns


def normalize_advantages(buffer):
    """Whiten the ``advantages`` of every item in ``buffer`` in place, using masked statistics.

    Both the mean and the variance are computed only over valid action positions. A position
    is valid when ``item.action_mask`` marks it (every position when the mask is None). When
    sys masks are present, positions flagged by ``item.sys_mask`` are excluded. Each item's
    ``advantages`` is then replaced by ``(advantages - mean) / std`` as bfloat16. Invalid
    positions are transformed too, but they do not influence the statistics.

    Previously the mean was the plain mean over every position (padding and system tokens
    included), while the variance was masked, which mis-centred the result.

    Args:
        buffer: re-iterable collection of experience items exposing ``advantages``,
            ``action_mask`` and ``sys_mask`` (it is iterated twice).

    Returns:
        The same ``buffer`` object, with each item's ``advantages`` replaced.
    """
    advantages, action_masks, sys_masks = [], [], []
    for item in buffer:
        adv = item.advantages
        advantages.append(adv)
        mask = item.action_mask
        if mask is None:
            mask = torch.ones(adv.shape, dtype=torch.bool, device=adv.device)
        action_masks.append(mask)
        sys_masks.append(item.sys_mask)

    if not advantages:
        return buffer

    values = torch.cat(advantages).float().flatten()
    valid = torch.cat(action_masks).flatten()
    if sys_masks[0] is not None:
        # System-prompt tokens are not actions: exclude them from the statistics.
        valid = torch.logical_and(valid, torch.logical_not(torch.cat(sys_masks).flatten()))

    weights = valid.to(values.dtype)
    # Guard against an all-masked buffer (0 / 0 -> NaN).
    count = weights.sum().clamp(min=1.0)

    mean = (values * weights).sum() / count
    var = ((values - mean).pow(2) * weights).sum() / count
    rstd = var.clamp(min=1e-8).rsqrt()

    for item, adv in zip(buffer, advantages):
        item.advantages = ((adv - mean) * rstd).bfloat16()
    return buffer


class ORZDeepspeedStrategy(DeepspeedStrategy):
    """DeepspeedStrategy whose DeepSpeed configs come from ``get_train_ds_config`` / ``get_eval_ds_config``.

    Both configs use ``self.micro_train_batch_size`` as the per-GPU micro batch and fix
    gradient accumulation at 1.
    """

    def get_ds_train_config(self, is_actor):
        """Return the training DS config built from this strategy's settings.

        ``is_actor`` is accepted for interface compatibility; it does not affect the result.
        """
        config = get_train_ds_config(
            offload=self.param_offload,
            adam_offload=self.adam_offload,
            stage=self.stage,
            bf16=self.bf16,
            max_norm=self.max_norm,
            zpg=self.zpg,
            grad_accum_dtype=self.grad_accum_dtype,
            disable_trace_cache=self.disable_trace_cache,
        )
        config.update(
            train_micro_batch_size_per_gpu=self.micro_train_batch_size,
            gradient_accumulation_steps=1,
        )
        return config

    def get_ds_eval_config(self, offload=False):
        """Return the eval DS config; keeps ZeRO-3 when training uses it, otherwise stage 0."""
        eval_stage = self.stage if self.stage == 3 else 0
        config = get_eval_ds_config(offload=offload, stage=eval_stage, bf16=self.bf16)
        config.update(
            train_micro_batch_size_per_gpu=self.micro_train_batch_size,
            gradient_accumulation_steps=1,
        )
        return config


def get_strategy(args):
    """Build an ``ORZDeepspeedStrategy`` from ``args``, using fallbacks for missing attributes.

    Fallbacks (via ``getattr``): seed=42, max_norm=1.0, micro_train_batch_size=1,
    train_batch_size=128, bf16=True. ``args.zero_stage`` is required.
    """
    with_fallbacks = {
        name: getattr(args, name, fallback)
        for name, fallback in (
            ("seed", 42),
            ("max_norm", 1.0),
            ("micro_train_batch_size", 1),
            ("train_batch_size", 128),
        )
    }
    return ORZDeepspeedStrategy(
        **with_fallbacks,
        zero_stage=args.zero_stage,
        bf16=getattr(args, "bf16", True),
        args=args,
    )


def create_vllm_engines(
    num_engines: int,
    tensor_parallel_size: int,
    pretrain: str,
    seed: int,
    summary: bool,
    multi_attempt: bool,
    colocate_with_actor: bool,
    colocate_pg: Optional[PlacementGroup] = None,
    **kwargs,
):
    """Launch ``num_engines`` vLLM Ray actors and return their handles.

    Actor class: ``MulLLMRayActor`` if ``multi_attempt``, else ``SumLLMRayActor`` if
    ``summary``, else ``LLMRayActor``. Every engine is created with ``pretrain``,
    ``trust_remote_code=True``, ``dtype="bfloat16"``, ``block_size=256`` and
    ``seed=seed + i``, plus ``kwargs``.

    Defaults filled into ``kwargs``: ``enable_chunked_prefill=False``,
    ``gpu_memory_utilization=0.85``, ``max_num_seqs=256``. ``max_num_batched_tokens``
    defaults to 2048 only when chunked prefill is enabled. Otherwise it is forced to None,
    which overrides any caller-supplied value.

    Placement:
      * ``tensor_parallel_size > 1``: one placement group per engine with
        ``tensor_parallel_size`` bundles of {GPU: 1, CPU: 8} (PACK). The actor itself takes 0
        GPUs on bundle 0 and captures child tasks into the group — presumably so vLLM's TP
        workers land on the remaining bundles (TODO confirm). Colocation is not supported,
        and ``colocate_pg`` is unused here.
      * ``tensor_parallel_size == 1`` and not colocated: one shared placement group of
        ``num_engines`` {GPU: 1, CPU: 1} bundles; engine i uses bundle i.
      * colocated: engine i requests 0.2 GPU / 0.2 CPU on bundle i of ``colocate_pg`` (which
        must therefore have at least ``num_engines`` bundles). After creation every engine's
        ``offload_to_cpu`` is called and awaited.

    Returns:
        list of Ray actor handles, one per engine.
    """
    kwargs["enable_chunked_prefill"] = kwargs.get("enable_chunked_prefill", False)
    kwargs["max_num_batched_tokens"] = kwargs.get("max_num_batched_tokens", 2048) if kwargs["enable_chunked_prefill"] else None
    kwargs["gpu_memory_utilization"] = kwargs.get("gpu_memory_utilization", 0.85)
    kwargs["max_num_seqs"] = kwargs.get("max_num_seqs", 256)

    if multi_attempt:
        Actor = MulLLMRayActor
    elif summary:
        Actor = SumLLMRayActor
    else:
        Actor = LLMRayActor

    vllm_engines = []
    if tensor_parallel_size > 1:
        assert not colocate_with_actor, "colocate_with_actor is not supported when tensor_parallel_size > 1"
        # The engine actor reserves no GPU itself; the GPUs live in the placement-group bundles.
        num_gpus = 0
        for i in range(num_engines):
            bundles = [{"GPU": 1, "CPU": 8}] * tensor_parallel_size
            pg = placement_group(bundles, strategy="PACK")
            ray.get(pg.ready())  # block until the resources are reserved

            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0
            )
            vllm_engines.append(
                Actor.options(num_cpus=8, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy,).remote(
                    pretrain,
                    trust_remote_code=True,
                    tensor_parallel_size=tensor_parallel_size,
                    dtype="bfloat16",
                    seed=seed + i,
                    block_size=256,
                    **kwargs
                )
            )
    else:
        if not colocate_with_actor:
            # Dedicated GPU per engine, all engines in one shared placement group.
            num_gpus = 1
            num_cpus = 1
            bundles = [{"GPU": 1, "CPU": 1}] * num_engines
            pg = placement_group(bundles, strategy="PACK")
            ray.get(pg.ready())
        else:
            # Fractional request so the engine can share a bundle of the caller's placement group.
            num_gpus = 0.2
            num_cpus = 0.2
            assert colocate_pg is not None, "colocate_pg must be provided when colocate_with_actor is True"

        for i in range(num_engines):
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=colocate_pg if colocate_with_actor else pg,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=i,
            )
            vllm_engines.append(
                Actor.options(
                    num_cpus=num_cpus,
                    num_gpus=num_gpus,
                    scheduling_strategy=scheduling_strategy,
                ).remote(
                    pretrain,
                    trust_remote_code=True,
                    tensor_parallel_size=tensor_parallel_size,
                    dtype="bfloat16",
                    seed=seed + i,
                    block_size=256,
                    **kwargs,
                )
            )
        if colocate_with_actor:
            # Colocated engines start offloaded; wait for all offloads to complete.
            offload_refs = []
            for llm in vllm_engines:
                offload_refs.append(llm.offload_to_cpu.remote())
            ray.get(offload_refs)
            logger.info("Offloaded all vLLM engines to CPU")

    return vllm_engines


# reflection pattern checking related

# check how many reflection pattern related words are in the responses
def check_reflection_pattern(response: str) -> dict[str, int]:
    """Count occurrences of each reflection-style regex in ``response``.

    Matching is case-sensitive (e.g. "Wait," does not match ``wait,``). The result is a
    ``defaultdict(int)`` keyed by the regex pattern string, so unknown keys read as 0.
    """
    # TODO: may need to add more pattern
    patterns = (
        r"wait,",
        r"recheck[,\s]",  # "recheck" followed by a comma or whitespace
        r"retry",
        r"alternatively,",
        r"however,",
    )
    counts = {pattern: len(re.findall(pattern, response)) for pattern in patterns}
    return defaultdict(int, counts)

def create_token_mask(input_tokens, tokenizer, start_token="<|im_end|>", end_token="<|im_start|>Assistant: <think>"):
    """
    Build a 0/1 mask marking every region that runs from `start_token` through `end_token`.

    Both markers are included in the region. A region left open at the end of the input
    stays masked up to the last token.

    Markers are located by token IDs (`tokenizer.encode(..., add_special_tokens=False)`).
    When a marker encodes to more than one ID, its last ID is dropped from the search
    pattern ("blur"), because the final sub-token may merge with the text that follows.
    After a match, one extra position beyond the truncated pattern is still masked, standing
    in for that dropped sub-token.

    Args:
        input_tokens (List[int] | torch.Tensor): Token IDs; a tensor is converted with `.long().tolist()`.
        tokenizer: Tokenizer used to encode the markers (must provide `encode`).
        start_token (str | None): Marker that opens a region. If None, masking is active from position 0.
        end_token (str | None): Marker that closes a region. If None, an opened region never closes.

    Returns:
        List[int]: A binary mask of the same length as input_tokens.
    """
    if isinstance(input_tokens, torch.Tensor):
        input_tokens = input_tokens.long().tolist()

    mask = [0] * len(input_tokens)
    inside_special = False  # Flag to track masking state

    # Tokenize the start and end tokens to get their corresponding token ID sequences
    if start_token is not None:
        start_token_ids = tokenizer.encode(start_token, add_special_tokens=False)
        if len(start_token_ids) > 1:
            start_token_ids = start_token_ids[:-1]
            start_blur_n = 1 # blur search for the last token, as sometimes subsequent string may change token
        else:
            start_blur_n = 0
    else:
        inside_special = True

    if end_token is not None:
        end_token_ids = tokenizer.encode(end_token, add_special_tokens=False)
        if len(end_token_ids) > 1:
            end_token_ids = end_token_ids[:-1]
            end_blur_n = 1 # blur search for the last token, as sometimes subsequent string may change token
        else:
            end_blur_n = 0

    i = 0
    while i < len(input_tokens):
        if start_token is not None:
            # Check for the start token sequence
            if input_tokens[i:i+len(start_token_ids)] == start_token_ids:
                inside_special = True
                # Slice assignment may extend `mask` past the input when the match sits at the
                # very end; the final truncation below trims it back.
                mask[i:i+len(start_token_ids)+start_blur_n] = [1] * (len(start_token_ids) + start_blur_n)
                i += len(start_token_ids) + start_blur_n
                if i >= len(input_tokens):
                    break  # Exit the loop if we've reached the end

        # Apply mask while inside the special region
        if inside_special:
            mask[i] = 1
            # Check for the end token sequence
            if end_token is not None and input_tokens[i:i+len(end_token_ids)] == end_token_ids:
                inside_special = False
                mask[i:i+len(end_token_ids)+end_blur_n] = [1] * (len(end_token_ids) + end_blur_n)
                i += len(end_token_ids) + end_blur_n - 1  # Move pointer to the end of end_token_ids

        i += 1  # Move to the next token
    mask = mask[:len(input_tokens)] # sometimes overflow if start token or end token are at the end with blurry search
    return mask

def packed_create_token_mask(sequences, tokenizer, num_actions, packed_seq_lens, start_token="<|im_end|>", end_token="<|im_start|>Assistant: <think>", non_action_value=None):
    """Apply ``create_token_mask`` to the action tokens of every sample in a packed row.

    ``sequences[0]`` holds the samples back to back. Sample k spans ``packed_seq_lens[k]``
    positions, and its last ``num_actions[k]`` positions are its actions. Only those action
    tokens are scanned. If ``non_action_value`` is given, the non-action prefix of each
    sample is filled with it. Otherwise only action positions are emitted.

    Returns:
        A bool tensor of shape (1, L) on the device of ``sequences[0]``.
    """
    packed_row = sequences[0].cpu().detach()
    flat_mask = []
    offset = 0
    for n_act, seg_len in zip(num_actions, packed_seq_lens):
        seg_end = offset + seg_len
        segment_mask = create_token_mask(packed_row[seg_end - n_act: seg_end], tokenizer, start_token, end_token)
        if non_action_value is not None:
            segment_mask = [non_action_value] * (seg_len - n_act) + segment_mask
        flat_mask.extend(segment_mask)
        offset = seg_end
    return torch.tensor(flat_mask, device=sequences[0].device, dtype=torch.bool).unsqueeze(0)

def join_ls_str(x):
    """Join each inner list of strings in ``x`` into a single string.

    If the first element of ``x`` is a list, returns a new list with one ``"".join``-ed string
    per element. Otherwise (including when ``x`` is empty) ``x`` is returned unchanged.
    """
    # len() rather than truthiness so array-like inputs are not ambiguous.
    if len(x) > 0 and isinstance(x[0], list):
        return ["".join(parts) for parts in x]
    return x

def join_str(x):
    """Concatenate ``x`` into one string when it is a list; return anything else unchanged."""
    if not isinstance(x, list):
        return x
    return "".join(x)

def save_debug_data(directory="large_data/tmp", prefix="debug", max_file=-1, **kwargs):
    """
    Save arbitrary keyword arguments as a dictionary to a new numbered file via ``torch.save``.

    The file is ``{directory}/{prefix}_{n}.pickle``, where ``n`` is the smallest index not
    already present on disk. Tensors are detached and moved to CPU, and NumPy arrays are
    converted to tensors. Both conversions recurse through lists, tuples and dicts, so the
    file can be loaded on a machine without a GPU.

    Args:
        directory (str): Directory where the file will be saved (created if missing).
        prefix (str): Filename prefix (default is "debug").
        max_file (int): If positive, nothing is saved once ``n >= max_file``. A value <= 0
            means no limit.
        **kwargs: Any keyword arguments to save.

    Returns:
        str | None: The filename where the data was saved, or None when ``max_file`` was reached.
    """
    os.makedirs(directory, exist_ok=True)  # Ensure the directory exists

    # Find the smallest available index.
    # NOTE: exists-then-save is not atomic; concurrent writers sharing a prefix may pick the same n.
    n = 0
    while os.path.exists(f"{directory}/{prefix}_{n}.pickle"):
        n += 1
    if max_file > 0 and n >= max_file:
        return None
    filename = f"{directory}/{prefix}_{n}.pickle"

    def process_data(value):
        """Recursively move tensors to CPU and convert NumPy arrays to tensors."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value)
        if isinstance(value, list):
            return [process_data(v) for v in value]
        if isinstance(value, tuple):
            converted = [process_data(v) for v in value]
            # namedtuples take their fields positionally; plain tuples take one iterable.
            return type(value)(*converted) if hasattr(value, "_fields") else tuple(converted)
        if isinstance(value, dict):
            return {k: process_data(v) for k, v in value.items()}
        return value  # Keep other objects as-is

    processed_data = {key: process_data(value) for key, value in kwargs.items()}
    torch.save(processed_data, filename)

    print(f"Saved debug data to {filename}")
    return filename

def load_debug_data(directory="large_data/tmp", prefix="debug", max_n=-1):
    """Load debug files saved with ``torch.save`` and merge them into one dict of lists.

    Files in ``directory`` whose names start with ``prefix`` and end with ``.pickle`` are
    loaded in numeric order of their ``_<n>`` index (``debug_2`` before ``debug_10``). Names
    without a numeric index come last, in alphabetical order. When ``max_n > 0``, at most
    ``max_n`` files are loaded.

    Files are unpickled by ``torch.load``; only load files you trust.

    Returns:
        dict: maps every key to the list of its values, one per loaded file that contains
        the key, in load order.
    """
    suffix = ".pickle"

    def _order(name):
        # "<prefix>_<n>.pickle" sorts numerically by n; anything else goes after, by name.
        index = name[len(prefix):-len(suffix)].lstrip("_")
        return (0, int(index), name) if index.isdecimal() else (1, 0, name)

    files = sorted(
        (f for f in os.listdir(directory) if f.startswith(prefix) and f.endswith(suffix)),
        key=_order,
    )

    merged_data = {}
    n = 0
    for file in files:
        data = torch.load(os.path.join(directory, file))
        for key, value in data.items():
            merged_data.setdefault(key, []).append(value)
        n += 1
        if max_n > 0 and n >= max_n:
            break
    print(f"Loaded {n} files from {directory} with prefix '{prefix}'")
    return merged_data

def cum_clip(x, min_val=-1, max_val=1):
    """
    Truncate a 1-D tensor so its running sum never leaves ``[min_val, max_val]``.

    Elements are kept until the first index where the cumulative sum would cross a bound.
    That element is replaced by exactly the amount needed to land on the bound, and every
    later element becomes 0. If no crossing occurs, a copy of ``x`` is returned.

    Args:
        x (torch.Tensor): A 1D tensor.
        min_val (float): Lower bound (default: -1).
        max_val (float): Upper bound (default: 1).

    Returns:
        torch.Tensor: The clipped tensor, same length as ``x``.
    """
    running = torch.cumsum(x, dim=0)
    out_of_range = torch.logical_or(running > max_val, running < min_val)
    crossings = out_of_range.nonzero(as_tuple=False)

    if crossings.numel() == 0:
        return x.clone()

    first = crossings[0].item()
    if first > 0:
        sum_before = running[first - 1]
    else:
        sum_before = torch.tensor(0., device=x.device)

    bound = max_val if running[first] > max_val else min_val
    replacement = bound - sum_before
    tail = torch.zeros_like(x[first + 1:])
    return torch.cat([x[:first], replacement.unsqueeze(0), tail])

def filter_seq_lens_num_actions(mask, packed_seq_lens, num_actions):
    """Recompute packed sequence lengths and action counts after dropping tokens with ``mask``.

    ``mask`` is 2-D, with the packed samples laid out back to back along its last dim
    (``packed_seq_lens[k]`` positions each). For every sample, the new length is the mask
    sum over its span. The new action count is that sum minus the original non-action
    length ``seq_len - num_action``, which assumes none of the prompt positions were masked
    out.

    Returns:
        (new_packed_seq_lens, new_num_actions) as lists of Python numbers.
    """
    assert len(mask.shape) == 2
    assert len(packed_seq_lens) == len(num_actions)

    new_seq_lens, new_num_actions = [], []
    start = 0
    for n_act, seg_len in zip(num_actions, packed_seq_lens):
        kept = mask[..., start: start + seg_len].sum().item()
        new_seq_lens.append(kept)
        new_num_actions.append(kept - (seg_len - n_act))
        start += seg_len

    return new_seq_lens, new_num_actions

def get_physical_gpu_id():
    """Return the UUID of the current CUDA device as a string.

    The UUID identifies the physical GPU regardless of ``CUDA_VISIBLE_DEVICES`` remapping.
    """
    import torch

    current = torch.cuda.current_device()
    return str(torch.cuda.get_device_properties(current).uuid)