"""
Meta-World Extra Description dataloaders that inherit from MetaWorldDataset and
MetaWorldPerTaskDataset but replace language instructions with ones loaded from
metaworld_extra_desc.json.
"""

import json
import random
from pathlib import Path
from typing import Dict, Any, Optional
from dataclasses import dataclass

from .meta_world_dataloader import MetaWorldDataset, MetaWorldDataConfig, MetaWorldPerTaskDataset, MetaWorldPerTaskDataConfig


# Default location of the JSON file mapping task names to per-split
# ("train" / "eval") instruction lists. Shared by both configs below so the
# two defaults cannot drift apart.
DEFAULT_EXTRA_DESC_JSON_PATH = "./config/data/meta-world_splits/metaworld_extra_desc.json"


@dataclass
class MetaWorldExtraDescDataConfig(MetaWorldDataConfig):
    """Configuration for Meta-World Extra Description dataset loading.

    Extends ``MetaWorldDataConfig`` with the path of the extra-description
    JSON file. A relative path is resolved against the current working
    directory when the dataset opens it.
    """

    extra_desc_json_path: str = DEFAULT_EXTRA_DESC_JSON_PATH


@dataclass
class MetaWorldExtraDescPerTaskDataConfig(MetaWorldPerTaskDataConfig):
    """Configuration for Meta-World per-task Extra Description dataset loading.

    Extends ``MetaWorldPerTaskDataConfig`` with the path of the
    extra-description JSON file. A relative path is resolved against the
    current working directory when the dataset opens it.
    """

    extra_desc_json_path: str = DEFAULT_EXTRA_DESC_JSON_PATH


class MetaWorldExtraDescDataset(MetaWorldDataset):
    """Meta-World dataset with extra descriptions from a JSON file.

    The JSON file (``config.extra_desc_json_path``) maps task names such as
    ``"door-close"`` or ``"door-close-v3"`` to an object that may hold
    ``"train"`` and ``"eval"`` lists of instruction strings. Every item returned
    by ``__getitem__`` has its ``"task_descriptions"`` replaced by a random
    instruction for the current split. When no instruction is available, the
    original description is kept.
    """

    def __init__(self, config: MetaWorldExtraDescDataConfig):
        self.extra_desc_json_path = Path(config.extra_desc_json_path)
        self.extra_descriptions = self._load_extra_descriptions()
        # Resolved JSON key (or None) per hyphenated task name, plus the set of
        # warnings already printed. These make the key scan and the fallback
        # warning happen once per task rather than once per sample.
        self._matching_task_cache: Dict[str, Optional[str]] = {}
        self._printed_warnings = set()

        # Initialize parent class (descriptions are loaded first, as before).
        super().__init__(config)

    def _load_extra_descriptions(self) -> Dict[str, Dict[str, list]]:
        """Load extra descriptions from the JSON file.

        Returns:
            The parsed JSON object: task name -> {split name -> instructions}.

        Raises:
            FileNotFoundError: If ``self.extra_desc_json_path`` does not exist.
            ValueError: If the top-level JSON value is not an object.
        """
        if not self.extra_desc_json_path.exists():
            raise FileNotFoundError(f"Extra descriptions file not found: {self.extra_desc_json_path}")

        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(self.extra_desc_json_path, 'r', encoding='utf-8') as f:
            descriptions = json.load(f)

        if not isinstance(descriptions, dict):
            raise ValueError(
                f"Extra descriptions file must contain a JSON object, got "
                f"{type(descriptions).__name__}: {self.extra_desc_json_path}"
            )

        print(f"Loaded extra descriptions for {len(descriptions)} tasks")
        return descriptions

    def _find_matching_task(self, task_hyphenated: str) -> Optional[str]:
        """Return the JSON key describing ``task_hyphenated``, or None.

        An exact key match takes precedence. Otherwise the first key (in file
        order) that is equal to ``task_hyphenated`` once every "-v3" is removed
        from both strings is returned.
        """
        if task_hyphenated in self.extra_descriptions:
            return task_hyphenated
        unversioned = task_hyphenated.replace("-v3", "")
        for json_task_name in self.extra_descriptions:
            if json_task_name.replace("-v3", "") == unversioned:
                return json_task_name
        return None

    def _warn_once(self, message: str) -> None:
        """Print ``message`` only the first time this instance emits it."""
        if message not in self._printed_warnings:
            self._printed_warnings.add(message)
            print(message)

    def _get_task_instruction(self, task_name: str) -> str:
        """Get a random instruction for the task based on the current split.

        Args:
            task_name: The item's space-separated description. Per the
                original comments, it is derived from a file name such as
                ``"door-close-v3_demo.hdf5"`` -> ``"door close"``, so spaces
                are turned back into hyphens to match JSON keys.

        Returns:
            A random entry of the task's "train" list when ``self.split`` is
            "train", otherwise of its "eval" list. Returns ``task_name``
            unchanged when the task or that list is missing or empty.
        """
        task_hyphenated = task_name.replace(" ", "-")
        if task_hyphenated not in self._matching_task_cache:
            self._matching_task_cache[task_hyphenated] = self._find_matching_task(task_hyphenated)
        matching_task = self._matching_task_cache[task_hyphenated]

        if not matching_task:
            # Fallback to original task description if no match found.
            self._warn_once(f"Warning: No extra description found for task '{task_name}', using original")
            return task_name

        # Any split other than "train" draws from the "eval" instructions.
        split_key = "train" if self.split == "train" else "eval"

        task_entry = self.extra_descriptions[matching_task]
        if split_key in task_entry:
            instructions = task_entry[split_key]
            if instructions:
                return random.choice(instructions)

        # Fallback to original task description.
        self._warn_once(f"Warning: No {split_key} instructions found for task '{matching_task}', using original")
        return task_name

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single transition with its language instruction replaced."""
        item = super().__getitem__(idx)
        item["task_descriptions"] = self._get_task_instruction(item["task_descriptions"])
        return item


class MetaWorldExtraDescPerTaskDataset(MetaWorldPerTaskDataset):
    """Meta-World per-task dataset with extra descriptions from a JSON file.

    The JSON file (``config.extra_desc_json_path``) maps task names to an
    object that may hold ``"train"`` and ``"eval"`` lists of instruction
    strings. The entry is looked up via ``self.task_name`` (presumably set by
    ``MetaWorldPerTaskDataset``; TODO confirm) rather than from each item's
    description. The chosen instruction is lower-cased before being stored in
    ``item["task_descriptions"]``.
    """

    def __init__(self, config: MetaWorldExtraDescPerTaskDataConfig):
        self.extra_desc_json_path = Path(config.extra_desc_json_path)
        self.extra_descriptions = self._load_extra_descriptions()
        # Resolved JSON key (or None) per task name, plus the set of warnings
        # already printed. These make the key scan and the fallback warning
        # happen once per task rather than once per sample.
        self._matching_task_cache: Dict[str, Optional[str]] = {}
        self._printed_warnings = set()

        # Initialize parent class (descriptions are loaded first, as before).
        super().__init__(config)

    def _load_extra_descriptions(self) -> Dict[str, Dict[str, list]]:
        """Load extra descriptions from the JSON file.

        Returns:
            The parsed JSON object: task name -> {split name -> instructions}.

        Raises:
            FileNotFoundError: If ``self.extra_desc_json_path`` does not exist.
            ValueError: If the top-level JSON value is not an object.
        """
        if not self.extra_desc_json_path.exists():
            raise FileNotFoundError(f"Extra descriptions file not found: {self.extra_desc_json_path}")

        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(self.extra_desc_json_path, 'r', encoding='utf-8') as f:
            descriptions = json.load(f)

        if not isinstance(descriptions, dict):
            raise ValueError(
                f"Extra descriptions file must contain a JSON object, got "
                f"{type(descriptions).__name__}: {self.extra_desc_json_path}"
            )

        print(f"Loaded extra descriptions for {len(descriptions)} tasks")
        return descriptions

    def _find_matching_task(self, task_hyphenated: str) -> Optional[str]:
        """Return the JSON key describing ``task_hyphenated``, or None.

        An exact key match takes precedence. Otherwise the first key (in file
        order) that is equal to ``task_hyphenated`` once every "-v3" is removed
        from both strings is returned.
        """
        if task_hyphenated in self.extra_descriptions:
            return task_hyphenated
        unversioned = task_hyphenated.replace("-v3", "")
        for json_task_name in self.extra_descriptions:
            if json_task_name.replace("-v3", "") == unversioned:
                return json_task_name
        return None

    def _warn_once(self, message: str) -> None:
        """Print ``message`` only the first time this instance emits it."""
        if message not in self._printed_warnings:
            self._printed_warnings.add(message)
            print(message)

    def _get_task_instruction(self, task_name: str) -> str:
        """Get a random instruction for ``self.task_name`` in the current split.

        Args:
            task_name: The item's original description. It is only used as the
                fallback return value; the JSON lookup uses ``self.task_name``.

        Returns:
            A random entry of the task's "train" list when ``self.split`` is
            "train", otherwise of its "eval" list. Returns ``task_name``
            unchanged when the task or that list is missing or empty.
        """
        lookup_key = self.task_name
        if lookup_key not in self._matching_task_cache:
            self._matching_task_cache[lookup_key] = self._find_matching_task(lookup_key)
        matching_task = self._matching_task_cache[lookup_key]

        if not matching_task:
            # Report the name that was actually looked up, not the description.
            self._warn_once(f"Warning: No extra description found for task '{lookup_key}', using original")
            return task_name

        # Any split other than "train" draws from the "eval" instructions.
        split_key = "train" if self.split == "train" else "eval"

        task_entry = self.extra_descriptions[matching_task]
        if split_key in task_entry:
            instructions = task_entry[split_key]
            if instructions:
                return random.choice(instructions)

        # Fallback to original task description.
        self._warn_once(f"Warning: No {split_key} instructions found for task '{matching_task}', using original")
        return task_name

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single transition with its language instruction replaced."""
        item = super().__getitem__(idx)
        new_task_desc = self._get_task_instruction(item["task_descriptions"])
        # NOTE(review): only this per-task variant lower-cases the instruction;
        # MetaWorldExtraDescDataset does not. Kept as-is to preserve outputs;
        # confirm whether the difference is intended.
        item["task_descriptions"] = new_task_desc.lower()
        return item