#!/usr/bin/env python3
"""Unified inference script using GPT API with post-processing"""

import os
import re
import json
from datasets import Dataset
import numpy as np
from transformers import AutoTokenizer
from utils.data_utils import smiles2selfies
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor
import threading
from tqdm import tqdm

from utils.config import load_config, get_base_parser
from data.data_loader import DataLoader
from data.prompt_builder import PromptBuilder
from utils.model_utils import apply_alpaca_chat_template, apply_round_chat_template
from utils.evaluation_utils import compute_bleu_nltk_batch
from molgeneval import create_score_indie, create_retrosynthesis_score

class PostProcessor:
    """Turn raw model outputs into task-specific evaluation results.

    The handler is selected from the task name by substring, checked in a
    fixed order ("molecule", "caption", "reaction_pred", "retro", "class"),
    followed by an exact match on "synth" / "iupac_low". Any other task gets
    the cleaned output strings back unchanged.
    """

    # Captures the payload of a <desc>...</desc> tag (case-insensitive, may span lines).
    _DESC_PATTERN = re.compile(r"<desc>(.*?)</desc>", flags=re.DOTALL | re.IGNORECASE)
    # A description must be longer than this many characters to be kept.
    _MIN_DESC_LENGTH = 20

    def __init__(self, task, mol_type="SELFIES"):
        """Store the task name and the molecule column/representation name."""
        self.task = task
        self.mol_type = mol_type

    def process(self, results, test_data, tokenizer):
        """Clean the raw outputs and dispatch to the handler for ``self.task``.

        Args:
            results: One raw model output per test sample. ``None`` entries
                (the chat API can return no content) are treated as "".
            test_data: Dataset the outputs were generated for.
            tokenizer: Tokenizer forwarded to the caption metrics.

        Returns:
            A ``Dataset`` for the recognised tasks, otherwise the list of
            cleaned output strings.
        """
        # Only newlines are trimmed here; handlers strip further where needed.
        clean_results = [(res or "").strip("\n") for res in results]

        if "molecule" in self.task:
            return self._process_molecule_with_scores(clean_results, test_data)
        elif "caption" in self.task:
            return self._process_caption_with_bleu(clean_results, test_data, tokenizer)
        elif "reaction_pred" in self.task:
            return self._process_reaction_prediction(clean_results, test_data)
        elif "retro" in self.task:
            return self._process_retrosynthesis(clean_results, test_data)
        elif "class" in self.task:
            return self._process_classification(clean_results, test_data)
        elif self.task in ["synth", "iupac_low"]:
            return self._process_structured_output(clean_results, test_data)
        else:
            return clean_results

    @staticmethod
    def _accuracy(flags):
        """Return the mean of the 0/1 ``flags``, or 0.0 when there are none."""
        return sum(flags) / len(flags) if flags else 0.0

    def _process_molecule_with_scores(self, results, test_data):
        """Score generated molecules against the ground truth column.

        Returns the test data extended with a ``gen_mol`` column and one
        column per metric returned by ``create_score_indie``.
        """
        gt = test_data[self.mol_type]
        results = [r.strip() for r in results]
        scores = create_score_indie(gt, results, self.mol_type)

        data_dict = test_data.to_dict()
        data_dict["gen_mol"] = results
        # Each metric contributes one column of values.
        data_dict.update(scores)

        return Dataset.from_dict(data_dict)

    def _process_caption_with_bleu(self, results, test_data, tokenizer):
        """Evaluate generated captions against the ``desc`` column.

        Returns the test data extended with a ``gen_desc`` column and one
        column per metric returned by ``compute_bleu_nltk_batch``.
        """
        data_dict = test_data.to_dict()
        gt = data_dict.get('desc', [])

        metrics = compute_bleu_nltk_batch(tokenizer, results, gt)

        data_dict["gen_desc"] = results
        data_dict.update(metrics)

        return Dataset.from_dict(data_dict)

    def _process_reaction_prediction(self, results, test_data):
        """Score predicted products against the ``prod`` column.

        Prints the exact-match accuracy and returns the test data extended
        with ``gen_prod``, a 0/1 ``match`` column and the metric columns.
        """
        gt = test_data["prod"]
        # Strip once so matching, scoring and the stored predictions all see
        # the same strings (as the molecule and retrosynthesis handlers do).
        results = [r.strip() for r in results]
        match = [int(pred == target) for pred, target in zip(results, gt)]
        scores = create_score_indie(gt, results, self.mol_type)

        data_dict = test_data.to_dict()
        print("Acc:", self._accuracy(match))
        data_dict["gen_prod"] = results
        data_dict["match"] = match
        data_dict.update(scores)

        return Dataset.from_dict(data_dict)

    def _process_retrosynthesis(self, results, test_data):
        """Score predicted reactants against the ``equa`` column.

        A prediction counts as a match when every "."-separated ground-truth
        component also appears among the predicted components. Prints the
        accuracy and returns the test data extended with ``gen_equa``, a 0/1
        ``match`` column and the metric columns.
        """
        gt = test_data["equa"]
        results = [r.strip() for r in results]
        match = [
            int(set(target.split(".")).issubset(pred.split(".")))
            for pred, target in zip(results, gt)
        ]
        scores = create_score_indie(gt, results, self.mol_type)

        data_dict = test_data.to_dict()
        print("Acc:", self._accuracy(match))
        data_dict["gen_equa"] = results
        data_dict["match"] = match
        data_dict.update(scores)

        return Dataset.from_dict(data_dict)

    def _process_classification(self, results, test_data):
        """Compare predictions with the ``class`` column, ignoring case.

        Prints the accuracy and returns the test data extended with ``pred``
        (the outputs as received) and a 0/1 ``acc`` column.
        """
        gt = test_data["class"]
        acc = [
            int(result.strip().lower() == target.lower())
            for result, target in zip(results, gt)
        ]

        data_dict = test_data.to_dict()
        print("Acc:", self._accuracy(acc))
        data_dict["pred"] = results
        data_dict["acc"] = acc
        return Dataset.from_dict(data_dict)

    def _process_structured_output(self, results, test_data):
        """Keep samples whose output has a long enough ``<desc>`` payload.

        For each kept sample the ``desc`` field is replaced by the stripped
        tag content; samples without the tag, or with a payload of
        ``_MIN_DESC_LENGTH`` characters or fewer, are dropped.
        """
        filtered_data = []
        for i, result in enumerate(results):
            found = self._DESC_PATTERN.search(result)
            if not found:
                continue
            description = found.group(1).strip()
            if len(description) > self._MIN_DESC_LENGTH:
                data_point = test_data[i].copy()
                data_point["desc"] = description
                filtered_data.append(data_point)
        return Dataset.from_list(filtered_data)

def create_few_shot_examples(train_data, task, mol_type, n_shots=0):
    """Render the first ``n_shots`` training rows as "Input/Output" examples.

    The task name decides, by substring and in the order listed below, which
    row field is shown as the input and which as the output. Tasks matching
    none of the keywords produce no examples. Examples are separated by a
    blank line; an empty string is returned when there are none.
    """
    # (task keyword, input field, output field) -- the first matching keyword wins.
    field_pairs = (
        ("molecule", "desc", mol_type),
        ("caption", mol_type, "desc"),
        ("reaction_pred", "equa", "prod"),
        ("retro", "prod", "equa"),
        ("class", mol_type, "class"),
    )

    shot_count = min(n_shots, len(train_data))
    rendered = []
    for idx in range(shot_count):
        row = train_data[idx]
        for keyword, source_field, target_field in field_pairs:
            if keyword in task:
                rendered.append(f"Input: {row[source_field]}\nOutput: {row[target_field]}")
                break
    return "\n\n".join(rendered)

# Holds one OpenAI client per thread so repeated requests issued from the same
# worker reuse a client instead of constructing a new one every call.
_GPT_THREAD_STATE = threading.local()

def gpt_single_request(messages, config):
    """Send one chat-completion request and return the reply text.

    Args:
        messages: Chat messages (list of role/content dicts) for the request.
        config: Settings object; the optional attributes ``gpt_model``
            (default 'gpt-4.1-mini'), ``temperature`` (default 0.7) and
            ``max_completion_tokens`` (default 10000) are read from it.

    Returns:
        The content of the first choice, or "" when the API returned no
        content for it.
    """
    client = getattr(_GPT_THREAD_STATE, "client", None)
    if client is None:
        # The API key is read from the environment on first use in this thread.
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        _GPT_THREAD_STATE.client = client

    response = client.chat.completions.create(
        model=getattr(config, 'gpt_model', 'gpt-4.1-mini'),
        messages=messages,
        temperature=getattr(config, 'temperature', 0.7),
        max_completion_tokens=getattr(config, 'max_completion_tokens', 10000),
        # top_p=getattr(config, 'top_p', 0.9)
    )
    content = response.choices[0].message.content
    # Downstream post-processing expects strings, so never hand back None.
    return content if content is not None else ""

def gpt_inference(messages_list, config):
    """Run one GPT request per entry of ``messages_list`` on a thread pool.

    The pool size comes from ``config.max_workers`` (10 when unset). Replies
    are returned in the same order as the inputs; progress is shown while
    the results are collected.
    """
    worker_count = getattr(config, 'max_workers', 10)
    pool = ThreadPoolExecutor(max_workers=worker_count)

    with pool:
        # Queue every request first so the workers stay busy ...
        pending = []
        for messages in messages_list:
            pending.append(pool.submit(gpt_single_request, messages, config))

        # ... then gather the replies in submission order.
        outputs = []
        for job in tqdm(pending, desc="GPT Inference"):
            outputs.append(job.result())

    return outputs

def _convert_to_selfies(dataset, mol_type):
    """Convert SMILES-based columns of ``dataset`` when the model uses SELFIES."""
    if mol_type != "SELFIES":
        return dataset
    if "SELFIES" not in dataset.column_names and "SMILES" in dataset.column_names:
        dataset = smiles2selfies(dataset)
    if "prod" in dataset.column_names and "equa" in dataset.column_names:
        dataset = smiles2selfies(dataset, "prod", "prod")
        dataset = smiles2selfies(dataset, "equa", "equa")
    return dataset

def main():
    """Run GPT inference on the configured test set and post-process the output.

    Steps: load the config and datasets, build the prompts, optionally add
    few-shot examples taken from the training data (``config.n_shots``,
    0 when unset), query the GPT API, post-process the replies for the
    configured task and, when ``config.output_root_path`` is set, save the
    results under ``{output_root_path}/{task}/gpt/{gpt_model}``.

    Returns:
        The post-processed results (a ``Dataset`` or a list of strings).
    """
    parser = get_base_parser()
    args = parser.parse_args()
    config = load_config(args.config, args.opts)

    # Load data
    data_loader = DataLoader()
    test_data = data_loader.load_multiple_datasets(
        config.test_datasets,
        getattr(config, 'test_limits', {}),
        getattr(config, 'test_processing', {})
    )

    # Load training data for few-shot examples
    train_data = data_loader.load_multiple_datasets(
        config.train_datasets,
        getattr(config, 'train_limits', {}),
        getattr(config, 'train_processing', {})
    )
    test_data = _convert_to_selfies(test_data, config.model_mol_type)

    # Handle special data preparation for certain tasks
    if config.task == "synth":
        # Give every sample a reference description drawn at random (with
        # replacement) from the descriptions of the test set itself.
        text_in = np.array(test_data["desc"])
        rand_indices = np.random.randint(0, len(text_in), size=len(text_in))
        test_data = test_data.add_column("ref_desc", text_in[rand_indices].tolist())

    # Build prompts
    prompt_builder = PromptBuilder(config.model_mol_type)
    test_data = prompt_builder.build_prompts(test_data, config.task, is_generation=True)

    # Create few-shot examples from training data
    n_shots = getattr(config, 'n_shots', 0) or 0
    shot_pool = train_data
    if n_shots > 0:
        # Convert only the rows that are actually shown, so the examples use
        # the same molecule representation as the test prompts.
        shot_pool = train_data.select(range(min(n_shots, len(train_data))))
        shot_pool = _convert_to_selfies(shot_pool, config.model_mol_type)
    few_shot_examples = create_few_shot_examples(
        shot_pool, config.task, config.model_mol_type, n_shots=n_shots
    )
    print(few_shot_examples)

    # Leave the prompts untouched when there is nothing to show; otherwise
    # the system message would announce examples that never follow.
    few_shot_block = f"Here are some examples:\n{few_shot_examples}" if few_shot_examples else ""

    # Add few-shot examples to prompts and keep as message lists
    def add_few_shot_to_messages(samples):
        formatted_messages = []
        for msg_list in samples["prompt"]:
            # Copy each message so the stored prompt dicts are never mutated.
            messages = [dict(msg) for msg in msg_list]
            if few_shot_block:
                system_msg = next((msg for msg in messages if msg["role"] == "system"), None)
                if system_msg is not None:
                    # Add to existing system message
                    system_msg["content"] = f"{system_msg['content']}\n\n{few_shot_block}"
                else:
                    # Create new system message
                    messages.insert(0, {"role": "system", "content": few_shot_block})
            formatted_messages.append(messages)
        return {"messages": formatted_messages}

    test_data = test_data.map(add_few_shot_to_messages, batched=True)
    print(test_data[0])

    # GPT inference
    all_outputs = gpt_inference(test_data["messages"], config)

    # Post-process results
    processor = PostProcessor(config.task, config.model_mol_type)

    # Remove intermediate columns
    test_data = test_data.remove_columns(["messages", "prompt"])
    eval_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
    results = processor.process(all_outputs, test_data, eval_tokenizer)

    # Save results if specified
    if hasattr(config, 'output_root_path') and config.output_root_path:
        # Create structured output path: {output_root_path}/{task}/gpt/{gpt_model}
        gpt_model_name = getattr(config, 'gpt_model', 'gpt-4.1-mini')
        output_path = os.path.join(config.output_root_path, config.task, "gpt", gpt_model_name)
        os.makedirs(output_path, exist_ok=True)

        if isinstance(results, Dataset):
            results.save_to_disk(output_path)
        else:
            with open(os.path.join(output_path, 'results.json'), 'w') as f:
                json.dump(results, f, indent=2)

    print("Inference completed successfully!")
    return results

# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()