import argparse
import json
import math
import os
import pathlib
import pickle
import random
from typing import List, Tuple

import pandas as pd
import torch
from accelerate import Accelerator
from datasets import load_from_disk, load_dataset
from torch import nn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, AutoModelForCausalLM

import sys
sys.path.append("../reward")
from sft.scenario_datasets import RolloutPreferencePairsDataset
from generate_bon_completions import generate_continuations
import fire
import datetime

def get_prompt_dataset(prompts: List[str], completions: List[str], max_length: int, tokenizer: AutoTokenizer) -> Tuple[List[str], List[str]]:
    """Filter prompts that fit within ``max_length`` tokens and normalise their text.

    Each prompt is tokenized with truncation enabled and then decoded back to
    text (special tokens skipped). A prompt is dropped, together with its
    completion, when its tokenization reaches ``max_length`` tokens or when
    the decoded text is empty.

    Args:
        prompts: Prompt strings.
        completions: Completions aligned index-for-index with ``prompts``.
        max_length: Token limit; prompts with ``>= max_length`` tokens are dropped.
        tokenizer: Tokenizer used for the encode/decode round trip.

    Returns:
        A ``(formatted_prompts, formatted_completions)`` pair of equal-length
        lists holding the surviving decoded prompts and their completions.

    Raises:
        ValueError: If ``prompts`` and ``completions`` differ in length.
    """
    if len(prompts) != len(completions):
        raise ValueError(
            f"prompts and completions must have the same length, "
            f"got {len(prompts)} and {len(completions)}"
        )

    formatted_prompts = []
    formatted_completions = []
    # zip() has no len(), so give tqdm the total explicitly to keep the progress bar.
    for prompt, completion in tqdm(zip(prompts, completions), total=len(prompts)):
        token_ids = tokenizer(
            prompt,
            truncation=True,
            max_length=max_length)["input_ids"]
        # Reaching the limit means the prompt was (or may have been) truncated.
        if len(token_ids) >= max_length:
            continue
        decoded = tokenizer.decode(
            token_ids,
            skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        if not decoded:
            continue
        formatted_prompts.append(decoded)
        formatted_completions.append(completion)
    return formatted_prompts, formatted_completions


def main(model_name = "EleutherAI/pythia-1.4b-deduped", split = "train", temperature = 1, max_input_length = 384, max_completion_length = 128, outer_batch_size = 1, inner_batch_size = 32, num_prompts = 0, num_completions_per_prompt = 1024, gpu_num=None, total_gpus=None):
    """Sample completions from a policy model and append the decoded text to a CSV.

    Prompts are taken from the final "Human/Assistant" turn of each sample in
    ``RolloutPreferencePairsDataset(split)``, filtered with
    ``get_prompt_dataset``, optionally sharded across GPUs, and passed in
    batches to ``generate_continuations``. Decoded generations are appended
    (no header, no index) to ``pairwise_generations_{split}_{model_name}[_{gpu_num}].csv``
    every 8 batches and after the final batch.

    Args:
        model_name: Key selecting the checkpoint/tokenizer pair. Supported keys:
            ``gpt2``, ``pythia-410m``, ``pythia-1b``, ``pythia-6b``,
            ``pythia-12b``, ``llama-7b``. NOTE: the default value is not one of
            these keys, so a supported name must be passed explicitly.
        split: Dataset split passed to ``RolloutPreferencePairsDataset``.
        temperature: Sampling temperature forwarded to ``generate_continuations``.
        max_input_length: Token limit for prompts; longer prompts are dropped.
        max_completion_length: Forwarded as ``max_new_tokens``.
        outer_batch_size: Number of prompts handed to ``generate_continuations``
            per call. Overridden to 16 for ``pythia-1b`` and ``llama-7b``.
        inner_batch_size: Forwarded as ``batch_size``. Overridden to 2 for
            ``pythia-1b`` and ``llama-7b``.
        num_prompts: Number of prompts to use; 0 means all. Values larger than
            the number of available prompts are clamped.
        num_completions_per_prompt: Forwarded as ``num_completions``.
        gpu_num: Zero-based shard index; required when ``total_gpus`` is set.
        total_gpus: Number of shards; when set, only shard ``gpu_num`` is processed.

    Raises:
        ValueError: If ``model_name`` is unsupported, if ``total_gpus`` is set
            without ``gpu_num``, or if no sample in the dataset could be parsed.
    """
    formatted_now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # Create an accelerator object
    accelerator = Accelerator()

    seed_val = 1
    random.seed(seed_val)

    # model_name -> (model_path, tokenizer_path, forced (outer, inner) batch sizes or None)
    model_registry = {
        "gpt2": ("sft/gpt-sft-long-hh-checkpoint", "gpt2", None),
        "pythia-410m": ("pythia-410m-deduped", "EleutherAI/pythia-410m-deduped", None),
        "pythia-1b": ("pythia-1", "EleutherAI/pythia-1.4b-deduped", (16, 2)),
        "pythia-6b": ("pythia-6", "EleutherAI/pythia-6.9b-deduped", None),
        "pythia-12b": ("pythia-12b", "EleutherAI/pythia-12b-deduped", None),
        "llama-7b": (
            "llama-7b",
            "/data/private_models/xx_models/llama/llama_hf_weights_v1.1/llama-7b",
            (16, 2),
        ),
    }
    if model_name not in model_registry:
        raise ValueError(
            f"invalid policy model: {model_name!r}; expected one of {sorted(model_registry)}"
        )
    model_path, tokenizer_path, forced_batch_sizes = model_registry[model_name]
    if forced_batch_sizes is not None:
        outer_batch_size, inner_batch_size = forced_batch_sizes

    # Fail early with a clear message instead of a TypeError in the shard arithmetic.
    if total_gpus and gpu_num is None:
        raise ValueError("gpu_num must be provided when total_gpus is set")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id).to("cuda")

    def get_samples(dataset):
        """Return shuffled (prompt, response) pairs for the last turn of each dialogue.

        Samples that do not split cleanly into "Human"/"Assistant" turns, or
        that contain no turns at all, are counted and skipped.
        """
        samples = []
        skipped = 0
        for sample in dataset:
            turns = []
            history = ""
            try:
                for chunk in sample.split("\n\nHuman: "):
                    if not chunk:
                        continue
                    # Raises ValueError unless the chunk has exactly one assistant marker.
                    question, response = chunk.split("\n\nAssistant: ")
                    history += "\n\nHuman: " + question + "\n\nAssistant: "
                    turns.append((history.rstrip(), response))
                    history += response
            except (ValueError, AttributeError):
                # ValueError: malformed turn structure; AttributeError: non-string sample.
                skipped += 1
                continue
            if not turns:
                # Empty sample: nothing to prompt with.
                skipped += 1
                continue
            samples.append(turns[-1])
        accelerator.print("skipped", skipped, "samples")
        random.shuffle(samples)
        return samples

    train_set = get_samples(RolloutPreferencePairsDataset(split))
    if not train_set:
        raise ValueError(f"no usable samples found in split {split!r}")
    train_posts, train_continuations = zip(*train_set)

    # Only the prompts are needed below; the filtered completions are discarded.
    train_prompts, _ = get_prompt_dataset(train_posts, train_continuations, max_input_length, tokenizer)
    print("train_prompts length:", len(train_prompts))

    # 0 means "all prompts"; clamp so shard boundaries never run past the data.
    num_prompts = min(num_prompts or len(train_prompts), len(train_prompts))
    if total_gpus:
        start = gpu_num * num_prompts // total_gpus
        end = (gpu_num + 1) * num_prompts // total_gpus
        print("start:", start)
        print("end:", end)
        train_prompts = train_prompts[start:end]
    else:
        train_prompts = train_prompts[:num_prompts]
    accelerator.print("loaded dataset, LEN:", len(train_prompts))
    if train_prompts:
        accelerator.print(train_prompts[-1])

    if total_gpus:
        out_filename = f"pairwise_generations_{split}_{model_name}_{gpu_num}.csv"
    else:
        out_filename = f"pairwise_generations_{split}_{model_name}.csv"

    # Ceil division so a trailing partial batch is generated instead of silently dropped.
    num_batches = math.ceil(len(train_prompts) / outer_batch_size)

    bon_generations = []
    for batch_index in tqdm(range(num_batches)):
        batch_start = batch_index * outer_batch_size
        prompt_batch = train_prompts[batch_start:batch_start + outer_batch_size]

        continuations = generate_continuations(model=model, batch_size=inner_batch_size, num_completions=num_completions_per_prompt, temperature=temperature, max_new_tokens=max_completion_length, max_input_length=max_input_length, tokenizer=tokenizer, prompt=prompt_batch, accelerator=accelerator)
        # NOTE(review): continuations are presumably token-id sequences, since
        # they are passed to tokenizer.batch_decode below - confirm against helper.
        bon_generations.extend(continuations)

        # Flush periodically, and always after the last batch so no generations
        # are left unsaved when the loop ends.
        is_last_batch = batch_index == num_batches - 1
        if batch_index % 8 == 0 or is_last_batch:
            print(f"saving generations {formatted_now}")
            data = {
                "samples": tokenizer.batch_decode(bon_generations, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            }
            df = pd.DataFrame(data)
            # append to csv file
            df.to_csv(out_filename, escapechar='\\', mode="a", header=False, index=False)

            bon_generations = []

# Script entry point: hand `main` to python-fire so it can be invoked from the command line.
if __name__ == "__main__":
    fire.Fire(main)