import os
import json
from tqdm import tqdm
from typing import List, Dict, Tuple, Optional
import argparse
import re
from datetime import datetime

                   
from llama_index.core import Document
from llama_index.core.schema import BaseNode
from llama_index.core.node_parser import (
    TokenTextSplitter,
)

           
from utils.txt_reader import TxtReader
from utils.ner_extractor import GameNERExtractor


class TimeSegmentCorpusBuilder:
    """Split a directory of text files into per-time-segment corpora.

    Each ``.txt``/``.md`` file under ``data_directory`` is assigned to a
    segment by the ``YYYY-MM-DD`` prefix of its file name. Files without such a
    prefix, or whose date lies in no configured segment, go to the "timeless"
    bucket (segment id ``-1``). Each bucket is chunked with
    ``TokenTextSplitter``, optionally annotated with NER entities, and saved to
    its own sub-directory of ``output_dir``.
    """

    # File names are expected to start with an ISO date, e.g. "2023-05-01_notes.txt".
    _DATE_PATTERN = re.compile(r'^(\d{4}-\d{2}-\d{2})')
    _DATE_FORMAT = '%Y-%m-%d'

    def __init__(self, segments_file: str, data_directory: str, output_dir: str,
                 chunk_size: int = 1000, chunk_overlap: int = 200, clean_html: bool = True,
                 enable_ner: bool = False, openai_api_key: Optional[str] = None,
                 openai_base_url: Optional[str] = None, ner_model: str = "gpt-4o",
                 use_self_icl: bool = True, num_pseudo_examples: int = 1):
        """Configure the builder and load the segment definitions.

        Args:
            segments_file: JSON file whose top-level ``segments`` list holds
                dicts with ``segment_id``, ``start_date`` and ``end_date``
                (dates formatted ``YYYY-MM-DD``).
            data_directory: Root directory searched recursively for
                ``.txt``/``.md`` files.
            output_dir: Directory that receives one sub-directory per segment.
            chunk_size: Chunk size passed to ``TokenTextSplitter``.
            chunk_overlap: Chunk overlap passed to ``TokenTextSplitter``.
            clean_html: Forwarded to ``TxtReader``.
            enable_ner: When True, every chunk is annotated with entities from
                ``GameNERExtractor``.
            openai_api_key, openai_base_url, ner_model, use_self_icl,
            num_pseudo_examples: Forwarded to ``GameNERExtractor`` (only used
                when ``enable_ner`` is True).
        """
        self.segments_file = segments_file
        self.data_directory = data_directory
        self.output_dir = output_dir
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.clean_html = clean_html
        self.enable_ner = enable_ner
        self.use_self_icl = use_self_icl
        self.num_pseudo_examples = num_pseudo_examples

        self.segments = self._load_segments()
        self.splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.txt_reader = TxtReader(clean_html=clean_html)

        if self.enable_ner:
            self.ner_extractor = GameNERExtractor(
                api_key=openai_api_key,
                base_url=openai_base_url,
                model=ner_model,
                use_self_icl=use_self_icl,
                num_pseudo_examples=num_pseudo_examples,
                max_workers=3,  # hard-coded; not exposed as a constructor parameter
            )
            method_type = "SELF-ICL" if use_self_icl else "Traditional"
            print(f"NER: {method_type}")
        else:
            self.ner_extractor = None

    def _load_segments(self) -> List[Dict]:
        """Return the ``segments`` list stored in ``self.segments_file``.

        Loading is best-effort: on any error (missing file, invalid JSON,
        unexpected top-level type) a message is printed and an empty list is
        returned, which sends every file to the timeless bucket.
        """
        try:
            with open(self.segments_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return data.get('segments', [])
        except Exception as e:
            # Report instead of swallowing, so an all-"timeless" result is explainable.
            print(f"Failed to load segments from {self.segments_file}: {e}")
            return []

    def _extract_date_from_filename(self, filename: str) -> Optional[str]:
        """Return the leading ``YYYY-MM-DD`` of *filename*, or None if absent.

        Only the shape is checked here; calendar validity is checked later by
        ``_get_segment_for_date``.
        """
        match = self._DATE_PATTERN.match(filename)
        return match.group(1) if match else None

    def _get_segment_for_date(self, date_str: str) -> Optional[int]:
        """Return the id of the first segment with ``start_date <= date < end_date``.

        The end date is exclusive. Returns None when *date_str* is empty or is
        not a valid date, or when no segment contains the date. A segment whose
        own dates are missing or malformed is reported and skipped, so the
        remaining segments are still searched.
        """
        if not date_str:
            return None

        try:
            date_obj = datetime.strptime(date_str, self._DATE_FORMAT)
        except ValueError as e:
            print(f"Error getting segment for date {date_str}: {e}")
            return None

        for segment in self.segments:
            try:
                start_date = datetime.strptime(segment['start_date'], self._DATE_FORMAT)
                end_date = datetime.strptime(segment['end_date'], self._DATE_FORMAT)
            except (KeyError, TypeError, ValueError) as e:
                print(f"Skipping segment {segment.get('segment_id')} with invalid dates: {e!r}")
                continue

            if start_date <= date_obj < end_date:
                return segment['segment_id']

        return None

    def _flatten_folder(self, root: str) -> List[str]:
        """Recursively collect ``.txt`` and ``.md`` file paths under *root*.

        Directory entries are visited in sorted order so the file list (and
        therefore document order in the saved corpus) is deterministic;
        ``os.listdir`` itself returns entries in arbitrary order.
        """
        files = []
        for item in sorted(os.listdir(root)):
            path = os.path.join(root, item)
            if os.path.isdir(path):
                files.extend(self._flatten_folder(path))
            elif path.endswith(('.txt', '.md')):
                files.append(path)
        return files

    def _categorize_files_by_segment(self) -> Dict[int, List[str]]:
        """Group every corpus file under the id of the segment its date falls into.

        Returns:
            Mapping of segment id to file paths. Key ``-1`` collects files with
            no date prefix or whose date matches no segment. Every configured
            segment id is present, possibly mapped to an empty list.
        """
        all_files = self._flatten_folder(self.data_directory)

        categorized_files = {-1: []}
        for segment in self.segments:
            categorized_files[segment['segment_id']] = []

        for file_path in all_files:
            file_date = self._extract_date_from_filename(os.path.basename(file_path))
            segment_id = self._get_segment_for_date(file_date) if file_date else None
            bucket = -1 if segment_id is None else segment_id
            categorized_files[bucket].append(file_path)

        return categorized_files

    def _process_segment_files(
            self, files: List[str], segment_id: int) -> Tuple[List[Document], List[BaseNode], Optional[Dict]]:
        """Load, filter, chunk and optionally NER-annotate one segment's files.

        Args:
            files: Paths of the files assigned to this segment.
            segment_id: Stored in each loaded document via
                ``extra_info={'segment_id': segment_id}``.

        Returns:
            ``(documents, nodes, ner_stats)``. ``documents`` excludes documents
            whose content is empty or whitespace-only. ``nodes`` are the
            splitter's chunks. ``ner_stats`` is the extractor's statistics, or
            None when NER is disabled or there are no nodes.
        """
        all_documents = []
        for file_path in tqdm(files, desc=f"Loading segment {segment_id} documents"):
            all_documents.extend(
                self.txt_reader.load_data(file_path, extra_info={'segment_id': segment_id}))

        valid_documents = [doc for doc in all_documents if doc.get_content().strip() != '']
        empty_count = len(all_documents) - len(valid_documents)
        if empty_count > 0:
            print(f"Empty documents: {empty_count}")

        if valid_documents:
            nodes = self.splitter.get_nodes_from_documents(valid_documents, show_progress=False)
        else:
            nodes = []

        ner_stats = None
        if self.enable_ner and self.ner_extractor and nodes:
            node_texts = [node.get_content() for node in nodes]
            # Materialize so the result can be both length-checked and re-read below.
            all_entities = list(self.ner_extractor.batch_extract_entities(
                node_texts, show_progress=True
            ))
            if len(all_entities) != len(nodes):
                # zip() would otherwise silently leave trailing nodes unannotated.
                print(f"Warning: NER returned {len(all_entities)} results for "
                      f"{len(nodes)} nodes in segment {segment_id}")

            for node, entities in zip(nodes, all_entities):
                node.metadata['entities'] = entities
                node.metadata['entity_count'] = len(entities)

            ner_stats = self.ner_extractor.get_entity_statistics(all_entities)

        return valid_documents, nodes, ner_stats

    @staticmethod
    def _to_record(item) -> Dict:
        """Serialize a document or node as ``{'id', 'text', 'metadata'}``."""
        return {'id': item.id_, 'text': item.get_content(), 'metadata': item.metadata}

    @staticmethod
    def _write_json(path: str, obj) -> None:
        """Write *obj* to *path* as indented UTF-8 JSON (non-ASCII kept as-is)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(obj, f, ensure_ascii=False, indent=2)

    def _save_segment_corpus(self, documents: List[Document], nodes: List[BaseNode],
                             segment_id: int, segment_info: Optional[Dict] = None, ner_stats: Optional[Dict] = None):
        """Write one segment's corpus files under ``self.output_dir``.

        The sub-directory is ``segment_timeless`` for id ``-1`` and
        ``segment_<id>`` otherwise. Files written:
        ``segment_info.json`` (only if *segment_info* is given),
        ``ner_stats.json`` (only if *ner_stats* is given), ``documents.json``,
        ``nodes.json``, and ``corpus.jsonl`` (one node per line with ``id``,
        ``title``, ``contents`` and ``metadata``).

        Note: when both *segment_info* and *ner_stats* are given, ``ner_stats``
        is stored into *segment_info* in place, so the caller's dict carries
        it afterwards.
        """
        print(f"Saving segment corpus for segment {segment_id}")
        subdir_name = 'segment_timeless' if segment_id == -1 else f'segment_{segment_id}'
        output_subdir = os.path.join(self.output_dir, subdir_name)
        os.makedirs(output_subdir, exist_ok=True)

        if segment_info:
            if ner_stats:
                segment_info['ner_stats'] = ner_stats
            self._write_json(os.path.join(output_subdir, 'segment_info.json'), segment_info)

        if ner_stats:
            self._write_json(os.path.join(output_subdir, 'ner_stats.json'), ner_stats)

        self._write_json(os.path.join(output_subdir, 'documents.json'),
                         [self._to_record(doc) for doc in documents])
        self._write_json(os.path.join(output_subdir, 'nodes.json'),
                         [self._to_record(node) for node in nodes])

        corpus_file = os.path.join(output_subdir, 'corpus.jsonl')
        with open(corpus_file, 'w', encoding='utf-8') as f:
            f.writelines(
                json.dumps({
                    'id': node.id_,
                    'title': node.metadata.get('title', ''),
                    'contents': node.get_content(),
                    'metadata': node.metadata,
                }, ensure_ascii=False) + '\n'
                for node in nodes
            )

    def build_time_segmented_corpus(self):
        """Process and save every non-empty segment, then write ``overall_stats.json``.

        Segments with no files are skipped. For each processed segment, the
        ``file_count``, ``document_count`` and ``node_count`` values are
        recorded in that segment's ``segment_info.json``. They are also added
        to the ``segments`` list in ``overall_stats.json``.
        """
        categorized_files = self._categorize_files_by_segment()
        os.makedirs(self.output_dir, exist_ok=True)

        overall_stats = {
            'total_segments': len(self.segments),
            'total_files': sum(len(files) for files in categorized_files.values()),
            'processing_config': {
                'chunk_size': self.chunk_size,
                'chunk_overlap': self.chunk_overlap,
                'clean_html': self.clean_html
            },
            'segments': []
        }

        for segment_id, files in categorized_files.items():
            if not files:
                continue

            if segment_id == -1:
                segment_info = {
                    'segment_id': -1,
                    'description': 'Timeless files',
                    'file_count': len(files)
                }
            else:
                segment_info = next((s for s in self.segments if s['segment_id'] == segment_id), None)
                if segment_info:
                    # Copy so the loaded segment definitions are not mutated.
                    segment_info = segment_info.copy()
                    segment_info['file_count'] = len(files)

            documents, nodes, ner_stats = self._process_segment_files(files, segment_id)

            if segment_info:
                # Record counts BEFORE saving so segment_info.json includes them too
                # (previously they only reached overall_stats.json).
                segment_info['document_count'] = len(documents)
                segment_info['node_count'] = len(nodes)

            self._save_segment_corpus(documents, nodes, segment_id, segment_info, ner_stats)

            if segment_info:
                overall_stats['segments'].append(segment_info)

        self._write_json(os.path.join(self.output_dir, 'overall_stats.json'), overall_stats)


def main():
    """Command-line entry point: parse arguments and build the time-segmented corpus.

    Returns early, printing the reason, in three cases:

    - the segments file does not exist;
    - the data directory does not exist;
    - NER is enabled without a usable API key.

    The ``{your ...}`` placeholder defaults for the API key and base URL count
    as "not provided".
    """
    api_key_placeholder = "{your api_key}"
    base_url_placeholder = "{your base_url}"

    parser = argparse.ArgumentParser(description='Partition data corpus by time segments')
    parser.add_argument('--segments_file', type=str,
                        default='../data/pubgm/question_segments_results.json',
                        help='Time segment configuration file path')
    parser.add_argument('--data_dir', type=str,
                        default='../data/pubgm/knowledge',
                        help='Data directory path')
    parser.add_argument('--output_dir', type=str,
                        default='../data/pubgm/corpus',
                        help='Output directory path')
    parser.add_argument('--chunk_size', type=int, default=2048,
                        help='Text chunk size')
    parser.add_argument('--chunk_overlap', type=int, default=256,
                        help='Text chunk overlap size')
    parser.add_argument('--no_clean_html', action='store_true',
                        help='Disable HTML cleaning functionality')
    parser.add_argument('--enable_ner', action='store_true',
                        help='Enable NER entity recognition functionality')
    parser.add_argument('--openai_api_key', type=str, default=api_key_placeholder,
                        help='OpenAI API key')
    parser.add_argument('--openai_base_url', type=str, default=base_url_placeholder,
                        help='OpenAI API base URL')
    parser.add_argument('--ner_model', type=str, default='gpt-4o',
                        help='Model used for NER')
    parser.add_argument('--disable_self_icl', action='store_true',
                        help='Disable SELF-ICL technique, use traditional NER method')
    parser.add_argument('--num_pseudo_examples', type=int, default=1,
                        help='Number of pseudo examples generated by SELF-ICL')

    args = parser.parse_args()

    if not os.path.exists(args.segments_file):
        print(f"Segments file not found: {args.segments_file}")
        return

    if not os.path.exists(args.data_dir):
        print(f"Data directory not found: {args.data_dir}")
        return

    # The placeholder defaults are non-empty (truthy) strings; without this the
    # "no API key" guard below could never trigger.
    api_key = None if args.openai_api_key == api_key_placeholder else args.openai_api_key
    base_url = None if args.openai_base_url == base_url_placeholder else args.openai_base_url

    if args.enable_ner and not api_key:
        print("Enable NER but no OpenAI API key provided")
        return

    builder = TimeSegmentCorpusBuilder(
        segments_file=args.segments_file,
        data_directory=args.data_dir,
        output_dir=args.output_dir,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        clean_html=not args.no_clean_html,
        enable_ner=args.enable_ner,
        openai_api_key=api_key,
        openai_base_url=base_url,
        ner_model=args.ner_model,
        use_self_icl=not args.disable_self_icl,
        num_pseudo_examples=args.num_pseudo_examples
    )

    builder.build_time_segmented_corpus()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
