import json
import os
import re
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

import click
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    pipeline,
)


def _load_any_dataset(ds_path):
    """Return a single ``datasets.Dataset`` for *ds_path*.

    *ds_path* may already be a ``Dataset``/``DatasetDict``; otherwise it is tried
    with ``datasets.load_dataset`` first and ``load_from_disk`` second. A
    ``DatasetDict`` is flattened by concatenating all of its splits.
    """
    if isinstance(ds_path, (datasets.Dataset, datasets.DatasetDict)):
        ds = ds_path
    else:
        try:
            ds = datasets.load_dataset(ds_path)
        except Exception:
            # Not loadable via load_dataset (e.g. a directory written by
            # save_to_disk): fall back to the on-disk Arrow format. A bare
            # ``except:`` here would also swallow KeyboardInterrupt/SystemExit.
            ds = load_from_disk(ds_path)

    if isinstance(ds, datasets.DatasetDict):
        ds = datasets.concatenate_datasets([ds[split] for split in ds.keys()])
    return ds


def _render_chat(tokenizer, messages):
    """Render *messages* with the chat template as plain text, BOS text removed."""
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    # Presumably stripped because the later tokenizer() call adds its own
    # special tokens (NOTE(review): confirm for the target tokenizer).
    # Some tokenizers define no BOS token; str.replace(None, "") would raise.
    bos = tokenizer.bos_token
    return text.replace(bos, "") if bos else text


def _last_token_hidden(output, tokenized):
    """Return the last-layer hidden state at each row's final non-padding token.

    The position is the largest index whose attention-mask value is 1. That is
    correct for both right- and left-padded batches, whereas
    ``mask.sum() - 1`` is only correct for right padding.

    Returns:
        One list of floats per batch row.
    """
    hidden = output.hidden_states[-1]  # last layer; (batch, seq, hidden) per HF convention
    # With device_map="auto" the last layer may live on a different device than
    # the inputs, so build every index tensor on the hidden-state device.
    mask = tokenized["attention_mask"].to(hidden.device)
    positions = torch.arange(mask.size(1), device=hidden.device)
    # Zero out padded positions; argmax then picks the last real token
    # (position 0 if only the first token is real).
    last_idx = (mask.long() * positions).argmax(dim=1)
    rows = torch.arange(hidden.size(0), device=hidden.device)
    return hidden[rows, last_idx].detach().to(torch.float).cpu().numpy().tolist()


def _shard_dir(save_to, rank, world_size):
    """Return the temporary output directory used by worker *rank*."""
    return save_to + f"_r{rank}_w{world_size}"


def _merge_shards_if_complete(save_to, world_size):
    """Merge every worker's shard into *save_to* once all shards are saved.

    Only the shards of this run (same *save_to* prefix and *world_size*) are
    considered. They are concatenated in rank order, which, given the
    contiguous ``np.array_split`` sharding, restores the original row order.
    Shard directories are deleted after a successful merge.

    Returns:
        True if the merge was performed, False if some shard is not ready yet.
    """
    shard_dirs = [
        Path(_shard_dir(save_to, rank, world_size)) for rank in range(world_size)
    ]
    # A shard directory exists as soon as another worker *starts* saving, so
    # wait for the metadata files load_from_disk needs. NOTE(review): assumes
    # datasets writes these after the Arrow data files; confirm for the
    # installed version.
    if not all(
        (p / name).is_file()
        for p in shard_dirs
        for name in ("state.json", "dataset_info.json")
    ):
        return False
    # NOTE(review): two workers finishing at the same moment could both reach
    # this point; a cross-process lock would be needed to rule that out.
    merged_ds = datasets.concatenate_datasets(
        [datasets.load_from_disk(str(p)) for p in shard_dirs]
    )
    merged_ds.save_to_disk(save_to)
    for p in shard_dirs:
        shutil.rmtree(str(p))
    return True


@click.command()
@click.option("--ds_path", type=str, required=True, help="Path to the dataset")
@click.option("--model_path", type=str, required=True, help="Path to the model")
@click.option(
    "--save_to", type=str, required=True, help="Path to save the results"
)
@click.option("--batch_size", type=int, default=8, help="Batch size")
@click.option("--local_rank", type=int, default=None, help="Local rank")
@click.option("--world_size", type=int, default=None, help="World size")
def main(
    ds_path: Union[str, datasets.Dataset],
    model_path: str,
    save_to: str,
    batch_size: int = 8,
    local_rank: Optional[int] = None,
    world_size: Optional[int] = None,
):
    """Score chosen/rejected pairs with a reward model and save hidden states.

    For every sample, save the last-token hidden state of the chosen, rejected
    and prompt (first chosen message) texts, plus the chosen/rejected scores.
    With --local_rank/--world_size each worker handles one contiguous shard and
    writes it to a temporary directory. The worker that finds all shards
    complete merges them into --save_to.
    """
    ds = _load_any_dataset(ds_path)

    # Unique IDs are assigned before sharding so rows can be traced back.
    ds = ds.map(lambda _, i: {"__ID__": i}, with_indices=True, num_proc=8)

    distributed = world_size is not None and local_rank is not None
    if distributed:
        selected_idxs = np.array_split(np.arange(len(ds)), world_size)[
            local_rank
        ]
        ds = ds.select(selected_idxs)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="eager",
        num_labels=1,
    )
    device = model.device

    rslt_ids = []
    rslt_chosen_hidden = []
    rslt_rejected_hidden = []
    rslt_prompt_hidden = []
    rslt_chosen_pred_scores = []
    rslt_rejected_pred_scores = []

    for start_idx in tqdm(
        range(0, len(ds), batch_size),
        desc=f"worker {local_rank} / {world_size}",
    ):
        list_batch = ds.select(
            range(start_idx, min(start_idx + batch_size, len(ds)))
        ).to_list()
        text_pos = [_render_chat(tokenizer, d["chosen"]) for d in list_batch]
        text_neg = [_render_chat(tokenizer, d["rejected"]) for d in list_batch]
        # The prompt is taken to be the first message of the chosen dialogue.
        text_prompt = [
            _render_chat(tokenizer, d["chosen"][:1]) for d in list_batch
        ]

        tokenized_pos = tokenizer(
            text_pos, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        tokenized_neg = tokenizer(
            text_neg, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        tokenized_prompt = tokenizer(
            text_prompt, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            output_pos = model(**tokenized_pos, output_hidden_states=True)
            output_neg = model(**tokenized_neg, output_hidden_states=True)
            output_prompt = model(**tokenized_prompt, output_hidden_states=True)

        pos_hidden = _last_token_hidden(output_pos, tokenized_pos)
        neg_hidden = _last_token_hidden(output_neg, tokenized_neg)
        prompt_hidden = _last_token_hidden(output_prompt, tokenized_prompt)

        # num_labels=1, so column 0 of the logits is the scalar score.
        chosen_pred_scores = (
            output_pos.logits[:, 0].detach().to(torch.float).cpu().numpy()
        )
        rejected_pred_scores = (
            output_neg.logits[:, 0].detach().to(torch.float).cpu().numpy()
        )

        for i, d in enumerate(list_batch):
            rslt_ids.append(d["__ID__"])
            rslt_chosen_hidden.append(pos_hidden[i])
            rslt_rejected_hidden.append(neg_hidden[i])
            rslt_prompt_hidden.append(prompt_hidden[i])
            rslt_chosen_pred_scores.append(chosen_pred_scores[i])
            rslt_rejected_pred_scores.append(rejected_pred_scores[i])

    rslt_ds = datasets.Dataset.from_dict(
        {
            "id": rslt_ids,
            "chosen_hidden": rslt_chosen_hidden,
            "rejected_hidden": rslt_rejected_hidden,
            "prompt_hidden": rslt_prompt_hidden,
            "chosen_pred_score": rslt_chosen_pred_scores,
            "rejected_pred_score": rslt_rejected_pred_scores,
        }
    )

    if distributed:
        rslt_ds.save_to_disk(_shard_dir(save_to, local_rank, world_size))
        _merge_shards_if_complete(save_to, world_size)
    else:
        rslt_ds.save_to_disk(save_to)


if "__main__" == __name__:
    # Script entry point: invoke the click command, which reads its options
    # from the command line.
    main()
