import json
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path

import hydra
import numpy as np
import torch
from datasets import load_from_disk
from hydra.core.config_store import ConfigStore
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, TokensPrompt
from vllm.lora.request import LoRARequest

from fair_gpt.evaluation.text_classification import GenderIdentification
from fair_gpt.generation.vllm_generation import generate_with_vllm
from fair_gpt.training.lora_training import (
    DataConfig,
    ModelConfig,
    TrainConfig,
    build_prompt_tokenized,
    make_run_name,
)


def make_prompts(n_samples, tokenizer, is_chat):
    """Build ``n_samples`` inference prompts as vLLM ``TokensPrompt`` objects.

    Each prompt is rendered by ``build_prompt_tokenized`` in inference mode
    (no text, no gender, ``tokenize=False``), forced to end with a newline,
    and then tokenized with ``tokenizer``. The first prompt is decoded and
    printed as a sanity check.

    Args:
        n_samples: Number of prompts to build.
        tokenizer: Callable tokenizer whose result exposes ``input_ids`` and
            which provides ``decode``.
        is_chat: Forwarded unchanged to ``build_prompt_tokenized``.

    Returns:
        A list with one ``TokensPrompt`` per requested sample; the list is
        empty when ``n_samples`` is zero or negative.
    """
    prompt_list = []
    for _ in range(n_samples):
        # The prompt is rebuilt on every iteration on purpose: whether
        # build_prompt_tokenized is deterministic is not visible from here.
        prompt = build_prompt_tokenized(
            tokenizer=tokenizer,
            text=None,
            gender=None,
            inference_mode=True,
            is_chat=is_chat,
            tokenize=False,
        )
        # Restore the trailing newline in case the template step dropped it.
        if not prompt.endswith("\n"):
            prompt += "\n"
        token_ids = {"prompt_token_ids": tokenizer(prompt).input_ids}
        prompt_list.append(TokensPrompt(token_ids))

    # Only show a sample when there is one; an empty request used to crash
    # here with an IndexError.
    if prompt_list:
        print(tokenizer.decode(prompt_list[0]["prompt_token_ids"]))

    return prompt_list


def generate_until_quota_reached(
    llm,
    tokenizer,
    filter_fn,
    sampling_params,
    chat,
    quota=6000,
    batch_size=10000,
    lora=None,
    gen_mode="base",
):
    """Generate texts batch by batch until each seen label reaches ``quota``.

    Every round builds ``batch_size`` prompts with ``make_prompts``, generates
    completions through ``generate_with_vllm``, trims each completion to the
    part before its first blank line, and labels it with ``filter_fn``. The
    loop ends once the smallest per-label count is at least ``quota``.

    Args:
        llm: Forwarded to ``generate_with_vllm``.
        tokenizer: Forwarded to ``make_prompts``.
        filter_fn: Callable mapping a generated text to a hashable label.
        sampling_params: Forwarded to ``generate_with_vllm``.
        chat: Forwarded to both ``make_prompts`` and ``generate_with_vllm``.
        quota: Minimum number of samples required for each label.
        batch_size: Number of prompts generated per round.
        lora: Forwarded to ``generate_with_vllm``.
        gen_mode: Generation mode; only ``"base"`` is accepted.

    Returns:
        A list of ``{"text": ..., "gender": ...}`` dicts covering every
        generated sample, in generation order (not truncated to ``quota``).

    Raises:
        ValueError: If ``gen_mode`` is anything other than ``"base"``.
    """
    if gen_mode != "base":
        raise ValueError(f"Unknown gen_mode: {gen_mode}")

    all_texts = []
    all_labels = []
    label_totals = defaultdict(int)
    reached = 0
    with tqdm(total=quota, desc="Generating samples") as pbar:
        while reached < quota:
            prompts = make_prompts(
                n_samples=batch_size,
                tokenizer=tokenizer,
                is_chat=chat,
            )
            batch_texts = generate_with_vllm(
                llm=llm,
                prompt_token_ids=prompts,
                sampling_params=sampling_params,
                chat=chat,
                lora=lora,
            )

            for idx, raw in enumerate(batch_texts):
                # Keep only what precedes the first blank line.
                if "\n\n" in raw:
                    batch_texts[idx] = raw.split("\n\n")[0].strip()
                label = filter_fn(batch_texts[idx])
                label_totals[label] += 1
                all_labels.append(label)
            all_texts.extend(batch_texts)

            print(label_totals)
            # NOTE(review): the minimum only covers labels seen so far, so a
            # label that never shows up does not hold the loop open, and any
            # extra label returned by filter_fn must reach the quota too.
            smallest = min(label_totals.values(), default=0)
            pbar.update(smallest - reached)
            reached = smallest

    return [
        {"text": text, "gender": label}
        for text, label in zip(all_texts, all_labels)
    ]


@dataclass
class EvalConfig:
    """Settings for the generation run performed by this script."""

    # Passed as the ``quota`` argument of generate_until_quota_reached.
    n_samples: int = 6000
    # Assigned to the sampling params' ``max_tokens`` before generation.
    max_new_tokens: int = 512
    # Generation mode: generate_until_quota_reached only accepts "base"; the
    # value is also used as a sub-directory name of the output path.
    gen_mode: str = "base"
    # Not read anywhere in this file; presumably the target share of male
    # samples — TODO confirm against the consumers of this config.
    prop_male: float = 0.5


@dataclass
class Config:
    """Top-level structured config, registered in Hydra's ConfigStore as "config"."""

    # The data/model/train sections reuse the config classes imported from
    # fair_gpt.training.lora_training; each section gets its own fresh instance.
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    # Generation settings specific to this script.
    eval: EvalConfig = field(default_factory=EvalConfig)


if __name__ == "__main__":
    # Register the structured config so that config_name="config" resolves.
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)

    @hydra.main(version_base=None, config_name="config")
    def main(cfg: Config):
        """Generate labelled samples with vLLM and save them as JSONL.

        Loads the tokenizer from ``<output_dir>/<run_name>/best_model``,
        starts a vLLM engine (with the LoRA adapter when
        ``cfg.model.lora_r > 0``), generates until the per-label quota is
        met, prints up to 20 example records, and writes one JSON object per
        line to ``<output_dir>/<run_name>/eval/<gen_mode>/wikibio_generation.jsonl``.
        """
        run_name = make_run_name(cfg)
        output_dir = Path(cfg.train.output_dir) / run_name
        model_path = output_dir / "best_model"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # A tokenizer carrying a chat template is treated as a chat model.
        is_chat = getattr(tokenizer, "chat_template", None) is not None
        print(f"Is chat model: {is_chat}")

        n_devices = torch.cuda.device_count()

        enable_lora = cfg.model.lora_r > 0
        lora_request = (
            LoRARequest("adapter", 1, str(model_path)) if enable_lora else None
        )
        print(f"LoRA enabled: {enable_lora}")

        # NOTE(review): the engine always loads the weights under models_dir,
        # even when LoRA is disabled; confirm that best_model is not the
        # intended checkpoint in that case.
        llm = LLM(
            model=f"models_dir/{cfg.model.model_name}",
            max_model_len=cfg.model.max_length,
            tensor_parallel_size=n_devices,
            generation_config="auto",
            seed=1234,
            enable_lora=enable_lora,
        )
        sampling = llm.get_default_sampling_params()
        sampling.max_tokens = cfg.eval.max_new_tokens
        sampling.min_tokens = 64
        sampling.stop = ["\n\n", "\n\n\n", "\n\n\n\n", "Biography:"]

        gender_identifier = GenderIdentification()
        results = generate_until_quota_reached(
            llm=llm,
            tokenizer=tokenizer,
            filter_fn=gender_identifier.classify_gender_text,
            sampling_params=sampling,
            chat=is_chat,
            quota=cfg.eval.n_samples,
            batch_size=10000,
            lora=lora_request,
            gen_mode=cfg.eval.gen_mode,
        )
        # The records are {"text", "gender"} dicts whose text was already
        # trimmed to its first paragraph by generate_until_quota_reached. The
        # former second trimming pass tested "\n\n" against the dict keys, so
        # it never changed anything and has been removed.

        for i, record in enumerate(results[:20]):
            print("=== Example", i, "===")
            print(record)
            print()

        output_path = (
            output_dir / "eval" / cfg.eval.gen_mode / "wikibio_generation.jsonl"
        )
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            for record in results:
                f.write(json.dumps(record) + "\n")
        print(f"Saved results to {output_path}")

    main()
