import os
import sys
sys.path.append(os.getcwd())
import json
import random
from typing import List, Dict, Tuple
import pandas as pd
import argparse
class Construct_Multi_Choice:
    def __init__(self, base_path: str):
        """
        Set up the directory layout used by the dataset builder.

        Args:
            base_path: Root directory containing the ``Article``, ``Cover``,
                ``Story`` and ``Other_Articles`` folders.
        """
        self.base_path = base_path
        # Every data folder is simply the root joined with a fixed folder name.
        (
            self.article_path,
            self.cover_path,
            self.story_path,
            self.other_articles_path,
        ) = (
            os.path.join(base_path, folder)
            for folder in ('Article', 'Cover', 'Story', 'Other_Articles')
        )
        
    def get_journals(self) -> List[str]:
        """Return the names of all journals (sub-directories of the story folder).

        The names are returned sorted: ``os.listdir`` yields entries in an
        arbitrary, filesystem-dependent order, which would otherwise make a
        seeded dataset build differ from one machine to another.

        Returns:
            Sorted list of directory names found directly under the story folder.
        """
        return sorted(
            name for name in os.listdir(self.story_path)
            if os.path.isdir(os.path.join(self.story_path, name))
        )
    
    def get_issues(self, journal: str) -> List[str]:
        """Return the issue identifiers available for ``journal``.

        Issues are discovered from the cover folder: only issues that have a
        cover are candidates, and callers check afterwards whether the matching
        story and article files exist.

        The identifier is the cover file name without its final extension, so
        names that themselves contain dots (e.g. ``2020.01.png``) are kept
        intact. Hidden files (names starting with ``.``) are ignored,
        duplicates are collapsed, and the result is sorted so that it does not
        depend on the filesystem's listing order.

        Args:
            journal: Journal name, i.e. a sub-folder of the cover directory.

        Returns:
            Sorted list of unique issue identifiers; empty if the journal has
            no cover folder.
        """
        journal_path = os.path.join(self.cover_path, journal)
        if not os.path.isdir(journal_path):
            # A journal without a cover folder simply has no usable issues.
            return []
        issues = {
            os.path.splitext(name)[0]
            for name in os.listdir(journal_path)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(journal_path, name))
        }
        return sorted(issues)
    
    def read_file_content(self, path: str) -> str:
        """Read a UTF-8 text file and return its content without surrounding whitespace.

        Args:
            path: Path of the file to read.

        Returns:
            The stripped file content, or an empty string if anything goes
            wrong while reading (the error is printed, not raised).
        """
        try:
            with open(path, 'r', encoding='utf-8') as handle:
                raw = handle.read()
            content = raw.strip()
        except Exception as err:
            # Best effort: report the problem and fall back to an empty text.
            print(f"Error reading file {path}: {err}")
            content = ""
        return content
            
    def construct_question(self, journal: str, issue: str) -> Dict:
        """
        Build a four-option multiple-choice question for a single issue.

        The correct option is the issue's cover story; the three distractors
        are randomly drawn from the values of the mapping stored in
        ``Other_Articles/<journal>/<issue>.json``. Option letters are assigned
        randomly and the options are returned sorted by letter.

        Args:
            journal: Journal name (sub-folder of the story and other-articles folders).
            issue: Issue identifier (file name without extension).

        Returns:
            ``None`` if the story file or the other-articles file is missing,
            or if fewer than three distractors are available. Otherwise a dict::

                {
                    'journal': journal,
                    'id': issue,
                    'question': <fixed question text>,
                    'story': <cover story text>,
                    'options': [{'id': 'A', 'text': '...', 'is_correct': bool}, ...],
                    'answer': <letter of the correct option>,
                    'cover_image': <path of the cover PNG>,
                }
        """
        story_path = os.path.join(self.story_path, journal, f"{issue}.txt")
        other_path = os.path.join(self.other_articles_path, journal, f"{issue}.json")
        if not (os.path.exists(story_path) and os.path.exists(other_path)):
            return None
        story_content = self.read_file_content(story_path)

        # Load the other articles of this issue; they supply the distractors.
        # The encoding is explicit so the result does not depend on the locale.
        try:
            with open(other_path, 'r', encoding='utf-8') as f:
                other_articles = json.load(f)
        except Exception as e:
            print(f"Error loading other articles for {journal}/{issue}: {e}")
            other_articles = {}
        if not isinstance(other_articles, dict):
            # Anything other than a mapping cannot provide distractors.
            other_articles = {}

        # Shuffle the letters so the correct answer is not always 'A'.
        option_ids = ['A', 'B', 'C', 'D']
        random.shuffle(option_ids)

        distractors = list(other_articles.items())
        random.shuffle(distractors)
        if len(distractors) < 3:
            # Same policy as construct_easy_question: a question with fewer
            # than four options is not usable, so report it and skip it.
            print(f"Error: {journal}/{issue} has less than 3 distractors")
            return None

        # The first shuffled letter goes to the correct option, the remaining
        # three letters go to the distractors.
        options = [{
            'id': option_ids[0],
            'text': story_content,
            'is_correct': True
        }]
        for option_id, (_url, abstract) in zip(option_ids[1:], distractors[:3]):
            options.append({
                'id': option_id,
                'text': abstract,
                'is_correct': False
            })

        # Always present the options in A, B, C, D order.
        options.sort(key=lambda opt: opt['id'])

        return {
            'journal': journal,
            'id': issue,
            'question': 'Which of the following options best describe the cover image?',
            'story': story_content,
            'options': options,
            'answer': option_ids[0],
            'cover_image': os.path.join(self.cover_path, journal, f"{issue}.png")
        }
    
    def construct_easy_question(self, journal: str, issue: str, other_stories: list):
        """
        Build an "easy" four-option single-answer question for one issue.

        The correct option is the issue's own cover story; the three
        distractors are cover stories of other issues of the same journal.
        Option letters are assigned randomly and the options are returned
        sorted by letter.

        Args:
            journal: Journal name (sub-folder of the story folder).
            issue: Issue identifier (story file name without ``.txt``).
            other_stories: File names inside the journal's story folder. The
                issue's own ``<issue>.txt`` is excluded here, so the caller
                may pass the complete listing.

        Returns:
            ``None`` if the issue has no story file or if fewer than three
            non-empty distractor stories can be read. Otherwise a dict with
            the keys ``journal``, ``id``, ``question``, ``story``, ``options``
            (list of ``{'id', 'text', 'is_correct'}`` dicts), ``answer`` and
            ``cover_image``.
        """
        story_path = os.path.join(self.story_path, journal, f"{issue}.txt")
        if not os.path.exists(story_path):
            return None
        story_content = self.read_file_content(story_path)

        # Shuffle the letters so the correct answer is not always 'A'.
        option_ids = ['A', 'B', 'C', 'D']
        random.shuffle(option_ids)

        # The first shuffled letter goes to the correct option.
        options = [{
            'id': option_ids[0],
            'text': story_content,
            'is_correct': True
        }]

        # Candidate distractors: every other story file of the journal.
        candidate_names = [s for s in other_stories if s != f"{issue}.txt"]
        random.shuffle(candidate_names)

        distractors = []
        for name in candidate_names:
            # Dedicated local names: story_path / story_content must keep
            # referring to this issue's own story, which is returned below.
            distractor_text = self.read_file_content(
                os.path.join(self.story_path, journal, name))
            if not distractor_text:
                # Empty or unreadable story: an empty option would make the
                # whole question invalid, so try the next candidate instead.
                continue
            distractors.append(distractor_text)
            if len(distractors) == 3:
                break
        if len(distractors) != 3:
            print(f"Error: {journal}/{issue} has less than 3 distractors")
            return None

        # The remaining three letters go to the distractors.
        for option_id, cover_story in zip(option_ids[1:], distractors):
            options.append({
                'id': option_id,
                'text': cover_story,
                'is_correct': False
            })

        # Always present the options in A, B, C, D order.
        options.sort(key=lambda opt: opt['id'])

        return {
            'journal': journal,
            'id': issue,
            'question': 'Which of the following options best describe the cover image?',
            'story': story_content,
            'options': options,
            'answer': option_ids[0],
            'cover_image': os.path.join(self.cover_path, journal, f"{issue}.png")
        }
    
    
    def check_dataset_integrity(self, dataset):
        """
        Filter out incomplete questions from a flattened dataset.

        A question is kept only if its cover image path is set and exists on
        disk, all of ``option_A`` .. ``option_D`` are present and non-empty,
        and it has a non-empty ``answer``. Every failed check prints its own
        message, so one question may be reported more than once.

        Args:
            dataset: List of question dicts (read keys: ``cover_image``,
                ``option_A`` .. ``option_D``, ``answer``, and ``journal`` /
                ``id`` for the messages).

        Returns:
            valid_dataset: The questions that passed every check, in input order.
            removed_count: Number of questions that were dropped.
        """
        option_keys = [f'option_{letter}' for letter in 'ABCD']
        kept = []

        for item in dataset:
            problems = 0

            # The cover image must be referenced and must exist on disk.
            cover = item['cover_image']
            if not (cover and os.path.exists(cover)):
                print(f"the cover image of {item['journal']}/{item['id']} does not exist: {cover}")
                problems += 1

            # All four options must be present and non-empty.
            if not all(item.get(key) for key in option_keys):
                print(f"the options of {item['journal']}/{item['id']} does not exist")
                problems += 1

            # A correct answer must be recorded.
            if not item.get('answer'):
                print(f"no correct answer of {item['journal']}/{item['id']}")
                problems += 1

            if problems == 0:
                kept.append(item)

        dropped = len(dataset) - len(kept)
        print(f"data set integrity check completed: total questions {len(dataset)}, valid questions {len(kept)}, removed questions {dropped}")

        return kept, dropped

    def construct_dataset(self,
                          output_dir: str,
                          train_ratio: float = 0.7,
                          val_ratio: float = 0.15,
                          seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Build the question dataset, split it and save the splits as CSV files.

        Files written to ``output_dir``: ``train.csv``, ``val.csv``,
        ``test.csv``, ``full_dataset.csv`` and ``dataset_stats.json``.

        The random generator is seeded before the questions are built (option
        letters and distractors are chosen randomly) and again right before
        the split, and journals, issues and story files are visited in sorted
        order, so the same input tree and seed produce the same dataset.

        Args:
            output_dir: Directory the output files are written to (created if missing).
            train_ratio: Fraction of the valid questions used for training.
            val_ratio: Fraction of the valid questions used for validation;
                the remainder becomes the test set.
            seed: Random seed.

        Returns:
            Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: The train,
            validation and test DataFrames, each with an added ``split`` column.

        Raises:
            ValueError: If a ratio is negative or the two ratios sum to more than 1.
        """
        # Small tolerance so that float sums such as 0.7 + 0.2 + 0.1 still pass.
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                f"invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}")

        os.makedirs(output_dir, exist_ok=True)

        # Seed before building the questions: their option order and
        # distractor choice are random too, not only the split.
        random.seed(seed)

        data_list = []
        journals = sorted(self.get_journals())

        for journal in journals:
            try:
                issues = sorted(self.get_issues(journal))
            except OSError as e:
                # E.g. the journal has stories but no cover folder: skip it
                # instead of aborting the whole build.
                print(f"Error listing issues for {journal}: {e}")
                continue
            # All story files of the journal; construct_easy_question itself
            # excludes the story of the issue being processed.
            story_files = sorted(os.listdir(os.path.join(self.story_path, journal)))

            for issue in issues:
                try:
                    question = self.construct_easy_question(journal, issue, story_files)
                    if question is None:
                        continue
                    # Flatten the question into one row per question.
                    flat_data = {
                        'journal': question['journal'],
                        'id': question['id'],
                        'question': question['question'],
                        'cover_image': question['cover_image'],
                        'answer': question['answer'],
                    }
                    # One column per option letter: option_A .. option_D.
                    for opt in question['options']:
                        flat_data[f'option_{opt["id"]}'] = opt["text"]

                    data_list.append(flat_data)

                except Exception as e:
                    # Best effort: one broken issue must not stop the build.
                    print(f"Error processing {journal}/{issue}: {e}")
                    continue

        # Drop incomplete questions before splitting.
        valid_dataset, removed_count = self.check_dataset_integrity(data_list)
        df = pd.DataFrame(valid_dataset)

        # Re-seed right before the split so that, for a given seed and dataset
        # size, the split does not depend on how many random draws the
        # question construction consumed.
        random.seed(seed)
        indices = list(range(len(df)))
        random.shuffle(indices)

        train_size = int(len(df) * train_ratio)
        val_size = int(len(df) * val_ratio)

        train_indices = indices[:train_size]
        val_indices = indices[train_size:train_size + val_size]
        test_indices = indices[train_size + val_size:]

        train_df = df.iloc[train_indices].copy()
        val_df = df.iloc[val_indices].copy()
        test_df = df.iloc[test_indices].copy()

        # Tag every row with the split it belongs to.
        train_df['split'] = 'train'
        val_df['split'] = 'val'
        test_df['split'] = 'test'

        # Save each split as CSV.
        splits = {
            'train': train_df,
            'val': val_df,
            'test': test_df
        }

        for split_name, split_df in splits.items():
            output_path = os.path.join(output_dir, f'{split_name}.csv')
            split_df.to_csv(output_path, index=False, encoding='utf-8')
            print(f"Saved {split_name} set ({len(split_df)} samples) to {output_path}")

        # Save the complete dataset (all splits concatenated).
        full_df = pd.concat([train_df, val_df, test_df], axis=0)
        full_output_path = os.path.join(output_dir, 'full_dataset.csv')
        full_df.to_csv(full_output_path, index=False, encoding='utf-8')

        # Save summary statistics next to the data.
        stats = {
            'total_samples': len(df),
            'train_samples': len(train_df),
            'val_samples': len(val_df),
            'test_samples': len(test_df),
            'train_ratio': train_ratio,
            'val_ratio': val_ratio,
            'test_ratio': 1 - train_ratio - val_ratio,
            'journals': journals,
            'seed': seed,
            'columns': list(df.columns),
            'removed_count': removed_count
        }

        stats_path = os.path.join(output_dir, 'dataset_stats.json')
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)

        return train_df, val_df, test_df

if __name__ == "__main__":
    # Command-line entry point: build the dataset from --base_path and write
    # the CSV splits and statistics to --output.
    parser = argparse.ArgumentParser(description="construct multi-choice dataset")
    # The options are not marked ``required``: a required option must always
    # be given on the command line, so its ``default`` could never be used.
    parser.add_argument('--output', type=str, default="./Data/Understanding/Nature", help="output directory")
    parser.add_argument('--base_path', type=str, default="./Nature", help="base path")
    args = parser.parse_args()
    construct = Construct_Multi_Choice(args.base_path)
    construct.construct_dataset(output_dir=args.output)