import json
import os
import logging
from typing import List, Dict, Tuple
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI
import argparse
from tqdm import tqdm
import time

      
# Module-wide logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class DuplicationDetector:
    """Find duplicate records in JSONL datasets by exact match and embedding similarity.

    Semantic duplicates are pairs whose OpenAI embeddings have a cosine
    similarity of at least ``similarity_threshold``. Within each group of
    duplicates the lowest index is kept and later indices are reported for
    removal.
    """

    def __init__(self, openai_api_key: str = None, similarity_threshold: float = 0.9, base_url: str = None):
        """Create the embedding client and store detection settings.

        Args:
            openai_api_key: API key for the OpenAI client. Falsy values are
                passed as ``None`` so the client falls back to its own default
                configuration (per the openai-python docs, environment variables).
            similarity_threshold: Minimum cosine similarity for two texts to be
                treated as duplicates.
            base_url: Optional alternative API endpoint; falsy values are
                passed as ``None``.
        """
        # Forward base_url even when no explicit key is given; previously it was
        # dropped in that case and requests went to the default endpoint.
        self.client = OpenAI(api_key=openai_api_key or None, base_url=base_url or None)
        self.similarity_threshold = similarity_threshold
        self.embedding_model = "text-embedding-3-small"

    def get_embeddings(self, texts: List[str], batch_size: int = 100) -> np.ndarray:
        """Embed ``texts`` with ``self.embedding_model``, ``batch_size`` texts per request.

        Only non-blank strings are sent: the embeddings endpoint rejects empty
        input, and one bad entry would otherwise fail its whole batch. Skipped
        texts, and every text in a batch whose request raised, get an all-zero
        row. sklearn's ``cosine_similarity`` leaves zero-norm rows unscaled, so
        such rows score 0 against everything and are never reported as
        duplicates.

        Args:
            texts: Texts to embed.
            batch_size: Maximum number of texts per API request.

        Returns:
            Float array with one row per input text, in input order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        logger.info(f"Getting embeddings for {len(texts)} texts...")

        # Embedding vectors keyed by input position; positions never filled stay zero.
        vectors = {}
        sendable = [pos for pos, text in enumerate(texts) if isinstance(text, str) and text.strip()]

        for start in tqdm(range(0, len(sendable), batch_size), desc="Getting embeddings"):
            batch_positions = sendable[start:start + batch_size]
            try:
                response = self.client.embeddings.create(
                    model=self.embedding_model,
                    input=[texts[pos] for pos in batch_positions]
                )
            except Exception as e:
                # Best effort: log, leave this batch as zero vectors, and continue.
                logger.error(f"Failed to get embeddings: {e}")
                continue

            for pos, item in zip(batch_positions, response.data):
                vectors[pos] = item.embedding

            # Short pause between requests (presumably to stay under rate limits).
            time.sleep(0.1)

        # Use the width of a real response when we have one; otherwise fall back
        # to 1536, the width this code has always assumed for the default model.
        dim = len(next(iter(vectors.values()))) if vectors else 1536
        matrix = np.zeros((len(texts), dim))
        for pos, vector in vectors.items():
            matrix[pos] = vector
        return matrix

    def find_duplicates(self, texts: List[str], indices: List[int] = None) -> List[Tuple[int, int, float]]:
        """Greedily pair each text with later texts that are semantically similar.

        Scanning in order, text ``i`` claims every later, unclaimed text ``j``
        whose cosine similarity to it is at least ``self.similarity_threshold``.
        Claimed texts are not used as anchors afterwards.

        Args:
            texts: Texts to compare.
            indices: Identifier reported for each text. Defaults to its
                position; must be the same length as ``texts``.

        Returns:
            ``(kept_index, duplicate_index, similarity)`` tuples ordered by
            anchor, then duplicate position. Each duplicate index appears once.

        Raises:
            ValueError: If ``indices`` and ``texts`` differ in length.
        """
        if indices is None:
            indices = list(range(len(texts)))
        elif len(indices) != len(texts):
            raise ValueError(f"indices has {len(indices)} entries but texts has {len(texts)}")

        # Fewer than two texts cannot form a pair. Returning early skips the
        # API call and avoids calling cosine_similarity on an empty matrix.
        if len(texts) < 2:
            return []

        embeddings = self.get_embeddings(texts)
        similarity_matrix = cosine_similarity(embeddings)

        logger.info("Finding duplicates...")
        duplicates = []
        claimed = set()
        count = len(texts)
        for i in range(count):
            if i in claimed:
                continue
            row = similarity_matrix[i]
            for j in range(i + 1, count):
                if j not in claimed and row[j] >= self.similarity_threshold:
                    # Plain float keeps the JSON report free of numpy scalars.
                    duplicates.append((indices[i], indices[j], float(row[j])))
                    claimed.add(j)

        return duplicates

    @staticmethod
    def _load_jsonl(path: str) -> List[Dict]:
        """Parse a UTF-8 JSONL file into a list of records, skipping blank lines.

        Record ``k`` is the ``k``-th non-blank line. ``remove_duplicates`` uses
        the same numbering when it drops lines.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def detect_template_duplicates(self, template_file: str) -> Dict:
        """Find semantically duplicate question templates in a JSONL file.

        Each template is embedded as ``"<template> <description>"``.

        Returns:
            Report dict with ``total_templates``, ``duplicate_pairs``,
            ``duplicates_by_type`` (pairs grouped by the kept template's
            ``question_type``) and ``to_remove_indices``.
        """
        logger.info(f"Starting template duplicate detection: {template_file}")

        templates = self._load_jsonl(template_file)
        logger.info(f"Loaded {len(templates)} question templates")

        texts = [f"{t.get('template', '')} {t.get('description', '')}" for t in templates]
        duplicates = self.find_duplicates(texts)

        duplicates_by_type = {}
        for idx1, idx2, similarity in duplicates:
            first, second = templates[idx1], templates[idx2]
            duplicates_by_type.setdefault(first.get('question_type', 'unknown'), []).append({
                'index1': idx1,
                'index2': idx2,
                'similarity': similarity,
                'template1': first.get('template', ''),
                'template2': second.get('template', ''),
                'description1': first.get('description', ''),
                'description2': second.get('description', '')
            })

        return {
            'total_templates': len(templates),
            'duplicate_pairs': len(duplicates),
            'duplicates_by_type': duplicates_by_type,
            'to_remove_indices': [idx2 for _, idx2, _ in duplicates]
        }

    def detect_role_duplicates(self, role_file: str) -> Dict:
        """Find exact and semantic duplicates among role records in a JSONL file.

        Both checks compare the ``player_description`` field. A pair found by
        both checks is reported once, and the exact match takes precedence.

        Returns:
            Report dict with the pair counts, the merged ``duplicates`` list of
            ``(kept_index, duplicate_index, similarity)`` tuples, and
            ``to_remove_indices`` (each index once, in first-seen order).
        """
        logger.info(f"Starting role duplicate detection: {role_file}")

        roles = self._load_jsonl(role_file)
        logger.info(f"Loaded {len(roles)} role information")

        exact_duplicates = self._find_exact_duplicates(roles)
        descriptions = [role.get('player_description', '') for role in roles]
        semantic_duplicates = self.find_duplicates(descriptions)

        # Merge the two lists, dropping pairs already reported in either orientation.
        unique_duplicates = []
        seen_pairs = set()
        for dup in exact_duplicates + semantic_duplicates:
            pair = (min(dup[0], dup[1]), max(dup[0], dup[1]))
            if pair not in seen_pairs:
                seen_pairs.add(pair)
                unique_duplicates.append(dup)

        return {
            'total_roles': len(roles),
            'exact_duplicate_pairs': len(exact_duplicates),
            'semantic_duplicate_pairs': len(semantic_duplicates),
            'total_duplicate_pairs': len(unique_duplicates),
            'duplicates': unique_duplicates,
            # An index can be the duplicate side of both an exact pair and a
            # semantic pair; list it once so removal counts stay accurate.
            'to_remove_indices': list(dict.fromkeys(dup[1] for dup in unique_duplicates))
        }

    def _find_exact_duplicates(self, roles: List[Dict]) -> List[Tuple[int, int, float]]:
        """Pair roles whose ``player_description`` strings are identical.

        The first role with a given description is kept, and each later one
        yields ``(first_index, later_index, 1.0)``. Roles whose description is
        missing, not a string, or blank are ignored, because absent text is not
        evidence of duplication.

        Returns:
            Pairs sorted by kept index, then duplicate index.
        """
        first_index = {}
        pairs = []
        for idx, role in enumerate(roles):
            description = role.get('player_description')
            if not isinstance(description, str) or not description.strip():
                continue
            anchor = first_index.setdefault(description, idx)
            if anchor != idx:
                pairs.append((anchor, idx, 1.0))
        # Anchor-major order keeps each group's pairs contiguous.
        pairs.sort()
        return pairs

    def remove_duplicates(self, input_file: str, output_file: str, to_remove_indices: List[int]):
        """Copy ``input_file`` to ``output_file``, leaving out the listed records.

        Records are numbered from 0 over non-blank lines (the same numbering as
        ``_load_jsonl``). Blank lines are copied through unchanged, and kept
        lines are written verbatim.

        Args:
            input_file: Source JSONL file.
            output_file: Destination path; overwritten if it exists.
            to_remove_indices: Record indices to drop. Repeated indices and
                indices past the end of the file are ignored.

        Raises:
            ValueError: If ``output_file`` is the same file as ``input_file``.
                Opening it for writing would truncate the input before it is read.
        """
        if os.path.realpath(input_file) == os.path.realpath(output_file):
            raise ValueError(f"output_file must differ from input_file: {input_file}")

        logger.info(f"Removing duplicates and saving to: {output_file}")

        to_remove_set = set(to_remove_indices)
        original_count = 0
        removed_count = 0

        with open(input_file, 'r', encoding='utf-8') as infile, \
                open(output_file, 'w', encoding='utf-8') as outfile:
            for line in infile:
                if not line.strip():
                    outfile.write(line)
                    continue
                if original_count in to_remove_set:
                    removed_count += 1
                else:
                    outfile.write(line)
                original_count += 1

        remaining_count = original_count - removed_count

        logger.info(f"Original entries: {original_count}")
        logger.info(f"Removed entries: {removed_count}")
        logger.info(f"Remaining entries: {remaining_count}")


def _dedup_output_path(input_file: str) -> str:
    """Return the path for the de-duplicated copy: ``dir/name.ext`` -> ``dir/name_dup_mov.ext``."""
    name, ext = os.path.splitext(os.path.basename(input_file))
    return os.path.join(os.path.dirname(input_file), f'{name}_dup_mov{ext}')


def _write_report(results: Dict, report_file: str) -> None:
    """Write ``results`` to ``report_file`` as indented UTF-8 JSON (non-ASCII kept as-is)."""
    with open(report_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


def main():
    """Command-line entry point: report duplicate templates and roles, optionally removing them.

    Each input file that exists gets a JSON report written next to it. With
    ``--remove-duplicates``, a ``*_dup_mov`` copy without the duplicates is
    also written.
    """
    parser = argparse.ArgumentParser(description='Game RAG Dataset Deduplication Detection Tool')
    parser.add_argument('--template-file',
                        default='../data/pubgm/question_templates.jsonl',
                        help='Question template file path')
    parser.add_argument('--role-file',
                        default='../data/pubgm/question_data_roles.jsonl',
                        help='Role information file path')
    # %(default)s keeps the help text in sync with the actual default.
    parser.add_argument('--similarity-threshold', type=float, default=0.7,
                        help='Similarity threshold (default: %(default)s)')
    parser.add_argument('--remove-duplicates', action='store_true',
                        help='Whether to remove duplicates and save new file')
    # Read credentials from the environment instead of shipping placeholder
    # strings that make every request fail.
    parser.add_argument('--api-key', default=os.environ.get('OPENAI_API_KEY'),
                        help='OpenAI API key (default: $OPENAI_API_KEY)')
    parser.add_argument('--base-url', default=os.environ.get('OPENAI_BASE_URL'),
                        help='OpenAI API URL (default: $OPENAI_BASE_URL)')

    args = parser.parse_args()

    # Cosine similarity lies in [-1, 1]; this comparison also rejects NaN.
    if not -1.0 <= args.similarity_threshold <= 1.0:
        parser.error(f"--similarity-threshold must be within [-1, 1], got {args.similarity_threshold}")

    detector = DuplicationDetector(
        openai_api_key=args.api_key,
        similarity_threshold=args.similarity_threshold,
        base_url=args.base_url
    )

    if os.path.isfile(args.template_file):
        logger.info("=" * 50)
        logger.info("Detecting Question Template Duplicates")
        logger.info("=" * 50)

        template_results = detector.detect_template_duplicates(args.template_file)

        template_report_file = os.path.join(os.path.dirname(args.template_file),
                                            'template_duplicates_report.json')
        _write_report(template_results, template_report_file)

        logger.info(f"Template detection results saved to: {template_report_file}")
        logger.info(f"Total templates: {template_results['total_templates']}")
        logger.info(f"Duplicate pairs: {template_results['duplicate_pairs']}")
        for qtype, duplicates in template_results['duplicates_by_type'].items():
            logger.info(f"  {qtype}: {len(duplicates)} duplicate pairs")

        if args.remove_duplicates:
            detector.remove_duplicates(
                args.template_file,
                _dedup_output_path(args.template_file),
                template_results['to_remove_indices']
            )
    else:
        logger.warning(f"Template file not found, skipping: {args.template_file}")

    if os.path.isfile(args.role_file):
        logger.info("=" * 50)
        logger.info("Detecting Role Information Duplicates")
        logger.info("=" * 50)

        role_results = detector.detect_role_duplicates(args.role_file)

        role_report_file = os.path.join(os.path.dirname(args.role_file),
                                        'role_duplicates_report.json')
        _write_report(role_results, role_report_file)

        logger.info(f"Role detection results saved to: {role_report_file}")
        logger.info(f"Total roles: {role_results['total_roles']}")
        logger.info(f"Exact duplicate pairs: {role_results['exact_duplicate_pairs']}")
        logger.info(f"Semantic duplicate pairs: {role_results['semantic_duplicate_pairs']}")
        logger.info(f"Total duplicate pairs: {role_results['total_duplicate_pairs']}")

        if args.remove_duplicates:
            detector.remove_duplicates(
                args.role_file,
                _dedup_output_path(args.role_file),
                role_results['to_remove_indices']
            )
    else:
        logger.warning(f"Role file not found, skipping: {args.role_file}")


# Run the command-line tool only when executed directly, not on import.
if __name__ == '__main__':
    main()
