import json
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional

from openai import OpenAI


class GameNERExtractor:
    """Extract game-related named entities from text with an OpenAI chat model.

    Two extraction strategies are available:

    * Self-ICL (``use_self_icl=True``, the default): the model first writes
      ``num_pseudo_examples`` texts similar to the input and labels each of
      them zero-shot. The labelled pseudo-examples then serve as in-context
      demonstrations for the real extraction.
    * Traditional: a single zero-shot extraction prompt.

    When no ``api_key`` is given, no client is created. Every extraction
    method then returns empty results without calling the API.

    Extracted entities are dicts with ``text``, ``type`` and ``context`` keys.
    They also carry ``start``/``end`` when the model returned offsets that
    actually point at the entity text.
    """

    # Default label -> description table rendered into every prompt.
    DEFAULT_ENTITY_TYPES: Dict[str, str] = {
        "CHARACTER": "Characters, NPCs, players, protagonists, companions",
        "ITEM": "Items, equipment, weapons, tools, consumables, gear",
        "LOCATION": "Locations, places, areas, maps, regions, zones",
        "QUEST": "Quests, missions, tasks, objectives, storylines",
        "SKILL": "Skills, abilities, talents, powers, upgrades",
        "FACTION": "Factions, organizations, groups, guilds, clans",
        "MONSTER": "Monsters, enemies, bosses, creatures, opponents",
        "MECHANIC": "Game mechanics, gameplay systems, features, rules",
        "EVENT": "Events, story events, plot points, scenarios"
    }

    # Fallback patterns for _parse_json_response, compiled once per class.
    # A fenced code block (optionally tagged ``json``) wrapping one object.
    _FENCED_JSON_RE = re.compile(r'```(?:json)?\s*(\{.*?\})\s*```', re.DOTALL)
    # An object containing at most one level of nested objects.
    _JSON_OBJECT_RE = re.compile(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', re.DOTALL)
    # The body of an ``"entities": [...]`` array (stops at the first ``]``).
    _ENTITIES_ARRAY_RE = re.compile(r'"entities"\s*:\s*\[(.*?)\]', re.DOTALL)

    def __init__(self,
                 api_key: Optional[str] = None,
                 base_url: Optional[str] = None,
                 model: str = "gpt-3.5-turbo",
                 max_retries: int = 3,
                 retry_delay: float = 1.0,
                 use_self_icl: bool = True,
                 num_pseudo_examples: int = 3,
                 max_workers: int = 3,
                 entity_types: Optional[Dict[str, str]] = None):
        """Configure the extractor.

        Args:
            api_key: OpenAI API key. If falsy, no client is created and all
                extraction methods return empty results.
            base_url: Optional alternative API endpoint passed to ``OpenAI``.
            model: Chat model name used for every request.
            max_retries: Attempts per API-backed extraction step. Values
                below 1 are treated as a single attempt.
            retry_delay: Seconds to sleep between failed attempts.
            use_self_icl: Use the Self-ICL strategy instead of plain zero-shot.
            num_pseudo_examples: Number of pseudo-inputs requested for Self-ICL.
            max_workers: Thread count for concurrent batch extraction.
            entity_types: Optional label -> description table. Defaults to a
                copy of ``DEFAULT_ENTITY_TYPES``.
        """
        # The conditional keeps OpenAI() from being constructed without a key.
        self.client = OpenAI(api_key=api_key, base_url=base_url) if api_key else None
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.use_self_icl = use_self_icl
        self.num_pseudo_examples = num_pseudo_examples
        self.max_workers = max_workers
        # Copy so edits on one instance never leak into the class-level default.
        self.entity_types = dict(
            self.DEFAULT_ENTITY_TYPES if entity_types is None else entity_types
        )

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #

    def _attempts(self) -> int:
        """Return how many API attempts to make; always at least one."""
        return max(1, self.max_retries)

    def _entity_type_description(self) -> str:
        """Render ``self.entity_types`` as ``- NAME: description`` lines for prompts."""
        return "\n".join(f"- {name}: {desc}" for name, desc in self.entity_types.items())

    def _chat(self, system_prompt: str, user_prompt: str,
              temperature: float, max_tokens: int) -> str:
        """Send one chat-completion request and return the first choice's stripped text.

        Client errors propagate to the caller. So does an ``AttributeError``
        when the returned message content is ``None``.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content.strip()

    # ------------------------------------------------------------------ #
    # Self-ICL building blocks
    # ------------------------------------------------------------------ #

    def _generate_pseudo_inputs(self, text: str) -> List[str]:
        """Ask the model for ``num_pseudo_examples`` texts similar to *text*.

        Returns at most ``num_pseudo_examples`` non-blank strings. Returns
        ``[]`` when there is no client, the request fails, or the reply has
        no usable ``pseudo_inputs`` list.
        """
        if not self.client:
            return []

        prompt = f"""You are tasked with generating pseudo-inputs for game-related Named Entity Recognition (NER).

Given the following game-related text, generate {self.num_pseudo_examples} similar but different game-related text examples that would be suitable for entity recognition. The generated texts should:
1. Be similar in style and domain to the input text
2. Contain various types of game entities
3. Be realistic and coherent
4. Have different specific entities but similar context patterns

Original text:
{text}

Return the result in the following JSON format only, no other text:
{{
    "pseudo_inputs": [
        "example text 1",
        "example text 2",
        "example text 3"
    ]
}}

Generate exactly {self.num_pseudo_examples} pseudo-input examples."""

        try:
            response_text = self._chat(
                "You are an expert in generating game-related text examples for NER training. Return only valid JSON.",
                prompt,
                temperature=0.7,
                max_tokens=800
            )
        except Exception:
            # Best effort: callers fall back to plain extraction on [].
            return []

        pseudo_inputs = self._parse_json_response(response_text).get("pseudo_inputs")
        if not isinstance(pseudo_inputs, list):
            return []

        # Drop non-string / blank items so they never become empty demonstrations.
        usable = [item.strip() for item in pseudo_inputs
                  if isinstance(item, str) and item.strip()]
        return usable[:self.num_pseudo_examples]

    def _predict_pseudo_labels(self, pseudo_inputs: List[str]) -> List[List[Dict]]:
        """Label each pseudo-input zero-shot; output is index-aligned with the input."""
        return [self._extract_entities_zero_shot(pseudo_input) for pseudo_input in pseudo_inputs]

    def _extract_entities_zero_shot(self, text: str) -> List[Dict]:
        """Run one zero-shot extraction on *text* and return the raw entity list.

        The entities are NOT validated here; they are used as pseudo-labels.
        Returns ``[]`` on any failure or when no client is configured.
        """
        if not self.client:
            return []

        entity_desc = self._entity_type_description()

        prompt = f"""Extract game-related entities from the following text.

Entity Types:
{entity_desc}

Text: {text}

Return the result in the following JSON format only, no other text:
{{
    "entities": [
        {{
            "text": "entity text",
            "type": "ENTITY_TYPE",
            "context": "brief context"
        }}
    ]
}}

If no entities are found, return: {{"entities": []}}
"""

        try:
            response_text = self._chat(
                "You are a game entity extraction expert. Return only valid JSON in the specified format.",
                prompt,
                temperature=0.1,
                max_tokens=800
            )
        except Exception:
            return []

        entities = self._parse_json_response(response_text).get("entities")
        return entities if isinstance(entities, list) else []

    def _build_pseudo_examples(self, text: str) -> List[tuple]:
        """Generate pseudo-inputs for *text* and pair each with its zero-shot labels.

        Pseudo-inputs whose zero-shot labelling returned no entities are dropped.
        """
        pseudo_inputs = self._generate_pseudo_inputs(text)
        if not pseudo_inputs:
            return []
        pseudo_labels = self._predict_pseudo_labels(pseudo_inputs)
        return [(pseudo_input, labels)
                for pseudo_input, labels in zip(pseudo_inputs, pseudo_labels)
                if labels]

    # ------------------------------------------------------------------ #
    # Response parsing
    # ------------------------------------------------------------------ #

    def _parse_json_response(self, response_text: str) -> Dict:
        """Best-effort extraction of a JSON object from a model reply.

        Strategies, in order:

        1. Parse the whole reply. If it is valid JSON but not an object,
           return ``{}`` immediately.
        2. Parse the first fenced code block containing an object.
        3. Parse the first object with at most one level of nesting.
        4. Wrap the body of an ``"entities": [...]`` array as
           ``{"entities": [...]}``.

        Returns ``{}`` when nothing parses or the input is not a string.
        """
        if not isinstance(response_text, str):
            return {}

        try:
            result = json.loads(response_text)
        except json.JSONDecodeError:
            pass
        else:
            return result if isinstance(result, dict) else {}

        for pattern, group in ((self._FENCED_JSON_RE, 1), (self._JSON_OBJECT_RE, 0)):
            match = pattern.search(response_text)
            if not match:
                continue
            try:
                candidate = json.loads(match.group(group))
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, dict):
                return candidate

        match = self._ENTITIES_ARRAY_RE.search(response_text)
        if match:
            try:
                return {"entities": json.loads(f'[{match.group(1)}]')}
            except json.JSONDecodeError:
                pass

        return {}

    def _parse_json_entities(self, response_text: str) -> List[Dict]:
        """Return the ``entities`` list from a model reply, or ``[]`` if absent or not a list."""
        entities = self._parse_json_response(response_text).get("entities")
        return entities if isinstance(entities, list) else []

    # ------------------------------------------------------------------ #
    # Prompt construction
    # ------------------------------------------------------------------ #

    def _create_icl_prompt(self, test_text: str, pseudo_examples: List[tuple]) -> str:
        """Build the Self-ICL prompt from ``(pseudo_input, pseudo_labels)`` pairs.

        Labels that are not dicts with both ``text`` and ``type`` are skipped
        when rendering the demonstrations.
        """
        entity_desc = self._entity_type_description()

        demonstrations = []
        for i, (pseudo_input, pseudo_labels) in enumerate(pseudo_examples):
            demo_entities = [
                {
                    "text": entity['text'],
                    "type": entity['type'],
                    "context": entity.get('context', '')
                }
                for entity in pseudo_labels
                if isinstance(entity, dict) and 'text' in entity and 'type' in entity
            ]
            demo_json = json.dumps({"entities": demo_entities}, ensure_ascii=False, indent=2)
            demonstrations.append(f"Example {i+1}:\nInput: {pseudo_input}\nOutput: {demo_json}")

        demonstrations_text = "\n\n".join(demonstrations)

        prompt = f"""You are a specialized assistant for identifying game-related entities. Learn from the following examples and then extract entities from the test input.

Entity Type Definitions:
{entity_desc}

Here are some examples:

{demonstrations_text}

Now, extract entities from the following test input.

Test Input: {test_text}

Return the result in the following JSON format only, no other text:
{{
    "entities": [
        {{
            "text": "entity text",
            "type": "ENTITY_TYPE",
            "context": "brief context"
        }}
    ]
}}

If no entities are found, return: {{"entities": []}}
"""
        return prompt

    def _create_ner_prompt(self, text: str) -> str:
        """Build the zero-shot prompt used by the traditional strategy (asks for offsets too)."""
        entity_desc = self._entity_type_description()

        prompt = f"""You are a specialized assistant for identifying game-related entities. Please identify all game-related entities from the following text and classify them according to the specified types.

Entity Type Definitions:
{entity_desc}

Text Content:
{text}

Return the result in the following JSON format only, no other text:
{{
    "entities": [
        {{
            "text": "entity text",
            "type": "ENTITY_TYPE",
            "start": start_position,
            "end": end_position,
            "context": "brief context"
        }}
    ]
}}

Instructions:
1. Only identify clear game-related entities
2. Ensure entity text is accurate
3. Provide accurate position information if possible
4. Include brief contextual information
5. If no entities are found, return: {{"entities": []}}
"""
        return prompt

    # ------------------------------------------------------------------ #
    # Extraction entry points
    # ------------------------------------------------------------------ #

    def extract_entities(self, text: str, max_chars: int = 3000) -> List[Dict]:
        """Extract validated entities from *text*; always returns a list.

        Args:
            text: Input text. Empty, whitespace-only or ``None`` yields ``[]``.
            max_chars: Text longer than this is cut to ``max_chars`` characters
                and suffixed with ``"..."`` before prompting.
        """
        if not self.client or not text or not text.strip():
            return []

        if len(text) > max_chars:
            text = text[:max_chars] + "..."

        if self.use_self_icl:
            return self._extract_entities_with_self_icl(text)
        return self._extract_entities_traditional(text)

    def _extract_entities_with_self_icl(self, text: str) -> List[Dict]:
        """Extract entities using self-generated, self-labelled demonstrations.

        Pseudo-examples are built once per call; only the final ICL request
        is retried. The method falls back to ``_extract_entities_traditional``
        when no pseudo-example survives or every ICL attempt fails.
        """
        pseudo_examples = self._build_pseudo_examples(text)
        if not pseudo_examples:
            return self._extract_entities_traditional(text)

        icl_prompt = self._create_icl_prompt(text, pseudo_examples)
        attempts = self._attempts()
        for attempt in range(attempts):
            try:
                response_text = self._chat(
                    "You are a professional game entity recognition assistant. Learn from the examples and extract entities accurately.",
                    icl_prompt,
                    temperature=0.1,
                    max_tokens=1500
                )
            except Exception:
                if attempt < attempts - 1:
                    time.sleep(self.retry_delay)
                continue
            entities = self._parse_json_response(response_text).get("entities", [])
            return self._validate_entities(entities, text)

        return self._extract_entities_traditional(text)

    def _extract_entities_traditional(self, text: str) -> List[Dict]:
        """Single-prompt zero-shot extraction with retries; returns ``[]`` if all attempts fail."""
        if not self.client:
            return []

        prompt = self._create_ner_prompt(text)
        attempts = self._attempts()
        for attempt in range(attempts):
            try:
                response_text = self._chat(
                    "You are a professional game entity recognition assistant. Extract game-related entities accurately and follow the specified format.",
                    prompt,
                    temperature=0.1,
                    max_tokens=1500
                )
            except Exception:
                if attempt < attempts - 1:
                    time.sleep(self.retry_delay)
                continue
            entities = self._parse_json_response(response_text).get("entities", [])
            return self._validate_entities(entities, text)

        return []

    def _validate_entities(self, entities: List[Dict], original_text: str) -> List[Dict]:
        """Keep only well-formed entities that actually occur in *original_text*.

        An entity is kept when all of the following hold:

        - it is a dict with ``text`` and ``type`` keys;
        - its type matches a configured label case-insensitively (the
          canonical label is stored);
        - its stripped text has at least 2 characters;
        - its text occurs case-insensitively in *original_text*.

        ``context`` is stringified and cut to 200 characters (missing or
        ``None`` becomes ``""``). ``start``/``end`` are kept only when they
        are an in-range span whose text equals the entity text
        case-insensitively.

        Returns ``[]`` when *entities* is not a list.
        """
        if not isinstance(entities, list):
            return []

        canonical_types = {name.upper(): name for name in self.entity_types}
        haystack = original_text.lower()
        validated = []

        for entity in entities:
            if not isinstance(entity, dict) or "text" not in entity or "type" not in entity:
                continue

            entity_type = canonical_types.get(str(entity["type"]).strip().upper())
            if entity_type is None:
                continue

            entity_text = str(entity["text"]).strip()
            if len(entity_text) < 2 or entity_text.lower() not in haystack:
                continue

            cleaned_entity = {
                "text": entity_text,
                "type": entity_type,
                "context": str(entity.get("context") or "").strip()[:200]
            }

            if "start" in entity and "end" in entity:
                try:
                    start = int(entity["start"])
                    end = int(entity["end"])
                except (ValueError, TypeError):
                    pass
                else:
                    # Model-reported offsets are often wrong; only keep spans
                    # that really point at the entity text.
                    if (0 <= start < end <= len(original_text)
                            and original_text[start:end].lower() == entity_text.lower()):
                        cleaned_entity["start"] = start
                        cleaned_entity["end"] = end

            validated.append(cleaned_entity)

        return validated

    # ------------------------------------------------------------------ #
    # Batch processing
    # ------------------------------------------------------------------ #

    def batch_extract_entities(self, text_chunks: List[str], show_progress: bool = True) -> List[List[Dict]]:
        """Extract entities for each chunk; the result list is index-aligned with *text_chunks*.

        A single chunk, or no configured client, runs serially. Otherwise
        chunks are processed on a thread pool. ``show_progress`` requires
        ``tqdm``.
        """
        if not text_chunks:
            return []

        if len(text_chunks) == 1 or not self.client:
            return self._batch_extract_serial(text_chunks, show_progress)

        return self._batch_extract_concurrent(text_chunks, show_progress)

    def _batch_extract_serial(self, text_chunks: List[str], show_progress: bool = True) -> List[List[Dict]]:
        """Process chunks one at a time, pausing 0.1 s between API-backed calls."""
        iterator = text_chunks
        if show_progress:
            from tqdm import tqdm
            iterator = tqdm(text_chunks, desc="NER extraction")

        results = []
        for text in iterator:
            results.append(self.extract_entities(text))
            # Light throttling between requests; skipped when nothing is sent.
            if self.client:
                time.sleep(0.1)

        return results

    def _batch_extract_concurrent(self, text_chunks: List[str], show_progress: bool = True) -> List[List[Dict]]:
        """Process chunks on a ``ThreadPoolExecutor`` of ``max_workers`` threads.

        Results are written back by original index, so output order matches
        input order. A chunk whose worker raised yields ``[]``.
        """
        results: List[List[Dict]] = [[] for _ in text_chunks]
        total = len(text_chunks)

        pbar = None
        if show_progress:
            from tqdm import tqdm
            pbar = tqdm(total=total, desc="NER concurrent extraction")

        try:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_index = {
                    executor.submit(self.extract_entities, text): index
                    for index, text in enumerate(text_chunks)
                }

                for completed, future in enumerate(as_completed(future_to_index), start=1):
                    index = future_to_index[future]
                    try:
                        results[index] = future.result()
                    except Exception:
                        results[index] = []
                    # Advance the bar for failures too so it always reaches total.
                    if pbar is not None:
                        pbar.update(1)
                        pbar.set_postfix({"Completed": f"{completed}/{total}"})
        finally:
            if pbar is not None:
                pbar.close()

        return results

    # ------------------------------------------------------------------ #
    # Reporting
    # ------------------------------------------------------------------ #

    def get_entity_statistics(self, all_entities: List[List[Dict]]) -> Dict:
        """Summarise per-chunk entity lists.

        Returns a dict with these keys:

        - ``total_chunks``, ``total_entities`` and ``chunks_with_entities``;
        - ``entity_types``: per-label counts for configured labels only;
        - ``unique_entities``: sorted ``(lowercased text, type)`` tuples;
        - ``unique_entity_count``.
        """
        type_counts = {entity_type: 0 for entity_type in self.entity_types}
        unique = set()
        total_entities = 0
        chunks_with_entities = 0

        for entities in all_entities:
            if entities:
                chunks_with_entities += 1

            for entity in entities:
                total_entities += 1
                entity_type = entity.get("type", "UNKNOWN")
                if entity_type in type_counts:
                    type_counts[entity_type] += 1

                # Case-insensitive on text so "Link"/"link" count once per type.
                unique.add((str(entity.get("text") or "").lower(), entity_type))

        return {
            "total_chunks": len(all_entities),
            "total_entities": total_entities,
            "entity_types": type_counts,
            # Sorted for a deterministic, JSON-friendly list.
            "unique_entities": sorted(unique, key=lambda key: (key[0], str(key[1]))),
            "chunks_with_entities": chunks_with_entities,
            "unique_entity_count": len(unique),
        }
