import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


def validate(
    model: torch.nn.Module,
    validloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> tuple[float, float]:
    """Validate the model on the validation set."""
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        start = time.perf_counter()
        for data, target in validloader:
            data, target = data.to(device), target.to(device)

            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
        end = time.perf_counter()
    model.train()
    return 1 - (correct / total), end - start


def create_model(model_name: str, num_classes: int = 10) -> nn.Module:
    """Build a randomly initialised torchvision classifier with a fresh output layer.

    Args:
        model_name: One of ``resnet18``, ``resnet34``, ``resnet50``,
            ``vit_b_16``, ``vit_b_32`` or ``vit_l_16``.
        num_classes: Number of outputs of the replaced classification layer.

    Returns:
        The model, with ``fc`` (ResNets) or ``heads.head`` (ViTs) replaced by a
        new ``nn.Linear`` producing ``num_classes`` logits.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    resnet_names = ("resnet18", "resnet34", "resnet50")
    vit_names = ("vit_b_16", "vit_b_32", "vit_l_16")
    if model_name not in resnet_names + vit_names:
        raise ValueError(f"Unsupported model: {model_name}")

    # Each supported name is also the torchvision builder function's name.
    network = getattr(torchvision.models, model_name)(weights=None)
    if model_name in vit_names:
        head_width = network.heads.head.in_features
        network.heads.head = nn.Linear(head_width, num_classes)
    else:
        head_width = network.fc.in_features
        network.fc = nn.Linear(head_width, num_classes)
    return network


def data_prep_c100(
    batch_size: int,
    get_val_set: bool = True,
    dataloader_workers: int = 4,
    prefetch_factor: int | None = None,
    val_size: int = 5000,
    image_size: int | None = 224,
    seed: int | None = None,
) -> tuple:
    """Prepare CIFAR100 data loaders for training, validation and testing.

    Training images are randomly augmented. Validation and test images are only
    converted to tensors, normalised and (optionally) resized, so evaluation
    metrics are not perturbed by random augmentation.

    Args:
        batch_size: Batch size of every returned loader.
        get_val_set: If True, hold out ``val_size`` training images for validation.
        dataloader_workers: Worker processes per loader; 0 loads in the main process.
        prefetch_factor: Batches prefetched per worker; ignored when
            ``dataloader_workers`` is 0 (DataLoader rejects it there).
        val_size: Number of training images reserved for validation.
        image_size: Images are resized to ``image_size x image_size``; None disables resizing.
        seed: Seed for the train/validation split; None uses torch's global RNG.

    Returns:
        ``(trainloader, validloader, testloader, 100)``; ``validloader`` is None
        when ``get_val_set`` is False.

    Raises:
        ValueError: If ``val_size`` is not strictly between 0 and the training set size.
    """
    mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
    resize = [transforms.Resize((image_size, image_size))] if image_size is not None else []
    train_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
            *resize,
        ]
    )
    # Deterministic transform for validation/test data.
    eval_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std), *resize]
    )
    dataset_class = torchvision.datasets.CIFAR100
    use_workers = dataloader_workers > 0

    def make_loader(dataset, shuffle: bool) -> torch.utils.data.DataLoader:
        """Build a DataLoader with the settings shared by all three splits."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=dataloader_workers,
            # DataLoader raises if these are set while num_workers == 0.
            persistent_workers=use_workers,
            pin_memory=False,
            prefetch_factor=prefetch_factor if use_workers else None,
        )

    trainset = dataset_class(root="./data", train=True, download=True, transform=train_transform)
    testset = dataset_class(root="./data", train=False, download=True, transform=eval_transform)

    if get_val_set:
        n_train_total = len(trainset)
        if not 0 < val_size < n_train_total:
            raise ValueError(f"val_size must be in (0, {n_train_total}), got {val_size}")
        # A second, un-augmented view of the training images serves the validation
        # indices; both views are split with the same permutation (costs one extra
        # in-memory copy of the training data).
        eval_view = dataset_class(
            root="./data", train=True, download=True, transform=eval_transform
        )
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        perm = torch.randperm(n_train_total, generator=generator).tolist()
        train_size = n_train_total - val_size
        train_set = torch.utils.data.Subset(trainset, perm[:train_size])
        val_set = torch.utils.data.Subset(eval_view, perm[train_size:])
        validloader = make_loader(val_set, shuffle=False)
    else:
        train_set = trainset
        validloader = None

    trainloader = make_loader(train_set, shuffle=True)
    testloader = make_loader(testset, shuffle=False)

    return trainloader, validloader, testloader, 100


def data_prep_c10(
    batch_size: int,
    get_val_set: bool = True,
    dataloader_workers: int = 4,
    prefetch_factor: int | None = None,
    val_size: int = 10000,
    image_size: int | None = None,
    seed: int | None = None,
) -> tuple:
    """Prepare CIFAR10 data loaders for training, validation and testing.

    Training images are randomly augmented. Validation and test images are only
    converted to tensors and normalised (and resized if ``image_size`` is set), so
    evaluation metrics are not perturbed by random augmentation.

    Args:
        batch_size: Batch size of every returned loader.
        get_val_set: If True, hold out ``val_size`` training images for validation.
        dataloader_workers: Worker processes per loader; 0 loads in the main process.
        prefetch_factor: Batches prefetched per worker; ignored when
            ``dataloader_workers`` is 0 (DataLoader rejects it there).
        val_size: Number of training images reserved for validation.
        image_size: If set, images are resized to ``image_size x image_size``.
            The default None keeps the native resolution. NOTE: torchvision ViT
            models default to 224x224 inputs.
        seed: Seed for the train/validation split; None uses torch's global RNG.

    Returns:
        ``(trainloader, validloader, testloader, 10)``; ``validloader`` is None
        when ``get_val_set`` is False.

    Raises:
        ValueError: If ``val_size`` is not strictly between 0 and the training set size.
    """
    mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    resize = [transforms.Resize((image_size, image_size))] if image_size is not None else []
    train_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
            *resize,
        ]
    )
    # Deterministic transform for validation/test data.
    eval_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std), *resize]
    )
    dataset_class = torchvision.datasets.CIFAR10
    use_workers = dataloader_workers > 0

    def make_loader(dataset, shuffle: bool) -> torch.utils.data.DataLoader:
        """Build a DataLoader with the settings shared by all three splits."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=dataloader_workers,
            # DataLoader raises if these are set while num_workers == 0.
            persistent_workers=use_workers,
            pin_memory=False,
            prefetch_factor=prefetch_factor if use_workers else None,
        )

    trainset = dataset_class(root="./data", train=True, download=True, transform=train_transform)
    testset = dataset_class(root="./data", train=False, download=True, transform=eval_transform)

    if get_val_set:
        n_train_total = len(trainset)
        if not 0 < val_size < n_train_total:
            raise ValueError(f"val_size must be in (0, {n_train_total}), got {val_size}")
        # A second, un-augmented view of the training images serves the validation
        # indices; both views are split with the same permutation (costs one extra
        # in-memory copy of the training data).
        eval_view = dataset_class(
            root="./data", train=True, download=True, transform=eval_transform
        )
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        perm = torch.randperm(n_train_total, generator=generator).tolist()
        train_size = n_train_total - val_size
        train_set = torch.utils.data.Subset(trainset, perm[:train_size])
        val_set = torch.utils.data.Subset(eval_view, perm[train_size:])
        validloader = make_loader(val_set, shuffle=False)
    else:
        train_set = trainset
        validloader = None

    trainloader = make_loader(train_set, shuffle=True)
    testloader = make_loader(testset, shuffle=False)

    return trainloader, validloader, testloader, 10


def full_fidelity_training(
    epochs: int = 10,
    learning_rate: float = 0.008,
    weight_decay: float = 0.01,
    beta1: float = 0.9,
    beta2: float = 0.999,
    optimizer: str = "adam",
    train_dataloader: torch.utils.data.DataLoader | None = None,
    valid_dataloader: torch.utils.data.DataLoader | None = None,
    num_classes: int = 10,
    device: torch.device = None,
) -> dict:
    """Main training interface for HPO."""
    # Prepare data
    # Define model with new parameters
    model = create_model(num_classes=num_classes)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer: torch.optim.Optimizer

    match optimizer:
        case "adam":
            optimizer = optim.Adam(
                model.parameters(),
                lr=learning_rate,
                betas=[beta1, beta2],
                weight_decay=weight_decay,
            )
        case "sgd":
            optimizer = optim.SGD(
                model.parameters(),
                lr=learning_rate,
                momentum=beta1,
                weight_decay=weight_decay,
            )
        case _:
            raise ValueError(f"Unsupported optimizer: {optimizer}")

    if device is not None:
        model = model.to(device)

    # Training loop
    _start = time.time()
    model.train()
    for _ in range(epochs):
        for data, target in train_dataloader:
            optimizer.zero_grad()

            # forward + backward + optimize
            if device is not None:
                data = data.to(device)
                target = target.to(device)

            outputs = model(data)
            loss: torch.Tensor = criterion(outputs, target)
            loss.backward()
            optimizer.step()

    _end = time.time()

    val_err, val_cost = validate(model, valid_dataloader, device)
    return {
        "full_fidelity/val_err": val_err,
        "full_fidelity/val_cost": val_cost,
        "full_fidelity/cost": _end - _start,
    }
