import math
import os
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import hydra
import numpy as np
import torch
import wandb
from datasets import DatasetDict, load_from_disk
from hydra.core.config_store import ConfigStore

# Import PEFT and LoRA components
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)

from fair_gpt.generation.utils import (
    CHAT_FEW_SHOT_PRE_PROMPT,
    CHAT_ZERO_SHOT,
    PRETRAINED_FEW_SHOT,
    SYSTEM_PROMPT,
)
from fair_gpt.training.loss_fn import CustomLossTrainer
from fair_gpt.utils import DataCollatorForLM

# Default Jinja chat template: renders every message as "<Role>: <content>"
# (role capitalized) followed by a blank line.
# NOTE(review): not referenced in the visible part of this file; presumably
# meant for tokenizers that ship without a chat template -- confirm.
DEFAULT_TEMPLATE = "{% for message in messages %}{{ message['role']|capitalize }}: {{ message['content'] }}\n\n{% endfor %}"


def set_seed(seed_value):
    """Seed every random number generator the training script relies on.

    Python's ``random``, NumPy and PyTorch (CPU, plus the CUDA generators when
    CUDA is available) are all seeded with ``seed_value``. cuDNN is then put
    into deterministic mode with benchmarking turned off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed the current device first, then every visible device.
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


@dataclass
class DataConfig:
    """Dataset location and sub-sampling options."""

    # Passed to `load_from_disk` in `process_data`, so this is treated as an
    # on-disk dataset path. NOTE(review): the default "ag_news" looks like a
    # hub name rather than a path -- confirm it is always overridden.
    dataset_name: str = "ag_news"
    # Optional caps applied in `prepare_data` after the train/eval split;
    # a falsy value (None or 0) keeps the full split.
    n_train_samples: Optional[int] = None
    n_eval_samples: Optional[int] = None
    # NOTE(review): `val_split` and `train_split` are not read anywhere in the
    # visible code; `prepare_data` always makes its own 90/10 split.
    val_split: Optional[str] = None
    train_split: str = "train"
    # Name of the dataset column that holds the text to train on.
    text_field: str = "content"


@dataclass
class ModelConfig:
    """Base model, sequence length and LoRA adapter options."""

    # Loaded from `models_dir/<model_name>`; also inspected to pick the LoRA
    # target modules.
    model_name: str = "gpt2"
    # Truncation length used when tokenizing the rendered prompts.
    max_length: int = 1024
    # LoRA rank; a value <= 0 disables LoRA and trains the base model directly.
    lora_r: int = 8
    # `lora_alpha` is computed as `lora_alpha_coef * lora_r`.
    lora_alpha_coef: int = 2
    # Passed as `lora_dropout` to the LoRA config.
    dropout: float = 0.05
    # When True, the gender column is injected into prompts and the
    # `gender_bin` / `weights` columns are produced by the data pipeline.
    conditional: bool = False


@dataclass
class TrainConfig:
    """Optimisation, scheduling, logging and output options for the trainer."""

    # General settings
    # Base directory; the run name is appended to form the actual output path.
    output_dir: str = "results_dir/wikibio-train"
    # Seeds the RNGs and the train/eval split.
    seed: int = 123
    # Request bf16; only honoured when CUDA is available and the device
    # capability major version is >= 8. When False, fp16 is used on CUDA.
    bf16: bool = True
    # NOTE(review): not read anywhere in the visible code.
    slurm: bool = True

    # Training loop parameters
    # Forwarded as `max_steps`.
    train_steps: int = 1000
    # Target effective batch size; gradient accumulation is derived from it,
    # `per_device_batch_size` and the WORLD_SIZE environment variable.
    global_batch_size: int = 64
    per_device_batch_size: int = 1
    # Used for both the evaluation interval and the checkpoint interval.
    eval_steps: int = 100
    save_total_limit: int = 3

    # Logging and evaluation settings
    logging_steps: int = 5

    # Optimization parameters
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    lr_scheduler_type: str = "cosine"
    # Only used when `warmup_steps` resolves to 0 (i.e. is None or 0).
    warmup_ratio: float = 0.05
    # NOTE(review): the default equals the default `train_steps`, so with
    # defaults the warm-up spans the whole run and `warmup_ratio` is ignored
    # -- confirm this is intended.
    warmup_steps: Optional[int] = 1000
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8

    # Exported as the WANDB_PROJECT environment variable before training.
    wandb_project: str = "ft-wikibio"

    # Forwarded to `CustomLossTrainer` as `loss_name`.
    loss_fn: str = "default"


@dataclass
class Config:
    """Root structured config grouping the data, model and training sections."""

    # Each section gets its own fresh default instance per Config.
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)


def build_prompt_chat(text, gender=None, inference_mode=False) -> List[dict]:
    """Build the chat-format message list for one biography sample.

    Args:
        text: Biography text placed after ``"Biography:\\n"`` in the assistant
            turn. Ignored when ``inference_mode`` is True.
        gender: Optional gender label. When not None, a sentence requesting a
            biography of a ``{gender}`` person is appended to the user prompt.
        inference_mode: When True the assistant turn only contains the
            ``"Biography:\\n"`` prefix so the model can continue it.

    Returns:
        A list of three messages (system, user, assistant), each a dict with
        ``"role"`` and ``"content"`` keys.
    """
    user_prompt = CHAT_ZERO_SHOT
    if gender is not None:
        user_prompt += f" It should be a biography of a {gender} person."

    # Training samples carry the full target text; inference leaves the
    # assistant turn open right after the "Biography:" header.
    assistant_content = "Biography:\n" if inference_mode else f"Biography:\n{text}"

    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_content},
    ]


def build_prompt_pretrained(
    text, eos_token, bos_token, gender=None, inference_mode=False
) -> str:
    """Assemble the plain-text prompt used for models without a chat template.

    The prompt is ``bos_token`` (empty when None) followed by the few-shot
    preamble, an optional gender hint, and a ``Biography:`` section. In
    training mode the section contains ``text`` and is closed with
    ``eos_token``; in inference mode it is left open for generation.
    """
    pieces = ["" if bos_token is None else bos_token, PRETRAINED_FEW_SHOT]

    if gender is not None:
        pieces.append(f" Biography of a {gender} person:")

    if not inference_mode:
        # Training sample: full target text terminated by the EOS token.
        pieces.append(f"\n\nBiography:\n{text}\n\n")
        pieces.append(eos_token)
    else:
        pieces.append("\n\nBiography:\n")

    return "".join(pieces)


def build_prompt_tokenized(
    tokenizer,
    text=None,
    gender=None,
    inference_mode=False,
    is_chat=True,
    tokenize=False,
    date_string="10 Sep 2025",
) -> dict:
    """Render one prompt and optionally tokenize it.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template``, ``eos_token``,
            ``bos_token`` and ``model_max_length``.
        text: Target biography text (unused in inference mode).
        gender: Optional gender label forwarded to the prompt builders.
        inference_mode: When True the prompt ends with an open ``Biography:``
            section; for chat prompts the final assistant message is continued
            rather than closed.
        is_chat: Use the chat message format when True, otherwise the plain
            few-shot format.
        tokenize: When True, the rendered string is tokenized (truncated to
            ``tokenizer.model_max_length``, no padding, no attention mask, no
            extra special tokens).
        date_string: Extra keyword forwarded to ``apply_chat_template``.
            Kept at a fixed default so rendered prompts do not depend on the
            day the script runs. NOTE(review): presumably consumed only by
            templates that embed a date -- confirm for the models in use.

    Returns:
        The rendered prompt string when ``tokenize`` is False (despite the
        ``dict`` annotation); otherwise the tokenizer's output for that string.
    """
    if is_chat:
        messages = build_prompt_chat(text, gender=gender, inference_mode=inference_mode)
        prompt = tokenizer.apply_chat_template(
            messages,
            continue_final_message=inference_mode,
            tokenize=False,
            date_string=date_string,
        )
    else:
        prompt = build_prompt_pretrained(
            text,
            eos_token=tokenizer.eos_token,
            bos_token=tokenizer.bos_token,
            gender=gender,
            inference_mode=inference_mode,
        )

    if not tokenize:
        return prompt

    # The rendered prompt already carries its special tokens, so none are added.
    return tokenizer(
        prompt,
        truncation=True,
        max_length=tokenizer.model_max_length,
        padding=False,
        return_attention_mask=False,
        add_special_tokens=False,
    )


def simple_tokenize_fn(
    tokenizer,
    text_key: str,
    max_len: int,
    is_chat: bool = True,
    conditional: bool = False,
    inference_mode: bool = False,
):
    """Create a batched mapping callback that renders and tokenizes prompts.

    Args:
        tokenizer: Tokenizer used both for prompt rendering and tokenization.
        text_key: Name of the batch column holding the raw text.
        max_len: Truncation length for the tokenized prompts.
        is_chat: Render chat-format prompts when True, plain prompts otherwise.
        conditional: When True and the batch has a ``"gender"`` column, the
            normalized gender is injected into each prompt and a numeric
            ``"gender_bin"`` column is emitted.
        inference_mode: Forwarded to the prompt builder.

    Returns:
        A function taking a batch (mapping of column name to list) and
        returning the tokenizer output with ``"labels"`` added (an independent
        copy of ``"input_ids"``) and, in conditional mode, ``"gender_bin"``
        (0 for "male", 1 for "female", -1 for anything else including None).
    """
    gender_map = {"male": 0, "female": 1, "other": -1}

    def _normalize(raw):
        # Missing values stay None; everything else is compared case- and
        # whitespace-insensitively.
        return None if raw is None else str(raw).strip().lower()

    def _tok(batch):
        texts = batch[text_key]
        raw_genders = batch.get("gender", None) if conditional else None

        prompts = []
        gender_bins = []

        for idx, text in enumerate(texts):
            gender = None
            if raw_genders is not None:
                gender = _normalize(raw_genders[idx])
                # Built in the same loop so it stays aligned with `texts`.
                gender_bins.append(gender_map.get(gender, -1))

            prompts.append(
                build_prompt_tokenized(
                    tokenizer,
                    text=text,
                    gender=gender,
                    inference_mode=inference_mode,
                    is_chat=is_chat,
                )
            )

        out = tokenizer(
            prompts,
            truncation=True,
            max_length=max_len,
            padding=False,
            return_attention_mask=False,
            add_special_tokens=False,
        )

        # Copy every row: a shallow copy of the outer list would leave each
        # label row aliasing its input row, so an in-place edit of one (e.g.
        # label masking) would silently change the other.
        out["labels"] = [list(ids) for ids in out["input_ids"]]

        # `raw_genders` is only ever set in conditional mode.
        if raw_genders is not None:
            out["gender_bin"] = gender_bins

        return out

    return _tok


def process_data(
    cfg: Config, tokenizer, is_chat: bool, inference_mode=False
) -> DatasetDict:
    """Load the on-disk dataset, tokenize it and attach per-sample weights.

    Args:
        cfg: Full run configuration; reads ``data.dataset_name``,
            ``data.text_field``, ``model.max_length`` and ``model.conditional``.
        tokenizer: Tokenizer forwarded to the tokenization callback.
        is_chat: Whether prompts use the chat format.
        inference_mode: Forwarded to the tokenization callback.

    Returns:
        The tokenized dataset. In non-conditional mode only ``input_ids`` and
        ``labels`` are kept. In conditional mode all columns are kept and a
        ``weights`` column is added: ``0.5 / proportion`` of the sample's
        gender class for male (0) and female (1) samples, and a neutral
        ``1.0`` for any other ``gender_bin`` value.
        NOTE(review): the caller uses ``train_test_split`` on the result, so
        the loaded object is presumably a single ``Dataset`` rather than a
        ``DatasetDict`` as annotated -- confirm.
    """
    ds = load_from_disk(cfg.data.dataset_name)

    tok_fn = simple_tokenize_fn(
        tokenizer=tokenizer,
        text_key=cfg.data.text_field,
        max_len=cfg.model.max_length,
        is_chat=is_chat,
        # Must follow the config, otherwise no `gender_bin` column is produced.
        conditional=cfg.model.conditional,
        inference_mode=inference_mode,
    )

    ds = ds.map(tok_fn, batched=True, num_proc=8)

    if not cfg.model.conditional:
        return ds.select_columns(["input_ids", "labels"])

    # NOTE(review): the conditional branch never restricted the columns to
    # ["input_ids", "labels", "weights", "gender_bin"]; every column is still
    # kept here so existing consumers see the same schema -- confirm whether
    # a `select_columns` was intended.
    gender_array = np.array(ds["gender_bin"])
    prop_male = np.mean(gender_array == 0)
    prop_female = np.mean(gender_array == 1)

    # Only classes that actually occur get a weight; this avoids dividing by
    # zero for an absent class.
    class_weight = {
        label: float(0.5 / prop)
        for label, prop in ((0, prop_male), (1, prop_female))
        if prop > 0
    }

    # Samples outside male/female (gender_bin == -1) previously fell through
    # to the female weight; they now get a neutral weight of 1.0.
    ds = ds.map(
        lambda x: {"weights": class_weight.get(x["gender_bin"], 1.0)},
        num_proc=8,
    )

    print(f"Proportion of male samples: {prop_male:.2f}")
    print(f"Proportion of female samples: {prop_female:.2f}")
    return ds


def prepare_data(cfg: Config, tokenizer, is_chat: bool) -> DatasetDict:
    """Tokenize the dataset and split it into train and validation sets.

    The processed data is split 90/10 using ``cfg.train.seed``. Each side is
    then optionally truncated to its first ``n_train_samples`` /
    ``n_eval_samples`` rows. The result exposes the splits as ``train`` and
    ``validation``.
    """
    processed = process_data(cfg, tokenizer, is_chat=is_chat)
    parts = processed.train_test_split(test_size=0.1, seed=cfg.train.seed)

    def _head(subset, limit):
        # A falsy limit (None or 0) means "keep everything".
        if not limit:
            return subset
        return subset.select(range(min(limit, len(subset))))

    return DatasetDict(
        train=_head(parts["train"], cfg.data.n_train_samples),
        validation=_head(parts["test"], cfg.data.n_eval_samples),
    )


def compute_grad_accum(global_bs: int, per_device_bs: int) -> int:
    """Return the gradient-accumulation steps needed to reach ``global_bs``.

    The number of samples processed per micro-step is ``per_device_bs`` times
    the ``WORLD_SIZE`` environment variable (1 when unset), floored at 1. The
    result is the ceiling of ``global_bs`` divided by that amount, never less
    than 1.
    """
    n_procs = int(os.getenv("WORLD_SIZE", "1"))
    samples_per_micro_step = max(1, per_device_bs * n_procs)
    # Integer ceiling division.
    steps = (global_bs + samples_per_micro_step - 1) // samples_per_micro_step
    return steps if steps > 1 else 1


def make_run_name(cfg: Config) -> str:
    """Compose the slash-separated run name from the key hyper-parameters.

    Slashes in the model name are replaced with dashes so the model occupies
    exactly one path segment. The name always starts with ``lora``.
    """
    model_cfg, train_cfg = cfg.model, cfg.train
    segments = (
        "lora",
        model_cfg.model_name.replace("/", "-"),
        f"loss{train_cfg.loss_fn}",
        f"cond{model_cfg.conditional}",
        f"r{model_cfg.lora_r}",
        f"bs{train_cfg.global_batch_size}",
        f"lr{train_cfg.learning_rate}",
        f"seed{train_cfg.seed}",
    )
    return "/".join(segments)


if __name__ == "__main__":
    # Register the structured config so Hydra can resolve `config_name="config"`.
    cs = ConfigStore.instance()
    cs.store(name="config", node=Config)

    @hydra.main(version_base=None, config_name="config")
    def main(cfg: Config):
        """Fine-tune a causal LM (optionally with LoRA) and report metrics.

        Steps: seed the RNGs, load the tokenizer and data, load the base model
        from ``models_dir/<model_name>``, wrap it with LoRA when
        ``cfg.model.lora_r > 0``, train with ``CustomLossTrainer``, save the
        model and tokenizer under ``<output_dir>/<run_name>/best_model``, then
        log train/eval metrics (including perplexity) and the best eval loss.
        """
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        set_seed(cfg.train.seed)

        tokenizer = AutoTokenizer.from_pretrained(f"models_dir/{cfg.model.model_name}")
        if tokenizer.pad_token is None:
            # Fall back to EOS so padding is possible.
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        # A tokenizer without a chat template is treated as a plain
        # pretrained model and gets the few-shot text prompt instead.
        if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
            is_chat = False
            print("Pretrained model detected, using pretrained prompt format.")
        else:
            is_chat = True
            print("Chat model detected, using chat prompt format.")
        data = prepare_data(cfg, tokenizer, is_chat=is_chat)
        print(
            f"Train samples: {len(data['train'])}, Eval samples: {len(data['validation'])}"
        )
        # Show one fully rendered training sample as a sanity check.
        print(
            tokenizer.decode(data["train"][0]["input_ids"], skip_special_tokens=False)
        )

        # Choose LoRA target modules by model family. The name is lower-cased
        # once so every family check is case-insensitive.
        model_family = cfg.model.model_name.lower()
        if "gpt2" in model_family:
            target_modules = ["c_attn", "c_proj"]
        elif "gemma" in model_family:
            # HuggingFace Gemma uses these names
            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
        else:
            target_modules = ["q_proj", "v_proj"]

        # NOTE(review): `device_map="auto"` is combined with a WORLD_SIZE-aware
        # accumulation computation below; confirm multi-process launches work
        # with automatic device placement.
        base = AutoModelForCausalLM.from_pretrained(
            f"models_dir/{cfg.model.model_name}",
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        if cfg.model.lora_r > 0:
            lora_conf = LoraConfig(
                r=cfg.model.lora_r,
                lora_alpha=cfg.model.lora_alpha_coef * cfg.model.lora_r,
                target_modules=target_modules,
                lora_dropout=cfg.model.dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(base, lora_conf)
            print(f"Initialized LoRA model with r={cfg.model.lora_r}")
        else:
            model = base
            print("Training full model without LoRA.")

        data_collator = DataCollatorForLM(tokenizer=tokenizer)

        grad_accum = compute_grad_accum(
            cfg.train.global_batch_size, cfg.train.per_device_batch_size
        )
        # An explicit step count wins; `warmup_ratio` is only used when the
        # step count resolves to 0.
        warmup_steps = (
            cfg.train.warmup_steps if cfg.train.warmup_steps is not None else 0
        )

        # W&B
        os.environ["WANDB_PROJECT"] = cfg.train.wandb_project

        run_name = make_run_name(cfg)
        output_dir = Path(cfg.train.output_dir) / run_name
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir}")
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            per_device_train_batch_size=cfg.train.per_device_batch_size,
            per_device_eval_batch_size=cfg.train.per_device_batch_size,
            gradient_accumulation_steps=grad_accum,
            max_steps=cfg.train.train_steps,
            eval_strategy="steps",
            eval_steps=cfg.train.eval_steps,
            logging_steps=cfg.train.logging_steps,
            # Checkpoints are written on the same schedule as evaluations.
            save_steps=cfg.train.eval_steps,
            save_total_limit=cfg.train.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            learning_rate=cfg.train.learning_rate,
            weight_decay=cfg.train.weight_decay,
            lr_scheduler_type=cfg.train.lr_scheduler_type,
            warmup_ratio=cfg.train.warmup_ratio if warmup_steps == 0 else 0.0,
            warmup_steps=warmup_steps,
            # bf16 is only enabled on CUDA devices with capability major >= 8.
            bf16=cfg.train.bf16
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 8,
            fp16=not cfg.train.bf16 and torch.cuda.is_available(),
            report_to=["wandb"],
            # Extra dataset columns are left in place for the collator/trainer.
            remove_unused_columns=False,
            dataloader_pin_memory=True,
            optim="adamw_torch_fused",
            dataloader_num_workers=8,
            adam_beta1=cfg.train.adam_beta1,
            adam_beta2=cfg.train.adam_beta2,
            adam_epsilon=cfg.train.adam_epsilon,
            run_name=run_name,
            torch_compile=False,
            group_by_length=True,
        )

        trainer = CustomLossTrainer(
            model=model,
            args=training_args,
            train_dataset=data["train"],
            eval_dataset=data["validation"],
            tokenizer=tokenizer,
            loss_name=cfg.train.loss_fn,
            data_collator=data_collator,
        )

        # This prints only trainable LoRA params
        if cfg.model.lora_r > 0:
            model.print_trainable_parameters()

        train_result = trainer.train()

        # The tokenizer is written next to the weights so the folder is
        # self-contained for easy upload.
        trainer.save_model(str(output_dir / "best_model"))
        tokenizer.save_pretrained(str(output_dir / "best_model"))
        print(f"Model saved to {output_dir / 'best_model'}")

        # Perplexity from losses for quick sanity check
        metrics = train_result.metrics
        if "train_loss" in metrics:
            try:
                metrics["train_perplexity"] = math.exp(metrics["train_loss"])
            except OverflowError:
                metrics["train_perplexity"] = float("inf")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        eval_metrics = trainer.evaluate()
        if "eval_loss" in eval_metrics:
            try:
                eval_metrics["perplexity"] = math.exp(eval_metrics["eval_loss"])
            except OverflowError:
                eval_metrics["perplexity"] = float("inf")
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

        # End W&B run if active
        if wandb.run is not None:
            wandb.finish()
        # Print best eval loss and corresponding epoch (from trainer state)
        best_eval_loss = trainer.state.best_metric
        best_epoch = None

        if best_eval_loss is None:
            print("No best eval loss found in trainer.state.best_metric")
        else:
            # Search log history for the entry that matches the best eval loss
            for entry in trainer.state.log_history:
                if "eval_loss" in entry and entry["eval_loss"] == best_eval_loss:
                    best_epoch = entry.get("epoch", None)
                    break

            print(f"Best eval loss: {best_eval_loss}")
            if best_epoch is not None:
                print(f"Epoch of best eval loss: {best_epoch}")
            else:
                print("Epoch for best eval loss not found in log history")

    main()
