#!/usr/bin/env python3
"""RL training with accelerate instead of torchrun"""

import torch
import os
import requests
import json
import datasets
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from datasets import Dataset
from accelerate import Accelerator

from utils.config import load_config, get_base_parser
from data.data_loader import DataLoader
from data.prompt_builder import PromptBuilder
from utils.data_utils import smiles2selfies
from utils.training_utils import setup_experiment, setup_tokenizer, setup_model, setup_lora, save_experiment, with_telegram_notifications
from utils.evaluation_utils import identify_smiles_components, compute_bleu_nltk_batch
from molgeneval import parse_molecules, is_same_mol, create_score_indie

class VLLMRewardFunction:
    """GRPO reward function that scores completions with a vLLM judge server.

    Each call rebuilds task dialogs from the completions plus the batch
    columns passed as keyword arguments, sends them to an OpenAI-compatible
    completions endpoint with ``echo=True``, and uses the averaged
    log-probability of the final dialog message as the judge reward.
    Depending on ``config.rl_task_type`` that score is combined with
    chemical-validity bonuses and, when ``config.add_gt`` is set, with a
    similarity score against the ground-truth column.

    Args:
        config: Experiment config; reads ``rl_task_type``, ``rl_task``,
            ``model_mol_type`` and ``add_gt``.
        tokenizer: Tokenizer with a chat template, used to render dialogs.
        judge_url: Completions endpoint of the judge server.
        timeout: Seconds to wait for each judge request; ``None`` waits
            indefinitely.
    """

    # TRL's GRPOTrainer labels non-model reward functions by their ``__name__``.
    __name__ = "VLLMReward"

    # rl_task_type -> name of the method computing rewards for that task.
    _TASK_HANDLERS = {
        "mol2cap": "_reward_mol2cap",
        "cap2mol": "_reward_cap2mol",
        "equa2prod": "_reward_equa2prod",
        "prod2equa": "_reward_prod2equa",
        "class2mol": "_reward_class2mol",
        "mol2class": "_reward_mol2class",
    }

    # create_score_indie metrics averaged into the ground-truth match score.
    _MOL_SIMILARITY_KEYS = ("bleu", "maccs_similarity", "rdk_similarity", "morgan_similarity")

    # Upper bound on simultaneous requests sent to the judge server.
    _MAX_CONCURRENT_REQUESTS = 10

    def __init__(self, config, tokenizer, judge_url="http://localhost:8000/v1/completions",
                 timeout=120.0):
        self.config = config
        self.judge_url = judge_url
        self.timeout = timeout
        self.task_type = config.rl_task_type
        self.task = config.rl_task
        self.tokenizer = tokenizer
        # Reward used when a judge request fails or returns no logprobs.
        self.min_reward = -8.0
        self.prompt_builder = PromptBuilder(config.model_mol_type)
        self.add_gt = config.add_gt
        # Every task type renders its judge dialogs with the template named after the task.
        self.template_name = self.task

    def __call__(self, completions, **kwargs):
        """Return one reward per completion for the configured task type.

        Args:
            completions: Conversational completions; the ``content`` of the
                last message of each one is what gets scored.
            **kwargs: Batch columns forwarded by the trainer (e.g. ``desc``,
                ``prod``, ``equa``, ``class`` or the molecule column).

        Raises:
            ValueError: If ``rl_task_type`` has no reward implementation.
        """
        handler_name = self._TASK_HANDLERS.get(self.task_type)
        if handler_name is None:
            raise ValueError(
                f"Unsupported rl_task_type {self.task_type!r}; expected one of "
                f"{sorted(self._TASK_HANDLERS)}"
            )
        return getattr(self, handler_name)(completions, **kwargs)

    @staticmethod
    def _mean_target_logprob(token_logprobs, prompt_length):
        """Average the echoed token logprobs over the scored span.

        The span starts at index ``prompt_length - 1`` and excludes the last
        two entries: the single token generated because ``max_tokens=1`` and,
        presumably, a trailing end-of-turn token of the chat template
        (TODO confirm against the judge's template). An empty span yields 0.0.
        """
        span = token_logprobs[prompt_length - 1:-2]
        return sum(span) / max(len(token_logprobs) - prompt_length - 1, 1)

    def _compute_single_reward(self, prompt):
        """Return the judge's mean log-prob for the last message of ``prompt``.

        ``prompt`` is a list of chat messages. The messages before the last one,
        rendered with a generation prompt, define where the scored span starts.
        Any failure of the request or of the response parsing is logged and mapped
        to ``self.min_reward`` so one bad call does not abort training.
        """
        context_messages = prompt[:-1]
        # The first token is dropped before decoding — presumably BOS, which the
        # server adds again when it tokenizes the text; TODO confirm for the judge.
        input_prompt = self.tokenizer.decode(
            self.tokenizer.apply_chat_template(prompt, tokenize=True)[1:]
        )
        context_ids = self.tokenizer.apply_chat_template(
            context_messages, tokenize=True, add_generation_prompt=True
        )

        # echo=True with max_tokens=1 returns logprobs for the prompt tokens themselves.
        payload = {
            "prompt": input_prompt,
            "max_tokens": 1,
            "temperature": 0.0,
            "logprobs": 0,
            "echo": True
        }

        try:
            response = requests.post(self.judge_url, json=payload, timeout=self.timeout)
            response.raise_for_status()
            result = response.json()["choices"][0]
            logprobs = result["logprobs"]["token_logprobs"]
            if not logprobs:
                return self.min_reward
            return self._mean_target_logprob(logprobs, len(context_ids))
        except Exception as e:  # deliberate best-effort boundary around the judge call
            print(f"Error computing reward: {e}")
            return self.min_reward

    def _compute_logprob_rewards(self, prompts):
        """Score ``prompts`` concurrently, returning rewards in input order."""
        if len(prompts) == 0:
            # ThreadPoolExecutor rejects max_workers=0.
            return []
        workers = min(len(prompts), self._MAX_CONCURRENT_REQUESTS)
        with ThreadPoolExecutor(max_workers=workers) as executor:
            return list(executor.map(self._compute_single_reward, prompts))

    def _judge_rewards(self, data):
        """Render ``data`` into judge dialogs with the task template and score them."""
        dialog_data = self.prompt_builder.build_prompts(data, self.template_name, is_generation=False)
        return self._compute_logprob_rewards(dialog_data["prompt"])

    @staticmethod
    def _average_metrics(scores, keys, divisor):
        """Element-wise sum of ``scores[k]`` over ``keys`` divided by ``divisor``, as a list."""
        total = 0
        for key in keys:
            total += np.array(scores[key])
        return list(total / divisor)

    def _mol_similarity(self, references, generated):
        """Per-pair mean of BLEU and MACCS/RDK/Morgan similarity from create_score_indie."""
        scores = create_score_indie(references, generated, self.config.model_mol_type, verbose=False)
        return self._average_metrics(scores, self._MOL_SIMILARITY_KEYS, len(self._MOL_SIMILARITY_KEYS))

    def _reward_cap2mol(self, completions, **kwargs):
        """Reward generated captions: judge logprob + 3 * ground-truth BLEU score.

        The completions here are captions, paired with the batch's molecule
        column; the ground-truth term is 0 unless ``add_gt`` is set.
        """
        generated_captions = [c[-1]['content'].strip() for c in completions]
        molecules = kwargs[self.config.model_mol_type]
        data = Dataset.from_dict({
            self.config.model_mol_type: molecules,
            "desc": generated_captions
        })
        if self.add_gt:
            scores = compute_bleu_nltk_batch(self.tokenizer, generated_captions, kwargs["desc"], verbose=False)
            # Sum over every returned metric divided by 6 — presumably the number of
            # metrics compute_bleu_nltk_batch returns; TODO confirm.
            matches = self._average_metrics(scores, scores, 6)
        else:
            matches = [0] * len(generated_captions)
        log_reward = self._judge_rewards(data)
        return [logp + match * 3 for logp, match in zip(log_reward, matches)]

    def _reward_mol2cap(self, completions, **kwargs):
        """Reward generated molecules: 4 * validity + judge logprob + 2 * GT similarity.

        Validity is 1 when identify_smiles_components returns exactly 1.
        """
        generated_molecules = [c[-1]['content'].strip() for c in completions]
        is_mol = [int(identify_smiles_components(mol) == 1) for mol in generated_molecules]
        data = Dataset.from_dict({
            self.config.model_mol_type: generated_molecules,
            "desc": kwargs["desc"]
        })
        if self.add_gt:
            matches = self._mol_similarity(kwargs[self.config.model_mol_type], generated_molecules)
        else:
            matches = [0] * len(generated_molecules)
        log_reward = self._judge_rewards(data)
        return [valid * 4 + logp + match * 2
                for valid, logp, match in zip(is_mol, log_reward, matches)]

    def _reward_equa2prod(self, completions, **kwargs):
        """Reward generated equations: judge logprob + validity + GT similarity.

        Validity is 1 when identify_smiles_components returns at least 1.
        """
        generated_equa = [c[-1]['content'].strip() for c in completions]
        is_chem = [int(identify_smiles_components(mol) >= 1) for mol in generated_equa]
        if self.add_gt:
            matches = self._mol_similarity(kwargs["equa"], generated_equa)
        else:
            matches = [0] * len(generated_equa)
        data = Dataset.from_dict({
            "prod": kwargs["prod"],
            "equa": generated_equa
        })
        log_reward = self._judge_rewards(data)
        return [logp + valid + match for logp, valid, match in zip(log_reward, is_chem, matches)]

    def _reward_prod2equa(self, completions, **kwargs):
        """Reward generated products: judge logprob + validity + GT similarity.

        Validity is 1 when identify_smiles_components returns exactly 1.
        """
        generated_prod = [c[-1]['content'].strip() for c in completions]
        if self.add_gt:
            matches = self._mol_similarity(kwargs["prod"], generated_prod)
        else:
            matches = [0] * len(generated_prod)
        is_prod = [int(identify_smiles_components(mol) == 1) for mol in generated_prod]
        data = Dataset.from_dict({
            "prod": generated_prod,
            "equa": kwargs["equa"]
        })
        log_reward = self._judge_rewards(data)
        return [logp + valid + match for logp, valid, match in zip(log_reward, is_prod, matches)]

    def _reward_class2mol(self, completions, **kwargs):
        """Reward generated class labels with the judge logprob only."""
        generated_class = [c[-1]['content'].strip() for c in completions]
        data = Dataset.from_dict({
            self.config.model_mol_type: kwargs[self.config.model_mol_type],
            "class": generated_class
        })
        return self._judge_rewards(data)

    def _reward_mol2class(self, completions, **kwargs):
        """Reward generated molecules with the judge logprob only."""
        generated_molecule = [c[-1]['content'].strip() for c in completions]
        data = Dataset.from_dict({
            self.config.model_mol_type: generated_molecule,
            "class": kwargs["class"]
        })
        return self._judge_rewards(data)


@with_telegram_notifications
def main():
    """Run GRPO training of the policy model scored by a vLLM judge server.

    Parses ``--config``/``opts`` from the command line and loads the training
    data, converting SMILES to SELFIES when the model works in SELFIES. It then
    builds generation prompts, sets up the tokenizer and the ``setup_lora``-wrapped
    model, trains with TRL's ``GRPOTrainer`` using ``VLLMRewardFunction``, and
    saves via ``save_experiment``.
    """
    parser = get_base_parser()
    args = parser.parse_args()
    config = load_config(args.config, args.opts)

    # Initialize accelerator
    accelerator = Accelerator()

    setup_experiment(config, accelerator.process_index, "molgen-rl-vllm")

    # Load data: per-dataset task definitions take precedence over the global task list.
    data_loader = DataLoader()
    dataset_limits = getattr(config, 'dataset_limits', {})
    dataset_processing = getattr(config, 'dataset_processing', {})
    per_dataset_tasks = bool(getattr(config, 'dataset_tasks', None))
    if per_dataset_tasks:
        train_data = data_loader.load_datasets_with_tasks(
            config.dataset_tasks, dataset_limits, dataset_processing
        )
    else:
        train_data = data_loader.load_multiple_datasets(
            config.target_datasets, dataset_limits, dataset_processing
        )

    # The model consumes SELFIES but the data only has SMILES: convert it.
    columns = train_data.column_names
    if config.model_mol_type == "SELFIES" and "SELFIES" not in columns and "SMILES" in columns:
        train_data = smiles2selfies(train_data)

    prompt_builder = PromptBuilder(config.model_mol_type)
    if per_dataset_tasks:
        train_data = prompt_builder.build_prompts_per_dataset(train_data, is_generation=True)
    else:
        train_data = prompt_builder.build_prompts(train_data, config.tasks, is_generation=True)

    # Setup tokenizer
    tokenizer = setup_tokenizer(config)
    tokenizer.padding_side = "left"

    # Setup model
    policy_model, lora_hist = setup_model(config)
    model = setup_lora(policy_model, config)

    # Training arguments
    training_args = GRPOConfig(
        output_dir=os.path.join(config.exp_save_dir, "hf"),
        save_strategy="no",
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        num_generations=config.num_generations,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type=config.lr_scheduler_type,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        # Must be the string "none": report_to=None makes transformers 4.x fall back
        # to all installed integrations instead of disabling reporting.
        report_to="wandb" if accelerator.is_main_process else "none",
        remove_unused_columns=False,
        ddp_find_unused_parameters=False,
        bf16=True,
        dataloader_pin_memory=False,
        dataloader_num_workers=2,
        top_p=config.top_p,
        temperature=config.temperature,
        top_k=config.top_k,
        beta=0.08,
        use_vllm=config.use_vllm,
        vllm_server_base_url="http://localhost:8001" if config.use_vllm else None,
    )

    # Setup reward function
    reward_func = VLLMRewardFunction(config, tokenizer)

    # Drop preference/SFT columns that GRPO does not use.
    unwanted_cols = ["chosen", "rejected", "completion", "messages"]
    cols_to_remove = [col for col in unwanted_cols if col in train_data.column_names]
    if cols_to_remove:
        train_data = train_data.remove_columns(cols_to_remove)
    datasets.disable_progress_bar()

    # Trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_func,
        args=training_args,
        train_dataset=train_data
    )

    # Train
    trainer.train()

    # Save
    save_experiment(config, model, lora_hist, accelerator.process_index)


if __name__ == "__main__":
    main()