from dotenv import load_dotenv
import os
import google.generativeai as genai
from PIL import Image
from pathlib import Path
import random
import json
import re
from datetime import datetime
import time

# Load GOOGLE_API_KEY (from the environment or a .env file) and configure the Gemini client.
load_dotenv()
api_key = os.getenv('GOOGLE_API_KEY')
if not api_key:
    # SystemExit with a string message prints it to stderr and exits with status 1,
    # so callers/shells see a failure (bare exit() exits 0 and needs the `site` module).
    raise SystemExit("Error: GOOGLE_API_KEY not found.")
genai.configure(api_key=api_key)


# --- Dataset locations ----------------------------------------------------------
# All evaluation splits live under one shared root. The lists below are presets;
# whichever list is assigned to `base_dirs` is what the evaluation loop iterates.
_EVAL_OOD_ROOT = Path("../../eval_ood")

# `in_dist` preset.
tk_mulit = _EVAL_OOD_ROOT / "bridge_many_skills_exp"
in_dist = [tk_mulit]

# `env_ood` preset (presumably environment-shift splits, per the name).
td_fold = _EVAL_OOD_ROOT / "tabletop_dark_wood_fold_cloth_pnp_200"
mn_ft_sweep = _EVAL_OOD_ROOT / "minsky_folding_table_white_tray_sweep_granular_100"  # ADJUST
ft_fold = _EVAL_OOD_ROOT / "folding_table_fold_cloth_pnp_200"
lm_pnp = _EVAL_OOD_ROOT / "laundry_machine_pnp_sweep_200"
rd_fold = _EVAL_OOD_ROOT / "robot_desk_fold_cloth_200"
env_ood = [td_fold, mn_ft_sweep, ft_fold, lm_pnp, rd_fold]

# `sc_ood` preset (scripted-collection splits, per the directory names).
sc_sweep = _EVAL_OOD_ROOT / "script_sweep_pnp_200"
sc_obj = _EVAL_OOD_ROOT / "script_pnp_objects_200"
sc_robj = _EVAL_OOD_ROOT / "script_pnp_rigid_objects_200"
sc_uten = _EVAL_OOD_ROOT / "script_pnp_utensils_200"
sc_stoys = _EVAL_OOD_ROOT / "script_pnp_soft_toys_200"
sc_ood = [sc_sweep, sc_obj, sc_robj, sc_uten, sc_stoys]

# `emb_ood` preset: splits used to test embodiment shift.
dt_ft_sb = _EVAL_OOD_ROOT / "deepthought_folding_table_stack_blocks_200"
dt_rd_dpnp = _EVAL_OOD_ROOT / "deepthought_robot_desk_drawer_pnp_200"
dt_tk2_sb = _EVAL_OOD_ROOT / "deepthought_toykitchen2_stack_blocks_200"
dt_tk1_ms = _EVAL_OOD_ROOT / "deepthought_toykitchen1_many_skills_200"
emb_ood = [dt_ft_sb, dt_rd_dpnp, dt_tk2_sb, dt_tk1_ms]

# Selection evaluated in this run.
remaining = [rd_fold, dt_tk1_ms]
base_dirs = remaining


# --- Run configuration -----------------------------------------------------------
GEMINI_MODEL_NAME = 'gemini-1.5-pro-latest'
REQUEST_TIMEOUT_S = 300  # per-request timeout passed to generate_content
# Pause between trajectories in seconds; 0 disables it (raise it if the API rate-limits).
PAUSE_BETWEEN_TRAJECTORIES_S = 0

# A response line that starts a new frame entry, e.g. "Frame 3:", "**Frame 3:**", "### Frame #3".
# Deliberately does NOT match "Frame Description:" (a digit must follow "Frame").
_FRAME_HEADER_RE = re.compile(r"^[*#\s]*Frame\s*#?\s*\d+")
# A "Frame Description: ..." line, optionally wrapped in markdown bold.
_FRAME_DESCRIPTION_RE = re.compile(r"^[*\s]*Frame Description:")
# The percentage value; tolerates markdown bold around it and decimals such as "42.5%".
_PERCENTAGE_RE = re.compile(r"Task Completion Percentage:\s*\**\s*(\d+(?:\.\d+)?)")


def save_to_jsonl(result_dict, filepath):
    """Append ``result_dict`` as one JSON line to ``filepath``.

    Errors are printed rather than raised so one failed write does not stop the run.
    """
    try:
        with open(filepath, "a", encoding="utf-8") as f:
            json.dump(result_dict, f)
            f.write('\n')
    except Exception as e:
        print(f"Error saving to JSONL for {result_dict.get('traj_name', 'Unknown Trajectory')}: {e}")


def save_cumulative_json(all_results_list, filepath):
    """Overwrite ``filepath`` with the whole results list as indented JSON.

    Errors are printed rather than raised so one failed write does not stop the run.
    """
    try:
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(all_results_list, f, indent=2)
    except Exception as e:
        print(f"Error saving cumulative JSON to {filepath}: {e}")


def _record_result(result, all_results, jsonl_path, cumulative_path):
    """Keep ``result`` in memory, append it to the JSONL file and rewrite the cumulative JSON."""
    all_results.append(result)
    save_to_jsonl(result, jsonl_path)
    save_cumulative_json(all_results, cumulative_path)


def _load_existing_results(jsonl_path):
    """Read a previous run's JSONL file for resuming.

    Returns:
        (processed_names, results): the set of ``traj_name`` values seen and the
        list of result dicts carrying one. Both are empty when the file does not
        exist or cannot be read; undecodable lines are warned about and skipped.
    """
    processed = set()
    results = []
    if not jsonl_path.exists():
        return processed, results

    print(f"Found existing results file: {jsonl_path}. Attempting to load processed trajectories.")
    try:
        with open(jsonl_path, "r", encoding="utf-8") as f_existing:
            for line in f_existing:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                try:
                    existing_result = json.loads(line)
                except json.JSONDecodeError:
                    print(f"Warning: Could not decode a line from existing .jsonl file: {line.strip()}")
                    continue
                if isinstance(existing_result, dict) and "traj_name" in existing_result:
                    processed.add(existing_result["traj_name"])
                    results.append(existing_result)
    except Exception as e:
        print(f"Error reading existing results file {jsonl_path}: {e}. Starting fresh for this file.")
        return set(), []

    print(f"Loaded {len(processed)} previously processed trajectory names.")
    return processed, results


def _load_traj_data(base_dir):
    """Collect trajectories under ``base_dir``.

    Each ``traj_*`` subdirectory needs a ``lang.txt`` (task description) and an
    ``images0`` folder with at least one ``im_*.jpg``; others are skipped with a
    warning. Returns ``{traj_dir_name: {"task_description", "image_paths"}}``
    with image paths sorted by filename (lexicographically).
    """
    traj_data = {}
    traj_dirs = sorted(d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith('traj_'))
    for traj_dir in traj_dirs:
        lang_file = traj_dir / "lang.txt"
        image_dir = traj_dir / "images0"
        if not lang_file.exists() or not image_dir.exists():
            print(f"Warning: Missing lang.txt or images0 in {traj_dir}. Skipping.")
            continue
        image_paths = sorted(image_dir.glob("im_*.jpg"))
        if not image_paths:
            print(f"Warning: No images found in {image_dir} for {traj_dir}. Skipping.")
            continue
        traj_data[traj_dir.name] = {
            "task_description": lang_file.read_text().strip(),
            "image_paths": image_paths,
        }
    return traj_data


def _open_image(path):
    """Open an image and decode it immediately.

    ``Image.open`` is lazy and keeps the file open. ``load()`` decodes now, so
    corrupt frames fail here, and Pillow closes the file it opened by name.
    """
    img = Image.open(path)
    img.load()
    return img


def _build_content_payload(task_description, initial_image, query_images):
    """Assemble the GVL prompt.

    The payload is: instructions, the initial frame anchored at 0%, the
    answer-format instructions, then the shuffled query frames separated by "\\n".
    The text is kept byte-identical to earlier runs so results stay comparable.
    NOTE(review): query frames are not individually labelled "Frame i:", so the
    frame->filename mapping relies on the model numbering frames in presentation order.
    """
    content_payload = [
        (
            f"You are an expert roboticist tasked to predict task completion\n"
            f"percentages for frames of a robot for the task of '{task_description}'.\n"
            f"The task completion percentages are between 0 and 100, where 100\n"
            f"corresponds to full task completion.\n"
            f"Note that the following frames are in random order, so please pay attention to the individual frames\n"
            f"when reasoning about task completion percentage.\n\n"
        ),
        "Initial robot scene: ",
        initial_image,
        "\nIn the initial robot scene, the task completion percentage is 0.\n\n",
        (
            f"Now, for the task of '{task_description}', output the task completion\n"
            f"percentage for the following frames that are presented in random\n"
            f"order. For each frame, format your response as follow: Frame {{i}}:\n"
            f"Frame Description: {{description}}, Task Completion Percentage: {{percentage}}%\n\n"
        ),
    ]
    for i_img, img in enumerate(query_images):
        content_payload.append(img)
        if i_img < len(query_images) - 1:
            content_payload.append("\n")
    return content_payload


def _extract_percentage(text):
    """Return the completion percentage found in ``text``, or None.

    The value is an int when whole (e.g. "30%") and a float otherwise (e.g. "42.5%").
    """
    match = _PERCENTAGE_RE.search(text)
    if match is None:
        return None
    value = float(match.group(1))
    return int(value) if value.is_integer() else value


def _parse_gemini_response(response_text, sent_filenames):
    """Split the model's answer into one record per frame header.

    The n-th header found is mapped to ``sent_filenames[n]``, or None when the
    model emits more headers than frames were sent. A "Frame Description:" line
    that follows a header (blank lines in between are ignored) is attached to
    that header, as long as the header has no percentage yet.
    """
    parsed_outputs = []
    for raw_line in response_text.strip().split('\n'):
        line = raw_line.strip()
        if _FRAME_HEADER_RE.match(line):
            frame_index = len(parsed_outputs)
            parsed_outputs.append({
                "model_output_line": line,
                "original_filename": sent_filenames[frame_index] if frame_index < len(sent_filenames) else None,
                "predicted_percentage": _extract_percentage(line),
            })
        elif (parsed_outputs
              and parsed_outputs[-1]["predicted_percentage"] is None
              and _FRAME_DESCRIPTION_RE.match(line)):
            last = parsed_outputs[-1]
            last["model_output_line"] += "\n" + line
            last["predicted_percentage"] = _extract_percentage(line)
    return parsed_outputs


# The model is the same for every dataset, so build it once.
model = genai.GenerativeModel(GEMINI_MODEL_NAME)

for base_dir in base_dirs:
    # A bad dataset skips only that dataset instead of aborting the whole batch.
    if not base_dir.exists():
        print(f"Error: Base directory '{base_dir}' does not exist. Skipping.")
        continue

    # --- Filename Configuration ---
    # OPTION 1: Generate new filenames for a fresh run
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    dataset_identifier = f"{base_dir.parent.name}_{base_dir.name}".replace(os.sep, "_").replace(" ", "_")
    dataset_identifier = re.sub(r'[^\w\-_.]', '_', dataset_identifier)
    output_jsonl_path_str = f"gvl_results_{dataset_identifier}_{timestamp}.jsonl"
    output_cumulative_json_path_str = f"gvl_results_{dataset_identifier}_{timestamp}_CUMULATIVE.json"

    # OPTION 2: Manually set filenames to resume a specific previous run
    # Comment out the lines above and uncomment these, then set the correct filenames:
    # output_jsonl_path_str = "gvl_results_eval_ood_minsky_folding_table_white_tray_sweep_granular_100_20231115_103000.jsonl"
    # output_cumulative_json_path_str = "gvl_results_eval_ood_minsky_folding_table_white_tray_sweep_granular_100_20231115_103000_CUMULATIVE.json"

    output_jsonl_path = Path(output_jsonl_path_str)
    output_cumulative_json_path = Path(output_cumulative_json_path_str)

    print(f"Incremental results (JSON Lines) will be saved/appended to: {output_jsonl_path}")
    print(f"Cumulative complete results (pretty JSON) will be updated in: {output_cumulative_json_path}")

    # --- Load Already Processed Trajectories (for resuming) ---
    processed_traj_names, all_trajectories_results_in_memory = _load_existing_results(output_jsonl_path)
    num_preexisting_results = len(all_trajectories_results_in_memory)
    if all_trajectories_results_in_memory:
        # Bring the cumulative JSON in sync with what the JSONL already holds.
        save_cumulative_json(all_trajectories_results_in_memory, output_cumulative_json_path)
        print(f"Cumulative JSON {output_cumulative_json_path} updated with existing results.")

    # --- Load Target Trajectory Data ---
    traj_data = _load_traj_data(base_dir)
    if not traj_data:
        print("No target trajectory data could be loaded. Skipping this dataset.")
        continue

    print(f"Found {len(traj_data)} total trajectories in the dataset.")
    trajectories_to_process = {name: data for name, data in traj_data.items() if name not in processed_traj_names}
    if not trajectories_to_process:
        print("All trajectories from the dataset have already been processed according to the results file. Skipping this dataset.")
        continue
    print(f"Will process {len(trajectories_to_process)} new/remaining trajectories.")

    # --- Process Each Trajectory (GVL Style) ---
    for traj_index, (traj_name, data_item) in enumerate(trajectories_to_process.items()):
        print(f"\n--- Processing Trajectory {traj_index + 1}/{len(trajectories_to_process)} (Overall {len(all_trajectories_results_in_memory) + 1} / {len(traj_data)}): {traj_name} ---")

        task_description = data_item["task_description"]
        all_image_paths = data_item["image_paths"]  # non-empty: guaranteed by _load_traj_data

        current_traj_result = {
            "traj_name": traj_name,
            "task_description": task_description,
            "trajectory_length": len(all_image_paths),
            "shuffled_input_frame_order_to_gemini": [],
            "gemini_response_text": None,
            "parsed_gemini_outputs": []
        }

        # The first frame (by filename order) is shown as the 0% anchor.
        try:
            initial_image_pil = _open_image(all_image_paths[0])
        except Exception as e:
            print(f"Error loading initial image for {traj_name}: {e}. Skipping.")
            current_traj_result["gemini_response_text"] = f"Skipped: Error loading initial image - {e}"
            _record_result(current_traj_result, all_trajectories_results_in_memory,
                           output_jsonl_path, output_cumulative_json_path)
            continue

        # Every frame (including the initial one) is queried in random order.
        query_frames_to_shuffle_paths = list(all_image_paths)
        random.shuffle(query_frames_to_shuffle_paths)
        shuffled_query_images = []
        shuffled_filenames_as_sent_to_gemini = []
        for img_path in query_frames_to_shuffle_paths:
            try:
                img = _open_image(img_path)
            except Exception as e:
                print(f"Warning: Error loading image {img_path.name} for shuffling: {e}. Skipping this image.")
                continue
            shuffled_query_images.append(img)
            shuffled_filenames_as_sent_to_gemini.append(img_path.name)

        if not shuffled_query_images:
            print(f"No query images could be loaded for shuffled set in {traj_name}. Skipping.")
            current_traj_result["shuffled_input_frame_order_to_gemini"] = [p.name for p in query_frames_to_shuffle_paths]
            current_traj_result["gemini_response_text"] = "Skipped: No images loaded into shuffled set."
            _record_result(current_traj_result, all_trajectories_results_in_memory,
                           output_jsonl_path, output_cumulative_json_path)
            continue

        current_traj_result["shuffled_input_frame_order_to_gemini"] = shuffled_filenames_as_sent_to_gemini
        content_payload = _build_content_payload(task_description, initial_image_pil, shuffled_query_images)

        # API call. The try covers only the call and .text access; .text raises when
        # the response has no usable candidate (e.g. blocked), in which case the
        # response's prompt_feedback is recorded.
        response = None
        try:
            response = model.generate_content(content_payload, request_options={"timeout": REQUEST_TIMEOUT_S})
            response_text = response.text
        except Exception as e:
            error_message = f"An error occurred during API call for {traj_name}: {e}"
            print(f"\n{error_message}")
            current_traj_result["gemini_response_text"] = error_message
            prompt_feedback = getattr(response, "prompt_feedback", None)
            if not prompt_feedback and hasattr(e, 'response'):
                prompt_feedback = getattr(e.response, "prompt_feedback", None)
            if prompt_feedback:
                current_traj_result["gemini_response_text"] += f" | Prompt Feedback: {prompt_feedback}"
        else:
            current_traj_result["gemini_response_text"] = response_text
            current_traj_result["parsed_gemini_outputs"] = _parse_gemini_response(
                response_text, shuffled_filenames_as_sent_to_gemini)

        # NOTE: failed calls are recorded like successes, so a resumed run
        # (OPTION 2) treats them as processed and does not retry them.
        _record_result(current_traj_result, all_trajectories_results_in_memory,
                       output_jsonl_path, output_cumulative_json_path)
        processed_traj_names.add(traj_name)
        print(f"Result for {traj_name} saved to {output_jsonl_path} and {output_cumulative_json_path} updated.")
        print("-" * 80)

        if PAUSE_BETWEEN_TRAJECTORIES_S > 0:
            print(f"Pausing for {PAUSE_BETWEEN_TRAJECTORIES_S} seconds before next trajectory...")
            time.sleep(PAUSE_BETWEEN_TRAJECTORIES_S)

    print("\nAll trajectories processed (or skipped if previously done).")
    print(f"Incremental results (JSON Lines) are in: {output_jsonl_path}")
    print(f"Final cumulative results (pretty JSON) are in: {output_cumulative_json_path}")
    num_new_results = len(all_trajectories_results_in_memory) - num_preexisting_results
    print(f"Processed {num_new_results} new trajectories in this session "
          f"({len(all_trajectories_results_in_memory)} total in results).")