import os
from dataclasses import dataclass, field
from typing import Optional
from accelerate import Accelerator
import torch
import random
from tqdm import tqdm
from transformers import HfArgumentParser
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import Dataset
from trl import AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer2, set_seed
import numpy as np
import pandas as pd
import gc  
tqdm.pandas()
from peft import LoraConfig
import matplotlib.pyplot as plt
import re
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    StoppingCriteria,
)
from peft import PeftModel, get_peft_model
# Weights & Biases client settings: raise the wandb service wait timeout
# (presumably seconds) and send runs to the "ppo_BioT5" project.
os.environ.update({
    "WANDB__SERVICE_WAIT": "3000",
    "WANDB_PROJECT": "ppo_BioT5",
})


@dataclass
class ScriptArguments:
    """Command-line options for this PPO fine-tuning script.

    Parsed with ``HfArgumentParser``; every field is exposed as a ``--<name>`` flag.
    """

    # --- logging / output ---
    log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
    # Annotated as bool (it used to be `str` while defaulting to False), so that
    # `--disable_wandb False` no longer parses to the truthy string 'False'.
    disable_wandb: Optional[bool] = field(default=False, metadata={'help': 'Whether to disable wandb or not.'})
    save_directory: Optional[str] = field(default='./rl_saved/', metadata={'help': "root directory for checkpoints; runs are written under <save_directory>/<wandb_name>/"})

    # --- optimisation / PPO ---
    epochs: Optional[int] = field(default=10, metadata={'help': "Number of training epochs"})
    learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "the learning rate"})
    mini_batch_size: Optional[int] = field(default=2, metadata={"help": "the PPO minibatch size"})
    batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
    # NOTE(review): not read anywhere in the visible script.
    load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "loading model in 8 bit or bfloat16"})
    gradient_accumulation_steps: Optional[int] = field(default=1, metadata={"help": "the number of gradient accumulation steps"})
    early_stopping: Optional[bool] = field(default=True, metadata={"help": "whether to early stop"})
    target: Optional[float] = field(default=3, metadata={"help": "target kl divergence of adaptive control"})
    init_kl_coef: Optional[float] = field(default=0.0, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"})
    max_grad_norm: Optional[float] = field(default=1.0, metadata={"help": "Maximum gradient norm for gradient clipping"})

    # --- experiment / model ---
    wandb_name: Optional[str] = field(default='None', metadata={"help": "Name for this experiment (also the checkpoint sub-directory and part of the wandb run name)"})
    base_model_name: Optional[str] = field(default='SFT_model', metadata={'help': "the path to the sft model; need to merge if using lora"})

    # --- sampling during rollouts ---
    top_p: Optional[float] = field(default=0.9, metadata={'help': "nucleus (top-p) sampling probability mass used when generating responses"})
    top_k: Optional[int] = field(default=0, metadata={'help': "top-k sampling cutoff used when generating responses (0 disables top-k filtering)"})
    max_new_tokens: Optional[int] = field(default=2560, metadata={'help': "maximum number of new tokens generated per query"})


def _flag_enabled(value):
    """Return True when a CLI flag value (bool or string such as 'True'/'false') is set."""
    return str(value).strip().lower() in {'1', 'true', 'yes', 'y'}


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
base_model_name = script_args.base_model_name
print('base model: ', base_model_name)
os.makedirs(os.path.join(script_args.save_directory, script_args.wandb_name), exist_ok=True)

# Honour --disable_wandb. The flag used to be parsed but never read, so
# tracking could not be switched off from the command line.
tracker_backend = None if _flag_enabled(script_args.disable_wandb) else script_args.log_with

config = PPOConfig(
    model_name=base_model_name,
    learning_rate=script_args.learning_rate,
    log_with=tracker_backend,
    mini_batch_size=script_args.mini_batch_size,
    batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    early_stopping=script_args.early_stopping,
    target=script_args.target,
    max_grad_norm=script_args.max_grad_norm,
    optimize_cuda_cache=True,
    init_kl_coef=script_args.init_kl_coef,
    tracker_project_name='ppo_BioT5',
    tracker_kwargs={"wandb":{"name":script_args.base_model_name.replace('/','_')+'_RL_'+script_args.wandb_name}},
)
# Build one Accelerator and reuse it. Constructing extra instances only to read
# the process index was redundant.
accelerator = Accelerator()
process_id = accelerator.local_process_index
# The local process index doubles as the CUDA device ordinal for this rank.
gpu_id = process_id
print('process: {}'.format(process_id))
set_seed(8888)
current_device = accelerator.local_process_index
print(current_device)


# LoRA adapter configuration for the seq2seq policy model.
lora_config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# NOTE(review): the tokenizer is always the public BioT5+ ChEBI-20 one, while the
# weights come from `base_model_name`; confirm the SFT checkpoint shares its vocab.
tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-plus-base-chebi20")

# Policy + value head, wrapped with the LoRA adapters above and placed on this
# process's GPU (gpu_id is the local process index).
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    base_model_name,
    peft_config=lora_config,
    device_map=gpu_id,
)

# (sic) name kept as-is; not read elsewhere in the visible script.
tokenzier_vocab_size = len(tokenizer)


def preprocess_data(examples, max_input_length=512):
    """Tokenize the prompt texts of a batch with the module-level ``tokenizer``.

    Args:
        examples: mapping whose ``'description'`` entry holds the prompt text(s).
        max_input_length: every sequence is truncated / padded to exactly this
            many tokens.

    Returns:
        The tokenizer output for those texts (the training loop reads its
        ``input_ids``).
    """
    return tokenizer(
        examples['description'],
        truncation=True,
        padding='max_length',
        max_length=max_input_length,
    )


from datasets import load_dataset
from torch.utils.data import DataLoader

# The training examples are stored under the "Instances" key of the JSON file.
train_data = load_dataset(
    'json',
    data_files='data/task1_chebi20_text2mol_train.json',
    field='Instances',
)['train']
# Shuffled loader, handed to accelerate so it is prepared for this process.
train_dataloader = accelerator.prepare(
    DataLoader(train_data, batch_size=script_args.batch_size, shuffle=True)
)


import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
from rdkit.Chem import QED
from rdkit.Chem import GraphDescriptors
from rdkit.Chem import Lipinski
from rdkit import RDLogger
# Silence RDKit's warning and error channels (presumably to keep logs readable
# when many generated candidates fail to parse).
for _rdkit_channel in ('rdApp.warning', 'rdApp.error'):
    RDLogger.DisableLog(_rdkit_channel)
from nltk.translate.bleu_score import corpus_bleu
import selfies as sf


def compute_fingerprints(smiles_list, ground_truth):
    """Decode candidate molecules, keep the chemically valid ones, and score them.

    Despite the names, the entries of ``smiles_list`` and ``ground_truth`` are
    SELFIES strings. Each is converted to SMILES with ``sf.decoder`` before use.

    Args:
        smiles_list: candidate molecules as SELFIES strings; empty entries are skipped.
        ground_truth: the reference molecule as a SELFIES string.

    Returns:
        A tuple ``(valid_smiles, fingerprints, bleus)`` of three parallel lists.
        It has one entry per candidate that decodes, parses with RDKit and can be
        scored: the original SELFIES string, its Morgan fingerprint (radius 2), and
        the character-level corpus BLEU of its SMILES against the ground-truth
        SMILES. Candidates that fail any of those steps are silently dropped.

    Raises:
        Whatever ``sf.decoder`` raises for a malformed ``ground_truth``. The
        reference is decoded outside the per-candidate error handling.
    """
    # Three accumulators. The former `= [], []` unpacked two lists into three
    # names and raised ValueError on every call.
    valid_smiles, fingerprints, bleus = [], [], []
    # Decode the reference once and split it into characters for BLEU.
    reference_chars = list(sf.decoder(ground_truth))
    for smiles in smiles_list:
        if not smiles:
            continue
        try:
            real_smiles = sf.decoder(smiles)
            mol = Chem.MolFromSmiles(real_smiles)
            if not mol:
                continue
            bleu = corpus_bleu([[reference_chars]], [list(real_smiles)])
            fingerprint = AllChem.GetMorganFingerprint(mol, radius=2)
        except Exception:
            # Best effort: generated text is often not a decodable molecule.
            # (Narrowed from a bare `except:` so KeyboardInterrupt still propagates.)
            continue
        # Append only once every step succeeded, so the three lists stay aligned.
        valid_smiles.append(smiles)
        fingerprints.append(fingerprint)
        bleus.append(bleu)
    return valid_smiles, fingerprints, bleus


def pairwise_tanimoto_similarities(smiles_list, ground_truth):
    """Compare the newest valid candidate with every earlier valid one.

    Args:
        smiles_list: candidate SELFIES strings; the last entry is the newest.
        ground_truth: reference SELFIES string, forwarded to ``compute_fingerprints``.

    Returns:
        ``(valid_smiles, similarities, bleus)``. ``valid_smiles`` and ``bleus``
        come straight from ``compute_fingerprints``. ``similarities`` is a float32
        array holding the Tanimoto similarity between the last valid fingerprint
        and each earlier one, so it has ``len(valid_smiles) - 1`` entries. When
        fewer than two candidates are valid, it is a length-1 array of zeros.
    """
    valid_smiles, fingerprints, bleus = compute_fingerprints(smiles_list, ground_truth)
    if len(valid_smiles) < 2:
        return valid_smiles, np.zeros((1), dtype=np.float32), bleus

    *earlier, newest = fingerprints
    similarities = np.array(
        [rdkit.DataStructs.TanimotoSimilarity(newest, other) for other in earlier],
        dtype=np.float32,
    )
    return valid_smiles, similarities, bleus


def return_reward_matrx(response_tensors, ground_truths):
    """Build per-token reward tensors for a batch of generated responses.

    Each response is scanned token by token. Text between ``<bom>`` and ``<eom>``
    is accumulated (spaces removed) as a SELFIES candidate. When ``<eom>`` is
    reached, the candidate is scored together with the earlier valid candidates
    of the same response, and the reward is written at the ``<eom>`` position:

    * ``8 * max(0, sqrt(BLEU) - max_tanimoto**2)`` if it is valid and earlier
      valid molecules exist (penalising similarity to them),
    * ``8 * sqrt(BLEU)`` if it is the first valid molecule,
    * ``0`` if it is not valid (the tensor starts as zeros).

    The position right after every ``<eom>`` is set to ``-1`` as a stage marker.
    The training loop drops ``-1`` entries when summing rewards.

    Args:
        response_tensors: iterable of 1-D token-id tensors, one per sample.
        ground_truths: reference SELFIES strings, indexed like ``response_tensors``.

    Returns:
        List of float tensors, each as long as its response, on ``accelerator.device``.
    """
    rewards_tensors = []
    for sample_idx, response in enumerate(response_tensors):
        rewards_tensor = torch.zeros(response.shape[0]).to(accelerator.device)
        smi, valid_list = "", []
        # One batched id -> token lookup instead of one tokenizer call per id.
        tokens = tokenizer.convert_ids_to_tokens(response.tolist())
        for pos, token in enumerate(tokens):
            tok = tokenizer.convert_tokens_to_string([token])
            if '<eom>' in tok:
                smi = smi.replace(' ', '')
                # pairwise_tanimoto_similarities returns three values. The former
                # 4-way unpack raised ValueError and clobbered the loop variable.
                mols, mtx, bleus = pairwise_tanimoto_similarities(
                    valid_list + [smi], ground_truths[sample_idx]
                )
                # Reward only when the new candidate itself is valid.
                if len(mols) == len(valid_list) + 1:
                    reward = bleus[-1] ** 0.5
                    if len(mols) > 1:
                        reward -= np.max(mtx) ** 2
                    rewards_tensor[pos] = (reward if reward > 0 else 0) * 8
                valid_list = mols
                if pos + 1 < rewards_tensor.shape[0]:
                    rewards_tensor[pos + 1] = -1  # symbol for multi-stage learning
                smi = ""
            elif '<bom>' in tok:
                smi = ""
            else:
                # Only plain molecule tokens are accumulated. Previously the
                # `<eom>` token itself leaked into the next candidate string.
                smi += tok
        rewards_tensors.append(rewards_tensor)
    return rewards_tensors


# Adam over the trainable parameters only (presumably the LoRA adapters and the
# value head; frozen base weights have requires_grad=False — confirm).
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
ppo_trainer = PPOTrainer2(
    config, model, tokenizer=tokenizer, optimizer=optimizer
)


# Sampling settings for ppo_trainer.generate.
# min_new_tokens is derived from max_new_tokens (one less), which forces
# generation to (nearly) the full length budget. It used to be hard-coded to
# 2559: that only matched the default --max_new_tokens=2560 and exceeded any
# smaller value.
generation_kwargs = {
    "max_new_tokens": script_args.max_new_tokens,
    "min_new_tokens": max(script_args.max_new_tokens - 1, 0),
    'min_length': -1,
    "top_k": script_args.top_k,
    "top_p": script_args.top_p,
    "do_sample": True,
    "temperature": 0.7,
    "pad_token_id": tokenizer.eos_token_id,
    # Keep EOS from being generated at the very first decoding step.
    "begin_suppress_tokens": [tokenizer.eos_token_id],
}


print("Training........")
# Start in generation mode: gradient checkpointing off, KV cache on. The loop
# below toggles both around each generate / PPO-step pair.
model.gradient_checkpointing_disable()
model.pretrained_model.config.use_cache = True


epochs = script_args.epochs
# NOTE(review): these containers are never filled in the visible script.
mean_scores, std_scores = [], []
save_data = dict(reward_mean=[], reward_std=[])
# Instruction prefix prepended to every prompt.
task_definition = 'Definition: You are given a molecule description in English. Your job is to generate the molecule SELFIES that fits the description.\n\n'
# PPO training loop. For each batch: build prompts, sample responses with
# generation_kwargs, score them with return_reward_matrx, then run one PPO step.
for epoch in range(epochs):
    for i, batch in enumerate(train_dataloader):
        print('epoch {}, batch {}'.format(epoch, i))
        task_inputs = [f'Now provide a set of molecules -\nInput: {get_input}\nOutput: ' for get_input in batch['input']]
        batch['description'] =  [task_definition + task_input for task_input in task_inputs]
        # Reference molecules with the <bom>/<eom> markers stripped. They are later
        # decoded with sf.decoder, so they are expected to be SELFIES.
        # NOTE(review): `[0]` assumes the collated `output` field is a list per
        # output position (i.e. each instance's `output` is a list) — confirm.
        batch['smiles'] = [ground_mols.replace('<bom>','').replace('<eom>','') for ground_mols in batch['output'][0]]
        # Prompts are padded/truncated to preprocess_data's max_input_length (512).
        input_ = preprocess_data(batch)
        query_tensor = input_["input_ids"]
        query_tensor = [torch.tensor(q).to(gpu_id) for q in query_tensor]


        # Generation mode: no gradient checkpointing, KV cache enabled.
        model.gradient_checkpointing_disable()
        model.pretrained_model.config.use_cache = True
        with torch.no_grad():
            response_tensors = ppo_trainer.generate(query_tensor, **generation_kwargs)


        full_responses = tokenizer.batch_decode(response_tensors)
        full_responses_clean, rewards = [], []
        # Per-token reward tensors (with -1 stage markers) go to the PPO step.
        rewards_tensors = return_reward_matrx(response_tensors, batch['smiles'])
        for e, (response, rewards_tensor) in enumerate(zip(full_responses,rewards_tensors)):
            # Scalar per-sample reward for logging: sum of entries that are not -1 markers.
            rewards.append(torch.sum(rewards_tensor[rewards_tensor!=-1]).item())
            full_responses_clean.append(response)


        batch['query'] = batch['description']
        batch['response'] = full_responses_clean
        print("iter {}, batch {}: mean score: {}".format(epoch, i, torch.mean(torch.tensor(rewards)).item()))
        # Training mode: gradient checkpointing on, KV cache off.
        model.gradient_checkpointing_enable()
        model.pretrained_model.config.use_cache = False
        # Match the trainer's batch size to the actual batch (presumably so a
        # shorter final batch is accepted by step()).
        ppo_trainer.config.batch_size = len(query_tensor)
        stats = ppo_trainer.step(query_tensor, response_tensors, rewards_tensors)
        ppo_trainer.log_stats(stats, batch, rewards)


        accelerator.wait_for_everyone()
        # Periodic checkpoint every 25 batches, main process only.
        # NOTE(review): the directory name has no epoch, so each epoch overwrites
        # the previous epoch's batch_N checkpoints — confirm this is intended.
        if ppo_trainer.accelerator.is_main_process and i % 25 == 0 and i!=0:
            save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'batch_{}'.format(i))
            ppo_trainer.save_pretrained(save_path)
            print("iter {}, batch {}: model saved".format(epoch, i))
        gc.collect()
        torch.cuda.empty_cache()
        accelerator.wait_for_everyone()


    # End-of-epoch checkpoint (main process only); `i` is the last batch index.
    # NOTE(review): raises NameError if the dataloader yields no batches.
    if ppo_trainer.accelerator.is_main_process:
        save_path = os.path.join(script_args.save_directory, script_args.wandb_name, 'epoch_{}_batch_{}'.format(epoch, i))
        ppo_trainer.save_pretrained(save_path)
        print("iter {}, batch {}: model saved".format(epoch, i))
        gc.collect()
        torch.cuda.empty_cache()
