#!/usr/bin/env python3
"""Unified reinforcement learning training script"""

import torch
import os
import numpy as np
import datasets
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from accelerate import Accelerator

from utils.config import load_config, get_base_parser
from data.data_loader import DataLoader, smiles2selfies
from data.prompt_builder import PromptBuilder
from utils.training_utils import setup_experiment, setup_tokenizer, setup_model, setup_lora, save_experiment, with_telegram_notifications
from utils.model_utils import load_lora_hist
from utils.evaluation_utils import compute_reward_llama3, compute_bleu_nltk_batch, create_score_indie, identify_smiles_components
from prompts import get_mol_gen_function, get_mol_cap_function, get_mol_plain_function, get_cap_plain_function

class RewardFunction:
    """Unified reward function for different RL tasks.

    An instance is callable: ``reward(completions, **kwargs)`` dispatches on
    ``config.rl_task_type`` to one of the ``_reward_*`` methods and returns one
    reward per completion. Every task scores the completions with the judge
    model (via ``compute_reward_llama3``) on prompts rebuilt from the
    completion text plus the matching dataset columns passed in ``kwargs``.
    Some tasks add a validity bonus and, when ``config.add_gt`` is true, a
    bonus for similarity to the ground-truth column.

    Args:
        config: Experiment configuration. Reads ``rl_task_type`` and
            ``model_mol_type`` (required), ``rl_task`` and ``add_gt``
            (optional, default ``None`` / ``False``).
        judge_model: Model handed to ``compute_reward_llama3``.
        judge_tokenizer: Tokenizer handed to ``compute_reward_llama3`` and
            ``compute_bleu_nltk_batch``.
    """
    # NOTE(review): instances have no `__name__` of their own; this class
    # attribute supplies one. Presumably GRPOTrainer reads
    # `reward_func.__name__` to label the reward in its logs -- confirm
    # against the installed TRL version before renaming or removing.
    __name__ = "EntropyReward"

    # Metrics of `create_score_indie` averaged into the ground-truth bonus.
    _SIMILARITY_KEYS = ("bleu", "maccs_similarity", "rdk_similarity", "morgan_similarity")

    # NOTE(review): the caption bonus sums *every* metric returned by
    # `compute_bleu_nltk_batch` and divides by this fixed value; presumably
    # that helper returns exactly six metrics -- confirm.
    _CAPTION_METRIC_DIVISOR = 6

    def __init__(self, config, judge_model, judge_tokenizer):
        self.config = config
        self.judge_model = judge_model
        self.judge_tokenizer = judge_tokenizer
        self.task_type = config.rl_task_type
        self.template_name = getattr(config, 'rl_task', None)
        self.add_gt = getattr(config, 'add_gt', False)
        self.prompt_builder = PromptBuilder(config.model_mol_type)

        # Map task types to reward functions
        self.reward_map = {
            "cap2mol": self._reward_cap2mol,
            "mol2cap": self._reward_mol2cap,
            "equa2prod": self._reward_equa2prod,
            "prod2equa": self._reward_prod2equa,
            "class2mol": self._reward_class2mol,
            "mol2class": self._reward_mol2class,
        }

    def __call__(self, completions, **kwargs):
        """Compute rewards for completions.

        Args:
            completions: Batch of completions, either conversational (list of
                message dicts per completion) or plain strings.
            **kwargs: Dataset columns aligned with ``completions``; which keys
                are read depends on the task (see the ``_reward_*`` methods).

        Returns:
            A sequence with one reward per completion.

        Raises:
            ValueError: If ``self.task_type`` has no registered reward method.
        """
        reward_func = self.reward_map.get(self.task_type)
        if not reward_func:
            raise ValueError(f"Unknown task type: {self.task_type}")
        # Reward computation never needs gradients through the judge model.
        with torch.no_grad():
            return reward_func(completions, **kwargs)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _completion_texts(completions):
        """Return the stripped text of each completion.

        Conversational completions contribute the ``content`` of their last
        message; plain-string completions are used directly.
        """
        return [
            (completion if isinstance(completion, str) else completion[-1]['content']).strip()
            for completion in completions
        ]

    @staticmethod
    def _average_metrics(scores, keys, divisor):
        """Sum the per-sample metric lists ``scores[key]`` and divide by ``divisor``."""
        total = sum(np.asarray(scores[key]) for key in keys)
        return list(total / divisor)

    def _similarity_bonus(self, generated, kwargs, reference_key):
        """Return the per-sample ground-truth similarity bonus for molecule-like text.

        With ``add_gt`` disabled this is all zeros and ``kwargs[reference_key]``
        is never read, so the ground-truth column may be absent from the batch.
        """
        if not self.add_gt:
            return [0] * len(generated)
        scores = create_score_indie(
            kwargs[reference_key], generated, self.config.model_mol_type, verbose=False
        )
        return self._average_metrics(scores, self._SIMILARITY_KEYS, len(self._SIMILARITY_KEYS))

    def _judge_log_rewards(self, columns):
        """Build judge prompts from ``columns`` and score them with the judge model."""
        data = Dataset.from_dict(columns)
        dialog_data = self.prompt_builder.build_prompts(
            data, self.template_name, is_generation=False
        )
        return self._compute_logprob_rewards(dialog_data["prompt"])

    def _compute_logprob_rewards(self, prompts):
        """Compute log probability rewards using judge model"""
        return compute_reward_llama3(self.judge_model, self.judge_tokenizer, prompts)

    # ------------------------------------------------------------ task rewards

    def _reward_cap2mol(self, completions, **kwargs):
        """Reward completions treated as captions for the batch's molecules.

        Reads ``kwargs[config.model_mol_type]`` and, when ``add_gt`` is set,
        ``kwargs["desc"]``. Reward = judge score + 3 * caption-metric bonus.
        """
        generated_captions = self._completion_texts(completions)
        molecules = kwargs[self.config.model_mol_type]
        if self.add_gt:
            scores = compute_bleu_nltk_batch(
                self.judge_tokenizer, generated_captions, kwargs["desc"], verbose=False
            )
            matches = self._average_metrics(scores, list(scores), self._CAPTION_METRIC_DIVISOR)
        else:
            matches = [0] * len(generated_captions)
        log_reward = self._judge_log_rewards({
            self.config.model_mol_type: molecules,
            "desc": generated_captions,
        })
        return [score + match * 3 for score, match in zip(log_reward, matches)]

    def _reward_mol2cap(self, completions, **kwargs):
        """Reward completions treated as molecules for the batch's descriptions.

        Reads ``kwargs["desc"]`` and, when ``add_gt`` is set,
        ``kwargs[config.model_mol_type]``. Reward = 4 * validity flag
        + judge score + 2 * similarity bonus.
        """
        generated_molecules = self._completion_texts(completions)
        # Flag is 1 only when `identify_smiles_components` returns exactly 1
        # (presumably "a single valid molecule" -- confirm in evaluation_utils).
        is_mol = [int(identify_smiles_components(mol) == 1) for mol in generated_molecules]
        descriptions = kwargs["desc"]
        matches = self._similarity_bonus(generated_molecules, kwargs, self.config.model_mol_type)
        log_reward = self._judge_log_rewards({
            self.config.model_mol_type: generated_molecules,
            "desc": descriptions,
        })
        return [
            valid * 4 + score + match * 2
            for valid, score, match in zip(is_mol, log_reward, matches)
        ]

    def _reward_equa2prod(self, completions, **kwargs):
        """Reward completions treated as the ``equa`` column for the batch's ``prod``.

        Reads ``kwargs["prod"]`` and, when ``add_gt`` is set, ``kwargs["equa"]``.
        Reward = judge score + validity flag + similarity bonus.
        """
        generated_equa = self._completion_texts(completions)
        # Flag is 1 when `identify_smiles_components` returns at least 1.
        is_chem = [int(identify_smiles_components(mol) >= 1) for mol in generated_equa]
        matches = self._similarity_bonus(generated_equa, kwargs, "equa")
        log_reward = self._judge_log_rewards({
            "prod": kwargs["prod"],
            "equa": generated_equa,
        })
        return [
            score + valid + match
            for score, valid, match in zip(log_reward, is_chem, matches)
        ]

    def _reward_prod2equa(self, completions, **kwargs):
        """Reward completions treated as the ``prod`` column for the batch's ``equa``.

        Reads ``kwargs["equa"]`` and, when ``add_gt`` is set, ``kwargs["prod"]``.
        Reward = judge score + validity flag + similarity bonus.
        """
        generated_prod = self._completion_texts(completions)
        matches = self._similarity_bonus(generated_prod, kwargs, "prod")
        # Flag is 1 only when `identify_smiles_components` returns exactly 1.
        is_prod = [int(identify_smiles_components(mol) == 1) for mol in generated_prod]
        log_reward = self._judge_log_rewards({
            "prod": generated_prod,
            "equa": kwargs["equa"],
        })
        return [
            score + valid + match
            for score, valid, match in zip(log_reward, is_prod, matches)
        ]

    def _reward_class2mol(self, completions, **kwargs):
        """Reward completions treated as the ``class`` column; judge score only.

        Reads ``kwargs[config.model_mol_type]``.
        """
        generated_class = self._completion_texts(completions)
        return self._judge_log_rewards({
            self.config.model_mol_type: kwargs[self.config.model_mol_type],
            "class": generated_class,
        })

    def _reward_mol2class(self, completions, **kwargs):
        """Reward completions treated as molecules; judge score only.

        Reads ``kwargs["class"]``.
        """
        generated_molecule = self._completion_texts(completions)
        return self._judge_log_rewards({
            self.config.model_mol_type: generated_molecule,
            "class": kwargs["class"],
        })

def setup_judge_model(config, device):
    """Load the judge model and tokenizer used to score completions.

    The tokenizer comes from ``setup_tokenizer(config)`` with its padding side
    switched to the left. The model is ``config.base_model_name`` loaded in
    float16, passed through ``load_lora_hist`` with ``config.judge_dir``,
    moved to ``device`` and put in eval mode.

    Args:
        config: Experiment configuration; reads ``base_model_name``,
            ``cache_dir`` and ``judge_dir``.
        device: Target device for the judge model.

    Returns:
        Tuple ``(judge_model, judge_tokenizer)``.
    """
    tokenizer = setup_tokenizer(config)
    tokenizer.padding_side = "left"

    cache_dir = config.cache_dir
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name, torch_dtype=torch.float16, cache_dir=cache_dir
    )

    # Second return value of load_lora_hist is not needed for the judge.
    judge, _ = load_lora_hist(config.judge_dir, base_model, cache_dir=cache_dir)

    # The judge is only used for inference: place it on the device, eval mode.
    placed = judge.to(device)
    placed.eval()

    return judge, tokenizer

@with_telegram_notifications
def main():
    """Run GRPO training end to end.

    Steps: parse CLI arguments and load the config, initialise Accelerate and
    the experiment, load the training data and build generation prompts, set
    up the policy model (with LoRA) and the judge model, train with
    ``GRPOTrainer`` using ``RewardFunction`` as the reward, then save the
    experiment.

    Data loading has two modes:
      * ``config.dataset_tasks`` set and truthy: datasets are loaded with
        their tasks and prompts are built per dataset;
      * otherwise: ``config.target_datasets`` are loaded and prompts are built
        for ``config.tasks``.
    """
    parser = get_base_parser()
    args = parser.parse_args()
    config = load_config(args.config, args.opts)

    # Initialize accelerator
    accelerator = Accelerator()

    setup_experiment(config, accelerator.process_index, "molgen-rl")

    # Load data
    data_loader = DataLoader()
    dataset_tasks = getattr(config, 'dataset_tasks', None)
    dataset_limits = getattr(config, 'dataset_limits', {})
    dataset_processing = getattr(config, 'dataset_processing', {})
    if dataset_tasks:
        train_data = data_loader.load_datasets_with_tasks(
            dataset_tasks, dataset_limits, dataset_processing
        )
    else:
        train_data = data_loader.load_multiple_datasets(
            config.target_datasets, dataset_limits, dataset_processing
        )

    # A SELFIES model needs a SELFIES column; derive it when only SMILES exists.
    columns = train_data.column_names
    if config.model_mol_type == "SELFIES" and "SELFIES" not in columns and "SMILES" in columns:
        train_data = smiles2selfies(train_data)

    # Build prompts: per dataset when tasks are attached to datasets, else globally
    prompt_builder = PromptBuilder(config.model_mol_type)
    if dataset_tasks:
        train_data = prompt_builder.build_prompts_per_dataset(train_data, is_generation=True)
    else:
        train_data = prompt_builder.build_prompts(train_data, config.tasks, is_generation=True)

    # Setup tokenizer
    tokenizer = setup_tokenizer(config)
    tokenizer.padding_side = "left"

    # Setup model
    policy_model, lora_hist = setup_model(config)
    model = setup_lora(policy_model, config)

    # Setup judge model
    judge_model, judge_tokenizer = setup_judge_model(config, accelerator.device)

    # Training arguments
    training_args = GRPOConfig(
        output_dir=os.path.join(config.exp_save_dir, "hf"),
        save_strategy="no",
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        num_generations=config.num_generations,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type=config.lr_scheduler_type,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        # Use the string "none" to disable reporting on non-main processes:
        # a literal None is treated by transformers 4.x as "not specified" and
        # falls back to every installed integration.
        report_to="wandb" if accelerator.is_main_process else "none",
        remove_unused_columns=False,
        ddp_find_unused_parameters=False,
        bf16=True,
        dataloader_pin_memory=False,
        dataloader_num_workers=2,
        top_p=getattr(config, 'top_p', 0.9),
        temperature=getattr(config, 'temperature', 0.7),
        top_k=getattr(config, 'top_k', 50),
        beta=0.08,
        use_vllm=False,
        vllm_server_base_url=None,
    )

    # Setup reward function
    reward_func = RewardFunction(config, judge_model, judge_tokenizer)

    # Clean data: drop columns left over from other training formats
    unwanted_cols = ["chosen", "rejected", "completion", "messages"]
    cols_to_remove = [col for col in unwanted_cols if col in train_data.column_names]
    if cols_to_remove:
        train_data = train_data.remove_columns(cols_to_remove)
    datasets.disable_progress_bar()

    # Trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_func,
        args=training_args,
        train_dataset=train_data,
    )

    # Train
    trainer.train()

    # Save
    save_experiment(config, model, lora_hist, accelerator.process_index)

# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()