from __future__ import annotations

import json
import os
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import *
from typing import Optional

import click
import datasets
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
from alpaca_eval.decoders.huggingface_local import ListDataset
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    pipeline,
)


def _load_flat_dataset(path: str) -> datasets.Dataset:
    """Load the dataset at ``path`` and flatten it to a single ``Dataset``.

    ``datasets.load_dataset`` is tried first; if it raises, ``path`` is read
    with ``datasets.load_from_disk`` instead. When the result is a
    ``DatasetDict``, all of its splits are concatenated.
    """
    try:
        loaded = datasets.load_dataset(path)
    except Exception:
        # Deliberately broad: load_dataset can fail in several ways for a
        # path it does not understand. Unlike a bare ``except:`` this still
        # lets KeyboardInterrupt / SystemExit propagate. If the fallback
        # fails too, its error is raised with the first one chained.
        loaded = datasets.load_from_disk(path)
    if isinstance(loaded, datasets.DatasetDict):
        loaded = datasets.concatenate_datasets(list(loaded.values()))
    return loaded


def _render_prompts(sampled_ds, prompt_column: str, tokenizer):
    """Return ``sampled_ds`` with ``prompt_column`` holding plain strings.

    String prompts are returned untouched. List prompts are passed through
    ``tokenizer.apply_chat_template`` with a generation prompt appended.

    Raises:
        ValueError: if the first prompt is neither a ``str`` nor a ``list``.
    """
    first_prompt = sampled_ds[prompt_column][0]
    if isinstance(first_prompt, str):
        return sampled_ds
    if isinstance(first_prompt, list):
        return sampled_ds.map(
            lambda row: {
                prompt_column: tokenizer.apply_chat_template(
                    row[prompt_column],
                    tokenize=False,
                    add_generation_prompt=True,
                )
            }
        )
    raise ValueError(f"Unsupported prompt type: {type(first_prompt)}")


def _annotate_item(item: dict, rm_pipeline, pipe_kwargs: dict, bos_token: str) -> dict:
    """Score every response of ``item`` and record best/worst/average.

    Each response is appended to the prompt, occurrences of ``bos_token`` are
    stripped from the text, and the first score returned by ``rm_pipeline``
    is used. ``item`` is updated in place and returned.
    """
    texts = [
        (item["prompt"] + response).replace(bos_token, "")
        for response in item["responses"]
    ]
    # One pipeline call per response, each on a single-element list, so the
    # result is indexed as [input 0][first returned entry]["score"].
    scores = [
        rm_pipeline([text], **pipe_kwargs)[0][0]["score"] for text in texts
    ]
    # Plain Python ints/floats keep the rows free of numpy scalar types.
    chosen_idx = int(np.argmax(scores))
    rejected_idx = int(np.argmin(scores))
    item.update(
        {
            "chosen_score": scores[chosen_idx],
            "rejected_score": scores[rejected_idx],
            "chosen": item["responses"][chosen_idx],
            "rejected": item["responses"][rejected_idx],
            "avg_score": float(np.mean(scores)),
        }
    )
    return item


def _win_rate(trained_ds, base_ds) -> float:
    """Return the percentage of shared prompts where trained beats base.

    Rows are joined on their ``"prompt"`` text and compared on
    ``"avg_score"``; ties count as losses. Returns ``0.0`` when the two
    datasets share no prompt.
    """
    trained_scores = {row["prompt"]: row["avg_score"] for row in trained_ds}
    base_scores = {row["prompt"]: row["avg_score"] for row in base_ds}
    shared_prompts = trained_scores.keys() & base_scores.keys()
    if not shared_prompts:
        return 0.0
    wins = sum(
        1
        for prompt in shared_prompts
        if trained_scores[prompt] > base_scores[prompt]
    )
    return wins / len(shared_prompts) * 100


@click.command()
@click.option("--ds_path", type=str, required=True)
@click.option("--trained_model_path", type=str, required=True)
@click.option("--base_model_path", type=str, required=True)
@click.option("--reward_model_path", type=str, required=True)
@click.option("--output_dir", type=str, required=True)
@click.option(
    "--prompt_column",
    type=str,
    default="prompt",
    help="the name of the column containing the prompts",
)
@click.option(
    "--max_samples",
    type=int,
    default=None,
    help="the maximum number of samples to generate, if None, generate all samples",
)
@click.option("--max_new_tokens", type=int, default=2048)
@click.option("--do_sample", type=bool, default=True)
@click.option("--batch_size", type=int, default=4)
@click.option("--n_generation_per_prompt", type=int, default=2)
@click.option("--repetition_penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.9)
@click.option("--top_p", type=float, default=0.95)
@click.option("--num_beams", type=int, default=3)
def main(
    ds_path: str,
    trained_model_path: str,
    base_model_path: str,
    reward_model_path: str,
    output_dir: str,
    prompt_column: str = "prompt",
    max_samples: Optional[int] = 20,
    max_new_tokens: int = 2048,
    do_sample: bool = True,
    batch_size: int = 4,
    n_generation_per_prompt: int = 2,
    repetition_penalty: float = 1.2,
    temperature: float = 0.5,
    top_p: float = 1.0,
    num_beams: int = 8,
):
    """Compare a trained model with its base model using a reward model.

    Generates completions from both models for a random sample of prompts,
    scores them with the reward model and prints the trained model's win rate.
    \f

    Everything after the form feed above is hidden from ``--help``.

    Stages:
      1. Sample prompts from ``ds_path`` and, for each of the trained and base
         models, generate ``n_generation_per_prompt`` completions per prompt,
         saved to ``<output_dir>/<trained|base>_not_annotated``.
      2. Score each prompt+response with the reward model and save the rows,
         extended with chosen/rejected/avg_score fields, to
         ``<output_dir>/<trained|base>_annotated``.
      3. Join both annotated sets on prompt text and print the share of
         prompts where the trained model's average score is strictly higher.

    Args:
        ds_path: dataset name/path for ``load_dataset`` or a ``save_to_disk``
            directory.
        trained_model_path: causal LM evaluated as the "trained" model.
        base_model_path: causal LM evaluated as the "base" model.
        reward_model_path: model loaded into the scoring pipeline.
        output_dir: directory receiving the four saved datasets.
        prompt_column: column holding prompts (strings, or lists that are
            rendered through the tokenizer's chat template).
        max_samples: cap on the number of sampled prompts; ``None`` uses the
            whole dataset.
        max_new_tokens, do_sample, batch_size, repetition_penalty,
        temperature, top_p, num_beams: forwarded to the generation pipeline.
        n_generation_per_prompt: forwarded as ``num_return_sequences``.

    Raises:
        ValueError: if there is nothing to sample, the prompt type is
            unsupported, or a generated dataset lacks the expected columns.
    """
    # Create the output directory up front so a bad path fails before the
    # expensive generation rather than after it.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    ds = _load_flat_dataset(ds_path)

    # click passes None when --max_samples is omitted; that means "use all".
    n_samples = len(ds) if max_samples is None else min(max_samples, len(ds))
    if n_samples == 0:
        raise ValueError(
            "Nothing to generate: the dataset is empty or --max_samples is 0."
        )
    sampled_ds = ds.select(random.sample(range(len(ds)), n_samples))

    model_paths = {"trained": trained_model_path, "base": base_model_path}

    # ---- Stage 1: generation -------------------------------------------
    for exp_purpose, generation_model_path in model_paths.items():
        # Left padding so batched prompts end where generation starts.
        tokenizer = AutoTokenizer.from_pretrained(
            generation_model_path,
            padding_side="left",
        )
        model = AutoModelForCausalLM.from_pretrained(
            generation_model_path, torch_dtype=torch.bfloat16, device_map="auto"
        ).eval()

        # NOTE: list-valued prompts are rendered on the first ("trained")
        # pass and the rendered strings are reused for the base model. This
        # keeps the prompt text identical across both runs, which the
        # prompt-keyed win-rate join in stage 3 depends on.
        sampled_ds = _render_prompts(sampled_ds, prompt_column, tokenizer)
        prompts = sampled_ds[prompt_column]

        # ``is None`` rather than truthiness: 0 is a valid pad token id.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token

        default_kwargs = dict(
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            num_return_sequences=n_generation_per_prompt,
            # De-duplicated: pad and eos are the same id after the fallback.
            eos_token_id=list(
                {tokenizer.eos_token_id, tokenizer.pad_token_id}
            ),
        )

        print(f"Model memory: {model.get_memory_footprint() / 1e9} GB")
        print(default_kwargs)

        generation_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            **default_kwargs,
            trust_remote_code=True,
        )

        prompts_ds = ListDataset(prompts)
        completions = [
            [candidate["generated_text"] for candidate in output]
            for output in tqdm(
                generation_pipeline(
                    prompts_ds,
                    return_full_text=False,
                    pad_token_id=tokenizer.pad_token_id,
                ),
                total=len(prompts_ds),
                desc=f"Generating completions for {exp_purpose} model",
            )
        ]

        datasets.Dataset.from_dict(
            {
                "prompt": prompts,
                "responses": completions,
            }
        ).save_to_disk(str(out_dir / f"{exp_purpose}_not_annotated"))

        # Release the model memory before loading the next model.
        del generation_pipeline
        model.cpu()
        del model
        del tokenizer
        torch.cuda.empty_cache()

    # ---- Stage 2: reward-model annotation ------------------------------
    rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_path)
    rm_pipeline = pipeline(
        "sentiment-analysis",
        model=reward_model_path,
        tokenizer=rm_tokenizer,
        device_map="auto",
        model_kwargs={"torch_dtype": torch.bfloat16},
        truncation=True,
    )
    pipe_kwargs = {
        "top_k": None,
        "function_to_apply": "none",
        "batch_size": 1,
    }
    # Loop-invariant, so computed once rather than per model.
    bos_token = (
        rm_tokenizer.bos_token if rm_tokenizer.bos_token is not None else ""
    )

    for exp_purpose in model_paths:
        generated_path = str(out_dir / f"{exp_purpose}_not_annotated")
        # Written by save_to_disk in stage 1, so load_from_disk is the
        # matching reader.
        generated_ds = load_from_disk(generated_path)

        if not {"prompt", "responses"} <= set(generated_ds.column_names):
            raise ValueError(
                f"Dataset {generated_path} should contain 'prompt' and 'responses' columns."
            )

        annotated_rows = [
            _annotate_item(item, rm_pipeline, pipe_kwargs, bos_token)
            for item in tqdm(
                generated_ds,
                desc=f"Annotating dataset {generated_path} / split train",
            )
        ]

        datasets.Dataset.from_list(annotated_rows).save_to_disk(
            str(out_dir / f"{exp_purpose}_annotated")
        )

    # ---- Stage 3: win rate ---------------------------------------------
    trained_ds = load_from_disk(str(out_dir / "trained_annotated"))
    base_ds = load_from_disk(str(out_dir / "base_annotated"))

    percentage_higher = _win_rate(trained_ds, base_ds)

    print(f"The win rate of the trained model is {percentage_higher:.2f}%")


# Script entry point: ``main`` is a click command, so calling it with no
# arguments lets click parse the options from the command line.
if __name__ == "__main__":
    main()
