"""
LIBERO Extra Description dataloader that inherits from LiberoOriginalDataset
but replaces language instructions with ones from libero_extra_desc.json.
"""

import json
import random
from pathlib import Path
from typing import Dict, Any, Optional
from dataclasses import dataclass

from .libero_original_dataloader import LiberoOriginalDataset, LiberoOriginalDataConfig


@dataclass
class LiberoExtraDescDataConfig(LiberoOriginalDataConfig):
    """Configuration for LIBERO Extra Description dataset loading.

    Extends ``LiberoOriginalDataConfig`` with the location of the JSON file
    that supplies the replacement language instructions.
    """

    # Path to the JSON file of extra task descriptions. The default is a
    # relative path, so it resolves against the process's working directory.
    extra_desc_json_path: str = "./config/data/libero90_splits/libero_extra_desc.json"


class LiberoExtraDescDataset(LiberoOriginalDataset):
    """LIBERO dataset whose language instructions come from a JSON file.

    Each item produced by the parent dataset has its ``"task_descriptions"``
    entry replaced by an instruction sampled at random from the extra
    descriptions loaded at construction time. When no suitable instruction
    exists, the parent's original description is kept.
    """

    def __init__(self, config: LiberoExtraDescDataConfig):
        """Load the extra descriptions, then initialize the parent dataset.

        Args:
            config: Dataset configuration; ``extra_desc_json_path`` names the
                JSON file to load.

        Raises:
            FileNotFoundError: If the extra descriptions file does not exist.
        """
        # NOTE(review): the descriptions are loaded before the parent
        # initializer runs; presumably so they are already available should
        # the parent call back into this class. Ordering kept as-is.
        self.extra_desc_json_path = Path(config.extra_desc_json_path)
        self.extra_descriptions = self._load_extra_descriptions()

        # Initialize parent class
        super().__init__(config)

    def _load_extra_descriptions(self) -> Dict[str, Dict[str, list]]:
        """Load extra descriptions from the JSON file.

        Returns:
            The parsed JSON content. Assumed to map full task names to a
            dict of ``"train"`` / ``"eval"`` instruction lists, which is how
            ``_get_task_instruction`` consumes it.

        Raises:
            FileNotFoundError: If ``self.extra_desc_json_path`` does not exist.
        """
        if not self.extra_desc_json_path.exists():
            raise FileNotFoundError(f"Extra descriptions file not found: {self.extra_desc_json_path}")

        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(self.extra_desc_json_path, 'r', encoding="utf-8") as f:
            descriptions = json.load(f)

        print(f"Loaded extra descriptions for {len(descriptions)} tasks")
        return descriptions

    def _find_matching_task(self, task_underscore: str) -> Optional[str]:
        """Return the first JSON key naming ``task_underscore``, or ``None``.

        A key matches when it equals the task exactly or ends with
        ``"_" + task``, i.e. the task preceded by a scene prefix such as
        ``KITCHEN_SCENE1``.
        """
        # An empty task would be a suffix of every key; never match it.
        if not task_underscore:
            return None

        # Requiring the "_" separator stops a task from matching the tail of
        # a longer word in an unrelated key.
        prefixed_suffix = "_" + task_underscore
        for json_task_name in self.extra_descriptions:
            if json_task_name == task_underscore or json_task_name.endswith(prefixed_suffix):
                return json_task_name
        return None

    def _get_task_instruction(self, task_name: str) -> str:
        """Get a random instruction for the task based on the current split.

        Args:
            task_name: Space-separated task description without scene info,
                e.g. ``"open the bottom drawer of the cabinet"``.

        Returns:
            A randomly chosen instruction from the ``"train"`` list when
            ``self.split == "train"``, otherwise from the ``"eval"`` list.
            Falls back to ``task_name`` unchanged (after printing a warning)
            when no key matches or the list is missing or empty.
        """
        # JSON keys are full names like
        # "KITCHEN_SCENE1_open_the_bottom_drawer_of_the_cabinet", so convert
        # the description back to underscore form before matching.
        matching_task = self._find_matching_task(task_name.replace(" ", "_"))

        if matching_task is None:
            # Fallback to original task description if no match found
            print(f"Warning: No extra description found for task '{task_name}', using original")
            return task_name

        # Any split other than "train" draws from the "eval" instructions.
        # NOTE(review): ``self.split`` is not set in this class; presumably
        # the parent dataset provides it.
        split_key = "train" if self.split == "train" else "eval"

        instructions = self.extra_descriptions[matching_task].get(split_key)
        if instructions:
            return random.choice(instructions)

        # Fallback to original task description
        print(f"Warning: No {split_key} instructions found for task '{matching_task}', using original")
        return task_name

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single transition with its language instruction replaced.

        Args:
            idx: Index forwarded to the parent dataset.

        Returns:
            The parent's item with ``"task_descriptions"`` replaced by the
            lower-cased instruction chosen by ``_get_task_instruction``.
        """
        # Get the original item
        item = super().__getitem__(idx)

        # Replace the task description with extra description
        original_task_desc = item["task_descriptions"]
        new_task_desc = self._get_task_instruction(original_task_desc)
        item["task_descriptions"] = new_task_desc.lower()

        return item