from dotenv import load_dotenv
import os
import google.generativeai as genai
from PIL import Image
from pathlib import Path
import random
import json
import re
from datetime import datetime
import time

# Load environment variables (e.g. from a .env file) and configure the Gemini client.
load_dotenv()
api_key = os.getenv('GOOGLE_API_KEY')
if not api_key:
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so shells/schedulers see the failure (a bare exit() reported success).
    raise SystemExit("Error: GOOGLE_API_KEY not found in environment variables.")

genai.configure(api_key=api_key)

# --- Dedicated few-shot example ---
# Directory holding the fixed in-context example (lang.txt + images0/im_*.jpg).
# Resolved relative to the current working directory; adjust if needed.
FEW_SHOT_EXAMPLE_BASE_DIR = Path("few_shot_example")

# --- Evaluation datasets ---
# Every dataset directory lives under this root (relative to the CWD).
_EVAL_ROOT = Path("../../eval_ood")

# Group: in_dist
tk_mulit = _EVAL_ROOT / "bridge_many_skills_exp"
in_dist = [tk_mulit]

# Group: env_ood
td_fold = _EVAL_ROOT / "tabletop_dark_wood_fold_cloth_pnp_200"
mn_ft_sweep = _EVAL_ROOT / "minsky_folding_table_white_tray_sweep_granular_100"
ft_fold = _EVAL_ROOT / "folding_table_fold_cloth_pnp_200"
lm_pnp = _EVAL_ROOT / "laundry_machine_pnp_sweep_200"
rb_fold = _EVAL_ROOT / "robot_desk_fold_cloth_200"
env_ood = [td_fold, mn_ft_sweep, ft_fold, lm_pnp, rb_fold]

# Group: sc_ood ("script_*" datasets)
sc_sweep = _EVAL_ROOT / "script_sweep_pnp_200"
sc_obj = _EVAL_ROOT / "script_pnp_objects_200"
sc_robj = _EVAL_ROOT / "script_pnp_rigid_objects_200"
sc_uten = _EVAL_ROOT / "script_pnp_utensils_200"
sc_stoys = _EVAL_ROOT / "script_pnp_soft_toys_200"
sc_ood = [sc_sweep, sc_obj, sc_robj, sc_uten, sc_stoys]

# Group: emb_ood ("deepthought_*" datasets)
dt_ft_sb = _EVAL_ROOT / "deepthought_folding_table_stack_blocks_200"
dt_rd_dpnp = _EVAL_ROOT / "deepthought_robot_desk_drawer_pnp_200"
dt_tk2_sb = _EVAL_ROOT / "deepthought_toykitchen2_stack_blocks_200"
dt_tk1_ms = _EVAL_ROOT / "deepthought_toykitchen1_many_skills_200"
emb_ood = [dt_ft_sb, dt_rd_dpnp, dt_tk2_sb, dt_tk1_ms]

# Datasets evaluated in this run. Point this at a whole group (e.g. `env_ood`)
# or a hand-picked subset; the groups above are left intact so switching is a
# one-line change instead of overwriting a group variable.
base_dirs = [mn_ft_sweep]


# Optional pause (seconds) between trajectories, e.g. to stay under API rate
# limits. 0 disables it.
INTER_TRAJECTORY_DELAY_S = 0.0

# A per-frame answer header such as "Frame 3:" or "**Frame 3:**". It must NOT
# match "Frame Description:", which also begins with "Frame ".
_FRAME_HEADER_RE = re.compile(r"^[*_\s]*Frame\s+\d+\s*[*_]*\s*:")
# "Task Completion Percentage: 42%" (tolerates markdown emphasis and decimals).
_PERCENTAGE_RE = re.compile(r"Task Completion Percentage:[*_\s]*(\d+(?:\.\d+)?)")


def _frame_sort_key(path):
    """Order `im_<n>.jpg` frames by their numeric index (im_2 before im_10).

    Plain path sorting is lexicographic (im_0, im_1, im_10, ..., im_2), which
    scrambles temporal order for trajectories with more than 10 frames.
    Names without a trailing number sort after the numbered ones, by name.
    """
    match = re.search(r"(\d+)$", path.stem)
    return (match is None, int(match.group(1)) if match else 0, path.name)


def _sorted_frames(image_dir):
    """Return the `im_*.jpg` files in *image_dir* in temporal (numeric) order."""
    return sorted(image_dir.glob("im_*.jpg"), key=_frame_sort_key)


def _load_image(path):
    """Open *path* with Pillow and decode it immediately.

    Image.open is lazy: without load(), decode errors only surface later
    (outside the caller's try/except) and the file handle stays open. Per
    Pillow's file-handling docs, load() closes the handle for single-frame
    formats such as JPEG. The returned object is still the file-backed image,
    so downstream consumers that use its filename keep working.
    """
    img = Image.open(path)
    img.load()
    return img


def _build_few_shot_payload(example_dir):
    """Build the fixed few-shot prompt segment from *example_dir*.

    Reads `example_dir/lang.txt` (optional task description) and
    `example_dir/images0/im_*.jpg`. Every frame is shown in random order,
    labelled with a completion percentage proportional to its temporal index
    (first frame 0%, last frame 100%).

    Returns:
        A list of prompt parts (str and PIL images); empty when the example
        directory or its frames are missing.
    """
    if not example_dir.exists():
        print(f"Warning: Dedicated few-shot example directory '{example_dir}' not found. Proceeding without few-shot.")
        return []

    print(f"Loading dedicated few-shot example from: {example_dir}")
    lang_file = example_dir / "lang.txt"
    image_dir = example_dir / "images0"

    task_description = "No specific description provided for the few-shot example task."
    if lang_file.exists():
        task_description = lang_file.read_text(encoding="utf-8").strip()
    else:
        print(f"Warning: lang.txt not found in {example_dir}. Using default description for example.")

    image_paths = _sorted_frames(image_dir)
    total_frames = len(image_paths)
    if total_frames == 0:
        print(f"Warning: No images found in {image_dir} for the dedicated few-shot example. Proceeding without few-shot.")
        return []

    payload = ["--- FEW-SHOT EXAMPLE START ---", f"Example Task: {task_description}"]

    # The first frame of the example trajectory is always 0%.
    try:
        initial_image = _load_image(image_paths[0])
    except Exception as e:  # Pillow raises OSError, UnidentifiedImageError, ValueError, ...
        print(f"Warning: Could not load initial image for dedicated few-shot example: {e}")
        payload.append("Initial robot scene for example task: [Image unavailable due to error]")
    else:
        payload.append("Initial robot scene for example task:")
        payload.append(initial_image)
    payload.append("In this initial scene for the example, task completion is 0%.")

    payload.append(f"\nNow, for the example task of '{task_description}', here are some frames in random order and their task completion percentages:")

    # Show ALL example frames, shuffled; labels come from the temporal index.
    shown_count = 0
    for original_idx, img_path in random.sample(list(enumerate(image_paths)), total_frames):
        try:
            img = _load_image(img_path)
        except Exception as e:
            print(f"Warning: Could not load/process dedicated few-shot example image {img_path.name}: {e}")
            continue
        percentage = round((original_idx / (total_frames - 1)) * 100) if total_frames > 1 else 0
        payload.append(f"\nExample Frame (original name: {img_path.name}):")
        payload.append(img)
        payload.append(f"Task Completion Percentage for this example frame: {percentage}%")
        shown_count += 1

    payload.append("\n--- FEW-SHOT EXAMPLE END ---\n")
    print(f"Successfully prepared dedicated few-shot example content with {shown_count} frames shown in prompt (from {total_frames} total in example).")
    return payload


def _load_target_trajectories(data_dir):
    """Index the `traj_*` sub-directories of *data_dir*.

    A trajectory is kept only if it has `lang.txt`, an `images0` directory and
    at least one `im_*.jpg` frame.

    Returns:
        Dict mapping trajectory directory name to {"task_description": str,
        "image_paths": list of Path in temporal order}, in sorted name order.
    """
    trajectories = {}
    traj_dirs = sorted(d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith('traj_'))
    for traj_dir in traj_dirs:
        lang_file = traj_dir / "lang.txt"
        image_dir = traj_dir / "images0"
        if not lang_file.exists() or not image_dir.exists():
            continue
        image_paths = _sorted_frames(image_dir)
        if not image_paths:
            continue
        trajectories[traj_dir.name] = {
            "task_description": lang_file.read_text(encoding="utf-8").strip(),
            "image_paths": image_paths,
        }
    return trajectories


def _parse_frame_predictions(response_text, sent_filenames):
    """Parse the model's per-frame answers into result records.

    The prompt asks for "Frame {i}:" followed by
    "Frame Description: ..., Task Completion Percentage: N%". Headers are
    matched to *sent_filenames* by order of appearance (the k-th header maps
    to the k-th image sent), not by the number the model printed.

    Returns:
        List of dicts with "model_output_line", "original_filename" (None once
        headers outnumber the images sent) and "predicted_percentage" (int, or
        None when no percentage was found).
    """
    lines = [line.strip() for line in response_text.strip().split('\n')]
    parsed = []
    i = 0
    while i < len(lines):
        header = lines[i]
        i += 1
        if not _FRAME_HEADER_RE.match(header):
            continue  # preamble, commentary, blank line or stray description

        entry = {
            "model_output_line": header,
            "original_filename": sent_filenames[len(parsed)] if len(parsed) < len(sent_filenames) else None,
            "predicted_percentage": None,
        }

        # The description normally follows on the next line; tolerate blank
        # lines in between, but never swallow the next frame's header.
        j = i
        while j < len(lines) and not lines[j]:
            j += 1
        detail = lines[j] if j < len(lines) else ""
        match = None
        if detail and not _FRAME_HEADER_RE.match(detail) and (
            "Frame Description:" in detail or _PERCENTAGE_RE.search(detail)
        ):
            entry["model_output_line"] += "\n" + detail
            match = _PERCENTAGE_RE.search(detail)
            i = j + 1  # the detail line belongs to this frame
        # Fall back to a percentage written on the header line itself.
        match = match or _PERCENTAGE_RE.search(header)
        if match:
            entry["predicted_percentage"] = round(float(match.group(1)))
        parsed.append(entry)
    return parsed


def _save_to_jsonl(result_dict, filepath):
    """Append *result_dict* to *filepath* as one JSON line (best effort, logs on failure)."""
    try:
        with open(filepath, "a", encoding="utf-8") as f:
            json.dump(result_dict, f)
            f.write('\n')
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving to JSONL for {result_dict.get('traj_name', 'Unknown Trajectory')}: {e}")


def _save_cumulative_json(all_results_list, filepath):
    """Rewrite *filepath* with every result so far as pretty JSON (best effort)."""
    try:
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(all_results_list, f, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving cumulative JSON to {filepath}: {e}")


def _process_trajectory(model, traj_name, data_item, few_shot_payload, few_shot_source):
    """Query the model for one target trajectory and return its result record.

    Skips (with an explanatory "gemini_response_text") when no images exist,
    the initial frame cannot be decoded, or no shuffled frame can be loaded.
    API errors are recorded in the result rather than raised.
    """
    task_description = data_item["task_description"]
    image_paths = data_item["image_paths"]
    result = {
        "traj_name": traj_name,
        "task_description": task_description,
        "trajectory_length": len(image_paths),
        "few_shot_example_source": few_shot_source,
        "shuffled_target_frame_order_to_gemini": [],
        "gemini_response_text": None,
        "parsed_gemini_outputs": [],
    }

    if not image_paths:
        print(f"Skipping {traj_name}: No images found.")
        result["gemini_response_text"] = "Skipped: No images found for target."
        return result

    try:
        initial_image = _load_image(image_paths[0])
    except Exception as e:
        print(f"Error loading initial image for target {traj_name}: {e}. Skipping.")
        result["gemini_response_text"] = f"Skipped: Error loading initial image for target - {e}"
        return result

    # Few-shot segment first (copied so the shared list is never mutated).
    payload = list(few_shot_payload)
    payload.append("--- TARGET TASK START ---")
    payload.append(
        f"Now, for the new task of '{task_description}'.\n"
        f"The task completion percentages are between 0 and 100, where 100 "
        f"corresponds to full task completion.\n"
        f"Note that the following frames (from the target task) are in random order, so please pay attention to the individual frames "
        f"when reasoning about task completion percentage.\n\n"
    )
    payload.append("Initial robot scene for target task:")
    payload.append(initial_image)
    payload.append("\nIn this initial scene for the target task, the task completion percentage is 0.\n\n")
    payload.append(
        f"Output the task completion percentage for the following frames (from the target task) that are presented in random order.\n"
        f"For each frame, format your response as follow: Frame {{i}}:\n"  # Model generates {i}
        f"Frame Description: {{description}}, Task Completion Percentage: {{percentage}}%\n\n"
    )

    shuffled_paths = image_paths[:]
    random.shuffle(shuffled_paths)
    loaded = []  # (filename, image) in the order sent to the model
    for img_path in shuffled_paths:
        try:
            loaded.append((img_path.name, _load_image(img_path)))
        except Exception as e:
            print(f"Warning: Error loading target image {img_path.name} for shuffling: {e}. Skipping this image.")

    if not loaded:
        print(f"No target query images could be loaded for {traj_name}. Skipping API call.")
        result["shuffled_target_frame_order_to_gemini"] = [p.name for p in shuffled_paths]
        result["gemini_response_text"] = "Skipped: No target images loaded into shuffled set for API call."
        return result

    sent_filenames = [name for name, _ in loaded]
    result["shuffled_target_frame_order_to_gemini"] = sent_filenames
    payload.extend(img for _, img in loaded)
    payload.append("\n--- TARGET TASK END ---")

    try:
        response = model.generate_content(payload, request_options={"timeout": 300})
        result["gemini_response_text"] = response.text
        result["parsed_gemini_outputs"] = _parse_frame_predictions(response.text, sent_filenames)
    except Exception as e:  # network errors, timeouts, blocked responses (response.text raises)
        error_message = f"An error occurred during API call for {traj_name}: {e}"
        print(f"\n{error_message}")
        result["gemini_response_text"] = error_message
        if hasattr(e, 'response') and hasattr(e.response, 'prompt_feedback'):
            result["gemini_response_text"] += f" | Prompt Feedback: {e.response.prompt_feedback}"
    return result


# --- Shared across all datasets: few-shot example (loaded ONCE) and model ---
prepared_few_shot_content_payload = _build_few_shot_payload(FEW_SHOT_EXAMPLE_BASE_DIR)
few_shot_example_source = FEW_SHOT_EXAMPLE_BASE_DIR.name if prepared_few_shot_content_payload else "None"
model = genai.GenerativeModel('gemini-1.5-pro-latest')

for base_dir in base_dirs:
    # A missing/empty dataset skips to the next one instead of aborting the batch.
    if not base_dir.exists():
        print(f"Error: Main data directory '{base_dir}' does not exist. Skipping.")
        continue

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    dataset_identifier = f"{base_dir.parent.name}_{base_dir.name}".replace(os.sep, "_").replace(" ", "_")
    dataset_identifier = re.sub(r'[^\w\-_.]', '_', dataset_identifier)

    output_jsonl_path = Path(f"gvl_fixed_fewshot_results_{dataset_identifier}_{timestamp}.jsonl")
    output_cumulative_json_path = Path(f"gvl_fixed_fewshot_results_{dataset_identifier}_{timestamp}_CUMULATIVE.json")

    print(f"Incremental results (JSON Lines) will be saved to: {output_jsonl_path}")
    print(f"Cumulative complete results (pretty JSON) will be updated in: {output_cumulative_json_path}")

    target_traj_data = _load_target_trajectories(base_dir)
    if not target_traj_data:
        print(f"No target trajectory data could be loaded from '{base_dir}'. Skipping.")
        continue
    print(f"Loaded data for {len(target_traj_data)} target trajectories.")

    all_trajectories_results_in_memory = []
    total_trajectories = len(target_traj_data)
    for traj_number, (target_traj_name, data_item) in enumerate(target_traj_data.items(), start=1):
        print(f"\n--- Processing Trajectory {traj_number}/{total_trajectories}: {target_traj_name} ---")
        current_traj_result = _process_trajectory(
            model, target_traj_name, data_item,
            prepared_few_shot_content_payload, few_shot_example_source,
        )

        # Persist after every trajectory so partial runs are not lost.
        all_trajectories_results_in_memory.append(current_traj_result)
        _save_to_jsonl(current_traj_result, output_jsonl_path)
        _save_cumulative_json(all_trajectories_results_in_memory, output_cumulative_json_path)
        print(f"Result for {target_traj_name} saved to {output_jsonl_path} and {output_cumulative_json_path} updated.")
        print("-" * 80)

        if INTER_TRAJECTORY_DELAY_S > 0:
            print(f"Pausing for {INTER_TRAJECTORY_DELAY_S} seconds before next trajectory...")
            time.sleep(INTER_TRAJECTORY_DELAY_S)

    print(f"\nAll trajectories processed.")
    print(f"Incremental results (JSON Lines) are in: {output_jsonl_path}")
    print(f"Final cumulative results (pretty JSON) are in: {output_cumulative_json_path}")