import ast
import os
import json
import pandas as pd
import numpy as np
import csv


class SR3DPlusPreprocess:
    """File-backed, per-scene cache of preprocessing artefacts.

    All files live in ``folder_path`` and are prefixed with the scene id:

    * ``<scene_id><json_filename>``: annotation JSON, loaded eagerly in
      ``__init__`` and returned by :meth:`get_data_dicts`.
    * ``<scene_id>_object.json``: ``object_name -> list of arrays``
      (:meth:`save_scene_object` / :meth:`load_scene_object`).
    * ``<scene_id>_<base_csv_name>``: CSV log of parsed commands
      (:meth:`append_scene_command` / :meth:`load_scene_commands`).
    * ``<scene_id>_views.json``: ``"<object>_<orientation>" -> values``
      (:meth:`save_object_view` / :meth:`load_object_view`).
    """

    # Column order of the command CSV, shared by the writer and the reader.
    _CSV_FIELDS = (
        "command",
        "raw_output",
        "main_object",
        "related_objects",
        "relation",
        "orientation_importance",
        "true_target",
        "is_easy",
        "is_view_dep",
    )

    def __init__(self, scene_name, folder_path: str, json_filename="_annotation.json"):
        """Load the annotation JSON of ``scene_name`` from ``folder_path``.

        Args:
            scene_name: Scene identifier, used as the prefix of every file this
                class reads or writes.
            folder_path: Directory that holds the per-scene files.
            json_filename: Suffix appended to ``scene_name`` to form the
                annotation file name.

        Raises:
            FileNotFoundError: If the annotation file does not exist.
            json.JSONDecodeError: If the annotation file is not valid JSON.
        """
        self.folder_path = folder_path
        self.scene_id = scene_name
        # Annotation file path; not the same file as get_json_path(), which is
        # the per-object cache.
        self.json_path = os.path.join(folder_path, scene_name + json_filename)
        with open(self.json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def get_data_dicts(self):
        """Return the annotation data loaded in ``__init__``."""
        return self.data

    # ------------------------------------------------------------------ #
    # JSON helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _read_json_dict(json_path):
        """Return the JSON object stored at ``json_path``, or ``{}`` if absent."""
        if not os.path.exists(json_path):
            return {}
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _write_json_atomic(json_path, data):
        """Write ``data`` as JSON to ``json_path`` without truncating it first.

        The data goes to ``<json_path>.tmp`` and is then moved into place with
        ``os.replace``. If ``json.dump`` fails (unserializable value, crash),
        the previously stored entries are left intact instead of being lost
        to a half-written file.
        """
        tmp_path = f"{json_path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f)
            os.replace(tmp_path, json_path)
        finally:
            # Only still present if something above failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # ------------------------------------------------------------------ #
    # Per-object array cache
    # ------------------------------------------------------------------ #

    def get_json_path(self):
        """Return the path of this scene's per-object cache (``<scene_id>_object.json``)."""
        return os.path.join(self.folder_path, f"{self.scene_id}_object.json")

    def load_scene_object(self, object_name):
        """Load the arrays cached for ``object_name`` by :meth:`save_scene_object`.

        Returns:
            A list of ``np.ndarray`` (one per stored entry), or ``None`` when the
            cache file or the object entry is missing (a message is printed).
        """
        json_path = self.get_json_path()
        if not os.path.exists(json_path):
            print(f"JSON for scene_id {self.scene_id} does not exist.")
            return None

        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if object_name not in data:
            print(f"Object '{object_name}' not found in scene {self.scene_id}.")
            return None

        # JSON only holds nested lists; restore them to numpy arrays.
        return [np.array(arr) for arr in data[object_name]]

    def save_scene_object(self, object_name, object_data):
        """Store ``object_data`` under ``object_name`` in the per-object cache.

        Args:
            object_name: Key to store under; an existing entry is overwritten and
                all other entries are kept.
            object_data: A single ``np.ndarray`` or a sequence of array-likes
                (ndarrays, nested lists, numbers). A lone ndarray is wrapped in
                a one-element list so :meth:`load_scene_object` always returns a
                list.
        """
        json_path = self.get_json_path()
        data = self._read_json_dict(json_path)

        if isinstance(object_data, np.ndarray):
            object_data = [object_data]

        # np.asarray lets plain lists/tuples/scalars through as well as ndarrays.
        data[object_name] = [np.asarray(arr).tolist() for arr in object_data]
        self._write_json_atomic(json_path, data)

    # ------------------------------------------------------------------ #
    # Command CSV log
    # ------------------------------------------------------------------ #

    def _find_command_csvs(self, base_csv_name):
        """Return the sorted names of this scene's command CSVs in ``folder_path``.

        A file matches when it starts with ``"<scene_id>_"``, ends with ``.csv``
        and contains the stem of ``base_csv_name``. A missing folder yields an
        empty list.
        """
        if not os.path.isdir(self.folder_path):
            return []
        stem = os.path.splitext(base_csv_name)[0]
        # The trailing "_" stops scene "s1" from matching files of scene "s10";
        # files created by append_scene_command always have it.
        prefix = f"{self.scene_id}_"
        return sorted(
            fname
            for fname in os.listdir(self.folder_path)
            if fname.startswith(prefix) and fname.endswith(".csv") and stem in fname
        )

    def append_scene_command(self, command: str, obj_dict: dict, true_target, is_easy, is_view_dep, base_csv_name="scene_commands+.csv"):
        """Append one parsed command as a row of this scene's command CSV.

        The row goes to the last (by sorted name) matching CSV, or to a new
        ``<scene_id>_<base_csv_name>`` if none exists. A header is written
        whenever the target file is missing or empty.

        Args:
            command: The command text.
            obj_dict: Parse result. The keys ``raw_output``, ``main_object``,
                ``related_objects``, ``relation`` and ``orientation_importance``
                are read; missing keys are written as empty strings.
            true_target: Ground-truth target. It is stringified by ``csv`` and
                parsed back with ``ast.literal_eval`` by
                :meth:`load_scene_commands`. ``None`` is written as an empty
                cell.
            is_easy: Flag stored as text (``True``/``False``).
            is_view_dep: Flag stored as text (``True``/``False``).
            base_csv_name: File-name suffix used to find or create the CSV.
        """
        os.makedirs(self.folder_path, exist_ok=True)

        existing = self._find_command_csvs(base_csv_name)
        if existing:
            # Sorted names: the last one is treated as the most recent file.
            csv_path = os.path.join(self.folder_path, existing[-1])
        else:
            csv_path = os.path.join(self.folder_path, f"{self.scene_id}_{base_csv_name}")

        # An existing but empty file (e.g. an interrupted first write) still
        # needs its header.
        needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

        row = {
            "command": command,
            "raw_output": obj_dict.get("raw_output", ""),
            "main_object": obj_dict.get("main_object", ""),
            "related_objects": obj_dict.get("related_objects", ""),
            "relation": obj_dict.get("relation", ""),
            "orientation_importance": obj_dict.get("orientation_importance", ""),
            "true_target": true_target,
            "is_easy": is_easy,
            "is_view_dep": is_view_dep,
        }

        with open(csv_path, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=self._CSV_FIELDS)
            if needs_header:
                writer.writeheader()
            writer.writerow(row)

        print(f"Saved row to {csv_path}")

    @staticmethod
    def _literal_or(text, default):
        """``ast.literal_eval(text)``, or ``default`` for an empty/missing cell.

        ``ast.literal_eval`` only evaluates Python literals, never arbitrary code.
        """
        return ast.literal_eval(text) if text else default

    @staticmethod
    def _is_true(text):
        """Return True iff ``text`` reads as ``"true"`` (case/space-insensitive)."""
        return (text or "").strip().lower() == "true"

    def _parse_command_row(self, row):
        """Convert one raw CSV row (all strings) back into typed values.

        Empty cells (which is also how ``csv`` writes ``None``) and cells missing
        from short rows map to ``[]`` / ``0.0`` / ``None`` / ``False``.
        """
        return {
            "command": row["command"],
            "raw_output": row["raw_output"],
            "main_object": row["main_object"],
            "related_objects": self._literal_or(row.get("related_objects"), []),
            "relation": row["relation"],
            "orientation_importance": self._literal_or(row.get("orientation_importance"), 0.0),
            "true_target": self._literal_or(row.get("true_target"), None),
            "is_easy": self._is_true(row.get("is_easy")),
            "is_view_dep": self._is_true(row.get("is_view_dep")),
        }

    def load_scene_commands(self, base_csv_name="scene_commands+.csv"):
        """Load every row of every matching command CSV of this scene.

        Files are read in sorted-name order. Cell values are decoded by
        ``_parse_command_row``.

        Returns:
            A list of dicts keyed like the CSV columns. The list is empty (and a
            message is printed) when no matching file exists.
        """
        csv_files = self._find_command_csvs(base_csv_name)
        if not csv_files:
            print("No CSV files found for this scene.")
            return []

        all_commands = []
        for csv_file in csv_files:
            csv_path = os.path.join(self.folder_path, csv_file)
            with open(csv_path, "r", newline="", encoding="utf-8") as f:
                all_commands.extend(self._parse_command_row(row) for row in csv.DictReader(f))

        print(f"Loaded {len(all_commands)} commands from {len(csv_files)} files.")
        return all_commands

    # ------------------------------------------------------------------ #
    # Object views
    # ------------------------------------------------------------------ #

    def get_view_json_path(self):
        """Return the JSON file path for storing object views"""
        return os.path.join(self.folder_path, f"{self.scene_id}_views.json")

    def save_object_view(self, object_name: str, orientation: str, values: list[float]):
        """Store ``values`` under ``"<object_name>_<orientation>"`` in the views JSON.

        An existing entry with the same key is overwritten; other entries are
        kept. An ``np.ndarray`` is accepted and converted with ``tolist()``.
        """
        json_path = self.get_view_json_path()
        data = self._read_json_dict(json_path)

        key = f"{object_name}_{orientation}"
        data[key] = values.tolist() if isinstance(values, np.ndarray) else values
        self._write_json_atomic(json_path, data)

        print(f"Saved view for {key} into {json_path}")

    def load_object_view(self, object_name: str, orientation: str, as_type=int):
        """Load the values stored by :meth:`save_object_view` for this key.

        Args:
            object_name: Object part of the key.
            orientation: Orientation part of the key.
            as_type: Callable applied to each stored value. The default ``int``
                keeps the historical behaviour and truncates fractional values;
                pass ``float`` to get the stored values back unchanged.

        Returns:
            ``[as_type(v) for v in stored]``, or ``None`` (with a printed message)
            when the views file or the key is missing.
        """
        # NOTE(review): save_object_view documents its values as floats while
        # the default here truncates to int -- confirm what callers expect.
        json_path = self.get_view_json_path()

        if not os.path.exists(json_path):
            print(f"No views JSON file found for scene {self.scene_id}.")
            return None

        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        key = f"{object_name}_{orientation}"
        if key not in data:
            print(f"No view found for {key} in scene {self.scene_id}.")
            return None

        return [as_type(x) for x in data[key]]


