"""
This utility file provides a centralized function to reconstruct human-readable
trajectories from a DataProto batch, ensuring consistent data representation
for different reward marking algorithms.
"""
import torch
import random
import numpy as np
from collections import defaultdict
from verl import DataProto
from transformers import PreTrainedTokenizer

import json

def reconstruct_readable_traces(batch: DataProto, tokenizer) -> dict:
    """
    Rebuild per-trajectory (observation, action) traces from a DataProto batch.

    Rows are grouped by ``non_tensor_batch['traj_uid']``. Within each group they are
    ordered by the step number parsed from the decoded prompt ("You are now at step N"),
    so the traces are chronological even if the rows of ``batch`` are not.

    Args:
        batch: Batch whose ``batch`` mapping holds ``'prompts'`` and ``'responses'``
            token ids. Its ``non_tensor_batch`` holds ``'traj_uid'`` and
            ``'episode_rewards'``, and optionally ``'is_action_valid'`` and ``'goal'``.
        tokenizer: Object exposing a HuggingFace-style ``batch_decode``.

    Returns:
        ``{traj_uid: {"uid", "success", "goal", "trace"}}``, with fields as follows:

        * ``trace``: list of ``{"batch_index", "observation", "action"}`` dicts in
          step order.
        * ``action``: ``"nothing"`` when the step is flagged invalid or its response
          has no ``<action>...</action>``.
        * ``observation``: ``"OBS_NOT_FOUND"`` when it cannot be located in the prompt.
        * ``success``: True when the last step's ``episode_rewards`` is positive.
        * ``goal``: taken from ``non_tensor_batch['goal']`` when present and non-empty.
          Otherwise it is the first prompt's "Your task is to: ..." line, and
          ``"Unknown Goal"`` if that is missing too.
    """
    step_prefix = "You are now at step "
    obs_start_tag, obs_end_tag = "is:", "\nYour"
    action_start_tag, action_end_tag = "<action>", "</action>"
    goal_prefix = "Your task is to: "

    def _between(text, start_tag, end_tag):
        # Stripped text between the first ``start_tag`` and the next ``end_tag`` after it;
        # None if either tag is missing.
        start = text.find(start_tag)
        if start == -1:
            return None
        start += len(start_tag)
        end = text.find(end_tag, start)
        return None if end == -1 else text[start:end].strip()

    def _parse_step_number(prompt):
        # Integer written directly after ``step_prefix``, or None.
        # Use ``isdecimal`` rather than ``isdigit``: isdigit also accepts characters
        # such as superscripts, and int() rejects those.
        start = prompt.find(step_prefix)
        if start == -1:
            return None
        begin = end = start + len(step_prefix)
        while end < len(prompt) and prompt[end].isdecimal():
            end += 1
        return int(prompt[begin:end]) if end > begin else None

    non_tensor = batch.non_tensor_batch
    has_valid_tags = 'is_action_valid' in non_tensor
    action_valid_array = non_tensor.get('is_action_valid')

    # Decode every prompt/response once instead of per step.
    all_prompts_str = tokenizer.batch_decode(batch.batch['prompts'], skip_special_tokens=True)
    all_responses_str = tokenizer.batch_decode(batch.batch['responses'], skip_special_tokens=True)

    # 1. Group rows by trajectory as (step_number, batch_index) pairs.
    #    Non-tensor columns are indexed directly. ``batch[i]`` would build a full
    #    per-row item just to read one uid.
    traj_to_steps = defaultdict(list)
    for row, prompt_str in enumerate(all_prompts_str):
        traj_uid = non_tensor['traj_uid'][row]
        step_num = _parse_step_number(prompt_str)
        if step_num is None:
            if "Prior to this step" not in prompt_str:
                # The first step's prompt has no history section and no "step N" marker.
                step_num = 1
            else:
                # Fallback for unforeseen formats: assume this trajectory's rows arrive in order.
                print("ERROR: wrong step number")
                step_num = len(traj_to_steps.get(traj_uid, [])) + 1
        traj_to_steps[traj_uid].append((step_num, row))

    output_dialogues = {}
    # 2. Reconstruct each trajectory.
    for traj_uid, steps in traj_to_steps.items():
        steps.sort()  # by step number; ties broken by batch index
        indices = [batch_index for _, batch_index in steps]

        reconstructed_trace = []
        for step_index in indices:
            observation = _between(all_prompts_str[step_index], obs_start_tag, obs_end_tag)
            if observation is None:
                observation = "OBS_NOT_FOUND"
                print("ERROR: wrong obs")

            is_valid = True
            if has_valid_tags:
                val = action_valid_array[step_index]
                is_valid = bool(val.item()) if hasattr(val, 'item') else bool(val)

            # Record invalid or unparsable actions as "nothing".
            action = "nothing"
            if is_valid:
                parsed_action = _between(all_responses_str[step_index], action_start_tag, action_end_tag)
                if parsed_action is not None:
                    action = parsed_action

            reconstructed_trace.append({
                "batch_index": step_index,
                "observation": observation,
                "action": action,
            })

        # 3. Success comes from the last step; the goal comes from the first.
        is_successful = bool(non_tensor['episode_rewards'][indices[-1]] > 0)

        goal = None
        if 'goal' in non_tensor:
            raw_goal = non_tensor['goal'][indices[0]]
            # Objects with .item() (e.g. numpy scalars) are unwrapped first.
            goal = str(raw_goal.item()) if hasattr(raw_goal, 'item') else str(raw_goal)

        if not goal or goal == "None":
            goal = "Unknown Goal"
            first_step_prompt = all_prompts_str[indices[0]]
            goal_start = first_step_prompt.find(goal_prefix)
            if goal_start != -1:
                goal_start += len(goal_prefix)
                goal_end = first_step_prompt.find("\n", goal_start)
                # Without a trailing newline the goal runs to the end of the prompt.
                goal = first_step_prompt[goal_start:goal_end if goal_end != -1 else None].strip()

        output_dialogues[traj_uid] = {
            "uid": traj_uid,
            "success": is_successful,
            "goal": goal,
            "trace": reconstructed_trace,
        }

    return output_dialogues
##################################################
########      sample data functions       ########
##################################################
def sample_and_extract_sft_pair(sft_data_dir: str="/workspace/Code/verl-agent/trpo_dataset/json/alfworld_sft_data_json") -> tuple[str, str]:
    """
    Randomly samples a QA pair from the SFT dataset and extracts the
    observation and action.
    """
    import random
    from pathlib import Path
    
    data_path = Path(sft_data_dir)
    batch_files = list(data_path.glob("sft_batch_*.jsonl"))

    if not batch_files:
        print(f"Error: No 'sft_batch_*.jsonl' files found in {sft_data_dir}.")
        return None, None, None

    while True:
        random_file = random.choice(batch_files)
        try:
            with open(random_file, 'r') as f:
                lines = [line for line in f if line.strip()]
            if lines:
                random_line = random.choice(lines)
                break
            else:
                print(f"Warning: The file {random_file} is empty. Selecting another file...")
                continue
        except Exception as e:
            print(f"Error reading file {random_file}: {e}")
            continue
    
    try:
        qa_pair = json.loads(random_line)
        prompt_str = qa_pair.get('user', '')
        assistant_str = qa_pair.get('assistant', '')

        observation = "OBS_NOT_FOUND"
        # Option 1: "Looking quickly around you,"
        # Option 2: "your current observation is:"
        # Ends with "Your task is to" or "Your admissible actions"
        
        obs_prefix1 = "Looking quickly around you,"
        obs_prefix2 = "your current observation is:"
        
        task_suffix = "Your task is to"
        actions_suffix = "Your admissible actions"
        
        obs_start_idx = -1
        prefix_len = 0

        # Try prefix 2 first as it's more explicit
        start2 = prompt_str.find(obs_prefix2)
        if start2 != -1:
            obs_start_idx = start2 + len(obs_prefix2)
            prefix_len = len(obs_prefix2)
        else: # Fallback to prefix 1
            start1 = prompt_str.find(obs_prefix1)
            if start1 != -1:
                obs_start_idx = start1 + len(obs_prefix1)
                prefix_len = len(obs_prefix1)

        if obs_start_idx != -1:
            search_start = obs_start_idx
            
            task_end_idx = prompt_str.find(task_suffix, search_start)
            actions_end_idx = prompt_str.find(actions_suffix, search_start)
            
            obs_end_idx = -1
            if task_end_idx != -1 and actions_end_idx != -1:
                obs_end_idx = min(task_end_idx, actions_end_idx)
            elif task_end_idx != -1:
                obs_end_idx = task_end_idx
            elif actions_end_idx != -1:
                obs_end_idx = actions_end_idx

            if obs_end_idx != -1:
                observation = prompt_str[obs_start_idx:obs_end_idx].strip()
            else:
                observation = prompt_str[obs_start_idx:].strip()

        action = "ACTION_NOT_FOUND"
        action_start_tag = "<action>"
        action_end_tag = "</action>"
        
        action_start = assistant_str.find(action_start_tag)
        if action_start != -1:
            action_end = assistant_str.find(action_end_tag, action_start + len(action_start_tag))
            if action_end != -1:
                action = assistant_str[action_start + len(action_start_tag):action_end].strip()

        goal = "GOAL_NOT_FOUND"
        goal_prefix = "Your task is to: "
        goal_suffix = "\n"

        goal_start = prompt_str.find(goal_prefix)
        if goal_start != -1:
            search_start = goal_start + len(goal_prefix)
            goal_end = prompt_str.find(goal_suffix, search_start)
            if goal_end != -1:
                goal = prompt_str[search_start:goal_end].strip()
            else: # if no newline after goal, take till end of string
                goal = prompt_str[search_start:].strip()
        
        return observation, action, goal

    except (json.JSONDecodeError, KeyError) as e:
        print(f"Error processing data from line: {random_line}. Error: {e}")
        return None, None, None

from agent_system.environments.prompts.alfworld import ALFWORLD_TEMPLATE, ALFWORLD_TEMPLATE_NO_HIS
##################################################
########   convert to dataproto functions    #####
##################################################
def convert_dialogue_to_dataproto(raw_dialogue, tokenizer, ref_batch):
    """
    Convert an expert dialogue into a DataProto with one row per (user, assistant) step.

    Args:
        raw_dialogue: List of ``{'content': str}`` turns alternating user and assistant,
            starting with a user turn. A user turn's content is either a JSON object with
            ``'observation'`` and ``'admissible_commands'``, or a plain observation string.
            Assistant turns are expected as ``<think>...</think><action>...</action>``.
        tokenizer: Tokenizer providing ``encode`` and ``pad_token_id``.
        ref_batch: Optional reference batch. When truthy, its prompt/response widths,
            dtypes and first ``uid`` are reused. Otherwise widths default to 1024 / 2048,
            dtypes default to ``torch.int64``, and a fresh group uid is generated.

    Returns:
        A DataProto with ``prompts``, ``responses``, ``episode_rewards`` (all 1), ``uid``,
        ``traj_uid`` and ``goal``. Every row of one dialogue shares the same ``traj_uid``.
        Returns None when the dialogue contains no complete (user, assistant) pair.

    The first step's prompt is rebuilt with ``ALFWORLD_TEMPLATE_NO_HIS``. Later steps use
    ``ALFWORLD_TEMPLATE`` with up to the 4 most recent (observation, action) pairs. The
    history records the text inside ``<action>`` tags, or the whole stripped response
    when no tags are found.
    """
    import uuid

    history_length = 4  # number of past (observation, action) pairs shown in the prompt
    action_start_tag, action_end_tag = "<action>", "</action>"

    def _parse_user_turn(turn):
        # Return (observation, admissible_commands) for one user turn. Content that is a
        # non-empty JSON object is treated as structured; anything else is used as the
        # observation, with no admissible commands.
        content = turn['content']
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict) and parsed:
            return parsed.get('observation', ""), parsed.get('admissible_commands', [])
        return content, []

    if not raw_dialogue:
        return None

    # 1. The initial user turn provides the first observation and the task description.
    current_observation, current_admissible_cmds = _parse_user_turn(raw_dialogue[0])

    task_description = ""
    if 'Your task is to:' in current_observation:
        task_description = current_observation.split('Your task is to:')[1].split('\n\n')[0].strip()
    goal_val = task_description if task_description else "Unknown Goal"

    # Tensor widths and dtypes mirror ref_batch so the rows can be concatenated with it.
    max_prompt_length = ref_batch.batch['prompts'].shape[1] if ref_batch else 1024
    max_response_length = ref_batch.batch['responses'].shape[1] if ref_batch else 2048
    dtype_prompts = ref_batch.batch['prompts'].dtype if ref_batch else torch.int64
    dtype_responses = ref_batch.batch['responses'].dtype if ref_batch else torch.int64

    group_uid = ref_batch.non_tensor_batch['uid'][0] if ref_batch else f"group_uid_{uuid.uuid4().hex}"
    # Use uuid4, not a 4-digit random suffix. With only 9000 possible values, ids of
    # different expert trajectories collide often, and rows sharing a traj_uid get
    # grouped as one trajectory (e.g. by reconstruct_readable_traces).
    expert_traj_uid = f"expert_traj_uid_{uuid.uuid4().hex}"

    history_tuples = []
    batch_items = []
    # 2. Emit one row per complete (user, assistant) pair; a trailing user turn is ignored.
    for user_idx in range(0, len(raw_dialogue) - 1, 2):
        if user_idx > 0:
            # Each step's prompt must show that step's observation, not the initial one.
            current_observation, current_admissible_cmds = _parse_user_turn(raw_dialogue[user_idx])

        response_str = raw_dialogue[user_idx + 1]['content']

        # Action text for the history: inside <action>...</action>, else the whole response.
        clean_action = response_str.strip()
        action_start = response_str.find(action_start_tag)
        if action_start != -1:
            action_end = response_str.find(action_end_tag, action_start + len(action_start_tag))
            if action_end != -1:
                clean_action = response_str[action_start + len(action_start_tag):action_end].strip()

        step_count = len(history_tuples)
        admissible_actions = ', '.join(f"'{a}'" for a in current_admissible_cmds)
        if step_count == 0:
            user_prompt = ALFWORLD_TEMPLATE_NO_HIS.format(
                current_observation=current_observation,
                admissible_actions=admissible_actions,
            )
        else:
            recent_history = history_tuples[-history_length:]
            first_shown_step = step_count - len(recent_history) + 1
            history_str = "\n".join(
                f"[Observation {first_shown_step + k}: '{obs}', Action {first_shown_step + k}: '{act}']"
                for k, (obs, act) in enumerate(recent_history)
            )
            user_prompt = ALFWORLD_TEMPLATE.format(
                current_observation=current_observation,
                admissible_actions=admissible_actions,
                task_description=task_description,
                step_count=step_count,
                history_length=len(recent_history),
                action_history=history_str,
                current_step=step_count + 1,
            )

        # Truncate, then right-pad with pad_token_id to the fixed widths.
        prompt_ids = tokenizer.encode(user_prompt, add_special_tokens=False)[:max_prompt_length]
        prompt_ids += [tokenizer.pad_token_id] * (max_prompt_length - len(prompt_ids))
        response_ids = tokenizer.encode(response_str, add_special_tokens=False)[:max_response_length]
        response_ids += [tokenizer.pad_token_id] * (max_response_length - len(response_ids))

        single_sample_data = {
            "prompts": torch.tensor(prompt_ids, dtype=dtype_prompts).unsqueeze(0),
            "responses": torch.tensor(response_ids, dtype=dtype_responses).unsqueeze(0),
            "episode_rewards": np.array([1]),
            "uid": np.array([group_uid], dtype=object),
            "traj_uid": np.array([expert_traj_uid], dtype=object),
            "goal": np.array([goal_val], dtype=object),
        }
        batch_items.append(DataProto.from_single_dict(single_sample_data))
        history_tuples.append((current_observation, clean_action))

    if not batch_items:
        return None
    return DataProto.concat(batch_items)


async def load_dialogues_from_jsonl(
    jsonl_path: str,
    tokenizer: PreTrainedTokenizer,
    ref_batch: DataProto # for shape/dtype reference
) -> DataProto:
    """
    Load logged dialogues from a JSONL file into one DataProto, with one row per step.

    Each non-blank line is a JSON object with these fields:

    * ``'dialogue'``: list of ``{'content': str}`` turns alternating user and assistant.
    * ``'success'``: optional boolean.
    * ``'task_name'``: optional string.

    Each (user, assistant) pair becomes one row. The user text is the prompt and the
    assistant text is the response; literal ``\\n`` sequences are unescaped first. A
    trailing user turn without a reply is ignored.

    All rows from one line share a fresh ``traj_uid``. ``uid`` is the placeholder
    ``"log_group_placeholder"``, which the caller may overwrite. The goal is
    ``task_name`` with underscores replaced by spaces. ``episode_rewards`` is 10.0 for
    successful dialogues and 0.0 otherwise. The same value is written into
    ``token_level_scores`` at the last kept (post-truncation) response token.

    Args:
        jsonl_path: Path to the JSONL log.
        tokenizer: Tokenizer providing ``encode`` and ``pad_token_id``.
        ref_batch: Optional batch whose prompt/response widths and dtypes are copied.
            When falsy, 5632 / 2048 tokens and ``torch.int64`` are used.

    Returns:
        The concatenated DataProto.

    Raises:
        ValueError: If the file yields no rows.
        json.JSONDecodeError, KeyError: If a line is malformed or has no ``'dialogue'``.
    """
    all_dataprotos = []

    # Defaults for shapes/dtypes if ref_batch is missing.
    max_prompt_length = ref_batch.batch['prompts'].shape[1] if ref_batch else 5632
    max_response_length = ref_batch.batch['responses'].shape[1] if ref_batch else 2048
    dtype_prompts = ref_batch.batch['prompts'].dtype if ref_batch else torch.int64
    dtype_responses = ref_batch.batch['responses'].dtype if ref_batch else torch.int64

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line_idx, line in enumerate(f):
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            trace_entry = json.loads(line)
            raw_dialogue = trace_entry['dialogue']

            # One trajectory id per dialogue; the line index keeps ids unique within a file.
            traj_uid = f"log_traj_{line_idx}_{random.randint(10000, 99999)}"

            is_success = trace_entry.get('success', False)
            episode_reward = 10.0 if is_success else 0.0

            goal = (trace_entry.get('task_name') or "").replace("_", " ").strip()

            for i in range(0, len(raw_dialogue) - 1, 2):
                user_content = raw_dialogue[i]['content'].replace('\\n', '\n')
                assistant_content = raw_dialogue[i + 1]['content'].replace('\\n', '\n')

                # Truncate, then right-pad with pad_token_id to the fixed widths.
                prompt_ids = tokenizer.encode(user_content, add_special_tokens=False)[:max_prompt_length]
                response_ids = tokenizer.encode(assistant_content, add_special_tokens=False)[:max_response_length]
                num_response_tokens = len(response_ids)

                prompt_ids += [tokenizer.pad_token_id] * (max_prompt_length - len(prompt_ids))
                response_ids += [tokenizer.pad_token_id] * (max_response_length - num_response_tokens)

                prompt_tensor = torch.tensor(prompt_ids, dtype=dtype_prompts)
                response_tensor = torch.tensor(response_ids, dtype=dtype_responses)

                # Sparse reward on the last real response token. Index with the *truncated*
                # length: the untruncated length can exceed the tensor width.
                token_level_scores = torch.zeros_like(response_tensor, dtype=torch.float32)
                if num_response_tokens > 0:
                    token_level_scores[num_response_tokens - 1] = episode_reward

                # Placeholder group UID; the caller handles unifying them if needed.
                group_uid = "log_group_placeholder"

                single_sample_data = {
                    "prompts": prompt_tensor.unsqueeze(0),
                    "responses": response_tensor.unsqueeze(0),
                    "token_level_scores": token_level_scores.unsqueeze(0),
                    "episode_rewards": np.array([episode_reward]),
                    "uid": np.array([group_uid], dtype=object),
                    "traj_uid": np.array([traj_uid], dtype=object),
                    "goal": np.array([goal], dtype=object),
                }
                all_dataprotos.append(DataProto.from_single_dict(single_sample_data))

    if not all_dataprotos:
        raise ValueError(f"No valid dataprotos could be created from {jsonl_path}")

    return DataProto.concat(all_dataprotos)


async def generate_expert_datapoint(
    expert_trial_path: str, 
    tokenizer: PreTrainedTokenizer, 
    ref_batch: DataProto
) -> DataProto:
    """
    Build a DataProto batch for a single expert trajectory.

    Ray is initialised if needed. The AlfworldRayManager singleton then produces the
    expert dialogue for ``expert_trial_path``, and ``convert_dialogue_to_dataproto``
    turns it into rows shaped like ``ref_batch``.

    Raises:
        ValueError: If no dialogue is generated, or the dialogue yields no datapoints.
    """
    # NOTE(review): `ray` and `AlfworldRayManager` are not imported in the visible part
    # of this module. Confirm they are in scope, otherwise this raises NameError.
    print("\n--- Generating Expert Datapoint ---")
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True)

    dialogue = await AlfworldRayManager.get_instance(
        config={"num_expert_workers": 4}
    ).generate_sft_dialogue(expert_trial_path, format_as_tool_code=False)
    if not dialogue:
        raise ValueError(f"Failed to generate dialogue from expert trial: {expert_trial_path}")

    converted = convert_dialogue_to_dataproto(dialogue, tokenizer, ref_batch)
    if converted is None:
        raise ValueError("Failed to create any datapoints from the expert trajectory.")

    print(f"Expert trajectory successfully converted into a DataProto batch of {len(converted)} steps.")
    return converted
