import json
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from datasets import load_dataset, Dataset
import os
from pathlib import Path
import logging
import random

# Configure the root logger at INFO level on import and create the
# module-level logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DataLoader:
    """Loader for several benchmark datasets stored in local directories.

    Every ``load_*`` method reads one dataset through ``load_dataset`` and
    converts its rows into plain dictionaries that all carry the
    ``question``, ``source`` and ``category`` keys.  Loading is best-effort:
    a failure is logged and an empty list is returned instead of raising.

    For every ``max_samples`` parameter, ``None`` and ``0`` both mean
    "no limit".
    """

    def __init__(self, cache_dir: str = "./data/cache"):
        """Create the loader and make sure the cache directory exists.

        Args:
            cache_dir: Directory handed to ``load_dataset`` as its cache.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _load(self, *args: Any, **kwargs: Any):
        """Call ``load_dataset`` with this loader's cache directory.

        The cache path is passed as a string because that is the type the
        ``cache_dir`` parameter of ``load_dataset`` is documented to take.
        """
        return load_dataset(*args, cache_dir=str(self.cache_dir), **kwargs)

    @staticmethod
    def _subsample(dataset, max_samples: Optional[int]):
        """Return a random subset holding at most ``max_samples`` rows.

        The dataset is returned untouched when no limit is set or when it
        already fits within the limit, so rows are only reordered when some
        of them actually have to be dropped.

        Args:
            dataset: Object supporting ``len()`` and ``select(indices)``.
            max_samples: Upper bound on the number of rows; falsy = no limit.
        """
        if not max_samples or len(dataset) <= max_samples:
            return dataset

        logger.info(f"限制样本数量从 {len(dataset)} 到 {max_samples}")
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        return dataset.select(indices[:max_samples])

    def load_alpaca_eval(self, data_dir: str, split: str = "eval", max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load the AlpacaEval dataset.

        ``<data_dir>/alpaca_eval.json`` is read when it exists; otherwise the
        directory itself is handed to ``load_dataset`` with ``split``.

        Args:
            data_dir: Dataset directory.
            split: Split name, used only for the directory fallback.
            max_samples: Maximum number of samples (random subset).

        Returns:
            List of normalized records; empty if loading failed.
        """
        logger.info(f"从 {data_dir} 加载AlpacaEval数据集 ({split})")

        try:
            json_path = os.path.join(data_dir, "alpaca_eval.json")
            if os.path.exists(json_path):
                # A single JSON file is exposed by the loader as split "train".
                eval_set = self._load("json", data_files=json_path, split="train")
            else:
                logger.warning(f"{json_path} does not exist. Trying to load from directory.")
                eval_set = self._load(data_dir, split=split)

            if eval_set is None:
                raise ValueError("Failed to load eval set")

            logger.info(f"成功加载AlpacaEval数据集，共 {len(eval_set)} 个样本")

            eval_set = self._subsample(eval_set, max_samples)

            data = []
            for example in eval_set:
                # The instruction is the question; ``or ''`` guards against
                # an explicit None value, which would break the ``+=`` below.
                question = example.get('instruction') or ''
                extra_input = example.get('input')
                if extra_input:
                    question += f"\n\nInput: {extra_input}"

                data.append({
                    'question': question,
                    'source': 'alpaca_eval',
                    'category': 'instruction_following',
                    'original_data': example
                })

            logger.info(f"✅ 成功加载了 {len(data)} 个AlpacaEval样本")
            return data

        except Exception as e:
            logger.error(f"加载AlpacaEval失败: {e}")
            return []

    def load_mmlu(self, data_dir: str, subject: str = "all", max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load the ``test`` split of one or several MMLU subjects.

        Each subject is expected in its own sub-directory of ``data_dir``.
        Missing or unreadable subjects are skipped with a warning.  Unlike the
        other loaders, the limit is applied in reading order: loading stops as
        soon as ``max_samples`` records have been collected.

        Args:
            data_dir: Dataset directory.
            subject: Subject name, or "all" for the built-in subject list.
            max_samples: Maximum number of samples.

        Returns:
            List of normalized records; empty if loading failed.
        """
        logger.info(f"从 {data_dir} 加载MMLU数据集 (学科: {subject})")

        try:
            if subject == "all":
                # A fixed selection of subjects, not the complete MMLU list.
                # NOTE(review): 'auxiliary_train' looks like a training pool
                # rather than a test subject - confirm it belongs here.
                subjects = ['astronomy', 'auxiliary_train', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security']
            else:
                subjects = [subject]

            all_data = []

            for subj in subjects:
                try:
                    subject_path = Path(data_dir) / subj
                    if not subject_path.exists():
                        logger.warning(f"MMLU学科路径不存在: {subject_path}, 跳过")
                        continue

                    dataset = self._load(str(subject_path), split="test")

                    for item in dataset:
                        # Letter the options A, B, C, ... for however many
                        # choices the item carries instead of assuming four.
                        options = "\n".join(
                            f"{chr(65 + idx)}. {choice}"
                            for idx, choice in enumerate(item['choices'])
                        )
                        question = f"{item['question']}\n\n{options}"

                        all_data.append({
                            'question': question,
                            'answer': item['answer'],
                            'source': 'mmlu',
                            'category': 'knowledge',
                            'subject': subj
                        })

                        if max_samples and len(all_data) >= max_samples:
                            break

                except Exception as e:
                    logger.warning(f"无法加载MMLU学科 {subj} from {data_dir}: {e}")
                    continue

                if max_samples and len(all_data) >= max_samples:
                    break

            if max_samples:
                all_data = all_data[:max_samples]

            logger.info(f"加载了 {len(all_data)} 个MMLU样本")
            return all_data

        except Exception as e:
            logger.error(f"加载MMLU失败: {e}")
            return []

    def load_hellaswag(self, data_dir: str, split: str = "validation", max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load the HellaSwag dataset.

        Rows lacking ``ctx_a``, ``endings`` or ``label``, rows with no
        endings, and rows that raise while being processed are skipped with
        a warning.

        Args:
            data_dir: Dataset directory.
            split: Split name.
            max_samples: Maximum number of samples (random subset).

        Returns:
            List of normalized records; empty if loading failed.
        """
        logger.info(f"从 {data_dir} 加载HellaSwag数据集 ({split})")

        try:
            dataset = self._load(data_dir, split=split)

            if dataset is None:
                raise ValueError("load_dataset returned None")

            logger.info(f"成功加载数据集，共 {len(dataset)} 个样本")

            dataset = self._subsample(dataset, max_samples)

            data = []
            for i, item in enumerate(dataset):
                try:
                    if not item or 'ctx_a' not in item or 'endings' not in item or 'label' not in item:
                        logger.warning(f"样本 {i} 缺少必要字段，跳过")
                        continue

                    ctx_a = item.get('ctx_a', '')
                    ctx_b = item.get('ctx_b', '')
                    context = ctx_a + (' ' + ctx_b if ctx_b else '')

                    question = f"Context: {context}\n\nWhich ending makes the most sense?\n"

                    endings = item.get('endings', [])
                    if not endings:
                        logger.warning(f"样本 {i} 没有endings，跳过")
                        continue

                    # Empty endings are omitted, but the remaining ones keep
                    # the letter matching their original position.
                    for j, ending in enumerate(endings):
                        if ending:
                            question += f"{chr(65+j)}. {ending}\n"

                    data.append({
                        'question': question,
                        'answer': item.get('label', 0),
                        'source': 'hellaswag',
                        'category': 'commonsense_reasoning'
                    })

                except Exception as item_error:
                    logger.warning(f"处理样本 {i} 时出错: {item_error}")
                    continue

            logger.info(f"成功处理了 {len(data)} 个HellaSwag样本")
            return data

        except Exception as e:
            logger.error(f"加载HellaSwag失败: {e}")
            return []

    def load_truthfulqa(self, data_dir: str, split: str = "validation", max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load TruthfulQA from ``<data_dir>/generation``.

        Args:
            data_dir: Dataset directory.
            split: Split name.
            max_samples: Maximum number of samples (random subset).

        Returns:
            List of normalized records; empty if loading failed.
        """
        logger.info(f"从 {data_dir} 加载TruthfulQA数据集 ({split})")

        try:
            dataset = self._load(os.path.join(data_dir, "generation"), split=split)
            dataset = self._subsample(dataset, max_samples)

            data = [
                {
                    'question': item['question'],
                    'best_answer': item['best_answer'],
                    'correct_answers': item['correct_answers'],
                    'incorrect_answers': item['incorrect_answers'],
                    'source': 'truthfulqa',
                    'category': 'truthfulness'
                }
                for item in dataset
            ]

            logger.info(f"加载了 {len(data)} 个样本")
            return data

        except Exception as e:
            logger.error(f"加载TruthfulQA失败: {e}")
            return []

    def load_gsm8k(self, data_dir: str, split: str = "test", max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load the GSM8K math reasoning dataset (configuration "main").

        Args:
            data_dir: Dataset directory.
            split: Split name.
            max_samples: Maximum number of samples (random subset).

        Returns:
            List of normalized records; empty if loading failed.
        """
        logger.info(f"从 {data_dir} 加载GSM8K数据集 ({split})")

        try:
            dataset = self._load(data_dir, "main", split=split)
            dataset = self._subsample(dataset, max_samples)

            data = [
                {
                    'question': item['question'],
                    'answer': item['answer'],
                    'source': 'gsm8k',
                    'category': 'math_reasoning'
                }
                for item in dataset
            ]

            logger.info(f"加载了 {len(data)} 个样本")
            return data

        except Exception as e:
            logger.error(f"加载GSM8K失败: {e}")
            return []

    def load_code_eval(self, data_dir: str, split: str = "test", max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load the HumanEval code generation dataset.

        Args:
            data_dir: Dataset directory.
            split: Split name.
            max_samples: Maximum number of samples (random subset).

        Returns:
            List of normalized records; empty if loading failed.
        """
        logger.info(f"从 {data_dir} 加载HumanEval数据集 ({split})")

        try:
            dataset = self._load(data_dir, split=split)
            dataset = self._subsample(dataset, max_samples)

            data = []
            for item in dataset:
                # The docstring column is optional; an empty string is
                # appended when the row does not have it.
                docstring = item['docstring'] if 'docstring' in item else ''
                question = f"Complete the following Python function:\n\n```python\n{item['prompt']}\n```\n\n{docstring}"

                data.append({
                    'question': question,
                    'canonical_solution': item['canonical_solution'],
                    'test': item['test'],
                    'source': 'humaneval',
                    'category': 'code_generation'
                })

            logger.info(f"加载了 {len(data)} 个样本")
            return data

        except Exception as e:
            logger.error(f"加载HumanEval失败: {e}")
            return []

    def save_data(self, data: List[Dict[str, Any]], file_path: str):
        """Write records to a ``.json`` or ``.csv`` file.

        The format is chosen from the file suffix (case-insensitive).  The
        suffix is validated before anything touches the file system, so an
        unsupported format does not leave an empty directory behind.

        Args:
            data: Records to save.
            file_path: Destination path; missing parent directories are
                created.

        Raises:
            ValueError: If the suffix is neither ``.json`` nor ``.csv``.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()
        if suffix not in ('.json', '.csv'):
            raise ValueError(f"不支持的文件格式: {file_path.suffix}")

        file_path.parent.mkdir(parents=True, exist_ok=True)

        if suffix == '.json':
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        else:
            pd.DataFrame(data).to_csv(file_path, index=False)

        logger.info(f"数据已保存到: {file_path}")

def aggregate_from_local_datasets(
    total_samples: int = 1000,
    base_dir: str = "/mnt/data/dataset",
    output_file: str = "./aggregated_1000_questions_9.json",
) -> None:
    """Build one shuffled question set from the datasets found in ``base_dir``.

    Each known dataset is looked up as a sub-directory of ``base_dir``.  The
    sample budget is split evenly (integer division) across the directories
    that exist; each dataset is loaded in full and a random subset is drawn
    from it.  Datasets that fail to load are skipped, so the result may hold
    fewer than ``total_samples`` questions - it can never hold more.

    Args:
        total_samples: Target size of the aggregated set.
        base_dir: Directory containing one sub-directory per dataset.
        output_file: Destination file, written through
            ``DataLoader.save_data`` (``.json`` or ``.csv``).
    """
    base_dir = Path(base_dir)

    loader = DataLoader()

    # Sub-directory name -> loader method responsible for it.
    datasets_to_load = {
        "alpacaeval": loader.load_alpaca_eval,
        "gsm8k": loader.load_gsm8k,
        "hellaswag": loader.load_hellaswag,
        "humaneval": loader.load_code_eval,
        "mmlu": loader.load_mmlu,
        "truthfulqa": loader.load_truthfulqa,
    }

    available_datasets = [d for d in datasets_to_load if (base_dir / d).exists()]
    if not available_datasets:
        logger.error(f"在 {base_dir} 中没有找到任何有效的数据集目录。")
        return

    logger.info(f"找到 {len(available_datasets)} 个数据集: {', '.join(available_datasets)}")

    # Even share per dataset; the remainder of the division is not
    # redistributed, so the total can fall slightly short of the target.
    samples_per_dataset = total_samples // len(available_datasets)

    all_sampled_questions = []

    for name in available_datasets:
        logger.info(f"--- 处理数据集: {name} ---")

        # Load everything first, then sample from the full set.
        full_data = datasets_to_load[name](data_dir=str(base_dir / name))

        if not full_data:
            logger.warning(f"未能从 {name} 加载数据，跳过。")
            continue

        num_to_sample = min(len(full_data), samples_per_dataset)
        sampled_data = random.sample(full_data, num_to_sample)

        all_sampled_questions.extend(sampled_data)
        logger.info(f"从 {name} 中抽样了 {len(sampled_data)} 个问题。")

    logger.info(f"聚合后总共 {len(all_sampled_questions)} 个问题。")

    # Mix the datasets together.  No trimming is needed afterwards: every
    # dataset contributes at most its even share of ``total_samples``.
    random.shuffle(all_sampled_questions)

    logger.info(f"最终数据集大小: {len(all_sampled_questions)}")

    loader.save_data(all_sampled_questions, str(output_file))

# Script entry point: when the file is run directly, build the aggregated
# question set with a target size of 1000.
if __name__ == "__main__":
    aggregate_from_local_datasets(total_samples=1000)