import inspect
import re
from typing import Literal, Optional

from datasets import load_dataset

from .base import BaseDataset, TextSample


class CombinedDataset(BaseDataset):
    """Concatenation of several child datasets into one sample pool.

    Children are loaded in the order given and their samples are appended
    one after another; this class itself does no shuffling or deduplication.
    """

    def __init__(self, datasets: list[BaseDataset], max_samples: Optional[int] = None):
        super().__init__(max_samples)
        # Child datasets whose samples are pooled, in concatenation order.
        self.datasets = datasets

    def _load_samples(self) -> list[TextSample]:
        """Load each child dataset and return all of their samples, in order."""
        pooled: list[TextSample] = []
        for child in self.datasets:
            # NOTE(review): assumes ``load()`` populates ``child._samples``
            # (BaseDataset internals, defined in base.py) — confirm there.
            child.load()
            pooled += child._samples
        return pooled


class S1KDataset(BaseDataset):
    """Reasoning traces from the ``simplescaling/s1K-1.1`` dataset (train split).

    Each row can contribute up to two samples: its Gemini thinking trajectory
    and/or its DeepSeek thinking trajectory, selected by ``trajectory_source``.
    """

    # Accepted values for ``trajectory_source``.
    _VALID_SOURCES = ("gemini", "deepseek", "both")

    # Trajectory source name -> dataset column holding that trace. Dict order
    # fixes the per-row emission order (gemini first, then deepseek).
    _TRAJECTORY_COLUMNS = {
        "gemini": "gemini_thinking_trajectory",
        "deepseek": "deepseek_thinking_trajectory",
    }

    def __init__(
        self,
        max_samples: Optional[int] = None,
        trajectory_source: Literal["gemini", "deepseek", "both"] = "both",
        min_text_length: int = 100,
        max_text_length: int = 4000,
    ):
        """Configure the dataset.

        Args:
            max_samples: Forwarded to ``BaseDataset``.
            trajectory_source: Which trajectory column(s) to read per row.
            min_text_length: Traces shorter than this (before truncation) are
                skipped.
            max_text_length: Kept traces are truncated to this many characters.

        Raises:
            ValueError: If ``trajectory_source`` is not a supported value.
        """
        # Fail fast: an unknown source would otherwise silently yield no samples.
        if trajectory_source not in self._VALID_SOURCES:
            raise ValueError(
                f"trajectory_source must be one of {self._VALID_SOURCES}; "
                f"got {trajectory_source!r}"
            )
        super().__init__(max_samples)
        self.trajectory_source = trajectory_source
        self.min_text_length = min_text_length
        self.max_text_length = max_text_length

    def _selected_sources(self) -> list[str]:
        """Return the trajectory source names to read, in emission order."""
        if self.trajectory_source == "both":
            return list(self._TRAJECTORY_COLUMNS)
        return [self.trajectory_source]

    def _load_samples(self) -> list[TextSample]:
        """Download the train split and build one sample per kept trajectory."""
        dataset = load_dataset(
            path="simplescaling/s1K-1.1",
            split="train",
        )
        sources = self._selected_sources()

        samples: list[TextSample] = []
        for item in dataset:
            cot_type = item.get("cot_type", "unknown")
            question = item.get("question", "")

            for source in sources:
                text = item.get(self._TRAJECTORY_COLUMNS[source], "")
                # Missing/empty traces are skipped; the minimum-length filter is
                # applied to the full trace, before truncation.
                if not text or len(text) < self.min_text_length:
                    continue

                samples.append(TextSample(
                    text=text[:self.max_text_length],
                    is_reasoning=True,
                    source=f"s1k_{source}",
                    metadata={
                        "cot_type": cot_type,
                        "trajectory_source": source,
                        "question": question,
                    },
                ))

        return samples


class GeneralInquiryCoTDataset(BaseDataset):
    """Reasoning traces from ``moremilk/General_Inquiry_Thinking-Chain-Of-Thought``.

    The trace text is read from each row's ``metadata["reasoning"]`` field and,
    when ``extract_thinking`` is set, narrowed to its ``<think>...</think>`` span.
    """

    # Non-greedy + DOTALL: capture only the first <think>...</think> span,
    # including any newlines inside it.
    _THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
    _THINK_OPEN = "<think>"

    def __init__(
        self,
        max_samples: Optional[int] = None,
        extract_thinking: bool = True,
        min_text_length: int = 100,
        max_text_length: int = 4000,
    ):
        """Configure the dataset.

        Args:
            max_samples: Forwarded to ``BaseDataset``.
            extract_thinking: If true, keep only the ``<think>`` content of
                each reasoning field.
            min_text_length: Texts shorter than this (after extraction, before
                truncation) are skipped.
            max_text_length: Kept texts are truncated to this many characters.
        """
        super().__init__(max_samples)
        self.extract_thinking = extract_thinking
        self.min_text_length = min_text_length
        self.max_text_length = max_text_length

    def _extract_thinking_content(self, text: str) -> str:
        """Return the stripped ``<think>`` content of ``text``.

        If there is a complete ``<think>...</think>`` span, its inner text is
        returned. If only an opening tag is present (e.g. a cut-off trace),
        everything after it is returned. Otherwise ``text`` is returned unchanged.
        """
        match = self._THINK_PATTERN.search(text)
        if match:
            return match.group(1).strip()
        # Unclosed tag: drop the raw tag instead of leaking it into the sample.
        _, tag, tail = text.partition(self._THINK_OPEN)
        if tag:
            return tail.strip()
        return text

    def _load_samples(self) -> list[TextSample]:
        """Download the train split and build one sample per usable row."""
        dataset = load_dataset(
            path="moremilk/General_Inquiry_Thinking-Chain-Of-Thought",
            split="train",
        )

        samples: list[TextSample] = []
        for item in dataset:
            # ``or {}`` also covers a present-but-None metadata value, which
            # ``.get("metadata", {})`` would pass through and then crash on.
            metadata = item.get("metadata") or {}
            reasoning = metadata.get("reasoning", "")
            if not reasoning:
                continue

            if self.extract_thinking:
                text = self._extract_thinking_content(reasoning)
            else:
                text = reasoning

            if len(text) < self.min_text_length:
                continue

            samples.append(TextSample(
                text=text[:self.max_text_length],
                is_reasoning=True,
                source="general_inquiry_cot",
                metadata={
                    "difficulty": metadata.get("difficulty"),
                    "topic": metadata.get("topic"),
                    "question": item.get("question", ""),
                },
            ))

        return samples


def _accepted_kwargs(cls: type, kwargs: dict) -> dict:
    """Return the subset of ``kwargs`` whose names ``cls.__init__`` declares."""
    accepted = inspect.signature(cls.__init__).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}


def get_reasoning_dataset(
    name: Literal["s1k", "general_inquiry_cot", "combined"],
    max_samples: Optional[int] = None,
    **kwargs,
) -> BaseDataset:
    """Construct a reasoning dataset by name.

    Args:
        name: ``"s1k"``, ``"general_inquiry_cot"``, or ``"combined"`` (both).
        max_samples: Optional sample cap. For ``"combined"`` it is split
            between the two sources (S1K takes the extra one when it is odd)
            and is also passed to the ``CombinedDataset`` itself.
        **kwargs: Extra constructor arguments. For ``"combined"``, each keyword
            is forwarded only to the dataset class(es) whose constructor
            accepts it. Source-specific options such as ``trajectory_source``
            or ``extract_thinking`` therefore work there.

    Raises:
        ValueError: If ``name`` is not a known dataset.
        TypeError: If ``"combined"`` receives a keyword that neither
            constructor accepts.
    """
    if name == "s1k":
        return S1KDataset(max_samples=max_samples, **kwargs)
    elif name == "general_inquiry_cot":
        return GeneralInquiryCoTDataset(max_samples=max_samples, **kwargs)
    elif name == "combined":
        s1k_kwargs = _accepted_kwargs(S1KDataset, kwargs)
        giq_kwargs = _accepted_kwargs(GeneralInquiryCoTDataset, kwargs)
        unknown = set(kwargs) - set(s1k_kwargs) - set(giq_kwargs)
        if unknown:
            raise TypeError(
                "Unexpected keyword argument(s) for combined reasoning dataset: "
                f"{sorted(unknown)}"
            )

        if max_samples:
            # Ceil/floor split so the halves add up to max_samples exactly.
            giq_limit = max_samples // 2
            s1k_limit = max_samples - giq_limit
        else:
            s1k_limit = giq_limit = None

        s1k = S1KDataset(max_samples=s1k_limit, **s1k_kwargs)
        giq = GeneralInquiryCoTDataset(max_samples=giq_limit, **giq_kwargs)
        return CombinedDataset([s1k, giq], max_samples=max_samples)
    else:
        raise ValueError(f"Unknown reasoning dataset: {name}")
