import argparse
import os
import subprocess
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import transformers
from preference_datasets import get_dataset
from tqdm import tqdm
import torch
import numpy as np
import json
from peft import PeftModel
import wandb

def _tensor_parallel_size(num_gpus):
    """Return the --tensor_parallel_size to pass to generate_evaluate.py.

    Four or more GPUs -> 4, two or three -> 2 (same rule as before). A single
    GPU now yields 1 instead of requesting 2-way parallelism on one device.
    """
    if num_gpus >= 4:
        return 4
    if num_gpus >= 2:
        return 2
    return 1


def _merge_adapter_if_needed(init_model_path, model_path, cache_dir):
    """Merge the PEFT adapter in ``model_path/adapter`` into the base model.

    Does nothing when ``model_path`` already contains a file ending in
    ``safetensors``. Otherwise loads ``init_model_path`` in float32, applies
    the adapter, merges it and saves the merged model into ``model_path``.
    Doing this in a helper lets the large float32 models be freed on return
    instead of staying referenced for the rest of ``main``.
    """
    if any(name.endswith('safetensors') for name in os.listdir(model_path)):
        return
    base_model = transformers.AutoModelForCausalLM.from_pretrained(init_model_path, cache_dir=cache_dir, low_cpu_mem_usage=True, torch_dtype=torch.float32)
    peft_model = PeftModel.from_pretrained(base_model, os.path.join(model_path, 'adapter'))
    merged_model = peft_model.merge_and_unload()
    merged_model.save_pretrained(model_path)


def _run_generation(gpus, args, model_path, num_response, tensor_parallel_size, output_path):
    """Run generate_evaluate.py for ``model_path``, writing ``output_path``.

    ``gpus`` is a comma-separated id list exported as CUDA_VISIBLE_DEVICES
    for the child process only. Arguments are passed as a list (no shell), so
    paths containing spaces or shell metacharacters are forwarded verbatim.

    Raises:
        subprocess.CalledProcessError: if the child exits non-zero; its
            captured stdout/stderr are printed first so the cause is visible.
    """
    cmd = [
        'python', 'generate_evaluate.py',
        '--dataset', args.dataset,
        '--model_id', args.model_id,
        '--original_data_dir', args.data_path,
        '--num_response', str(num_response),
        '--model_path', model_path,
        '--tensor_parallel_size', str(tensor_parallel_size),
        '--generated_data_path', output_path,
    ]
    env = {**os.environ, 'CUDA_VISIBLE_DEVICES': gpus}
    try:
        subprocess.run(cmd, env=env, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as err:
        # Output is captured, so without this the child's traceback is lost
        # and only the exit status would be reported.
        for stream in (err.stdout, err.stderr):
            if stream:
                print(stream.decode(errors='replace'))
        raise


def _load_jsonl(path):
    """Return the JSON value on each non-blank line of ``path``, in file order."""
    with open(path, 'r') as f:
        return [json.loads(line) for line in f if line.strip()]


def _score_responses(data_points, rm_tokenizer, reward_model, device, use_padding, batch_size=1):
    """Score each (prompt, response) pair with the reward model.

    Each data point is indexed as ``item[0]`` = prompt and ``item[1]`` =
    response; any further fields are ignored. Returns the first logit of each
    pair as numpy float32 scalars, in the same order as ``data_points``.

    When ``use_padding`` is False the tokenizer is called with
    ``padding=False``, which only produces a rectangular tensor for
    ``batch_size=1``.
    """
    rewards = []
    for start in tqdm(range(0, len(data_points), batch_size)):
        batch = data_points[start:start + batch_size]
        prompts = [item[0] for item in batch]
        responses = [item[1] for item in batch]
        if use_padding:
            inputs = rm_tokenizer(prompts, responses, return_tensors='pt', truncation=True, padding=True, max_length=5000)
        else:
            inputs = rm_tokenizer(prompts, responses, return_tensors='pt', truncation=True, padding=False)
        with torch.no_grad():
            logits = reward_model(**inputs.to(device)).logits
        rewards.extend(logits[:, 0].float().cpu().numpy())
    return rewards


def _compare_rewards(trained_rewards, init_rewards):
    """Summarise paired rewards of the trained and the initial model.

    The i-th trained reward is compared with the i-th initial reward, so both
    sequences must be non-empty and of equal length. A tie counts as a win for
    the trained model (``>=``).

    Returns a dict with keys ``win_rate``, ``trained_mean_reward``,
    ``trained_std_reward``, ``init_mean_reward`` and ``init_std_reward``.

    Raises:
        ValueError: if either sequence is empty or their lengths differ.
    """
    trained = np.asarray(trained_rewards)
    init = np.asarray(init_rewards)
    if trained.size == 0 or init.size == 0:
        raise ValueError('no rewards to compare; check the generated data files')
    if trained.shape != init.shape:
        raise ValueError(f'reward count mismatch: {len(trained)} trained vs {len(init)} initial')
    return {
        'win_rate': np.mean(trained >= init),
        'trained_mean_reward': np.mean(trained),
        'trained_std_reward': np.std(trained),
        'init_mean_reward': np.mean(init),
        'init_std_reward': np.std(init),
    }


def main(args):
    """Evaluate a trained model against its initial model with a reward model.

    Steps:
      1. If ``args.model_path`` has no ``*safetensors`` file, merge the PEFT
         adapter in ``args.model_path/adapter`` into ``args.init_model_path``
         and save the result into ``args.model_path``.
      2. Run generate_evaluate.py for the trained model and then for the
         initial model; each run writes a JSONL file under ``args.cache_dir``.
      3. Score every (prompt, response) pair with a dataset-specific reward
         model, print the win rate and reward statistics and log them to wandb.
      4. If ``args.delete`` is truthy, delete the regular files in
         ``args.model_path`` other than ``adapter``.

    Raises:
        ValueError: if ``args.avaliable_gpus`` is empty, or the generated
            files yield no rewards or different numbers of rewards.
        KeyError: if ``args.dataset`` has no configured reward model.
        subprocess.CalledProcessError: if generate_evaluate.py fails.
    """
    # GPU ids are single characters of the string, e.g. '0123'; at most four are used.
    gpu_ids = args.avaliable_gpus[:4]
    if not gpu_ids:
        raise ValueError('--avaliable_gpus must contain at least one GPU id')
    gpus_ = ','.join(gpu_ids)
    model_name = args.model_id.split('/')[-1]
    run_suffix = f'{model_name}_{args.dataset}_{args.method_name}_{args.exp_name}'
    num_response = 3
    tensor_parallel_size = _tensor_parallel_size(len(gpu_ids))

    # Resolve the reward model first so an unsupported dataset fails before
    # the slow generation runs, not after them.
    model_dict = {'hh': 'Ray2333/gpt2-large-harmless-reward_model', 'tldr': 'OpenAssistant/reward-model-deberta-v3-large', 'webgpt': 'OpenAssistant/reward-model-deberta-v3-large-v2', 'syntheticgpt': 'OpenAssistant/reward-model-deberta-v3-large-v2'}
    rm_name = model_dict[args.dataset]

    _merge_adapter_if_needed(args.init_model_path, args.model_path, args.cache_dir)

    generated_data_path_trained = f'{args.cache_dir}/trained_model_evaluation_{run_suffix}_generated_data.jsonl'
    _run_generation(gpus_, args, args.model_path, num_response, tensor_parallel_size, generated_data_path_trained)

    generated_data_path_init = f'{args.cache_dir}/initial_model_evaluation_{run_suffix}_generated_data.jsonl'
    _run_generation(gpus_, args, args.init_model_path, num_response, tensor_parallel_size, generated_data_path_init)

    # This script sets CUDA_VISIBLE_DEVICES only for the child processes, so
    # here the first requested id is used as the CUDA index directly
    # (assuming the parent itself was not launched with CUDA_VISIBLE_DEVICES).
    device = torch.device(f'cuda:{gpu_ids[0]}')
    rm_tokenizer = AutoTokenizer.from_pretrained(rm_name, cache_dir=args.cache_dir)
    # Load straight onto the target GPU instead of staging on cuda:0 first.
    reward_model = AutoModelForSequenceClassification.from_pretrained(
                    rm_name,
                    num_labels=1, torch_dtype=torch.bfloat16,
                    device_map=device.index, cache_dir=args.cache_dir)

    # The gpt2-based reward model is tokenized without padding (presumably
    # its tokenizer has no pad token -- TODO confirm); others pad/truncate.
    use_padding = 'gpt2' not in rm_name

    # The first record of each file is skipped, as before; presumably a
    # header/metadata line written by generate_evaluate.py -- confirm there.
    trained_data_points = _load_jsonl(generated_data_path_trained)[1:]
    init_data_points = _load_jsonl(generated_data_path_init)[1:]

    trained_rewards = _score_responses(trained_data_points, rm_tokenizer, reward_model, device, use_padding)
    init_rewards = _score_responses(init_data_points, rm_tokenizer, reward_model, device, use_padding)

    stats = _compare_rewards(trained_rewards, init_rewards)
    print(f'Win rate: {stats["win_rate"]}')
    print(f'Trained mean reward: {stats["trained_mean_reward"]}, std: {stats["trained_std_reward"]}')
    print(f'Init mean reward: {stats["init_mean_reward"]}, std: {stats["init_std_reward"]}')
    wandb.init(
            entity=None,
            project='direct-preference-optimization',
            name=f'evaluation_{run_suffix}',
        )
    wandb.log(stats)
    wandb.finish()

    if args.delete:
        for name in os.listdir(args.model_path):
            path = os.path.join(args.model_path, name)
            # Keep the adapter. Only regular files (e.g. merged weights) are
            # removed; os.remove cannot delete directories, so any other
            # subdirectory is left in place instead of crashing the script.
            if name != 'adapter' and os.path.isfile(path):
                print(args.model_path, name)
                os.remove(path)

if __name__ == "__main__":
    # Command-line entry point: parse options and run the evaluation.
    parser = argparse.ArgumentParser(description='Evaluate a trained model against its initial model using a reward model.')
    parser.add_argument('--method_name', type=str, default='ours', help='The name of the method')
    parser.add_argument('--dataset', type=str, default='tldr', help='Dataset name')
    parser.add_argument('--data_path', type=str, default='.cache/ours_llama2-tldr-iteration-0.jsonl', help='Path to the data file')
    parser.add_argument('--cache_dir', type=str, default='.cache/', help='Cache directory for model and tokenizer')
    parser.add_argument('--model_id', type=str, default='meta-llama/Llama-2-7b-hf', help='The model id')
    parser.add_argument('--init_model_path', type=str, default='', help='The initial model path')
    parser.add_argument('--model_path', type=str, default='', help='The trained model path')
    # '--available_gpus' is accepted as a correctly spelled alias; the
    # original (misspelled) flag and attribute name are kept for callers.
    parser.add_argument('--avaliable_gpus', '--available_gpus', dest='avaliable_gpus', type=str, default='0123',
                        help='GPU ids as a string of single-character ids, e.g. "0123"; at most the first four are used')
    parser.add_argument('--delete', type=int, default=1,
                        help='If nonzero, delete the files in model_path (except the adapter) after evaluation')
    parser.add_argument('--exp_name', type=str, default='', help='The name of the experiment')

    args = parser.parse_args()
    main(args)

