import argparse
import json
import math
import os
import pathlib
import random
from typing import List, Tuple

import pandas as pd
import torch
from torch import nn
from datasets import load_from_disk, load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, AutoModelForCausalLM
import trlx
from trlx.data.configs import TRLConfig

import sys
sys.path.append("../../proxy")
from sft.scenario_datasets import RolloutDataset

def load_model(model: str, load_path: str):
    """Instantiate a single-logit sequence-classification model and load saved weights.

    Args:
        model: Hugging Face model name or path; used for both the config and
            the pretrained architecture.
        load_path: Path to a state dict readable by ``torch.load``.

    Returns:
        The model with the checkpoint weights loaded (strict key matching, so a
        missing or unexpected key raises).
    """
    # num_labels=1 gives a scalar head, used as a reward score by callers.
    model_config = AutoConfig.from_pretrained(model, num_labels=1)
    classifier = AutoModelForSequenceClassification.from_pretrained(model, config=model_config)
    state_dict = torch.load(load_path)
    classifier.load_state_dict(state_dict, strict=True)
    return classifier

def get_prompt_dataset(prompts: List[str], completions: List[str], max_length: int, prompt_tokenizer=None) -> Tuple[List[str], List[str]]:
    """Truncate prompts to at most ``max_length`` tokens, dropping ones that become empty.

    Each prompt is tokenized with truncation, on whichever side the tokenizer's
    ``truncation_side`` selects; the ``__main__`` section sets it to "left".
    The tokens are then decoded back to text with special tokens removed. A
    prompt whose decoded text is empty is dropped together with its paired
    completion, so the two returned lists stay aligned.

    Args:
        prompts: Raw prompt strings.
        completions: Completion paired with each prompt; must have the same
            length as ``prompts``.
        max_length: Maximum number of prompt tokens to keep.
        prompt_tokenizer: Tokenizer to use. Defaults to the module-level
            ``tokenizer`` global created in the ``__main__`` section.

    Returns:
        ``(formatted_prompts, formatted_completions)``: two lists of equal length.

    Raises:
        ValueError: If ``prompts`` and ``completions`` differ in length.
    """
    if len(prompts) != len(completions):
        raise ValueError(
            f"prompts and completions must have the same length, got {len(prompts)} and {len(completions)}"
        )
    tok = tokenizer if prompt_tokenizer is None else prompt_tokenizer

    formatted_prompts = []
    formatted_completions = []
    for prompt, completion in tqdm(zip(prompts, completions), total=len(prompts)):
        # Over-long prompts are truncated rather than skipped, so the policy still
        # sees every example within its input budget.
        token_ids = tok(prompt, truncation=True, max_length=max_length)["input_ids"]
        text = tok.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        if not text:
            continue
        formatted_prompts.append(text)
        formatted_completions.append(completion)
    return formatted_prompts, formatted_completions


def create_reward_fn():
    """Build the ``reward_fn`` argument passed to ``trlx.train``.

    On the main process (env var ``RANK`` unset or ``"0"``), this loads the proxy
    reward model and two gold models. It moves them onto the last visible CUDA
    device, or onto the CPU when no GPU is visible, and returns a scoring closure.
    On every other rank it returns ``True`` instead of a callable.

    The function reads the module-level globals ``args``, ``config`` and
    ``RM_STATS``, which are defined in this script's ``__main__`` section.
    """
    rw_tokenizer = AutoTokenizer.from_pretrained(args.reward_model)
    rw_tokenizer.truncation_side = "left"
    gold_tokenizer = AutoTokenizer.from_pretrained(args.gold_model)
    gold_tokenizer.truncation_side = "left"

    if os.environ.get("RANK", "0") != "0":
        # Only rank 0 hosts the scoring models; other ranks get a truthy placeholder.
        # NOTE(review): presumably trlx routes scoring to rank 0 -- confirm for the trlx version in use.
        return True

    rw_model = load_model(args.reward_model, args.reward_checkpoint_path).eval()
    gold_model1 = load_model(args.gold_model, args.gold_checkpoint_path + 'deberta_large_gold1.pkl').eval()
    gold_model2 = load_model(args.gold_model, args.gold_checkpoint_path + 'deberta_large_gold2.pkl').eval()

    # All scoring models share the last visible GPU. Without this fallback, a
    # machine with no GPU would get the invalid device index -1.
    num_gpus = torch.cuda.device_count()
    rw_device = gold_device = torch.device(f"cuda:{num_gpus - 1}") if num_gpus > 0 else torch.device("cpu")
    rw_model.to(rw_device)
    print("loaded reward model")
    gold_model1.to(gold_device)
    gold_model2.to(gold_device)
    print("loaded gold model")
    print(rw_model.device, gold_model2.device)

    batch_size = 16 * 8

    @torch.no_grad()
    def get_scores_ensemble(samples: List[str], models, tokenizer, device) -> List[torch.Tensor]:
        """Score ``samples`` with every model in ``models``, in batches of ``batch_size``.

        Each batch is tokenized once and fed to all models. Returns one 1-D tensor
        per model, holding one score per sample in input order. An empty
        ``samples`` yields empty tensors instead of a ``torch.cat`` error.
        """
        per_model_chunks = [[] for _ in models]
        for start in range(0, len(samples), batch_size):
            encodings_dict = tokenizer(
                samples[start : start + batch_size],
                truncation=True,
                max_length=config.train.seq_length,
                padding=True,
                return_tensors="pt",
            )
            input_ids = encodings_dict["input_ids"].to(device)
            attn_masks = encodings_dict["attention_mask"].to(device)
            for scorer, chunks in zip(models, per_model_chunks):
                # The head has num_labels=1 (see load_model); flatten to one score per sample.
                chunks.append(scorer(input_ids=input_ids, attention_mask=attn_masks).logits.reshape((-1)))
        return [
            torch.cat(chunks, dim=0) if chunks else torch.empty(0, device=device)
            for chunks in per_model_chunks
        ]

    @torch.no_grad()
    def get_scores(samples: List[str], model, tokenizer, device) -> torch.Tensor:
        """Score ``samples`` with a single model; returns a 1-D tensor of scores."""
        return get_scores_ensemble(samples, [model], tokenizer, device)[0]

    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], evaluate: bool=False, kl: float=0, **kwargs) -> torch.Tensor:
        """Return normalized proxy-reward scores for ``samples``.

        The proxy scores are standardized with ``RM_STATS["mean"]`` and
        ``RM_STATS["std"]``. When ``evaluate`` is true, both gold models also
        score the samples, and samples, proxy scores, gold scores and ``kl`` are
        appended (no header) to a CSV named after the run's hyperparameters.
        """
        scores = (get_scores(samples, rw_model, rw_tokenizer, rw_device) - RM_STATS["mean"]) / RM_STATS["std"]

        if evaluate:
            gold_scores = get_scores_ensemble(samples, [gold_model1, gold_model2], gold_tokenizer, gold_device)

            data = {
                "samples": samples,
                "new_scores": scores.detach().cpu().numpy(),
                "new_gold_scores1": gold_scores[0].detach().cpu().numpy(),
                "new_gold_scores2": gold_scores[1].detach().cpu().numpy(),
            }
            df = pd.DataFrame(data)
            df['kl'] = kl
            # Append to a per-run CSV; successive evaluations accumulate in the same file.
            csv_name = f"scores_{args.policy_model}_{args.reward_checkpoint_path.split('/')[-1][:-4]}_{args.num_train}_{config.model.num_layers_unfrozen}_"
            csv_name += f"{config.method.gen_kwargs['max_new_tokens']}_{config.train.batch_size}_"
            csv_name += f"{config.optimizer.kwargs['lr']}_{config.method.init_kl_coef}_{args.reward_model.split('/')[-1]}_{args.seed}.csv"
            df.to_csv(csv_name, escapechar='\\', mode="a", header=False, index=False)

        return scores

    return reward_fn

if __name__ == "__main__":
    # Everything below runs at module scope, so `args`, `config`, `tokenizer` and
    # `RM_STATS` become module globals that get_prompt_dataset and create_reward_fn read.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="../../data/moral_stories_dataset.hf", help="path to load HuggingFace dataset of continuation pairs")
    parser.add_argument("--reward_model", type=str, default="roberta-large", help="reward model")
    parser.add_argument("--gold_model", type=str, default="microsoft/deberta-v3-large", help="gold model")
    parser.add_argument("--reward_checkpoint_path", type=str, required=True, help="path to load reward model weights")
    parser.add_argument("--gold_checkpoint_path", type=str, required=True, help="path to load gold model weights")
    parser.add_argument("--policy_model", type=str, default="llama", help="policy model")
    parser.add_argument("--num_train", type=int, default=0, help="number of training examples to use, 0 for all")
    parser.add_argument("--seed", type=int, default=42, help="seed")

    args = parser.parse_args()
    # Validate up front instead of with an `assert` after the slow dataset
    # preprocessing; asserts are also stripped under `python -O`.
    if 0 < args.num_train < 512:
        parser.error("--num_train must be 0 (use all) or at least 512")

    seed_val = args.seed
    random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    config_path = pathlib.Path(__file__).parent.joinpath("configs/ppo_config.yml")
    config = TRLConfig.load_yaml(config_path)

    if args.policy_model == "gpt2-large":
        config.model.model_path = "gpt2-large"
        config.tokenizer.tokenizer_path = 'gpt2-large'
        config.train.batch_size = 16
        config.optimizer.kwargs["lr"] = 1e-6
    elif args.policy_model == "gpt2-xl":
        config.model.model_path = "gpt2-xl"
        config.tokenizer.tokenizer_path = 'gpt2-xl'
        config.train.batch_size = 16
        config.optimizer.kwargs["lr"] = 1e-6
    elif args.policy_model == "llama":
        config.model.model_path = "llama-7b"
        config.tokenizer.tokenizer_path = 'llama-7b'
        config.train.batch_size = 8
        config.method.chunk_size = 32
        config.optimizer.kwargs["lr"] = 1e-6
    else:
        raise ValueError("invalid policy model")

    # The policy is initialized from the SFT checkpoint directory for the chosen model.
    config.model.model_path = "../sft/" + config.model.model_path

    config.model.num_layers_unfrozen = 2
    config.method.init_kl_coef = 0.01
    config.method.ppo_epochs = 4
    config.method.target = None
    config.train.checkpoint_dir = f"checkpoints/{args.policy_model}_{args.reward_checkpoint_path.split('/')[-1][:-4]}_{args.num_train}_{config.model.num_layers_unfrozen}_{config.method.gen_kwargs['max_new_tokens']}_{config.train.batch_size}_{config.optimizer.kwargs['lr']}_{config.method.init_kl_coef}_{args.reward_model.split('/')[-1]}_{args.seed}"

    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    # Leave room in the sequence budget for the generated continuation.
    max_length_input = (config.train.seq_length - config.method.gen_kwargs["max_new_tokens"])

    with open("../../proxy/normalization_coeffs.json") as coeffs_file:
        metadata = json.load(coeffs_file)
    RM_STATS = metadata[args.reward_checkpoint_path]

    def get_samples(dataset):
        """Split each dialogue into (prompt, response) for its final turn; return them shuffled.

        The prompt is the whole conversation up to and including the last
        "\\n\\nAssistant: " marker (right-stripped), and the response is the text
        after it. Dialogues that cannot be parsed (a turn without exactly one
        Assistant marker, or a non-string item), and dialogues with no turns,
        are counted and skipped.
        """
        samples = []
        c = 0
        for sample in dataset:
            sample_list = []
            current_sample = ""
            try:
                for chunk in sample.split("\n\nHuman: "):
                    if not chunk:
                        continue
                    question, response = chunk.split("\n\nAssistant: ")
                    current_sample += "\n\nHuman: " + question + "\n\nAssistant: "
                    sample_list.append((current_sample.rstrip(), response))
                    current_sample += response
            except (AttributeError, TypeError, ValueError):
                # ValueError: a turn did not split into exactly (question, response).
                # AttributeError/TypeError: presumably a non-str dataset item -- confirm RolloutDataset's item type.
                c += 1
                continue
            if not sample_list:
                # Empty dialogue: previously crashed with IndexError on sample_list[-1].
                c += 1
                continue
            samples.append(sample_list[-1])
        print("skipped", c, "samples")
        random.shuffle(samples)
        return samples

    train_set = get_samples(RolloutDataset("train"))
    val_set = get_samples(RolloutDataset("test"))
    train_posts, train_continuations = zip(*train_set)
    val_posts, val_continuations = zip(*val_set)

    train_prompts, train_continuations = get_prompt_dataset(train_posts, train_continuations, max_length_input)
    if args.num_train > 0:
        train_prompts = train_prompts[:args.num_train]
    val_prompts, val_continuations = get_prompt_dataset(val_posts, val_continuations, max_length_input)
    post_continuation_dict = {}
    # zip stops at the (possibly truncated) train_prompts, so this matches indexing by its length.
    for prompt, continuation in zip(train_prompts, train_continuations):
        post_continuation_dict[prompt.strip()] = continuation
    for prompt, continuation in zip(val_prompts, val_continuations):
        post_continuation_dict[prompt.strip()] = continuation

    # Evaluate on up to num_eval held-out prompts plus up to num_eval training prompts.
    num_eval = 1024
    val_prompts = val_prompts[:num_eval] + train_prompts[:num_eval]

    print("loaded dataset, LEN:", len(train_prompts))
    print(train_prompts[-1])

    reward_fn = create_reward_fn()

    trainer = trlx.train(
        config.model.model_path,
        reward_fn=reward_fn,
        prompts=train_prompts,
        eval_prompts=val_prompts,
        config=config,
    )

    print("done training")
