"""
Module: sensor_config_loader
----------------------------
This module provides functionality to load and parse the sensor mapping YAML configuration file.
It includes methods to retrieve dataset-specific sensor mappings and list all available datasets.
Additionally, it constructs a global mapping of unique IDs for all body parts, sensors, and axes.

Usage:
    from sensor_config_loader import SensorConfigLoader

    # Initialize the loader
    config_loader = SensorConfigLoader()

    # Retrieve dataset-specific sensors
    realworld_sensors = config_loader.get_dataset_sensors("realworld2016")

    # Retrieve all dataset names
    datasets = config_loader.get_all_datasets()

    # Retrieve global mappings
    global_mappings = config_loader.get_global_mappings()

    # Retrieve embedded configuration data
    embedded_config = config_loader.get_embedded_config()
"""

import os
import yaml
import json
import logging

from google import genai
from google.genai import types

from utils.path_utils import get_directory_path


# Root-logger setup performed at import time: INFO threshold, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class SensorConfigLoader:
    """
    SensorConfigLoader
    ------------------
    A class to load and manage sensor mapping configurations from a YAML file. This loader
    provides methods to retrieve sensor configurations for specific datasets or retrieve all
    dataset names. It also constructs a global mapping of unique IDs for body parts, sensors,
    and axes, and loads (or fetches and caches) one embedding vector per sensor entry.

    Attributes:
        api_key (str or None): Google Generative AI API key used when embeddings must be fetched.
        proxy_url (str or None): URL for HTTP/HTTPS proxy, or None to leave proxy settings untouched.
        config_path (str): The full path to the configuration file.
        config_data (dict): Parsed content of the configuration file (dataset name -> sensor list).
        embedded_vector (dict): Cached or freshly fetched embeddings (dataset name -> list of vectors).
        global_mappings (dict): A global mapping of unique IDs for body parts, sensors, and axes.
    """

    # Environment variables consulted, in order, when no API key is passed explicitly.
    _API_KEY_ENV_VARS = ("GOOGLE_API_KEY", "GEMINI_API_KEY")
    # File name of the embeddings cache inside the directory returned by get_directory_path("cache").
    _CACHE_FILENAME = "sensor_embeddings_cache.json"

    def __init__(self, config_filename="sensor_mapping.yaml", api_key=None, proxy_url=None):
        """
        Initialize the SensorConfigLoader.

        Loads the YAML configuration, then loads or fetches the sensor embeddings,
        then builds the global ID mappings.

        Parameters:
            config_filename (str): Name of the YAML configuration file. Default is "sensor_mapping.yaml".
            api_key (str, optional): Google Generative AI API Key. Only needed when the embeddings
                cache is unavailable. If omitted, the first non-empty value among the
                GOOGLE_API_KEY and GEMINI_API_KEY environment variables is used.
            proxy_url (str, optional): Proxy URL (e.g., "http://user:pass@host:port"). When given,
                it is written to both http_proxy and https_proxy before calling the API.

        Raises:
            FileNotFoundError: If the configuration file is not found.
            ValueError: If the YAML file cannot be parsed or is not a mapping, or if the API key
                is missing when embeddings have to be fetched.
        """
        # SECURITY: credentials and site-specific proxy addresses must never be committed
        # to source control. Any key that was previously hard-coded here has to be
        # treated as compromised and rotated.
        if not api_key:
            api_key = next(
                (
                    os.environ[name]
                    for name in self._API_KEY_ENV_VARS
                    if os.environ.get(name)
                ),
                None,
            )
        self.api_key = api_key
        self.proxy_url = proxy_url

        self.config_path = os.path.join(get_directory_path("configs"), config_filename)
        self.config_data = self._load_config()
        self.embedded_vector = self.get_embedded_vector()
        self.global_mappings = self._construct_global_mappings()

    def _load_config(self):
        """
        Private method to load and validate the YAML configuration file.

        Returns:
            dict: Parsed content of the configuration file.

        Raises:
            FileNotFoundError: If the configuration file is not found at the specified path.
            ValueError: If the YAML cannot be parsed, or if its top level is not a mapping
                (this includes an empty file, which parses to None).
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Config file not found at {self.config_path}")

        with open(self.config_path, "r", encoding="utf-8") as file:
            try:
                config = yaml.safe_load(file)
            except yaml.YAMLError as e:
                raise ValueError(f"Error parsing YAML file: {e}") from e

        # Fail here with a clear message instead of an AttributeError on `.items()` later.
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {self.config_path} must contain a mapping of dataset names "
                f"to sensor lists, got {type(config).__name__}"
            )
        return config

    def _construct_global_mappings(self):
        """
        Private method to construct global mappings of unique IDs for body parts, sensors, and axes.

        The indices in the mappings start from 1 to ensure no conflicts with default zero values in embedding layers.
        Note: The all-zero token (0) is reserved for special tokens like CLS.

        IDs are assigned in sorted order of the names, so they are deterministic for a
        given configuration. Datasets whose sensor list is null are treated as empty.

        Returns:
            dict: A dictionary containing global mappings with keys 'body_parts', 'sensors', and 'axes'.
        """
        body_parts = set()
        sensors = set()
        axes = set()

        for sensor_list in self.config_data.values():
            # A dataset declared without a value parses to None; treat it as "no sensors".
            for sensor in sensor_list or []:
                body_parts.add(sensor["body_part"])
                sensors.add(sensor["sensor"])
                axes.add(sensor["axis"])

        return {
            "body_parts": {bp: idx for idx, bp in enumerate(sorted(body_parts), start=1)},
            "sensors": {name: idx for idx, name in enumerate(sorted(sensors), start=1)},
            "axes": {axis: idx for idx, axis in enumerate(sorted(axes), start=1)},
        }

    def get_dataset_sensors(self, dataset_name):
        """
        Retrieve the sensor configuration for a specific dataset.

        Parameters:
            dataset_name (str): Name of the dataset (e.g., "realworld2016").

        Returns:
            list: A list of sensor configurations for the specified dataset. Returns an empty
                list if the dataset is not found or has no sensors configured.

        Raises:
            ValueError: If the configuration has not been loaded yet.
        """
        if self.config_data is None:
            raise ValueError(
                "Configuration not loaded. Ensure the file is properly initialized."
            )

        # `or []` also covers a dataset whose value is null in the YAML file.
        return self.config_data.get(dataset_name) or []

    def get_all_datasets(self):
        """
        Retrieve a list of all dataset names available in the configuration file.

        Returns:
            list: A list of dataset names, in the order they appear in the configuration.

        Raises:
            ValueError: If the configuration has not been loaded yet.
        """
        if self.config_data is None:
            raise ValueError(
                "Configuration not loaded. Ensure the file is properly initialized."
            )

        return list(self.config_data)

    def get_global_mappings(self):
        """
        Retrieve the global mappings of unique IDs for body parts, sensors, and axes.

        Returns:
            dict: A dictionary containing global mappings with keys 'body_parts', 'sensors', and 'axes'.
        """
        return self.global_mappings

    def get_embedded_config(self):
        """
        Apply the global mappings to the configuration data to make it embedding-ready.

        Returns:
            dict: A dictionary where each dataset's sensor configuration has been transformed
                  into dicts with the keys 'body_part_id', 'sensor_id' and 'axis_id'.
                  Datasets with a null sensor list map to an empty list.
        """
        body_part_ids = self.global_mappings["body_parts"]
        sensor_ids = self.global_mappings["sensors"]
        axis_ids = self.global_mappings["axes"]

        return {
            dataset: [
                {
                    "body_part_id": body_part_ids[sensor["body_part"]],
                    "sensor_id": sensor_ids[sensor["sensor"]],
                    "axis_id": axis_ids[sensor["axis"]],
                }
                for sensor in sensor_list or []
            ]
            for dataset, sensor_list in self.config_data.items()
        }

    def get_embedded_vector(self):
        """Loads or fetches embedded vectors for sensor configurations.

        Attempts to load embeddings from a JSON cache file first. If the cache
        is missing or invalid, it fetches the embeddings from the Google AI API
        using the configured API key and optional proxy settings. Fetched
        embeddings are then saved back to the cache file.

        The cache file path is constructed based on the 'cache' directory
        retrieved via `get_directory_path`.

        Note:
            A usable cache is returned as-is; it is not regenerated when the
            configuration changes (only a warning is logged if datasets are
            missing from it). Delete the cache file to force a refetch.

            Proxy environment variables (`http_proxy`, `https_proxy`), if set
            by this method, are not unset within this method's scope and
            will persist for the process lifetime unless cleared elsewhere.

        Returns:
            dict: A dictionary mapping dataset names (str) to lists of
                  embedding vectors (list of float). Datasets with no sensor
                  configurations map to an empty list.

        Raises:
            ValueError: If no API key is available and the embeddings cache
                        file cannot be loaded successfully.
            Exception: Errors raised by the API client are logged and re-raised.
        """
        cache_dir = get_directory_path("cache")
        cache_file_path = os.path.join(cache_dir, self._CACHE_FILENAME)

        # 1. Prefer the cache: it avoids network access and the need for an API key.
        cached_embeddings = self._load_cached_embeddings(cache_file_path)
        if cached_embeddings is not None:
            missing = [name for name in self.config_data if name not in cached_embeddings]
            if missing:
                logging.warning(
                    "Embeddings cache %s has no entry for dataset(s) %s; "
                    "delete the cache file to regenerate it.",
                    cache_file_path,
                    missing,
                )
            return cached_embeddings

        # 2. Cache miss or unusable cache: fetch from the Google AI API.
        logging.info("Cache not found or invalid. Fetching embeddings from Google AI.")
        if not self.api_key:
            raise ValueError(
                "API key was not provided during initialization, and embeddings cache was not available. "
                "Pass api_key=... or set the GOOGLE_API_KEY environment variable."
            )
        embeddings_result = self._fetch_embeddings()

        # 3. Persist for next time; a failed write is logged but does not fail the call.
        self._save_embeddings_cache(cache_dir, cache_file_path, embeddings_result)
        return embeddings_result

    def _load_cached_embeddings(self, cache_file_path):
        """Return the embeddings stored in the cache file, or None if it is missing or unusable."""
        if not os.path.exists(cache_file_path):
            return None

        logging.info("Attempting to load embeddings from cache: %s", cache_file_path)
        try:
            with open(cache_file_path, 'r', encoding='utf-8') as f:
                cached_embeddings = json.load(f)
        except json.JSONDecodeError:
            logging.warning("Cache file %s is corrupted. Fetching from API.", cache_file_path)
            return None
        except OSError as e:
            logging.warning("Failed to read cache file %s: %s. Fetching from API.", cache_file_path, e)
            return None
        except Exception as e:
            # The cache is strictly best-effort: any other load failure (e.g. a decoding
            # error) must fall back to the API instead of aborting initialization.
            logging.warning("Unexpected error loading cache file %s: %s. Fetching from API.", cache_file_path, e)
            return None

        # Valid JSON is not enough: callers index the result by dataset name.
        if not isinstance(cached_embeddings, dict):
            logging.warning("Cache file %s is corrupted. Fetching from API.", cache_file_path)
            return None

        logging.info("Successfully loaded embeddings from cache.")
        return cached_embeddings

    def _fetch_embeddings(self):
        """Request one embedding per sensor entry from the Google AI API, grouped by dataset."""
        if self.proxy_url:
            logging.info("Setting http_proxy and https_proxy to: %s", self.proxy_url)
            # Note: These environment variables affect the entire process.
            os.environ["http_proxy"] = self.proxy_url
            os.environ["https_proxy"] = self.proxy_url

        # The client typically respects standard proxy environment variables.
        client = genai.Client(api_key=self.api_key)

        embeddings_result = {}
        for dataset_name, sensor_config_list in self.config_data.items():
            # Each sensor dict is embedded as its string representation without the braces.
            content_strings = [
                str(sensor_dict).strip("{}") for sensor_dict in sensor_config_list or []
            ]

            if not content_strings:
                logging.warning("No sensor configurations found for dataset '%s'. Skipping embedding.", dataset_name)
                embeddings_result[dataset_name] = []
                continue

            logging.info("Fetching embeddings for dataset: %s (%d items)", dataset_name, len(content_strings))
            try:
                response = client.models.embed_content(
                    model="models/text-embedding-004",
                    contents=content_strings,
                    config=types.EmbedContentConfig(task_type="CLUSTERING"),
                )
            except Exception:
                # Record which dataset failed, then let the caller decide what to do.
                logging.exception("Failed to fetch embeddings for dataset '%s'", dataset_name)
                raise

            embeddings_result[dataset_name] = [embedding.values for embedding in response.embeddings]
            logging.info("Successfully fetched embeddings for %s", dataset_name)

        return embeddings_result

    def _save_embeddings_cache(self, cache_dir, cache_file_path, embeddings_result):
        """Write the embeddings to the cache file; failures are logged, never raised."""
        try:
            # Ensure the cache directory exists before writing.
            os.makedirs(cache_dir, exist_ok=True)
            logging.info("Saving fetched embeddings to cache: %s", cache_file_path)
            with open(cache_file_path, 'w', encoding='utf-8') as f:
                # Write JSON data with indentation for readability.
                json.dump(embeddings_result, f, indent=4)
        except OSError as e:
            logging.error("Failed to write cache file %s: %s", cache_file_path, e)
        except Exception as e:
            # E.g. a non-JSON-serializable value; the fetched data is still returned to the caller.
            logging.error("Unexpected error saving cache file %s: %s", cache_file_path, e)
