import os
import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Tuple
from transformers import AutoTokenizer
from src import utils,loader

def is_api_model(model_path:str):
    """Return True when ``model_path`` names a model served through the chat API.

    The check is an exact string match against a fixed list of model names;
    anything else (e.g. a filesystem path to local weights) yields False.
    """
    # NOTE(review): the single-space entry looks like a leftover placeholder
    # for a removed model name -- confirm whether it can be dropped.
    known_api_models = (
        'deepseek-r1-250120',
        'deepseek-r1-250528',
        'doubao-1.5-thinking-pro-250415',
        'doubao-1.5-pro-32k-250115',
        'doubao-1.5-lite-32k-250115',
        ' ',
        'doubao-seed-1.6-250615',
        'doubao-seed-1.6-thinking-250615',
        'doubao-seed-1.6-flash-250615',
    )
    return model_path in known_api_models

class DistillationPipeline:
    """Generate model answers for whole datasets, chunk by chunk, as JSONL files.

    A dataset is loaded through ``loader.loaderFuncDict``, turned into prompts,
    split into fixed-size chunks, answered either by a locally loaded model or
    by a hosted chat API (see ``is_api_model``), and each chunk is written to
    ``{output_base_dir}/{model_name}/{dataset_name}/chunk{index}.jsonl``.
    Chunks whose output file already exists and is non-empty are skipped, so
    an interrupted run can be resumed.
    """

    # NOTE: the ``tensor_parallel_size`` default is read from the environment
    # once, when the class body is executed, not on every instantiation.
    def __init__(self, model_path: str, output_base_dir: str = "distillation_results",tensor_parallel_size=os.getenv('ARNOLD_WORKER_GPU',8)):
        """Create the output directory and, for non-API models, load the model.

        Args:
            model_path: Local model path, or the name of an API model.
            output_base_dir: Root directory; results go into a sub-directory
                named after the last path component of ``model_path``.
            tensor_parallel_size: Passed (as ``int``) to ``utils.load_model``;
                only used for non-API models.
        """
        self.model_path = model_path
        self.output_base_dir = Path(output_base_dir)
        self.model_name = Path(model_path).name
        self.output_dir = self.output_base_dir / self.model_name
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # API models need neither local weights nor a tokenizer object here.
        self.tokenizer = None
        self.model = None
        if not is_api_model(model_path):
            self.model = utils.load_model(self.model_path, tensor_parallel_size=int(tensor_parallel_size))
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        print(f"[DistillationPipeline] __init__  \nmodel_name: {self.model_name}\noutput_dir:{self.output_dir}\nmodel_path: {self.model_path}")

    def _get_output_path(self, dataset_name: str, chunk_index: int) -> Path:
        """Return ``{output_dir}/{dataset_name}/chunk{index}.jsonl``, creating its parent directory."""
        file_path = self.output_dir / dataset_name / f"chunk{chunk_index}.jsonl"
        file_path.parent.mkdir(parents=True, exist_ok=True)
        return file_path

    def _chunk_dataset(self, dataset: List[loader.Question], chunk_size: int = 500) -> List[List[loader.Question]]:
        """Split ``dataset`` into consecutive chunks of at most ``chunk_size`` items.

        Raises:
            ValueError: If ``chunk_size`` is not positive. A negative size would
                otherwise yield no chunks at all and silently drop the dataset.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
        return [dataset[i:i + chunk_size] for i in range(0, len(dataset), chunk_size)]

    def _generate_answers(self, dataset_chunk:List[loader.QuestionQuery]) -> List[loader.QuestionOutput]:
        """Generate one answer per query in ``dataset_chunk``.

        Non-API models go through ``utils.llm_inference`` (the first candidate
        of every output is used); API models go through
        ``utils.multi_chat_completion_api``.

        Raises:
            ValueError: If the backend returns a different number of responses
                than prompts, since pairing them positionally would then
                silently drop or misalign items.
        """
        prompts = [item.prompt for item in dataset_chunk]
        if not is_api_model(self.model_path):
            # Local inference (vLLM, per the original comment).
            llm_responses = utils.llm_inference(
                model=self.model,
                prompts=prompts,
                max_tokens=1024*16
            )
            responses = [output.outputs[0].text for output in llm_responses]
        else:
            # list() so the length check below also works for lazy iterables.
            responses = list(utils.multi_chat_completion_api(
                model=self.model_name,
                question_list=prompts
            ))

        if len(responses) != len(dataset_chunk):
            raise ValueError(
                f"expected {len(dataset_chunk)} responses but got {len(responses)}"
            )

        # Merge each response back into its originating query.
        return [
            loader.QuestionOutput(
                id=item.id,
                origin_question=item.origin_question,
                origin_correct_answer=item.origin_correct_answer,
                prompt=item.prompt,
                generated_text=response
            )
            for item, response in zip(dataset_chunk, responses)
        ]

    def _save_chunk(self, dataset_chunk: List[loader.QuestionOutput], output_path: Path):
        """Write a processed chunk to ``output_path`` via ``utils.write_json_list``.

        The data is first written to a sibling ``*.tmp`` file and then moved
        into place, so ``output_path`` never holds a partially written chunk.
        This matters because ``process_dataset`` treats any non-empty output
        file as already completed.
        """
        print(f"_save_chunk {output_path},len(dataset_chunk):{len(dataset_chunk)}")
        json_list = [x.__json__() for x in dataset_chunk]

        final_path = Path(output_path)
        tmp_path = final_path.with_name(final_path.name + ".tmp")
        try:
            utils.write_json_list(json_list, tmp_path)
            os.replace(tmp_path, final_path)
        finally:
            # Only still present if writing or the move failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def _load_dataset(self, dataset_name: str,rep=1) -> List[loader.QuestionQuery]:
        """Load ``dataset_name`` and build ``rep`` prompt queries per question.

        The dataset is looked up in ``loader.loaderFuncDict``; an unknown name
        prints a message and yields an empty list. Each query id is
        ``"{question.id}_{i}"`` where ``i`` is the repetition index.
        """
        dataset = []
        func = loader.loaderFuncDict.get(dataset_name) if dataset_name in loader.loaderFuncDict else None
        if func is not None:
            question_list = func()
            for i in range(rep):
                # NOTE(review): self.tokenizer is None for API models -- confirm
                # that build_distill_prompt accepts that.
                prompt_list = loader.build_distill_prompt(self.tokenizer, [x.origin_question for x in question_list])
                for question, prompt in zip(question_list, prompt_list):
                    dataset.append(loader.QuestionQuery(
                        id=f"{question.id}_{i}",
                        origin_question=question.origin_question,
                        origin_correct_answer=question.correct_answer,
                        prompt=prompt
                    ))
        else:
            print(f"dataset_name {dataset_name} not in loaderFuncDict")

        print(f"load dataset_name {dataset_name},rep:{rep},{len(dataset)} questions")
        return dataset

    def process_dataset(self, dataset_name: str, retries: int = 3):
        """Process a single dataset: load, chunk, generate and save each chunk.

        Args:
            dataset_name: Key into ``loader.loaderFuncDict``.
            retries: Number of additional attempts per chunk after the first
                failure (``retries + 1`` attempts in total).

        Chunks that still fail after all attempts are reported and skipped;
        no exception is propagated for them.
        """
        # Load the full dataset and split it into chunks.
        full_dataset = self._load_dataset(dataset_name)
        chunks = self._chunk_dataset(full_dataset)

        # Progress tracking.
        completed_chunks = 0
        total_chunks = len(chunks)

        for chunk_index, chunk_data in enumerate(chunks):
            output_path = self._get_output_path(dataset_name, chunk_index)

            # Resume support: a non-empty output file means the chunk is done.
            if output_path.exists() and output_path.stat().st_size > 0:
                print(f"Chunk {chunk_index} already processed. Skipping...")
                completed_chunks += 1
                continue

            for attempt in range(retries + 1):
                try:
                    # Generate the answers and persist them.
                    processed_chunk = self._generate_answers(chunk_data)
                    self._save_chunk(processed_chunk, output_path)

                    print(f"Successfully processed chunk {chunk_index}")
                    completed_chunks += 1
                    break

                # Deliberately broad: any failure of a chunk is retried and
                # must not abort the remaining chunks.
                except Exception as e:
                    print(f"Attempt {attempt} failed for chunk {chunk_index}: {str(e)}")
                    if attempt < retries:
                        print("Retrying...")
                    else:
                        print("Retries exhausted. Moving to next chunk.")
                        # Remove a possibly incomplete file so a later run retries it.
                        if output_path.exists():
                            os.remove(output_path)

            print(f"Progress: {completed_chunks}/{total_chunks} chunks completed")

        failed_chunks = total_chunks - completed_chunks
        if failed_chunks:
            print(f"⚠️ {failed_chunks}/{total_chunks} chunks failed for dataset: {dataset_name}")
        print(f"✅ Finished processing dataset: {dataset_name}")

def main():
    """Command-line entry point: run the distillation pipeline on each dataset.

    Parses ``--model_path``, ``--dataset_list`` (comma-separated dataset
    names), ``--output_dir`` and ``--retries``, builds one
    ``DistillationPipeline`` and processes the datasets in the given order.
    """
    parser = argparse.ArgumentParser(description="Model Distillation Pipeline")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the model directory or file")
    # The values are looked up in loader.loaderFuncDict, so they are dataset
    # names rather than file paths.
    parser.add_argument("--dataset_list", type=str, required=True,
                        help="Comma-separated list of dataset names")
    parser.add_argument("--output_dir", type=str, default="out_put/distillation_results",
                        help="Output directory for results")
    parser.add_argument("--retries", type=int, default=3,
                        help="Number of retries for failed chunks")
    args = parser.parse_args()

    # Tolerate spaces after commas and stray/trailing commas ("a, b,").
    dataset_list = [name.strip() for name in args.dataset_list.split(',') if name.strip()]
    print(f"dataset_list:{dataset_list}")
    print(args)

    pipeline = DistillationPipeline(
        model_path=args.model_path,
        output_base_dir=args.output_dir
    )

    for dataset_name in dataset_list:
        print(f"\n{'='*50}")
        print(f"Starting distillation for dataset: {dataset_name}")
        pipeline.process_dataset(
            dataset_name=dataset_name,
            retries=args.retries
        )

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()