import argparse
import logging
import time
from functools import partial
from pathlib import Path

import git
import neps
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torchvision.models import ResNet
from torchvision.models.vision_transformer import Encoder, VisionTransformer

from experiments.resnet.utils import create_model, data_prep_c10, data_prep_c100, validate
from layer_freeze.model_agnostic_freezing import FrozenModel


def training_pipeline(
    epochs: int,
    n_trainable_layers: int,
    learning_rate: float,
    weight_decay: float,
    beta1: float,
    beta2: float,
    optimizer_name: str,
    model_name: str,
    dataset: str,
    batch_size: int,
    dataloader_workers: int,
) -> dict:
    """Main training interface for HPO."""
    wandb.init(
        project=f"layer-freeze-rank-correlation-{model_name}",
        group=f"{n_trainable_layers}_trainable",
        reinit=True,
    )

    wandb.config.update(locals())

    match dataset:
        case "c100":
            trainloader, validloader, _, num_classes = data_prep_c100(
                batch_size=batch_size, dataloader_workers=dataloader_workers
            )
        case "c10":
            trainloader, validloader, _, num_classes = data_prep_c10(
                batch_size=batch_size, dataloader_workers=dataloader_workers
            )
        case _:
            raise ValueError(f"Unsupported dataset: {dataset}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define model with new parameters
    model = create_model(model_name=model_name, num_classes=num_classes)

    # freeze layers
    model = FrozenModel(
        n_trainable=n_trainable_layers,
        base_model=model,
        print_summary=False,
        unwrap=(ResNet, VisionTransformer, Encoder),
    )

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    match optimizer_name.lower():
        case "adam":
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=learning_rate,
                betas=(beta1, beta2),
                weight_decay=weight_decay,
            )
        case "sgd":
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=learning_rate,
                momentum=beta1,
                weight_decay=weight_decay,
            )
        case _:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    model = model.to(device)

    # Training loop
    _start = time.time()
    forward_times = []
    backward_times = []
    validation_errors = []
    losses = []
    model.train()
    for epoch in range(epochs):
        tmp_losses = []
        for _, (data, target) in enumerate(trainloader):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            data = data.to(device)
            target = target.to(device)

            forward_start = time.time()
            outputs = model(data)
            forward_times.append(time.time() - forward_start)

            loss = criterion(outputs, target)
            wandb.log({"train_loss": loss.cpu().item()})
            tmp_losses.append(loss.cpu().item())
            backward_start = time.time()
            loss.backward()
            optimizer.step()
            backward_times.append(time.time() - backward_start)

            # print statistics
        losses.append(sum(tmp_losses) / len(tmp_losses))
        val_err, _ = validate(model, validloader, device)
        validation_errors.append(val_err)
        wandb.log({"epoch": epoch, "val/error": val_err})
    _end = time.time()

    memory_used = torch.cuda.memory_allocated(device=device) / (1024**2)

    avg_forward_time = sum(forward_times) / len(forward_times)
    avg_backward_time = sum(backward_times) / len(backward_times)

    # Validation loop
    val_err, val_time = validate(model, validloader, device)

    extra = {
        "validation_time": val_time,
        "current_epoch": epochs,
        "gpu_memory_used_mb": memory_used,
        "avg_forward_time_ms": avg_forward_time * 1000,
        "avg_backward_time_ms": avg_backward_time * 1000,
        "n_trainable_params": sum(
            p.numel() for p in filter(lambda p: p.requires_grad, model.parameters())
        ),
        "n_total_params": sum(p.numel() for p in model.parameters()),
        "perc_trainable_params": (
            sum(p.numel() for p in filter(lambda p: p.requires_grad, model.parameters()))
            / sum(p.numel() for p in model.parameters())
        )
        * 100,
    }
    wandb.config.perc_trainable_params = (
        sum(p.numel() for p in filter(lambda p: p.requires_grad, model.parameters()))
        / sum(p.numel() for p in model.parameters())
    ) * 100
    wandb.log(extra)
    wandb.finish()

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    return {
        "objective_to_minimize": val_err,
        "cost": _end - _start,
        "learning_curve": losses,
        "info_dict": {
            **extra,
            "validation_errors": validation_errors,
            "batch_size": batch_size,
        },
    }


if __name__ == "__main__":
    # Script entry point: run a NePS grid search over optimizer
    # hyperparameters for a fixed model, dataset and number of trainable
    # layers.
    import random

    import numpy as np

    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()

    parser.add_argument("--group_name", type=str, default="", help="Group name")
    parser.add_argument(
        "--n_trainable_layers", type=int, default=1, help="Number of layers to train"
    )
    parser.add_argument(
        "--dataloader_workers", type=int, default=2, help="Number of dataloader workers"
    )
    # Restrict to the datasets training_pipeline supports so a typo fails at
    # argument parsing instead of inside every evaluation.
    parser.add_argument(
        "--dataset", type=str, default="c100", choices=("c10", "c100"), help="Dataset to use"
    )
    parser.add_argument("--model_name", type=str, default="resnet18", help="Model to use")
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    args = parser.parse_args()

    # Set seeds for reproducibility
    SEED = 42
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(SEED)
    random.seed(SEED)

    # Search space: 3 * 3 * 3 * 4 * 2 = 216 combinations.
    pipeline_space = {
        "learning_rate": neps.Categorical(choices=[1e-4, 1e-3, 1e-2]),
        "beta1": neps.Categorical(choices=[0.9, 0.95, 0.99]),
        "beta2": neps.Categorical(choices=[0.9, 0.95, 0.99]),
        "weight_decay": neps.Categorical(choices=[1e-5, 1e-4, 1e-3, 1e-2]),
        "optimizer_name": neps.Categorical(choices=["adam", "sgd"]),
    }

    # Results live in "<git working tree>/output".
    root_directory = Path(git.Repo(".", search_parent_directories=True).working_tree_dir) / "output"
    # exist_ok avoids the check-then-create race when several workers start
    # at the same time.
    root_directory.mkdir(parents=True, exist_ok=True)

    # NOTE(review): model_name is not part of the output path, so runs that
    # differ only in --model_name share a directory -- confirm this is
    # intended.
    output_tree = f"{args.dataset}/{args.epochs}_epochs/{args.n_trainable_layers}_trainable"

    optimizer = "grid_search"

    # NOTE(review): max_evaluations_total (100) is smaller than the 216
    # combinations in pipeline_space -- confirm the cap is intended.
    neps.run(
        evaluate_pipeline=partial(
            training_pipeline,
            n_trainable_layers=args.n_trainable_layers,
            epochs=args.epochs,
            model_name=args.model_name,
            dataset=args.dataset,
            batch_size=args.batch_size,
            dataloader_workers=args.dataloader_workers,
        ),
        pipeline_space=pipeline_space,
        optimizer=optimizer,
        root_directory=f"{root_directory}/{args.group_name}/{optimizer}/{output_tree}/",
        max_evaluations_total=100,
    )
