import json
from argparse import ArgumentParser
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from datasets import load_from_disk
from tqdm import tqdm
from vllm import LLM, TokensPrompt

from fair_gpt.evaluation.text_classification import GenderIdentification
from fair_gpt.generation.vllm_generation import build_prompt, generate_with_vllm


def make_few_shot_prompts(data, n_samples, few_shots, tokenizer, is_chat, rng):
    """Build ``n_samples`` tokenized prompts with randomly drawn few-shot examples.

    Args:
        data: Indexable pool of examples (accessed as ``data[int]``) from which
            few-shot examples are drawn. Only used when ``few_shots > 0``.
        n_samples: Number of prompts to build.
        few_shots: Number of examples per prompt; ``0`` builds zero-shot prompts.
        tokenizer: Tokenizer forwarded to ``build_prompt``; it is also called to
            encode the prompt (``.input_ids``) and its ``decode`` is used to
            print the first prompt.
        is_chat: Forwarded to ``build_prompt`` as ``chat``.
        rng: NumPy ``Generator`` used to sample example indices.

    Returns:
        A list of ``TokensPrompt`` objects carrying ``prompt_token_ids``.
        It is empty when ``n_samples <= 0``.

    Raises:
        ValueError: If ``few_shots > 0`` but ``data`` is None or empty.
    """
    if few_shots > 0 and (data is None or len(data) == 0):
        raise ValueError("few_shots > 0 requires a non-empty `data` pool.")

    prompt_list = []
    for _ in range(n_samples):
        if few_shots > 0:
            # Sampled with replacement: the same example may appear more than once.
            few_shot_indices = rng.choice(len(data), size=few_shots, replace=True)
            few_shot_examples = [data[int(idx)] for idx in few_shot_indices]
        else:
            few_shot_examples = []
        prompt = build_prompt(
            few_shots=few_shot_examples, chat=is_chat, tokenizer=tokenizer
        )
        # The rendered prompt may lose its trailing "\n" (e.g. after chat-template
        # formatting); re-add it so generation starts on a fresh line.
        if not prompt.endswith("\n"):
            prompt += "\n"
        prompt_list.append(
            TokensPrompt(prompt_token_ids=tokenizer(prompt).input_ids)
        )

    # Guard against n_samples == 0, where there is no first prompt to show.
    if prompt_list:
        print("First prompt:")
        print(tokenizer.decode(prompt_list[0]["prompt_token_ids"]))
    return prompt_list


def generate_until_quota_reached(
    llm,
    tokenizer,
    filter_fn,
    sampling_params,
    chat,
    few_shots,
    quota=6000,
    batch_size=10000,
    data_few_shots=None,
    rng=None,
):
    """Generate texts in batches until every observed label has ``quota`` texts.

    Each batch builds ``batch_size`` prompts and generates with vLLM. Each text
    is truncated at its first blank line and labeled with ``filter_fn``.
    Generation stops once the least frequent label returned by ``filter_fn``
    so far has at least ``quota`` texts. Labels that never occur are not counted.

    Args:
        llm: vLLM engine forwarded to ``generate_with_vllm``.
        tokenizer: Tokenizer forwarded to ``make_few_shot_prompts``.
        filter_fn: Callable mapping a text to a (hashable) label.
        sampling_params: Sampling parameters forwarded to ``generate_with_vllm``.
        chat: Whether to build/generate in chat mode.
        few_shots: Number of few-shot examples per prompt.
        quota: Target count for the least frequent label.
        batch_size: Number of prompts per generation batch (must be >= 1).
        data_few_shots: Pool of few-shot examples (required if ``few_shots > 0``).
        rng: NumPy ``Generator`` for few-shot sampling. Defaults to
            ``np.random.default_rng(1234)``, the same seed the script uses.

    Returns:
        A list of ``{"text": ..., "gender": ...}`` dicts for every generated
        text, in generation order. It is not truncated to ``quota``.

    Raises:
        ValueError: If ``batch_size < 1``.
        RuntimeError: If a batch yields no texts (the loop could never finish).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if rng is None:
        # Previously this function read an undefined module-level `rng`; use an
        # equally seeded generator so results stay reproducible.
        rng = np.random.default_rng(1234)

    text_list = []
    gender_list = []
    gender_count = defaultdict(int)
    n_generated = 0
    with tqdm(total=quota, desc="Generating samples") as pbar:
        while n_generated < quota:
            prompt_list = make_few_shot_prompts(
                data=data_few_shots,
                n_samples=batch_size,
                few_shots=few_shots,
                tokenizer=tokenizer,
                is_chat=chat,
                rng=rng,
            )
            batch_texts = generate_with_vllm(
                llm=llm,
                prompt_token_ids=prompt_list,
                sampling_params=sampling_params,
                chat=chat,
            )
            if not batch_texts:
                raise RuntimeError(
                    "generate_with_vllm returned no texts; the quota cannot be reached."
                )
            for text in batch_texts:
                # Keep only the first paragraph (text before the first blank
                # line); only truncated texts are stripped.
                if "\n\n" in text:
                    text = text.split("\n\n")[0].strip()
                gender = filter_fn(text)
                gender_count[gender] += 1
                text_list.append(text)
                gender_list.append(gender)
            print(gender_count)
            # Progress = count of the least frequent label seen so far. It can
            # move backwards when a new label first appears.
            new_n_generated = min(gender_count.values(), default=0)
            pbar.update(new_n_generated - n_generated)
            n_generated = new_n_generated

    return [
        {"text": text, "gender": gender}
        for text, gender in zip(text_list, gender_list)
    ]


if __name__ == "__main__":
    # Script entry point: parse CLI args, load tokenizer/model/few-shot pool,
    # generate until the per-label quota is met, then write the results as JSONL.
    parser = ArgumentParser()
    parser.add_argument(
        "--few_shots",
        type=int,
        default=0,
        help="Number of few-shot examples to include in the prompt.",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=6000,
        help=(
            "Target number of samples for the least frequent gender label; "
            "generation continues until every observed label reaches it."
        ),
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="qwen-4b-chat",
        help="HuggingFace model name or path.",
    )
    parser.add_argument(
        "--data_name",
        type=str,
        default="data_dir/wikibio_train",
        help="Path to a dataset saved with `datasets.save_to_disk`, used as the few-shot pool.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--max_model_length",
        type=int,
        default=8192,
        help="Maximum number of tokens for the model context.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="results_dir/wikibio/generation.jsonl",
        help="Path to the output JSONL file for generated samples.",
    )
    parser.add_argument(
        "--gen_mode",
        type=str,
        default="base",
        help="Generation mode to use (currently unused by this script).",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10000,
        help="Number of prompts generated per vLLM batch.",
    )

    args = parser.parse_args()

    from transformers import AutoTokenizer

    # Llama-2 chat checkpoints are paired with a separately stored tokenizer.
    if "llama2" in args.model_name and "chat" in args.model_name:
        tokenizer = AutoTokenizer.from_pretrained("models_dir/llama2-tokenizer")
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    is_chat = getattr(tokenizer, "chat_template", None) is not None
    print(f"Is chat model: {is_chat}")
    print(f"Chat template: {getattr(tokenizer, 'chat_template', None)}")

    # tensor_parallel_size must be >= 1; device_count() is 0 without CUDA.
    n_devices = max(1, torch.cuda.device_count())

    data = load_from_disk(args.data_name)
    # Module-level seeded RNG for few-shot example sampling.
    rng = np.random.default_rng(1234)

    llm = LLM(
        model=args.model_name,
        max_model_len=args.max_model_length,
        tensor_parallel_size=n_devices,
        generation_config="auto",
        seed=1234,
    )
    sampling = llm.get_default_sampling_params()
    sampling.max_tokens = args.max_new_tokens
    sampling.min_tokens = 64
    sampling.stop = ["\n\n", "\n\n\n", "\n\n\n\n", "Biography:"]
    gender_identifier = GenderIdentification()
    filter_fn = gender_identifier.classify_gender_text
    results = generate_until_quota_reached(
        llm=llm,
        tokenizer=tokenizer,
        filter_fn=filter_fn,
        sampling_params=sampling,
        chat=is_chat,
        few_shots=args.few_shots,
        quota=args.n_samples,
        batch_size=args.batch_size,
        data_few_shots=data,
    )

    for i, item in enumerate(results[:20]):
        print("=== Example", i, "===")
        print(item)
        print()

    output_path = Path(args.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for item in results:
            f.write(json.dumps(item) + "\n")
    print(f"Saved results to {output_path}")
