import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import hydra
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from hydra.core.config_store import ConfigStore
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM

from fair_gpt.generation.compute_pr_v2 import launch_nll_eval, launch_pr_eval
from fair_gpt.training.lora_training import (
    DataConfig,
    ModelConfig,
    TrainConfig,
    make_run_name,
)

# Task instruction prepended to texts by get_detailed_instruct.
TASK_DESCRIPTION = "Find similar texts, according to both content, writing style, etc."


def get_detailed_instruct(query: str) -> str:
    """Return *query* prefixed with the embedding task instruction.

    The result is two lines: ``"Instruct: <TASK_DESCRIPTION>"`` followed by
    ``"Text:<query>"`` (no space after the colon).
    """
    header = f"Instruct: {TASK_DESCRIPTION}"
    body = f"Text:{query}"
    return "\n".join((header, body))


# NOTE: `LLM` is imported from vllm at the top of this file; `geom_mean` is not referenced in this module.


def _read_jsonl_in_chunks(path: Path, chunk_size: int) -> Iterable[List[Dict]]:
    """Yield lists of JSON objects from a JSONL file."""
    buf = []
    with open(path, "r") as f:
        for line in f:
            buf.append(json.loads(line))
            if len(buf) >= chunk_size:
                yield buf
                buf = []
    if buf:
        yield buf


def compute_embeddings_streaming(
    embed_model_name: str,
    gen_files: Path,
    batch_size: int = 4096,  # pick a safe default, adjust to your GPU or API limits
    writer_rows: int = 4096,  # number of rows per parquet row group
):
    """Embed a JSONL file of generations and save the result as an HF dataset.

    Records are read ``batch_size`` at a time and embedded with a vLLM model
    loaded from ``models_dir/<embed_model_name>``. Each batch is written to its
    own zstd-compressed parquet shard under
    ``gen_files.parent / f"{embed_model_name}_embed"``. The shards are then
    loaded with ``Dataset.from_parquet`` and written with ``save_to_disk`` into
    that same directory. Raw texts and embeddings are never all held in memory
    at once.

    Args:
        embed_model_name: model directory name under ``models_dir``; also used
            to name the output directory.
        gen_files: JSONL file; each record needs a ``"text"`` key and may have a
            ``"gender"`` key (missing -> null).
        batch_size: records embedded and written per parquet shard.
        writer_rows: maximum rows per parquet row group within a shard.

    Raises:
        FileNotFoundError: if ``gen_files`` is not an existing file (checked
            before the embedding model is loaded).
        ValueError: if ``gen_files`` contains no records, or a single batch has
            too many embedding values for an int32-offset list column.
    """
    gen_files = Path(gen_files)
    # Fail fast: loading the vLLM model is expensive, so validate input first.
    if not gen_files.is_file():
        raise FileNotFoundError(f"Generation file not found: {gen_files}")

    model = LLM(model=f"models_dir/{embed_model_name}", task="embed")
    out_dir = gen_files.parent / f"{embed_model_name}_embed"
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Streaming from {gen_files}")
    total = 0
    shard_paths = []
    int32_max = np.iinfo(np.int32).max

    for shard_idx, raw_chunk in enumerate(_read_jsonl_in_chunks(gen_files, batch_size)):
        texts = [r["text"] for r in raw_chunk]
        genders = [r.get("gender") for r in raw_chunk]

        emb_outputs = model.embed(texts)
        # One contiguous float32 matrix per batch, instead of a list of lists of
        # boxed Python floats (millions of objects per default-sized batch).
        matrix = np.stack(
            [np.asarray(item.outputs.embedding, dtype=np.float32) for item in emb_outputs]
        )
        n_rows, dim = matrix.shape
        # list<float32> column from flat values plus offsets: row i covers
        # values [i * dim, (i + 1) * dim). ListArray offsets are int32.
        offsets = np.arange(n_rows + 1, dtype=np.int64) * dim
        if offsets[-1] > int32_max:
            raise ValueError(
                f"Batch of {n_rows} x {dim} embedding values exceeds the int32 "
                "offset range of a list column; lower batch_size"
            )
        embedding_col = pa.ListArray.from_arrays(
            pa.array(offsets.astype(np.int32)),
            pa.array(matrix.reshape(-1)),
        )

        table = pa.table(
            {
                "text": pa.array(texts, type=pa.large_string()),
                "gender": pa.array(genders, type=pa.large_string()),
                "embedding": embedding_col,
            }
        )

        shard_path = out_dir / f"part-{shard_idx:05d}.parquet"
        # Row groups hold at most writer_rows rows; with the defaults
        # (writer_rows == batch_size) each shard is a single row group.
        pq.write_table(
            table,
            shard_path,
            compression="zstd",
            use_dictionary=False,
            row_group_size=writer_rows,
        )

        shard_paths.append(str(shard_path))
        total += len(texts)
        print(f"Wrote {len(texts)} rows to {shard_path}")

    if total == 0:
        raise ValueError(f"No samples found in {gen_files}")

    # Materialize an HF Dataset from the shard files, then save to disk
    print("Building Hugging Face Dataset from parquet shards")
    ds = Dataset.from_parquet(shard_paths)
    ds = ds.with_format(None)  # ensure default format
    ds.save_to_disk(out_dir)
    print(f"Saved {len(ds)} rows to {out_dir}")


@dataclass
class EvalConfig:
    """Evaluation settings; the ``eval`` field of ``Config``.

    In this file these fields select which evaluations ``main`` runs (PR over
    embeddings, NLL under the trained model) and where their inputs/outputs
    live. Fields not read in this file are presumably consumed elsewhere via
    ``cfg`` (e.g. by ``launch_nll_eval``), which is not visible here.
    """

    # NOTE(review): not read in this file; presumably used by the generation
    # step or launch_nll_eval via cfg -- confirm.
    n_samples: int = 15000
    # NOTE(review): not read in this file; presumably a generation length
    # limit -- confirm.
    max_new_tokens: int = 512
    # Embedding model directory under models_dir/; also used in the name of the
    # "<name>_embed" output directory and of the result JSON files.
    embed_model_name: str = "qwen3-embed-4b"
    # Reference embeddings for PR eval; main() substitutes
    # results_dir/wikibio/<embed_model_name>_ref_embed when this is None.
    ref_embed_path: Optional[str] = None
    # If True, compute_embeddings_streaming is run before PR eval; otherwise
    # the existing "<name>_embed" directory is used as-is.
    do_embed: bool = True
    # The next five fields are passed to launch_pr_eval as seed, n_runs, k,
    # pca_var and max_samples respectively.
    seed: int = 1234
    n_runs: int = 40
    k: int = 4
    pca_var: float = 0.9
    max_samples_pr: int = 3000
    # Subdirectory of <run output dir>/eval/ that contains the generations.
    gen_mode: str = "base"
    # NOTE(review): not read in this file; presumably a batch size for NLL
    # eval -- confirm.
    batch_size: int = 2
    # Toggles for the NLL and PR evaluations in main().
    compute_nll: bool = True
    compute_pr: bool = True
    # Used as the "prop_m<value>" path component when cfg.model.conditional;
    # presumably the proportion of male samples -- confirm.
    prop_male: float = 0.5


@dataclass
class Config:
    """Top-level structured config, registered with Hydra's ConfigStore as "config".

    ``main`` reads ``model.conditional``, ``model.lora_r``, ``model.model_name``,
    ``train.output_dir`` and the ``eval`` fields, and passes the whole config
    to ``make_run_name`` and ``launch_nll_eval``.
    """

    # Sub-configs imported from fair_gpt.training.lora_training.
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    # Evaluation settings defined in this file.
    eval: EvalConfig = field(default_factory=EvalConfig)


def _generation_file(cfg: "Config", output_dir: Path) -> Path:
    """Return the generations JSONL evaluated for this run.

    Layout: ``<output_dir>/eval/<gen_mode>[/prop_m<prop_male>]/wikibio_generation.jsonl``;
    the ``prop_m`` level is present only when ``cfg.model.conditional`` is set.
    """
    eval_dir = output_dir / "eval" / cfg.eval.gen_mode
    if cfg.model.conditional:
        eval_dir = eval_dir / f"prop_m{cfg.eval.prop_male}"
    return eval_dir / "wikibio_generation.jsonl"


def _load_nll_model(cfg: "Config", model_path: Path):
    """Load the trained causal LM in bfloat16 for NLL evaluation.

    With LoRA (``cfg.model.lora_r > 0``) the base weights come from
    ``models_dir/<model_name>`` and the adapter from ``model_path``; otherwise
    the full model is loaded directly from ``model_path``.
    """
    load_kwargs = dict(
        device_map="auto",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    if cfg.model.lora_r > 0:
        model = AutoModelForCausalLM.from_pretrained(
            f"models_dir/{cfg.model.model_name}", **load_kwargs
        )
        model.load_adapter(model_path, adapter_name="adapter")
        model.set_adapter("adapter")
        return model
    return AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)


if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)

    @hydra.main(version_base=None, config_name="config")
    def main(cfg: Config):
        """Run the PR and/or NLL evaluations selected in ``cfg.eval``."""
        embed_name = cfg.eval.embed_model_name
        if cfg.eval.ref_embed_path is None:
            cfg.eval.ref_embed_path = f"results_dir/wikibio/{embed_name}_ref_embed"
        output_dir = Path(cfg.train.output_dir) / make_run_name(cfg)
        gen_files = _generation_file(cfg, output_dir)

        if cfg.eval.compute_pr:
            if cfg.eval.do_embed:
                compute_embeddings_streaming(embed_name, gen_files)
            launch_pr_eval(
                ref_embed_path=cfg.eval.ref_embed_path,
                gen_embed_path=gen_files.parent / f"{embed_name}_embed",
                output_path=gen_files.parent / f"{embed_name}_embed_pr_results.json",
                k=cfg.eval.k,
                max_samples=cfg.eval.max_samples_pr,
                n_runs=cfg.eval.n_runs,
                pca_var=cfg.eval.pca_var,
                seed=cfg.eval.seed,
            )

        if cfg.eval.compute_nll:
            model_path = output_dir / "best_model"
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            is_chat = getattr(tokenizer, "chat_template", None) is not None
            print(f"Is chat model: {is_chat}")
            model = _load_nll_model(cfg, model_path)
            # A plain Path: a parenthesised expression with a trailing comma
            # would silently build a 1-tuple instead.
            # NOTE(review): the file name includes the embedding-model name even
            # though this file does not pass an embedding model to NLL eval;
            # kept for compatibility with existing result paths.
            output_path = gen_files.parent / f"{embed_name}_embed_nll_results.json"
            launch_nll_eval(
                cfg=cfg,
                model=model,
                tokenizer=tokenizer,
                is_chat=is_chat,
                output_path=output_path,
            )

    main()
