from __future__ import annotations

import argparse
import json
import logging
import math
import shutil
import time
from pathlib import Path
from typing import Dict, Optional

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm

try:
    import wandb
except ImportError:
    wandb = None

from data import (
    MaskPruningCollator,
    build_chatml_training_dataset,
    build_training_dataset,
    load_dataset_specs,
)
from models import load_llama3_model
from pruning.ffn_prune import prune_llama_ffn
from utils.config import load_yaml_config

# Module-level logger; every record from this script is emitted under the
# "ifpruning.train" logger name.
logger = logging.getLogger("ifpruning.train")


def _parse_chatml_overrides(entries):
    """Turn ``path[:split[:task_name]]`` strings into dataset-file dicts.

    Each entry is split on ``":"``. A missing or empty split becomes
    ``"train"``; a missing or empty task name becomes the stem of the path.
    Any fields after the third one are ignored.

    Returns a list of ``{"path", "split", "task_name"}`` dicts, in input order.
    """

    def _to_spec(raw):
        fields = raw.split(":")
        file_path = fields[0]
        split_name = "train"
        if len(fields) > 1 and fields[1]:
            split_name = fields[1]
        if len(fields) > 2 and fields[2]:
            task_label = fields[2]
        else:
            task_label = Path(file_path).stem
        return {"path": file_path, "split": split_name, "task_name": task_label}

    return [_to_spec(raw) for raw in entries]


def _sanitize_for_wandb(value):
    """Recursively convert *value* into plain containers and strings.

    ``Path``, ``torch.dtype`` and ``torch.device`` objects are stringified.
    Dicts are rebuilt with ``str`` keys and sanitized values. Lists, tuples
    and sets become lists of sanitized items. Anything else is returned
    unchanged.
    """
    if isinstance(value, (Path, torch.dtype, torch.device)):
        return str(value)
    if isinstance(value, dict):
        cleaned = {}
        for key, item in value.items():
            cleaned[str(key)] = _sanitize_for_wandb(item)
        return cleaned
    if isinstance(value, (list, tuple, set)):
        return list(map(_sanitize_for_wandb, value))
    return value


def _is_mask_export(path: Path) -> bool:
    """Return True when *path* is a directory that contains ``topk_indices.pt``."""
    if not path.is_dir():
        return False
    marker = path / "topk_indices.pt"
    return marker.exists()


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument strings to parse (without the program name). When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            which is the original behaviour; passing a list allows the parser
            to be driven programmatically.

    Returns:
        The populated ``argparse.Namespace``. Every override option defaults
        to ``None`` (or ``False`` for the two flags); only ``--config`` has a
        non-empty default.
    """
    parser = argparse.ArgumentParser(description="Train an exported (pruned) LLM without masking.")
    parser.add_argument(
        "--config",
        type=Path,
        default=Path("configs/mask_training.yaml"),
        help="Training configuration YAML file.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Optional override for output directory.",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default=None,
        help="Override model.base_model (useful for export model paths).",
    )
    # The three chatml options are repeatable; each occurrence appends one
    # "path[:split[:task_name]]" entry to the resulting list.
    parser.add_argument(
        "--chatml-file",
        action="append",
        default=None,
        help="Override chatml_files entries. Format: path[:split[:task_name]]",
    )
    parser.add_argument(
        "--val-chatml-file",
        action="append",
        default=None,
        help="Override validation chatml files. Format: path[:split[:task_name]]",
    )
    parser.add_argument(
        "--test-chatml-file",
        action="append",
        default=None,
        help="Override test chatml files. Format: path[:split[:task_name]]",
    )
    parser.add_argument(
        "--wandb-run-name",
        type=str,
        default=None,
        help="Override wandb run name (useful for persona-specific runs).",
    )
    parser.add_argument(
        "--prompt-text-mode",
        type=str,
        default=None,
        choices=["system_user", "system_only", "user_only"],
        help="Override data.prompt_text_mode for input text.",
    )
    parser.add_argument(
        "--llm-prompt-mode",
        type=str,
        default=None,
        choices=["system_user", "system_only", "user_only"],
        help="Override data.llm_prompt_mode for ChatML encoding to LLM.",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=None,
        help="Override training learning_rate.",
    )
    parser.add_argument(
        "--learning-rate-llm",
        type=float,
        default=None,
        help="Override LLM learning rate (defaults to training.learning_rate).",
    )
    parser.add_argument(
        "--warmup-ratio",
        type=float,
        default=None,
        help="Override training warmup_ratio.",
    )
    parser.add_argument(
        "--warmup-ratio-llm",
        type=float,
        default=None,
        help="Override LLM warmup_ratio (defaults to warmup_ratio).",
    )
    parser.add_argument(
        "--keep-last-epoch-only",
        action="store_true",
        help="After saving the final epoch, remove older epoch_* checkpoints.",
    )
    parser.add_argument(
        "--skip-final-save",
        action="store_true",
        help="Skip saving the final model to the output root (use epoch_* only).",
    )
    return parser.parse_args(argv)


def setup_logging():
    """Configure logging via ``basicConfig``: INFO level, timestamped records."""
    record_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
    logging.basicConfig(format=record_format, level=logging.INFO)


def save_model_checkpoint(
    base_dir: Path,
    model: torch.nn.Module,
    tokenizer,
    epoch: Optional[int] = None,
    metadata: Optional[dict] = None,
) -> None:
    """Write the model, and optionally the tokenizer and metadata, to disk.

    Args:
        base_dir: Root output directory.
        model: Object exposing ``save_pretrained(directory)``.
        tokenizer: Saved alongside the model when it is not ``None`` and has a
            ``save_pretrained`` attribute; otherwise skipped.
        epoch: When given, the checkpoint goes to ``base_dir/epoch_XX``
            (zero-padded to two digits); when ``None`` it goes to ``base_dir``.
        metadata: When truthy, dumped as indented JSON to ``meta.json`` in the
            checkpoint directory.
    """
    target_dir = base_dir / f"epoch_{epoch:02d}" if epoch is not None else base_dir
    target_dir.mkdir(parents=True, exist_ok=True)

    model.save_pretrained(target_dir)

    tokenizer_is_saveable = tokenizer is not None and hasattr(tokenizer, "save_pretrained")
    if tokenizer_is_saveable:
        tokenizer.save_pretrained(target_dir)

    # Nothing more to write when there is no (or empty) metadata.
    if not metadata:
        return
    with open(target_dir / "meta.json", "w", encoding="utf-8") as handle:
        json.dump(metadata, handle, indent=2)


def cleanup_old_epochs(output_dir: Path, keep_epoch: int) -> None:
    """Delete every ``epoch_*`` directory in *output_dir* except the kept one.

    The directory named ``epoch_XX`` for *keep_epoch* (zero-padded to two
    digits) is preserved. Non-directory matches are left alone, and removal
    errors are ignored.
    """
    survivor = f"epoch_{keep_epoch:02d}"
    stale_dirs = [
        entry
        for entry in output_dir.glob("epoch_*")
        if entry.is_dir() and entry.name != survivor
    ]
    for stale in stale_dirs:
        shutil.rmtree(stale, ignore_errors=True)


def main() -> None:
    """Run the full training pipeline driven by a YAML config and CLI overrides.

    In order: parse arguments and load the config; resolve the model config
    (optionally pruning the FFN from a mask-only export directory); build the
    train / validation / test datasets and loaders; optionally start a wandb
    run; train with AdamW and a cosine schedule using gradient accumulation;
    evaluate after every epoch; and save checkpoints as configured.

    Raises:
        RuntimeError: If the training dataset is empty or the model has no
            trainable parameters (e.g. ``freeze_base_model`` is set).
    """
    setup_logging()
    args = parse_args()
    cfg = load_yaml_config(args.config)
    # CLI-provided ChatML files replace the corresponding config entries.
    if args.chatml_file or args.val_chatml_file or args.test_chatml_file:
        cfg.setdefault("data", {})
        if args.chatml_file:
            cfg["data"]["chatml_files"] = _parse_chatml_overrides(args.chatml_file)
        if args.val_chatml_file:
            cfg["data"]["chatml_val_files"] = _parse_chatml_overrides(args.val_chatml_file)
        if args.test_chatml_file:
            cfg["data"]["chatml_test_files"] = _parse_chatml_overrides(args.test_chatml_file)

    # The model config is either inlined under "model" or loaded from a
    # separate YAML file referenced by "model_config".
    if "model" in cfg:
        model_cfg = dict(cfg["model"])
    else:
        model_cfg_path = cfg.get("model_config", "configs/model.yaml")
        model_cfg = load_yaml_config(model_cfg_path)

    # --base-model may point either at a regular model path or at a mask-only
    # export directory (one containing topk_indices.pt). In the latter case the
    # real base model name is read from mask_export.json when available.
    mask_export_dir = None
    if args.base_model:
        candidate = Path(args.base_model)
        if _is_mask_export(candidate):
            mask_export_dir = candidate
            meta_path = mask_export_dir / "mask_export.json"
            if meta_path.exists():
                with meta_path.open("r", encoding="utf-8") as f:
                    meta = json.load(f)
                base_model_override = meta.get("base_model")
                if base_model_override:
                    model_cfg["base_model"] = base_model_override
        else:
            model_cfg["base_model"] = args.base_model
    logger.info("Loading base model from %s", model_cfg.get("base_model"))
    llama_bundle = load_llama3_model(model_cfg)
    if mask_export_dir is not None:
        # NOTE(security): torch.load may unpickle arbitrary objects; only point
        # --base-model at export directories from a trusted source.
        topk_indices = torch.load(mask_export_dir / "topk_indices.pt", map_location="cpu")
        prune_llama_ffn(llama_bundle.model, topk_indices)
        logger.info("Applied pruning from mask-only export: %s", mask_export_dir)
    # Batches are moved to the device of the model's first parameter.
    model_device = next(llama_bundle.model.parameters()).device

    freeze_base_model = bool(model_cfg.get("freeze_base_model", False))
    if freeze_base_model:
        for param in llama_bundle.model.parameters():
            param.requires_grad_(False)
        llama_bundle.model.eval()
        logger.info("Base model parameters frozen.")
    else:
        llama_bundle.model.train()
        logger.info("Base model parameters will be updated during training.")

    data_cfg = cfg.get("data", {})
    # Copy before applying CLI overrides so the loaded config is not mutated.
    if args.prompt_text_mode is not None:
        data_cfg = dict(data_cfg)
        data_cfg["prompt_text_mode"] = args.prompt_text_mode
    if args.llm_prompt_mode is not None:
        data_cfg = dict(data_cfg)
        data_cfg["llm_prompt_mode"] = args.llm_prompt_mode
    system_prompt = data_cfg.get(
        "system_prompt",
        "You are a helpful assistant.",
    )
    cutoff_len = int(data_cfg.get("cutoff_len", 2048))
    prompt_max_length = int(data_cfg.get("prompt_max_length", 512))
    prompt_text_mode = data_cfg.get("prompt_text_mode", "system_user")
    # The LLM prompt mode follows the prompt-text mode unless set explicitly.
    llm_prompt_mode = data_cfg.get("llm_prompt_mode", prompt_text_mode)
    append_eos = bool(data_cfg.get("append_eos_to_response", False))
    seed = int(data_cfg.get("seed", 42))

    logger.info("Building training dataset...")
    chatml_files = data_cfg.get("chatml_files")
    chatml_val_files = data_cfg.get("chatml_val_files")
    if chatml_files:
        dataset = build_chatml_training_dataset(
            files=chatml_files,
            tokenizer=llama_bundle.tokenizer,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            default_system_prompt=system_prompt,
            prompt_text_mode=prompt_text_mode,
            llm_prompt_mode=llm_prompt_mode,
            append_eos=append_eos,
            seed=seed,
        )
        dataset_names = [
            item.get("task_name") or Path(item["path"]).stem for item in chatml_files
        ]
    else:
        specs = load_dataset_specs(data_cfg)
        dataset = build_training_dataset(
            specs=specs,
            tokenizer=llama_bundle.tokenizer,
            system_prompt=system_prompt,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            seed=seed,
        )
        dataset_names = [spec.name for spec in specs]

    val_dataset = None
    test_dataset = None
    if chatml_val_files:
        val_dataset = build_chatml_training_dataset(
            files=chatml_val_files,
            tokenizer=llama_bundle.tokenizer,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            default_system_prompt=system_prompt,
            prompt_text_mode=prompt_text_mode,
            llm_prompt_mode=llm_prompt_mode,
            append_eos=append_eos,
            seed=seed,
        )
        logger.info("Loaded %d validation samples.", len(val_dataset))
    chatml_test_files = data_cfg.get("chatml_test_files")
    if chatml_test_files:
        test_dataset = build_chatml_training_dataset(
            files=chatml_test_files,
            tokenizer=llama_bundle.tokenizer,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            default_system_prompt=system_prompt,
            prompt_text_mode=prompt_text_mode,
            llm_prompt_mode=llm_prompt_mode,
            append_eos=append_eos,
            seed=seed,
        )
        logger.info("Loaded %d test samples.", len(test_dataset))
    if len(dataset) == 0:
        raise RuntimeError("No training samples were created. Please check dataset configuration.")
    logger.info("Loaded %d training samples.", len(dataset))
    task_counts = {}
    for sample in dataset.samples:
        task_counts[sample.task] = task_counts.get(sample.task, 0) + 1
    logger.info("Task distribution: %s", task_counts)

    # Fall back to EOS only when the tokenizer has no pad token at all. An
    # explicit `is None` check is required because a pad id of 0 is valid but
    # falsy, so `pad_token_id or eos_token_id` would wrongly discard it.
    pad_token_id = llama_bundle.tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = llama_bundle.tokenizer.eos_token_id
    collator = MaskPruningCollator(
        pad_token_id=pad_token_id,
        label_pad_token_id=-100,
        pad_to_multiple_of=8,
    )

    train_batch_size = int(cfg["training"]["batch_size"])
    num_workers = int(data_cfg.get("num_workers", 0))
    train_loader = DataLoader(
        dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = None
    if val_dataset is not None and len(val_dataset) > 0:
        val_loader = DataLoader(
            val_dataset,
            batch_size=train_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=num_workers,
            pin_memory=True,
        )

    test_loader = None
    if test_dataset is not None and len(test_dataset) > 0:
        test_loader = DataLoader(
            test_dataset,
            batch_size=train_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Hyper-parameters: config values first, then CLI overrides. The "*_llm"
    # variants default to their generic counterparts.
    train_cfg = cfg["training"]
    num_epochs = int(train_cfg.get("num_epochs", 1))
    grad_accum = int(train_cfg.get("grad_accumulation_steps", 1))
    learning_rate = float(train_cfg.get("learning_rate", 5e-5))
    if args.learning_rate is not None:
        learning_rate = float(args.learning_rate)
    learning_rate_llm = float(args.learning_rate_llm) if args.learning_rate_llm is not None else learning_rate
    weight_decay = float(train_cfg.get("weight_decay", 0.0))
    warmup_ratio = float(train_cfg.get("warmup_ratio", 0.0))
    if args.warmup_ratio is not None:
        warmup_ratio = float(args.warmup_ratio)
    warmup_ratio_llm = warmup_ratio
    if args.warmup_ratio_llm is not None:
        warmup_ratio_llm = float(args.warmup_ratio_llm)
    max_grad_norm = float(train_cfg.get("max_grad_norm", 1.0))
    logging_steps = int(train_cfg.get("logging_steps", 10))
    save_each_epoch = bool(train_cfg.get("save_each_epoch", False))

    wandb_cfg = train_cfg.get("wandb", {})
    if args.wandb_run_name:
        wandb_cfg = dict(wandb_cfg)
        wandb_cfg["run_name"] = args.wandb_run_name
    use_wandb = bool(wandb_cfg.get("enable", False))
    wandb_run = None
    if use_wandb:
        if wandb is None:
            logger.warning("wandb logging requested but wandb is not installed. Disable wandb or install the package.")
        else:
            resolved_cfg = {
                "model": _sanitize_for_wandb(model_cfg),
                "data": _sanitize_for_wandb(data_cfg),
                "training": _sanitize_for_wandb(train_cfg),
            }
            wandb_run = wandb.init(
                project=wandb_cfg.get("project", "mask-pruning"),
                name=wandb_cfg.get("run_name"),
                entity=wandb_cfg.get("entity"),
                tags=wandb_cfg.get("tags", None),
                config={
                    "learning_rate": learning_rate_llm,
                    "batch_size": train_batch_size,
                    "grad_accum": grad_accum,
                    "num_epochs": num_epochs,
                    "weight_decay": weight_decay,
                    "warmup_ratio": warmup_ratio_llm,
                    "cutoff_len": cutoff_len,
                    "prompt_max_length": prompt_max_length,
                    "datasets": dataset_names,
                    "task_distribution": task_counts,
                    "config_file": str(args.config),
                    "config_raw": _sanitize_for_wandb(cfg),
                    "config_resolved": resolved_cfg,
                    "args": _sanitize_for_wandb(vars(args)),
                },
            )

    llm_params = [p for p in llama_bundle.model.parameters() if p.requires_grad]
    if len(llm_params) == 0:
        raise RuntimeError("No trainable parameters found for optimization.")

    optimizer_llm = AdamW(llm_params, lr=learning_rate_llm, weight_decay=weight_decay)
    # One optimizer step per full accumulation window, plus one for a trailing
    # partial window in each epoch (hence the ceil). The training loop below
    # flushes that partial window so the step count matches this schedule.
    num_train_batches = len(train_loader)
    total_steps = math.ceil(num_train_batches / grad_accum) * num_epochs
    warmup_steps = int(total_steps * warmup_ratio_llm)
    scheduler_llm = get_cosine_schedule_with_warmup(optimizer_llm, warmup_steps, total_steps)

    output_dir = Path(args.output_dir or train_cfg.get("output_dir", "outputs/llm_export"))
    output_dir.mkdir(parents=True, exist_ok=True)

    def run_eval(loader):
        """Return the mean per-batch LM loss over *loader*, or None if it is None."""
        if loader is None:
            return None
        # Remember the current mode so evaluation does not change it.
        was_llm_train = llama_bundle.model.training
        llama_bundle.model.eval()
        total_loss = 0.0
        batches = 0
        with torch.no_grad():
            for batch in loader:
                # Drop the non-tensor entries before moving to the device.
                batch.pop("prompt_text")
                batch.pop("task", None)
                inputs = {k: v.to(model_device) for k, v in batch.items()}
                outputs = llama_bundle.model(**inputs)
                lm_loss = outputs["loss"]
                total_loss += lm_loss.item()
                batches += 1
        if was_llm_train:
            llama_bundle.model.train()
        return total_loss / max(batches, 1)

    def compute_grad_norm(params) -> float:
        """Return the global L2 norm of the gradients of *params* (pre-clipping)."""
        total_sq = 0.0
        for p in params:
            if p.grad is None:
                continue
            param_norm = p.grad.data.norm(2).item()
            total_sq += param_norm * param_norm
        return total_sq ** 0.5

    global_step = 0
    optimizer_llm.zero_grad()
    training_start = time.perf_counter()
    ema_loss = None

    for epoch in range(num_epochs):
        logger.info("Starting epoch %d / %d", epoch + 1, num_epochs)
        epoch_task_counts = {}
        running_loss = 0.0
        micro_batches = 0
        epoch_desc = f"Epoch {epoch + 1}/{num_epochs}"
        for step, batch in enumerate(tqdm(train_loader, desc=epoch_desc, leave=False), start=1):
            batch.pop("prompt_text")
            tasks = batch.pop("task")
            inputs = {k: v.to(model_device) for k, v in batch.items()}
            outputs = llama_bundle.model(**inputs)
            lm_loss = outputs["loss"]
            # Scale so that gradients summed over a window average the losses.
            loss = lm_loss / grad_accum
            loss.backward()
            running_loss += loss.item()
            micro_batches += 1
            for t in tasks:
                epoch_task_counts[t] = epoch_task_counts.get(t, 0) + 1

            # Step on every full accumulation window AND on the last
            # micro-batch of the epoch. Without the second condition the
            # trailing partial window would never be applied: its gradients
            # would leak into the next epoch (or be dropped after the final
            # one) and the scheduler would never reach total_steps. The
            # partial window keeps the 1/grad_accum scaling, so that update is
            # proportionally smaller.
            if step % grad_accum == 0 or step == num_train_batches:
                next_step = global_step + 1
                log_grad = (next_step % logging_steps == 0)
                # Measure the norm before clipping, and only when it is logged.
                grad_norm_llm = compute_grad_norm(llm_params) if log_grad else None

                torch.nn.utils.clip_grad_norm_(llm_params, max_grad_norm)
                optimizer_llm.step()
                scheduler_llm.step()
                optimizer_llm.zero_grad()
                global_step += 1

                if global_step % logging_steps == 0:
                    llm_lr = scheduler_llm.get_last_lr()[0]
                    log_entry: Dict[str, float] = {
                        "global_step": global_step,
                        "loss": loss.item() * grad_accum,
                        "language_model_loss": lm_loss.item(),
                        "learning_rate": llm_lr,
                        "learning_rate/llm": llm_lr,
                    }
                    if grad_norm_llm is not None:
                        log_entry["grad_norm/llm"] = grad_norm_llm
                    # The EMA is updated only on logging steps.
                    ema_loss = loss.item() * grad_accum if ema_loss is None else (
                        0.98 * ema_loss + 0.02 * (loss.item() * grad_accum)
                    )
                    log_entry["loss/ema"] = ema_loss
                    for task_name, count in epoch_task_counts.items():
                        log_entry[f"train/task_count/{task_name}"] = count
                    logger.info("Step %s", json.dumps(log_entry))
                    if wandb_run is not None:
                        wandb_run.log(log_entry, step=global_step)

        logger.info("Completed epoch %d", epoch + 1)
        # running_loss holds scaled losses; undo the 1/grad_accum scaling.
        epoch_loss_value = (running_loss / max(micro_batches, 1)) * grad_accum
        logger.info("Epoch %d summary: loss=%.4f, tasks=%s", epoch + 1, epoch_loss_value, epoch_task_counts)
        val_loss = run_eval(val_loader)
        if val_loss is not None:
            logger.info("Validation loss after epoch %d: %.4f", epoch + 1, val_loss)
        test_loss = run_eval(test_loader)
        if test_loss is not None:
            logger.info("Test loss after epoch %d: %.4f", epoch + 1, test_loss)
        if wandb_run is not None:
            epoch_log = {
                "epoch": epoch + 1,
                "epoch_loss": epoch_loss_value,
            }
            if val_loss is not None:
                epoch_log["validation_loss"] = val_loss
            if test_loss is not None:
                epoch_log["test_loss"] = test_loss
            for task_name, count in epoch_task_counts.items():
                epoch_log[f"epoch/task_count/{task_name}"] = count
            wandb_run.log(epoch_log, step=global_step)

        if save_each_epoch:
            epoch_metadata = {
                "epoch": epoch + 1,
                "epoch_loss": epoch_loss_value,
                "task_counts": epoch_task_counts,
                "global_step": global_step,
            }
            # Record both evaluation losses, mirroring the wandb epoch log.
            if val_loss is not None:
                epoch_metadata["val_loss"] = val_loss
            if test_loss is not None:
                epoch_metadata["test_loss"] = test_loss
            save_model_checkpoint(output_dir, llama_bundle.model, llama_bundle.tokenizer, epoch=epoch + 1, metadata=epoch_metadata)
            if args.keep_last_epoch_only and (epoch + 1 == num_epochs):
                logger.info("Removing older epoch checkpoints (keeping epoch_%02d).", epoch + 1)
                cleanup_old_epochs(output_dir, keep_epoch=epoch + 1)

    total_seconds = time.perf_counter() - training_start
    final_metadata = {
        "num_epochs": num_epochs,
        "total_steps": global_step,
        "elapsed_seconds": total_seconds,
    }
    if args.skip_final_save:
        if not save_each_epoch:
            # No per-epoch checkpoint exists yet, so still persist the final
            # weights, just under epoch_XX instead of the output root.
            logger.info(
                "Training complete. save_each_epoch is false; saving final model to epoch_%02d (elapsed %.2f s).",
                num_epochs,
                total_seconds,
            )
            save_model_checkpoint(
                output_dir,
                llama_bundle.model,
                llama_bundle.tokenizer,
                epoch=num_epochs,
                metadata=final_metadata,
            )
        else:
            logger.info(
                "Training complete. Skipping final root checkpoint save (elapsed %.2f s).",
                total_seconds,
            )
    else:
        logger.info("Training complete. Saving final model to %s (elapsed %.2f s)", output_dir, total_seconds)
        save_model_checkpoint(output_dir, llama_bundle.model, llama_bundle.tokenizer, metadata=final_metadata)
    if wandb_run is not None:
        wandb_run.finish()


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
