import argparse
import copy
import hashlib
import json
import os
import random
import re
import sys
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from langdetect import detect
from tqdm import tqdm

## Project-local imports; global_config stays first in case it configures sys.path for the modules below — TODO confirm.
from global_config import PROJECT_ROOT
from logs.logger import logger
from modules.utils import is_en_lang, seed_everything, readjsonl, writejsonl, writejson, readjson

class BaseDataProcessor(ABC):
    """Base class for all data processors.

    A processor reads one raw dataset file, converts every record with
    :meth:`process_item`, keeps only records whose prompt and response are
    both detected as English, and randomly sub-samples the survivors.

    Subclasses must implement :meth:`process_item`.
    """

    def __init__(self, input_path: str, sample_size: Optional[int] = None):
        """
        Args:
            input_path: Path to the raw ``.json`` / ``.jsonl`` dataset file.
            sample_size: Number of records :meth:`process` should return, or
                ``None`` to return every record that passes the filters.
        """
        self.input_path = input_path
        self.sample_size = sample_size
        # Dataset name: file name up to its first dot ("alpaca.jsonl" -> "alpaca").
        self.source = os.path.basename(input_path).split('.')[0]

    def read_data(self) -> List[Dict]:
        """Read all records from ``self.input_path``.

        ``.jsonl`` files are loaded with ``readjsonl``. Any other file is loaded
        with ``readjson``: a top-level dict with a ``"data"`` key yields that
        value, any other dict is wrapped in a one-element list, and anything
        else is returned as-is.
        """
        if self.input_path.endswith('.jsonl'):
            return readjsonl(self.input_path)
        data = readjson(self.input_path)
        if isinstance(data, dict):
            return data["data"] if "data" in data else [data]
        return data

    @abstractmethod
    def process_item(self, item: Dict) -> Optional[Dict]:
        """Convert one raw record into a processed record.

        Implementations return a dict containing at least ``"prompt"`` and
        ``"response"``, or a falsy value (e.g. ``None``) to drop the record.
        """

    def is_english(self, text: str) -> bool:
        """Return True if langdetect classifies *text* as English (``"en"``)."""
        # NOTE(review): langdetect may give different answers across runs unless
        # langdetect.DetectorFactory.seed is set — confirm whether that matters.
        try:
            return detect(text) == 'en'
        except Exception:
            # detect() presumably raises on text it cannot classify (e.g. empty
            # strings); treat that as non-English. Catching Exception instead of
            # a bare except keeps Ctrl-C / SystemExit working during long runs.
            return False

    def process(self) -> List[Dict]:
        """Run the full pipeline and return the sampled records.

        Returns:
            ``sample_size`` distinct English records drawn at random (fewer,
            with a warning, if not enough records survive filtering), or every
            surviving record in shuffled order when ``sample_size`` is ``None``.

        Side effect:
            ``self.sample_size`` is lowered to the number of surviving records
            when fewer than requested survive.
        """
        logger.info(f"Processing {self.source}...")
        data = self.read_data()
        random.shuffle(data)
        if self.sample_size is not None:
            # For efficiency, only convert/filter a 3x over-sample of candidates;
            # the slack absorbs records dropped by the filters below.
            data = data[:self.sample_size * 3]

        processed_data = []
        for item in tqdm(data, desc=f"Processing {self.source} items"):
            processed = self.process_item(item)
            if not processed or "prompt" not in processed or "response" not in processed:
                continue
            # Keep only records whose prompt AND response are English.
            if self.is_english(processed["prompt"]) and self.is_english(processed["response"]):
                processed_data.append(processed)

        if self.sample_size is None:
            logger.info(f'{self.source}: {len(processed_data)}')
            return processed_data

        if len(processed_data) < self.sample_size:
            logger.warning(
                f'{self.source}: too many instructions filtered '
                f'({len(processed_data)} left, {self.sample_size} requested)'
            )
            self.sample_size = len(processed_data)
        # Sample WITHOUT replacement: random.choices would repeat records.
        processed_data = random.sample(processed_data, self.sample_size)
        logger.info(f'{self.source}: {self.sample_size}')
        return processed_data

class OpenHermesProcessor(BaseDataProcessor):
    """Processor for records holding a ``conversations`` list of ``{"from", "value"}`` dicts."""

    def process_item(self, item: Dict) -> Optional[Dict]:
        """Turn the first human turn and the gpt reply after it into a record.

        Returns ``None`` if *item* has no ``conversations`` or no human turn
        followed by a gpt turn.
        """
        messages = item.get("conversations")
        if not messages:
            return None

        # Only the first round of dialogue is used. One shared iterator makes the
        # gpt search resume after the first human turn, so a gpt message that
        # precedes it (e.g. a greeting) is never paired with the prompt.
        turns = iter(messages)
        human_msg = next((msg for msg in turns if msg.get("from") == "human"), None)
        assistant_msg = next((msg for msg in turns if msg.get("from") == "gpt"), None)
        if human_msg is None or assistant_msg is None:
            return None

        # `or ""` also covers an explicit null value.
        prompt = human_msg.get("value") or ""
        response = assistant_msg.get("value") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": item["source"] if "source" in item else self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

class ShareGPTProcessor(BaseDataProcessor):
    """Processor for records holding a ``conversations`` list of ``{"from", "value"}`` dicts."""

    def process_item(self, item: Dict) -> Optional[Dict]:
        """Turn the first human turn and the gpt reply after it into a record.

        Returns ``None`` if *item* has no ``conversations`` or no human turn
        followed by a gpt turn. Any extra per-item metadata is not copied.
        """
        messages = item.get("conversations")
        if not messages:
            return None

        # One shared iterator: the gpt search resumes after the first human turn,
        # so a conversation that opens with a gpt message is still paired correctly.
        turns = iter(messages)
        human_msg = next((msg for msg in turns if msg.get("from") == "human"), None)
        assistant_msg = next((msg for msg in turns if msg.get("from") == "gpt"), None)
        if human_msg is None or assistant_msg is None:
            return None

        # `or ""` also covers an explicit null value.
        prompt = human_msg.get("value") or ""
        response = assistant_msg.get("value") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": item["source"] if "source" in item else self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

class AlpacaProcessor(BaseDataProcessor):
    """Processor for records holding a ``messages`` list of ``{"role", "content"}`` dicts."""

    def process_item(self, item: Dict) -> Optional[Dict]:
        """Turn the first user turn and the assistant reply after it into a record.

        Returns ``None`` if *item* has no ``messages`` or no user turn followed
        by an assistant turn.
        """
        messages = item.get("messages")
        if not messages:
            return None

        # One shared iterator: the assistant search resumes after the first user
        # turn, so an assistant message preceding it is never used as the reply.
        turns = iter(messages)
        user_msg = next((msg for msg in turns if msg.get("role") == "user"), None)
        assistant_msg = next((msg for msg in turns if msg.get("role") == "assistant"), None)
        if user_msg is None or assistant_msg is None:
            return None

        # `or ""` also covers an explicit null content.
        prompt = user_msg.get("content") or ""
        response = assistant_msg.get("content") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": item["source"] if "source" in item else self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

class WizardLMProcessor(BaseDataProcessor):
    """Processor for records holding a ``messages`` list of ``{"role", "content"}`` dicts."""

    def process_item(self, item: Dict) -> Optional[Dict]:
        """Turn the first user turn and the assistant reply after it into a record.

        Returns ``None`` if *item* has no ``messages`` or no user turn followed
        by an assistant turn.
        """
        messages = item.get("messages")
        if not messages:
            return None

        # One shared iterator: the assistant search resumes after the first user
        # turn, so an assistant message preceding it is never used as the reply.
        turns = iter(messages)
        user_msg = next((msg for msg in turns if msg.get("role") == "user"), None)
        assistant_msg = next((msg for msg in turns if msg.get("role") == "assistant"), None)
        if user_msg is None or assistant_msg is None:
            return None

        # `or ""` also covers an explicit null content.
        prompt = user_msg.get("content") or ""
        response = assistant_msg.get("content") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": item["source"] if "source" in item else self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

class OrcaChatProcessor(BaseDataProcessor):
    """Processor for records holding a ``messages`` list of ``{"role", "content"}`` dicts."""

    def process_item(self, item: Dict) -> Optional[Dict]:
        """Turn the first user turn and the assistant reply after it into a record.

        Returns ``None`` if *item* has no ``messages`` or no user turn followed
        by an assistant turn.
        """
        messages = item.get("messages")
        if not messages:
            return None

        # One shared iterator: the assistant search resumes after the first user
        # turn, so an assistant message preceding it is never used as the reply.
        turns = iter(messages)
        user_msg = next((msg for msg in turns if msg.get("role") == "user"), None)
        assistant_msg = next((msg for msg in turns if msg.get("role") == "assistant"), None)
        if user_msg is None or assistant_msg is None:
            return None

        # `or ""` also covers an explicit null content.
        prompt = user_msg.get("content") or ""
        response = assistant_msg.get("content") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": item["source"] if "source" in item else self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

class Oasst2Processor(BaseDataProcessor):
    """Processor for records holding a ``messages`` list of ``{"role", "content"}`` dicts."""

    def process_item(self, item: Dict) -> Optional[Dict]:
        """Turn the first user turn and the assistant reply after it into a record.

        Returns ``None`` if *item* has no ``messages`` or no user turn followed
        by an assistant turn.
        """
        messages = item.get("messages")
        if not messages:
            return None

        # One shared iterator: the assistant search resumes after the first user
        # turn, so an assistant message preceding it is never used as the reply.
        turns = iter(messages)
        user_msg = next((msg for msg in turns if msg.get("role") == "user"), None)
        assistant_msg = next((msg for msg in turns if msg.get("role") == "assistant"), None)
        if user_msg is None or assistant_msg is None:
            return None

        # `or ""` also covers an explicit null content.
        prompt = user_msg.get("content") or ""
        response = assistant_msg.get("content") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": item["source"] if "source" in item else self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

class NoRobotsProcessor(BaseDataProcessor):
    """Processor for records holding a ``messages`` list of ``{"role", "content"}`` dicts."""

    def process_item(self, item: Dict) -> Optional[Dict]:
        """Turn the first user turn and the assistant reply after it into a record.

        Returns ``None`` if *item* has no ``messages`` or no user turn followed
        by an assistant turn.
        """
        messages = item.get("messages")
        if not messages:
            return None

        # One shared iterator: the assistant search resumes after the first user
        # turn, so an assistant message preceding it is never used as the reply.
        turns = iter(messages)
        user_msg = next((msg for msg in turns if msg.get("role") == "user"), None)
        assistant_msg = next((msg for msg in turns if msg.get("role") == "assistant"), None)
        if user_msg is None or assistant_msg is None:
            return None

        # `or ""` also covers an explicit null content.
        prompt = user_msg.get("content") or ""
        response = assistant_msg.get("content") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": item["source"] if "source" in item else self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

class SupernaturalProcessor(BaseDataProcessor):
    """Processor for records that already carry flat ``prompt``/``response`` fields."""

    def process_item(self, item: Dict) -> Dict:
        """Copy ``prompt``/``response`` (missing or null -> ``""``) into a record.

        Records with empty text are not dropped here; they are expected to be
        removed by the English filter in ``process``.
        """
        prompt = item.get("prompt") or ""
        response = item.get("response") or ""
        # Content-derived id: built-in hash() of str is salted per process, so it
        # would produce different ids on every run.
        digest = hashlib.sha256(
            f"{prompt}\x00{response}".encode("utf-8", "surrogatepass")
        ).hexdigest()[:16]
        return {
            "source": self.source,
            "prompt": prompt,
            "response": response,
            "id": f"{self.source}_{digest}",
            "original_category": item.get("category"),
        }

def get_processor_class(source: str) -> type:
    """Return the processor class registered for dataset *source*.

    Args:
        source: Dataset name, i.e. the raw file name without its extension.

    Raises:
        ValueError: if no processor is registered for *source*. The abstract
            BaseDataProcessor cannot serve as a fallback because it cannot be
            instantiated.
    """
    processors = {
        "openhermes2_5": OpenHermesProcessor,
        "supernatural": SupernaturalProcessor,
        "sharegpt_clean_en_reduce_rep": ShareGPTProcessor,
        "wizardLM": WizardLMProcessor,
        "orca_chat": OrcaChatProcessor,
        "alpaca": AlpacaProcessor,
        "oasst2": Oasst2Processor,
        "no_robots": NoRobotsProcessor,
    }
    try:
        return processors[source]
    except KeyError:
        raise ValueError(
            f"No processor registered for source {source!r}; "
            f"known sources: {sorted(processors)}"
        ) from None

def process_dataset(source: str, sample_size: Optional[int] = None) -> List[Dict]:
    """Load, filter and sample one raw dataset from ``<PROJECT_ROOT>/data/v0_raw``.

    Args:
        source: Dataset name; the input file is ``<source>.jsonl`` if it
            exists, otherwise ``<source>.json``.
        sample_size: Number of records to return (``None`` keeps all).

    Returns:
        The processed records produced by the dataset's processor.

    Raises:
        FileNotFoundError: if neither ``<source>.jsonl`` nor ``<source>.json`` exists.
    """
    input_dir = os.path.join(PROJECT_ROOT, "data/v0_raw")
    jsonl_path = os.path.join(input_dir, f"{source}.jsonl")
    json_path = os.path.join(input_dir, f"{source}.json")
    # Prefer the .jsonl file; fall back to .json.
    if os.path.exists(jsonl_path):
        input_path = jsonl_path
    elif os.path.exists(json_path):
        input_path = json_path
    else:
        raise FileNotFoundError(
            f"No input file for source {source!r}: tried {jsonl_path} and {json_path}"
        )

    processor_class = get_processor_class(source)
    processor = processor_class(input_path, sample_size)
    return processor.process()
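### NOTE(review): the sample sizes in process_all_datasets sum to 350k (matching the output file name), not 500k — confirm which figure is intended.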

def process_all_datasets(
    sample_sizes: Optional[Dict[str, int]] = None,
    output_path: Optional[str] = None,
) -> None:
    """Sample every raw dataset and write the merged records to one JSONL file.

    Args:
        sample_sizes: Mapping of dataset name -> number of records to draw.
            Defaults to the production mix below (350k records in total).
            Pass a tiny mapping (e.g. ``{"alpaca": 2}``) for a quick smoke run.
        output_path: Destination ``.jsonl`` file. Defaults to
            ``<PROJECT_ROOT>/data/v1_seed/random_selected_data_350k.jsonl``.
    """
    if sample_sizes is None:
        # Trailing numbers are presumably each raw dataset's full size — verify.
        sample_sizes = {
            "sharegpt_clean_en_reduce_rep": 51000,  # 94145
            "no_robots": 19000,  # 19000
            "openhermes2_5": 200000,  # 1001551
            "supernatural": 10000,  # 1990915
            "wizardLM": 25000,  # 46034
            "orca_chat": 25000,  # 44885
            "alpaca": 15000,  # 15304
            "oasst2": 5000,  # 5236
        }
    if output_path is None:
        output_path = os.path.join(PROJECT_ROOT, 'data/v1_seed/random_selected_data_350k.jsonl')

    all_list = []
    for source, sample_size in tqdm(sample_sizes.items(), desc="Processing datasets"):
        all_list.extend(process_dataset(source, sample_size))
    writejsonl(all_list, output_path)
if __name__ == "__main__":
    # Script entry point: run the complete sampling pipeline.
    process_all_datasets()

