import json
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path

import hydra
import numpy as np
import torch
from datasets import load_from_disk
from hydra.core.config_store import ConfigStore
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, TokensPrompt
from vllm.lora.request import LoRARequest

from fair_gpt.evaluation.text_classification import GenderIdentification
from fair_gpt.generation.vllm_generation import generate_with_vllm
from fair_gpt.training.lora_training import (
    DataConfig,
    ModelConfig,
    TrainConfig,
    build_prompt_tokenized,
    make_run_name,
)


def _build_gender_prompts(n_samples, gender, tokenizer, is_chat):
    """Build ``n_samples`` tokenized generation prompts for a single gender.

    Args:
        n_samples: Number of prompts to create; zero or a negative value
            yields an empty list.
        gender: Gender label forwarded to ``build_prompt_tokenized``.
        tokenizer: Tokenizer used both to render the prompt text and to
            convert it into token ids.
        is_chat: Forwarded to ``build_prompt_tokenized`` as ``is_chat``.

    Returns:
        A list of ``TokensPrompt`` objects, one per requested sample.
    """
    prompts = []
    # NOTE(review): the prompt is rebuilt for every sample instead of being
    # built once and reused, because it is not visible from this file whether
    # build_prompt_tokenized is deterministic -- confirm before hoisting.
    for _ in range(n_samples):
        prompt = build_prompt_tokenized(
            tokenizer=tokenizer,
            text=None,
            gender=gender,
            inference_mode=True,
            is_chat=is_chat,
            tokenize=False,
        )
        # The rendered prompt may have lost its trailing newline; restore it
        # so that every prompt ends with "\n".
        if not prompt.endswith("\n"):
            prompt += "\n"
        token_ids = {"prompt_token_ids": tokenizer(prompt).input_ids}
        prompts.append(TokensPrompt(token_ids))
    return prompts


def make_prompts(n_samples_male, n_samples_female, tokenizer, is_chat):
    """Create the male and female generation prompts as token-id prompts.

    Args:
        n_samples_male: Number of prompts to build with ``gender="male"``.
        n_samples_female: Number of prompts to build with ``gender="female"``.
        tokenizer: Tokenizer used to render and tokenize each prompt.
        is_chat: Whether prompts are built in chat format (passed through to
            ``build_prompt_tokenized``).

    Returns:
        A ``(prompt_list_male, prompt_list_female)`` tuple of lists of
        ``TokensPrompt`` objects.
    """
    prompt_list_male = _build_gender_prompts(
        n_samples_male, "male", tokenizer, is_chat
    )
    prompt_list_female = _build_gender_prompts(
        n_samples_female, "female", tokenizer, is_chat
    )
    return prompt_list_male, prompt_list_female


def generate_proportions(
    llm,
    tokenizer,
    sampling_params,
    chat,
    n_samples_male,
    n_samples_female,
    lora=None,
    gen_mode="base",
):
    """Generate texts for male and female prompts and label them by gender.

    Args:
        llm: Engine handed to ``generate_with_vllm``.
        tokenizer: Tokenizer used by ``make_prompts`` to build the prompts.
        sampling_params: Sampling parameters handed to ``generate_with_vllm``.
        chat: Chat flag, used both for prompt construction and generation.
        n_samples_male: Number of male prompts to generate from.
        n_samples_female: Number of female prompts to generate from.
        lora: Optional LoRA request handed to ``generate_with_vllm``.
        gen_mode: Generation mode; only ``"base"`` is accepted.

    Returns:
        A list of ``{"text": ..., "gender": ...}`` dicts: all male results
        first, followed by all female results.

    Raises:
        ValueError: If ``gen_mode`` is anything other than ``"base"``.
    """
    if gen_mode != "base":
        raise ValueError(f"Unknown gen_mode: {gen_mode}")

    def keep_first_paragraph(generated):
        # Texts without a blank line are returned untouched (not stripped).
        if "\n\n" not in generated:
            return generated
        return generated.split("\n\n")[0].strip()

    male_prompts, female_prompts = make_prompts(
        n_samples_male=n_samples_male,
        n_samples_female=n_samples_female,
        tokenizer=tokenizer,
        is_chat=chat,
    )

    records = []
    # Male prompts are generated first, then female ones; the output keeps
    # that order.
    for gender, prompts in (("male", male_prompts), ("female", female_prompts)):
        completions = generate_with_vllm(
            llm=llm,
            prompt_token_ids=prompts,
            sampling_params=sampling_params,
            chat=chat,
            lora=lora,
        )
        records.extend(
            {"text": keep_first_paragraph(completion), "gender": gender}
            for completion in completions
        )
    return records


@dataclass
class EvalConfig:
    """Generation settings read by this script's ``main`` entry point."""

    # Number of prompts for the less-represented gender; the other gender's
    # count is scaled up from it according to ``prop_male``.
    n_samples: int = 6000
    # Assigned to the sampling parameters' ``max_tokens``.
    max_new_tokens: int = 512
    # Generation mode; ``generate_proportions`` only accepts "base". Also used
    # as a directory component of the output path.
    gen_mode: str = "base"
    # Target share of male prompts; must lie strictly between 0 and 1
    # (0.5 gives equal male and female counts).
    prop_male: float = 0.5


@dataclass
class Config:
    """Root config registered with the ConfigStore under the name "config".

    Combines the data/model/train configs imported from
    ``fair_gpt.training.lora_training`` with this script's ``EvalConfig``.
    """

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)


def compute_sample_counts(min_samples, prop_male):
    """Return ``(n_samples_male, n_samples_female)`` for a target male share.

    The less-represented gender receives exactly ``min_samples`` prompts and
    the other gender is scaled up so that males make up ``prop_male`` of the
    total.

    Args:
        min_samples: Number of prompts for the less-represented gender.
        prop_male: Target proportion of male prompts, strictly between 0 and 1.

    Returns:
        A ``(n_samples_male, n_samples_female)`` tuple of ints.

    Raises:
        ValueError: If ``prop_male`` is not strictly between 0 and 1 (the
            scaling below would otherwise divide by zero or go negative).
    """
    if not 0 < prop_male < 1:
        raise ValueError(
            f"prop_male must be strictly between 0 and 1, got {prop_male}"
        )
    # round() rather than int(): float error can land just below the intended
    # integer (e.g. 6000 * 0.7 / (1 - 0.7) evaluates to 13999.999999999998),
    # and truncation would silently drop one sample.
    if prop_male >= 0.5:
        n_samples_female = min_samples
        n_samples_male = round(n_samples_female * prop_male / (1 - prop_male))
    else:
        n_samples_male = min_samples
        n_samples_female = round(n_samples_male * (1 - prop_male) / prop_male)
    return n_samples_male, n_samples_female


if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)

    @hydra.main(version_base=None, config_name="config")
    def main(cfg: Config):
        """Generate gender-conditioned texts with vLLM and save them as JSONL.

        Loads the tokenizer from ``<output_dir>/<run_name>/best_model``,
        generates male and female samples in the proportion requested by
        ``cfg.eval.prop_male``, prints the first 20 records and writes all
        records to ``eval/<gen_mode>/prop_m<prop_male>/wikibio_generation.jsonl``
        under the run directory.
        """
        run_name = make_run_name(cfg)
        output_dir = Path(cfg.train.output_dir) / run_name
        model_path = output_dir / "best_model"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        is_chat = getattr(tokenizer, "chat_template", None) is not None
        print(f"Is chat model: {is_chat}")

        n_devices = torch.cuda.device_count()

        # A positive LoRA rank means ``best_model`` is served as an adapter
        # on top of the base model loaded below.
        enable_lora = cfg.model.lora_r > 0
        lora_request = (
            LoRARequest("adapter", 1, str(model_path)) if enable_lora else None
        )
        print(f"LoRA enabled: {enable_lora}")

        llm = LLM(
            model=f"models_dir/{cfg.model.model_name}",
            max_model_len=cfg.model.max_length,
            tensor_parallel_size=n_devices,
            generation_config="auto",
            seed=1234,
            enable_lora=enable_lora,
        )
        sampling = llm.get_default_sampling_params()
        sampling.max_tokens = cfg.eval.max_new_tokens
        sampling.min_tokens = 64
        sampling.stop = ["\n\n", "\n\n\n", "\n\n\n\n", "Biography:"]

        n_samples_male, n_samples_female = compute_sample_counts(
            cfg.eval.n_samples, cfg.eval.prop_male
        )

        results = generate_proportions(
            llm=llm,
            tokenizer=tokenizer,
            sampling_params=sampling,
            chat=is_chat,
            n_samples_male=n_samples_male,
            n_samples_female=n_samples_female,
            lora=lora_request,
            gen_mode=cfg.eval.gen_mode,
        )
        # Keep only the first paragraph of each generated text. Each result is
        # a {"text", "gender"} dict, so the check must look at the "text"
        # field; testing the dict itself only inspects its keys and never
        # matches. generate_proportions already truncates, so this is a
        # safety net.
        for record in results:
            if "\n\n" in record["text"]:
                record["text"] = record["text"].split("\n\n")[0].strip()

        for i, record in enumerate(results[:20]):
            print("=== Example", i, "===")
            print(record)
            print()

        output_path = (
            output_dir
            / "eval"
            / cfg.eval.gen_mode
            / f"prop_m{cfg.eval.prop_male}"
            / "wikibio_generation.jsonl"
        )
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            for record in results:
                f.write(json.dumps(record) + "\n")
        # Report only once the file has been closed and flushed.
        print(f"Saved results to {output_path}")

    main()
