import time

import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torchvision.models import ResNet

from experiments.resnet.utils import create_model, data_prep_c10, data_prep_c100, validate
from layer_freeze.model_agnostic_freezing import FrozenModel


def training_pipeline(
    learning_rate: float,
    weight_decay: float,
    beta1: float,
    beta2: float,
    n_trainable: int,
    optimizer_name: str,
    batch_size: int,
    num_dataloader_workers: int,
    dataset: str,
    model_name: str,
    optimizer: str,
    epochs: int,
    fidelity: str,
) -> dict:
    """Main training interface for HPO."""
    wandb.init(project=f"hpo-{optimizer}-{dataset}-{model_name}-{fidelity}", reinit=True)
    wandb.config.update(locals())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data
    match dataset:
        case "c100":
            trainloader, validloader, _, num_classes = data_prep_c100(
                batch_size=batch_size, dataloader_workers=num_dataloader_workers
            )
        case "c10":
            trainloader, validloader, _, num_classes = data_prep_c10(
                batch_size=batch_size, dataloader_workers=num_dataloader_workers
            )
        case _:
            raise ValueError(f"Unsupported dataset: {dataset}")

    # Define model with new parameters
    model = create_model(model_name=model_name, num_classes=num_classes)

    # freeze layers
    model = FrozenModel(
        n_trainable=n_trainable, base_model=model, print_summary=False, unwrap=(ResNet,)
    )

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    match optimizer_name.lower():
        case "adam":
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=learning_rate,
                betas=(beta1, beta2),
                weight_decay=weight_decay,
            )
        case "sgd":
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=learning_rate,
                momentum=beta1,
                weight_decay=weight_decay,
            )
        case _:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    model = model.to(device)

    _start = time.time()
    forward_times = []
    backward_times = []
    losses = []
    validation_errors = []
    model.train()
    for epoch in range(epochs):
        for _, (data, target) in enumerate(trainloader):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            data = data.to(device)
            target = target.to(device)

            forward_start = time.time()
            outputs = model(data)
            forward_times.append(time.time() - forward_start)

            loss = criterion(outputs, target)
            wandb.log({"train/loss": loss.cpu().item()})
            losses.append(loss.cpu().item())
            backward_start = time.time()
            loss.backward()
            optimizer.step()
            backward_times.append(time.time() - backward_start)

        val_err, _ = validate(model, validloader, device)
        validation_errors.append(val_err)
        wandb.log({"train/epoch": epoch, "val/error": val_err})
    _end = time.time()

    memory_used = torch.cuda.memory_allocated(device=device) / (1024**2)

    avg_forward_time = sum(forward_times) / len(forward_times)
    avg_backward_time = sum(backward_times) / len(backward_times)

    # Validation loop
    val_err, val_time = validate(model, validloader, device)
    extra = {
        "val/time": val_time,
        "current_epoch": epochs,
        "gpu_memory_used_mb": memory_used,
        "train/avg_forward_time_ms": avg_forward_time * 1000,
        "train/avg_backward_time_ms": avg_backward_time * 1000,
        "train/avg_loop_time_ms": (avg_forward_time + avg_backward_time) * 1000,
        "train/n_trainable_params": sum(
            p.numel() for p in filter(lambda p: p.requires_grad, model.parameters())
        ),
        "train/n_total_params": sum(p.numel() for p in model.parameters()),
        "train/n_trainable": model.n_trainable,
        "train/perc_trainable_params": (
            sum(p.numel() for p in filter(lambda p: p.requires_grad, model.parameters()))
            / sum(p.numel() for p in model.parameters())
        )
        * 100,
    }
    wandb.log(extra)
    wandb.finish()

    return {
        "objective_to_minimize": val_err,  # type: ignore
        "cost": _end - _start,  # type: ignore
        "learning_curve": losses,  # type: ignore
        "info_dict": {**extra, "validation_errors": validation_errors},
    }
