import json
import logging
import os
from pathlib import Path
from time import perf_counter

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import LinearLR, SequentialLR

from model.guided_lm import DataSplitClassifier
from trainer.utils import seed_everything
logger = logging.getLogger(__name__)


def _get_device(device_pref: str) -> torch.device:
    if device_pref != "auto":
        return torch.device(device_pref)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _resolve_dtype(dtype_str: str) -> torch.dtype:
    if dtype_str is None:
        raise ValueError("Classifier dtype must be specified in config under classifier.dtype.")
    normalized = str(dtype_str).replace("torch.", "").lower()
    dtype_map = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float": torch.float32,
        "float64": torch.float64,
        "double": torch.float64,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    torch_dtype = dtype_map.get(normalized, getattr(torch, normalized, None))
    if not isinstance(torch_dtype, torch.dtype):
        raise ValueError(f"Unsupported classifier dtype '{dtype_str}'.")
    if not torch_dtype.is_floating_point:
        raise ValueError(f"Classifier dtype '{dtype_str}' must be a floating point type.")
    return torch_dtype


def _build_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int, warmup_steps: int):
    """Build a per-step schedule: linear warmup to the base lr, then linear decay to 0.

    Args:
        optimizer: Optimizer to schedule; each param group's configured ``lr`` is the peak lr.
        total_steps: Total number of optimizer steps in the run.
        warmup_steps: Steps spent ramping from ``lr / warmup_steps`` up to ``lr``. The
            remaining ``total_steps - warmup_steps`` steps decay linearly to 0.

    Returns:
        A scheduler to be stepped once per optimizer step, or ``None`` when there is
        nothing to schedule.
    """
    decay_steps = total_steps - warmup_steps
    schedulers = []
    if warmup_steps > 0:
        # LinearLR requires start_factor > 0, so the ramp starts at lr / warmup_steps.
        # Do NOT overwrite param_group["lr"] here: LinearLR rescales the *current* lr on
        # every step, so a manually zeroed lr would stay at 0 for the whole run.
        schedulers.append(
            LinearLR(
                optimizer,
                start_factor=1.0 / warmup_steps,
                end_factor=1.0,
                total_iters=warmup_steps,
            )
        )
    if decay_steps > 0:
        schedulers.append(
            LinearLR(
                optimizer,
                start_factor=1.0,
                end_factor=0.0,
                total_iters=decay_steps,
            )
        )
    if not schedulers:
        return None
    if len(schedulers) == 1:
        return schedulers[0]
    # Switch from the warmup schedule to the decay schedule once warmup is done.
    return SequentialLR(optimizer, schedulers, milestones=[warmup_steps])


@hydra.main(version_base=None, config_path="../configs", config_name="unlearn_T3_preprocessed.yaml")
def main(cfg: DictConfig) -> None:
    """Train a DataSplitClassifier on precomputed pooled states and save the result.

    Loads the payload at ``cfg.data.precomputed_path``. It must contain
    ``pooled_states``, ``classifier_labels``, ``token_ids`` and ``output_dim``,
    and may contain ``base_lm_config``.

    Training uses BCE-with-logits on the logit selected by each sample's token id.
    The learning rate follows a per-step linear warmup/decay schedule.

    Writes a checkpoint (``cfg.save.filename``) and ``performance.json`` into
    ``cfg.save.output_dir``.

    Raises:
        ValueError: if the classifier dtype is invalid, the payload lacks
            ``output_dim``, or the configuration yields no training steps.
    """
    seed_everything(cfg.training.seed)
    device = _get_device(cfg.training.device)
    classifier_dtype = _resolve_dtype(cfg.classifier.dtype)
    performance_mode = bool(getattr(cfg.training, "performance_mode", False))
    if performance_mode and device.type == "cuda":
        # Let cuDNN benchmark and pick kernels (performance mode only).
        torch.backends.cudnn.benchmark = True
    logger.info("Using device %s with classifier dtype %s", device, classifier_dtype)

    # Everything is loaded onto CPU; batches are moved to `device` inside the loop.
    payload = torch.load(cfg.data.precomputed_path, map_location="cpu")
    features = payload["pooled_states"]
    if features.dtype != classifier_dtype:
        logger.warning(
            "Pooled states dtype %s does not match classifier dtype %s; casting pooled states.",
            features.dtype,
            classifier_dtype,
        )
        features = features.to(classifier_dtype)

    # reshape (not view) so non-contiguous saved tensors can be flattened as well.
    labels = payload["classifier_labels"].to(classifier_dtype).reshape(-1)
    token_ids = payload["token_ids"].long().reshape(-1)
    logger.info(
        "Loaded precomputed tensors | pooled_states shape=%s dtype=%s | labels shape=%s dtype=%s | token_ids shape=%s dtype=%s",
        tuple(features.shape),
        features.dtype,
        tuple(labels.shape),
        labels.dtype,
        tuple(token_ids.shape),
        token_ids.dtype,
    )
    output_dim = payload.get("output_dim", None)
    if output_dim is None:
        raise ValueError("Precomputed payload missing output_dim; please regenerate preprocessing artifacts.")

    base_lm_config = payload.get("base_lm_config")
    if base_lm_config:
        logger.info("Loaded precomputed states for base LM: %s", base_lm_config.get("_name_or_path", "unknown"))

    dataset = TensorDataset(features, token_ids, labels)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.training.batch_size,
        shuffle=True,
        num_workers=cfg.training.num_workers,
        pin_memory=device.type == "cuda",
        persistent_workers=performance_mode and cfg.training.num_workers > 0,
    )

    model = DataSplitClassifier(
        input_dim=features.shape[-1],
        output_dim=output_dim,
        hidden_size=cfg.classifier.hidden_size,
        num_hidden_layers=cfg.classifier.num_hidden_layers,
        activation_str=cfg.classifier.activation_str,
        bias=cfg.classifier.bias,
    ).to(device=device, dtype=classifier_dtype)

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.training.learning_rate,
        weight_decay=cfg.training.weight_decay,
    )

    num_batches = len(dataloader)
    total_steps = cfg.training.epochs * num_batches
    if total_steps <= 0:
        raise ValueError("No training steps available; check dataset size or epoch count.")
    # int() keeps step counts integral even if warmup_epochs is fractional.
    warmup_steps = max(0, min(total_steps, int(cfg.training.warmup_epochs * num_batches)))
    decay_steps = total_steps - warmup_steps
    scheduler = _build_lr_scheduler(optimizer, total_steps, warmup_steps)
    logger.info(
        "LR schedule | total_steps=%d warmup_steps=%d decay_steps=%d",
        total_steps,
        warmup_steps,
        decay_steps,
    )

    log_interval = cfg.training.log_interval
    non_blocking = device.type == "cuda"
    training_start = perf_counter()
    for epoch in range(1, cfg.training.epochs + 1):
        model.train()
        epoch_loss = 0.0
        for step, (batch_x, batch_token_ids, batch_y) in enumerate(dataloader, 1):
            batch_x = batch_x.to(device, non_blocking=non_blocking)
            batch_token_ids = batch_token_ids.to(device, non_blocking=non_blocking)
            targets = batch_y.to(device, non_blocking=non_blocking)

            # Assumes pooled_states is (N, dim), so the model receives (batch, 1, dim)
            # and the singleton axis is squeezed back out. TODO confirm against the model.
            logits = model(batch_x.unsqueeze(1)).squeeze(1)
            # Score only the logit at each sample's token id.
            gathered_logits = logits.gather(1, batch_token_ids.unsqueeze(1)).squeeze(1)
            loss = criterion(gathered_logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                # The schedule is defined in optimizer steps (batches), not epochs.
                scheduler.step()

            if not performance_mode:
                # .item() synchronizes with the device, so it is skipped in performance mode.
                loss_value = loss.item()
                epoch_loss += loss_value
                if log_interval and step % log_interval == 0:
                    logger.info("Epoch %d Step %d Loss %.4f", epoch, step, loss_value)

        if not performance_mode:
            logger.info("Epoch %d Mean Loss %.4f", epoch, epoch_loss / num_batches)
    training_runtime = perf_counter() - training_start
    logger.info("Training completed in %.2f seconds", training_runtime)

    output_dir = Path(cfg.save.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    save_path = output_dir / cfg.save.filename
    torch.save(
        {
            "state_dict": model.state_dict(),
            "config": OmegaConf.to_container(cfg, resolve=True),
            "output_dim": output_dim,
            "base_lm_config": base_lm_config,
            "training_runtime_seconds": training_runtime,
            "performance_mode": performance_mode,
        },
        save_path,
    )
    logger.info("Saved checkpoint to %s", save_path)

    performance_payload = {
        "training_runtime": training_runtime,
        "performance_mode": performance_mode,
    }
    with open(os.path.join(output_dir, "performance.json"), "w", encoding="utf-8") as f:
        json.dump(performance_payload, f)

if __name__ == "__main__":
    # Script entry point. main() is called with no arguments; the @hydra.main
    # decorator on main is responsible for building and passing `cfg`.
    main()