from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import *
from typing import Optional

import click
import datasets
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
from alpaca_eval.decoders.huggingface_local import ListDataset
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    pipeline,
)


# Maximum number of generation/scoring rounds. The sampling temperature is
# multiplied by ``temperature_scaler`` before every round, so with the default
# scaler of 2.0 the rounds run at temperatures 2.0, 4.0 and 8.0.
_MAX_REGENERATION_ROUNDS = 3


def _load_flat_dataset(ds_path: str) -> datasets.Dataset:
    """Load ``ds_path`` and flatten a ``DatasetDict`` into a single dataset."""
    try:
        ds = datasets.load_dataset(ds_path)
    except Exception:
        # Deliberate fallback: the path may be a ``save_to_disk`` directory
        # rather than something ``load_dataset`` understands.
        ds = datasets.load_from_disk(ds_path)
    if isinstance(ds, datasets.DatasetDict):
        ds = datasets.concatenate_datasets(list(ds.values()))
    return ds


def _generate_completions(
    prompts: list[str],
    generation_model_path: str,
    generation_kwargs: dict,
) -> list[list[str]]:
    """Generate candidate completions for every prompt.

    The generation model is loaded, used and released inside this helper so
    that its memory can be reclaimed before the reward model is loaded.

    Returns one list of generated texts per prompt, in the order of
    ``prompts``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        generation_model_path,
        padding_side="left",
    )
    model = AutoModelForCausalLM.from_pretrained(
        generation_model_path, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()

    # Compare against None explicitly: a pad token id of 0 is a valid id.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token
    pad_token_id = tokenizer.pad_token_id

    pipeline_kwargs = dict(
        generation_kwargs,
        eos_token_id=list({tokenizer.eos_token_id, pad_token_id}),
    )

    print(f"Model memory: {model.get_memory_footprint() / 1e9} GB")
    print(pipeline_kwargs)

    generation_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **pipeline_kwargs,
        trust_remote_code=True,
    )

    completions = []
    for output in tqdm(
        generation_pipeline(
            ListDataset(prompts),
            return_full_text=False,
            pad_token_id=pad_token_id,
        )
    ):
        completions.append([candidate["generated_text"] for candidate in output])

    # release the model memory
    del generation_pipeline
    model.cpu()
    del model
    del tokenizer
    torch.cuda.empty_cache()

    return completions


def _select_best_completions(
    prompts: list[str],
    completions: list[list[str]],
    reward_model_path: str,
    reward_threshold: float,
) -> list[dict]:
    """Score every candidate and keep the best one per prompt if good enough.

    Returns a list of ``{"prompt", "ref_response", "ref_score"}`` dicts, one
    for each prompt whose best candidate scores strictly above
    ``reward_threshold``.
    """
    rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_path)
    rm_pipeline = pipeline(
        "sentiment-analysis",
        model=reward_model_path,
        tokenizer=rm_tokenizer,
        device_map="auto",
        model_kwargs={"torch_dtype": torch.bfloat16},
        truncation=True,
    )
    pipe_kwargs = {
        "top_k": None,
        "function_to_apply": "none",
        "batch_size": 1,
    }
    bos_token = rm_tokenizer.bos_token or ""

    accepted = []
    for prompt, candidates in zip(prompts, completions):
        if not candidates:
            continue
        contents = [(prompt + c).replace(bos_token, "") for c in candidates]
        # Pass a list so the pipeline returns one list of label dicts per
        # input; the first entry of each carries the reward score.
        outputs = rm_pipeline(contents, **pipe_kwargs)
        scores = [output[0]["score"] for output in outputs]
        best_idx = int(np.argmax(scores))
        if scores[best_idx] > reward_threshold:
            accepted.append(
                {
                    "prompt": prompt,
                    "ref_response": candidates[best_idx],
                    "ref_score": scores[best_idx],
                }
            )

    # release the reward model memory
    del rm_pipeline
    del rm_tokenizer
    torch.cuda.empty_cache()

    return accepted


def _build_preference_pairs(
    regenerated: list[dict], invalid_df: pd.DataFrame
) -> list[dict]:
    """Pair each regenerated response with the old responses it replaces.

    For every row of ``invalid_df`` whose prompt has a regenerated response,
    two pairs are emitted: the new response as ``chosen`` against the old
    ``chosen``, and the new response as ``chosen`` against the old
    ``rejected``.
    """
    # Explicit columns keep the merge valid when nothing was regenerated.
    ref_df = pd.DataFrame(
        regenerated, columns=["prompt", "ref_response", "ref_score"]
    )
    merged = ref_df.merge(invalid_df, on="prompt")

    pairs = []
    for row in merged.to_dict("records"):
        for old_response, old_score in (
            ("chosen", "chosen_score"),
            ("rejected", "rejected_score"),
        ):
            pairs.append(
                {
                    "prompt": row["prompt"],
                    "chosen": row["ref_response"],
                    "rejected": row[old_response],
                    "chosen_score": row["ref_score"],
                    "rejected_score": row[old_score],
                }
            )
    return pairs


@click.command()
@click.option("--ds_path", type=str, required=True)
@click.option("--generation_model_path", type=str, required=True)
@click.option("--reward_model_path", type=str, required=True)
@click.option("--output_dir", type=str, required=True)
@click.option("--regenrate_reward_threshold", type=float, default=3.0)
@click.option("--max_new_tokens", type=int, default=2048)
@click.option("--do_sample", type=bool, default=True)
@click.option("--batch_size", type=int, default=4)
@click.option("--n_generation_per_prompt", type=int, default=4)
@click.option("--repetition_penalty", type=float, default=1.2)
@click.option("--temperature_scaler", type=float, default=2.0)
@click.option("--top_p", type=float, default=1.0)
@click.option("--num_beams", type=int, default=8)
def main(
    ds_path: str,
    generation_model_path: str,
    reward_model_path: str,
    output_dir: str,
    regenrate_reward_threshold: float = 3.0,
    max_new_tokens: int = 2048,
    do_sample: bool = True,
    batch_size: int = 4,
    n_generation_per_prompt: int = 4,
    repetition_penalty: float = 1.2,
    temperature_scaler: float = 2.0,
    top_p: float = 1.0,
    num_beams: int = 8,
):
    """Regenerate responses for preference pairs with a low chosen score.

    \f
    Rows whose ``chosen_score`` is strictly above
    ``regenrate_reward_threshold`` are kept unchanged. For the remaining rows,
    new responses are generated for each distinct prompt and scored with the
    reward model; the best candidate is accepted when it scores above the
    threshold. Prompts that are still unresolved are retried with a higher
    temperature, for at most ``_MAX_REGENERATION_ROUNDS`` rounds. Every
    accepted response becomes the new ``chosen`` entry, paired against both
    old responses of the row. The combined dataset is saved to
    ``output_dir``.

    Raises:
        ValueError: if the dataset lacks ``chosen_score`` or
            ``rejected_score``.
    """
    ds = _load_flat_dataset(ds_path)

    print("size of original ds: ", len(ds))

    if "chosen_score" not in ds.features or "rejected_score" not in ds.features:
        raise ValueError(
            "Dataset must contain chosen_score and rejected_score columns"
        )

    # as the chosen score is always greater than rejected score
    # we only need to check the chosen score
    valid_indices: list[int] = []
    invalid_indices: list[int] = []
    for idx, score in enumerate(ds["chosen_score"]):
        if score > regenrate_reward_threshold:
            valid_indices.append(idx)
        else:
            invalid_indices.append(idx)

    invalid_ds = ds.select(invalid_indices)
    invalid_prompts = invalid_ds["prompt"]

    print("size of invalid ds before regeneration: ", len(invalid_indices))

    generation_kwargs = dict(
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        batch_size=batch_size,
        repetition_penalty=repetition_penalty,
        temperature=1.0,  # overwritten before every round
        top_p=top_p,
        num_beams=num_beams,
        num_return_sequences=n_generation_per_prompt,
    )

    updated_sample_list: list[dict] = []
    # Distinct prompts that still need a response, in first-seen order.
    pending = list(dict.fromkeys(invalid_prompts))
    temperature = 1.0
    rounds_done = 0
    while pending:
        if rounds_done >= _MAX_REGENERATION_ROUNDS:
            print("Temperature too high, please check the prompts")
            break
        rounds_done += 1
        temperature *= temperature_scaler
        generation_kwargs["temperature"] = temperature
        print(
            f"Generating {len(pending)} prompts with scaled temperature {temperature}"
        )

        # Sort by length when batching so that each batch needs less padding.
        if batch_size > 1:
            round_prompts = sorted(pending, key=len)
        else:
            round_prompts = list(pending)

        completions = _generate_completions(
            round_prompts, generation_model_path, generation_kwargs
        )
        accepted = _select_best_completions(
            round_prompts,
            completions,
            reward_model_path,
            regenrate_reward_threshold,
        )
        updated_sample_list.extend(accepted)

        resolved = {sample["prompt"] for sample in accepted}
        pending = [prompt for prompt in pending if prompt not in resolved]

    print("size of updated data after regeneration: ", len(updated_sample_list))
    # NOTE: this is the number of originally invalid rows; regeneration does
    # not change it.
    print("size of invalid ds after regeneration: ", len(invalid_indices))

    result = _build_preference_pairs(updated_sample_list, invalid_ds.to_pandas())

    result_ds = ds.select(valid_indices)
    if result:
        # Skip the concatenation when there is nothing to add: an empty
        # ``Dataset.from_list([])`` has no columns to align with.
        result_ds = datasets.concatenate_datasets(
            [result_ds, datasets.Dataset.from_list(result)]
        )

    print("size of ds after regeneration: ", len(result_ds))
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    result_ds.save_to_disk(str(output_path))


# Script entry point: when this file is run directly, invoke the click
# command ``main``, which reads its options from the command line.
if __name__ == "__main__":
    main()
