#!/usr/bin/env python3
"""Unified inference script with LoRA merging and post-processing"""

import os
import re
import json
import multiprocessing as mp
from datasets import Dataset
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.data_utils import smiles2selfies
from vllm import LLM, SamplingParams

from utils.config import load_config, get_base_parser
from data.data_loader import DataLoader
from data.prompt_builder import PromptBuilder
from utils.model_utils import load_lora_hist, apply_alpaca_chat_template, apply_round_chat_template
from utils.evaluation_utils import compute_bleu_nltk_batch
from molgeneval import create_score_indie, create_retrosynthesis_score

class PostProcessor:
    """Turn raw generation outputs into task-specific result tables.

    The handler is chosen in ``process`` by substring matching on the task
    name. ``mol_type`` is forwarded to the scoring helper and is also the
    name of the ground-truth column read for molecule generation.
    """

    # Extracts the payload of a <desc>...</desc> tag (case-insensitive,
    # spanning newlines) from a structured generation.
    _DESC_PATTERN = re.compile(r"<desc>(.*?)</desc>", flags=re.DOTALL | re.IGNORECASE)

    # Minimum number of characters a <desc> payload must exceed to be kept.
    _MIN_DESC_LENGTH = 20

    def __init__(self, task, mol_type="SELFIES"):
        """Remember the task name and the molecule representation.

        Args:
            task: Task identifier used by ``process`` to select a handler.
            mol_type: Molecule representation / column name.
        """
        self.task = task
        self.mol_type = mol_type

    @staticmethod
    def _accuracy(flags):
        """Return the mean of a list of 0/1 flags, or 0.0 for an empty list."""
        return sum(flags) / len(flags) if flags else 0.0

    @staticmethod
    def _build_dataset(test_data, new_columns, metric_columns=None):
        """Return ``test_data`` as a Dataset extended with extra columns.

        ``new_columns`` are written first and ``metric_columns`` second, so a
        metric with the same name as an earlier column overrides it.
        """
        data_dict = test_data.to_dict()
        for column_name, column_values in new_columns.items():
            data_dict[column_name] = column_values
        if metric_columns is not None:
            for metric_name, metric_values in metric_columns.items():
                data_dict[metric_name] = metric_values
        return Dataset.from_dict(data_dict)

    def process(self, results, test_data, tokenizer):
        """Post-process generation outputs according to ``self.task``.

        Args:
            results: Generation outputs; the text of the first candidate of
                each entry (``res.outputs[0].text``) is used.
            test_data: Dataset the prompts were built from.
            tokenizer: Tokenizer forwarded to the caption metrics; unused by
                the other handlers.

        Returns:
            A ``Dataset`` for the recognised tasks, otherwise the list of
            generated strings with surrounding newlines removed.
        """
        clean_results = [res.outputs[0].text.strip("\n") for res in results]

        # Order matters: the first matching substring wins.
        if "molecule" in self.task:
            return self._process_molecule_with_scores(clean_results, test_data)
        elif "caption" in self.task:
            return self._process_caption_with_bleu(clean_results, test_data, tokenizer)
        elif "reaction_pred" in self.task:
            return self._process_reaction_prediction(clean_results, test_data)
        elif "retro" in self.task:
            return self._process_retrosynthesis(clean_results, test_data)
        elif "class" in self.task:
            return self._process_classification(clean_results, test_data)
        elif self.task in ["synth", "iupac_low"]:
            return self._process_structured_output(clean_results, test_data)
        else:
            return clean_results

    def _process_molecule_with_scores(self, results, test_data):
        """Score generated molecules against the ``mol_type`` column.

        Adds a ``gen_mol`` column plus one column per metric returned by
        ``create_score_indie``.
        """
        gt = test_data[self.mol_type]
        results = [r.strip() for r in results]
        scores = create_score_indie(gt, results, self.mol_type)
        return self._build_dataset(test_data, {"gen_mol": results}, scores)

    def _process_caption_with_bleu(self, results, test_data, tokenizer):
        """Evaluate generated captions against the ``desc`` column.

        Adds a ``gen_desc`` column plus one column per metric returned by
        ``compute_bleu_nltk_batch``.
        """
        # A missing 'desc' column is passed on as an empty reference list.
        gt = test_data.to_dict().get('desc', [])
        metrics = compute_bleu_nltk_batch(tokenizer, results, gt)
        return self._build_dataset(test_data, {"gen_desc": results}, metrics)

    def _process_reaction_prediction(self, results, test_data):
        """Score predicted products against the ``prod`` column.

        Adds ``gen_prod`` (stripped prediction), ``match`` (1 on exact string
        equality with the reference, else 0) and one column per metric, and
        prints the overall exact-match accuracy.
        """
        gt = test_data["prod"]
        # Strip once up front so exact match, scoring and the saved column
        # all see the same text.
        results = [r.strip() for r in results]
        match = [int(pred == ref) for pred, ref in zip(results, gt)]
        scores = create_score_indie(gt, results, self.mol_type)
        print("Acc:", self._accuracy(match))
        return self._build_dataset(test_data, {"gen_prod": results, "match": match}, scores)

    def _process_retrosynthesis(self, results, test_data):
        """Score predicted reactants against the ``equa`` column.

        A prediction counts as a match when every '.'-separated component of
        the reference appears among the '.'-separated components of the
        prediction. Adds ``gen_equa``, ``match`` and one column per metric,
        and prints the overall accuracy.
        """
        gt = test_data["equa"]
        results = [r.strip() for r in results]
        match = [
            int(set(ref.split(".")).issubset(set(pred.split("."))))
            for pred, ref in zip(results, gt)
        ]
        scores = create_score_indie(gt, results, self.mol_type)
        print("Acc:", self._accuracy(match))
        return self._build_dataset(test_data, {"gen_equa": results, "match": match}, scores)

    def _process_classification(self, results, test_data):
        """Compare predicted labels with the ``class`` column.

        The comparison is case-insensitive and ignores whitespace around the
        prediction. Adds ``pred`` and ``acc`` columns and prints the overall
        accuracy.
        """
        gt = test_data["class"]
        acc = [
            int(pred.strip().lower() == label.lower())
            for pred, label in zip(results, gt)
        ]
        print("Acc:", self._accuracy(acc))
        return self._build_dataset(test_data, {"pred": results, "acc": acc})

    def _process_structured_output(self, results, test_data):
        """Keep rows whose generation contains a usable ``<desc>`` payload.

        For each generation with a ``<desc>...</desc>`` section longer than
        20 characters after stripping, the matching input row is copied and
        its ``desc`` field replaced by that payload. Other rows are dropped.
        """
        filtered_data = []
        for i, result in enumerate(results):
            found = self._DESC_PATTERN.search(result)
            if not found:
                continue
            description = found.group(1).strip()
            if len(description) > self._MIN_DESC_LENGTH:
                data_point = test_data[i].copy()
                data_point["desc"] = description
                filtered_data.append(data_point)
        return Dataset.from_list(filtered_data)

def inference_worker(rank, prompts, model_path, params_dict, return_dict):
    """Generate completions for one prompt slice on a single GPU.

    Args:
        rank: Device identifier; written to ``CUDA_VISIBLE_DEVICES`` and used
            as the key under which the outputs are stored.
        prompts: Prompts handed to the engine's ``generate`` call.
        model_path: Model name or path passed to ``LLM``.
        params_dict: Mapping that may override the sampling defaults
            (``top_p``, ``temperature``, ``max_tokens``, ``top_k``).
        return_dict: Shared mapping that receives the outputs under ``rank``.
    """
    # Restrict this process to its own device before the engine is created.
    os.environ.update({
        "CUDA_VISIBLE_DEVICES": str(rank),
        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
    })

    engine_options = {
        "gpu_memory_utilization": 0.8,
        "max_model_len": 8192,
        "trust_remote_code": True,
    }
    llm = LLM(model_path, **engine_options)

    # Each sampling knob falls back to its default when absent from params_dict.
    sampling_defaults = (
        ("top_p", 0.9),
        ("temperature", 0.7),
        ("max_tokens", 1024),
        ("top_k", 20),
    )
    sampling = SamplingParams(
        **{key: params_dict.get(key, fallback) for key, fallback in sampling_defaults}
    )

    return_dict[rank] = llm.generate(prompts, sampling)

def merge_and_save_model(config):
    """Merge the LoRA weights described by ``config`` via ``utils.model_utils``.

    Forwards ``base_model_name``, ``load_directory``, ``temp_model_dir`` and
    ``cache_dir`` from ``config`` to the project helper and returns its
    result unchanged.
    """
    from utils.model_utils import merge_and_save_model as _merge_lora_weights

    merge_args = (
        config.base_model_name,
        config.load_directory,
        config.temp_model_dir,
        config.cache_dir,
    )
    return _merge_lora_weights(*merge_args)

def main():
    """Run the full inference pipeline and return the post-processed results.

    Steps: parse the config, optionally merge LoRA weights, load and prepare
    the test data, build and format prompts, generate (one worker process per
    GPU when several are visible, otherwise in-process), post-process with
    ``PostProcessor`` and, when ``output_root_path`` is configured, save the
    results under ``{output_root_path}/{task}/{base_model_name}/{load_directory}``.

    Raises:
        RuntimeError: If a multi-GPU worker exits abnormally or returns no
            outputs.
    """
    parser = get_base_parser()
    args = parser.parse_args()
    config = load_config(args.config, args.opts)

    # load_directory is optional; read it once so every later use agrees.
    load_directory = getattr(config, 'load_directory', None)

    # Merge LoRA and save model
    if load_directory:
        model_path = merge_and_save_model(config)
    else:
        model_path = config.base_model_name

    # Load data
    data_loader = DataLoader()
    test_data = data_loader.load_multiple_datasets(
        config.test_datasets,
        getattr(config, 'test_limits', {}),
        getattr(config, 'test_processing', {})
    )
    if "SELFIES" not in test_data.column_names and "SMILES" in test_data.column_names and config.model_mol_type=="SELFIES":
        test_data = smiles2selfies(test_data)
    if "prod" in test_data.column_names and "equa" in test_data.column_names and config.model_mol_type=="SELFIES":
        test_data = smiles2selfies(test_data, "prod", "prod")
        test_data = smiles2selfies(test_data, "equa", "equa")

    # Handle special data preparation for certain tasks
    if config.task == "synth":
        # Add random reference descriptions for synthesis task
        text_in = np.array(test_data["desc"])
        rand_indices = np.random.randint(0, len(text_in), size=len(text_in))
        test_data = test_data.add_column("ref_desc", text_in[rand_indices].tolist())

    # Build prompts
    prompt_builder = PromptBuilder(config.model_mol_type)
    test_data = prompt_builder.build_prompts(test_data, config.task, is_generation=True)

    # Setup tokenizer for formatting
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.bos_token
    tokenizer.padding_side = "left"

    if config.use_alpaca:
        tokenizer.apply_chat_template = apply_alpaca_chat_template.__get__(tokenizer)
    elif config.use_chemdfm:
        tokenizer.apply_chat_template = apply_round_chat_template.__get__(tokenizer)

    # Format for inference
    test_data = prompt_builder.format_for_inference(test_data, tokenizer, thinking=getattr(config, 'thinking', None))
    if len(test_data) > 0:
        # Show one formatted example as a sanity check.
        print(test_data[0])

    # Multi-GPU inference
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        # contiguous=True keeps each shard a consecutive slice, so that
        # concatenating the per-GPU outputs in shard order restores the
        # original row order. Some `datasets` releases default to
        # interleaved shards, which would misalign outputs and ground truth.
        test_data_shards = [
            test_data.shard(num_shards=num_gpus, index=i, contiguous=True)
            for i in range(num_gpus)
        ]
        data_slices = [shard.to_dict()["messages"] for shard in test_data_shards]

        # Use the device ids the caller exposed; otherwise number them 0..N-1
        # so any GPU count is supported.
        env_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if env_devices:
            visible_devices = env_devices.split(",")[:num_gpus]
        else:
            visible_devices = [str(i) for i in range(num_gpus)]

        manager = mp.Manager()
        return_dict = manager.dict()
        processes = []

        for rank in range(num_gpus):
            p = mp.Process(
                target=inference_worker,
                args=(visible_devices[rank], data_slices[rank], model_path, vars(config), return_dict)
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Fail loudly if a worker died instead of hitting a KeyError below.
        failed = [
            device
            for device, p in zip(visible_devices, processes)
            if p.exitcode != 0 or device not in return_dict
        ]
        if failed:
            raise RuntimeError(
                "Inference worker(s) failed on GPU(s): " + ", ".join(failed)
            )

        # Collect results in shard order
        all_outputs = []
        for rank in range(num_gpus):
            all_outputs.extend(return_dict[visible_devices[rank]])
    else:
        # Single GPU inference
        # NOTE(review): this path does not set top_k or the engine limits
        # used by inference_worker, so sampling may differ from the
        # multi-GPU path - confirm whether that is intended.
        engine = LLM(model_path, download_dir=config.cache_dir)
        sampling_params = SamplingParams(
            top_p=config.top_p,
            temperature=config.temperature,
            max_tokens=config.max_tokens
        )
        all_outputs = engine.generate(test_data["messages"], sampling_params)

    # Post-process results
    processor = PostProcessor(config.task, config.model_mol_type)

    # Remove intermediate columns
    test_data = test_data.remove_columns(["messages", "prompt"])
    eval_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
    results = processor.process(all_outputs, test_data, eval_tokenizer)

    # Save results if specified
    output_root_path = getattr(config, 'output_root_path', None)
    if output_root_path:
        # Create structured output path: {output_root_path}/{task}/{base_model_name}/{load_directory}
        # normpath drops a trailing separator, which would otherwise make
        # basename return an empty string.
        if load_directory:
            load_dir_name = os.path.basename(os.path.normpath(load_directory))
        else:
            load_dir_name = "base_model"
        base_model_name = os.path.basename(os.path.normpath(config.base_model_name))
        output_path = os.path.join(output_root_path, config.task, base_model_name, load_dir_name)
        os.makedirs(output_path, exist_ok=True)

        if isinstance(results, Dataset):
            results.save_to_disk(output_path)
        else:
            with open(os.path.join(output_path, 'results.json'), 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)

    print("Inference completed successfully!")
    return results

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()