import torch
import torch.nn as nn
import torch.optim as optim
import wandb

from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Define the training function
def train_model(model, train_loader, optimizer, lr_scheduler, epoch, log_fn=None):
    """Train ``model`` for one epoch, step the LR scheduler once, and log metrics.

    Each batch is moved to the device of the model's parameters before the
    forward pass, so the loader may yield CPU tensors for an accelerator model.

    Args:
        model: Classifier whose output is fed to ``nn.CrossEntropyLoss``
            (assumed to be logits of shape ``(batch, num_classes)``).
        train_loader: Iterable of ``(data, target)`` tensor pairs.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Scheduler stepped once, after the epoch.
        epoch: Zero-based epoch index; metrics are logged at ``step=epoch + 1``.
        log_fn: Callable invoked as ``log_fn(metrics, step=...)``. Defaults
            to ``wandb.log``.

    Returns:
        ``(train_loss, train_acc)``: per-sample mean loss and accuracy over
        the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    log = wandb.log if log_fn is None else log_fn
    criterion = nn.CrossEntropyLoss()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = data.size(0)
        n_samples += batch_size
        # The criterion returns the batch mean; re-weight by batch size so a
        # short final batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        n_correct += (output.argmax(dim=1) == target).sum().item()

    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")

    train_loss = loss_sum / n_samples
    train_acc = n_correct / n_samples
    lr_scheduler.step()
    log(
        {
            "loss/train": train_loss,
            "acc/train": train_acc,
            # Read after the scheduler step: this is the LR for the next epoch.
            "lr": optimizer.param_groups[0]["lr"],
        },
        step=epoch + 1,
    )
    return train_loss, train_acc


# Define the evaluation function
def test_model(model, test_loader, epoch, log_fn=None):
    """Evaluate ``model`` on ``test_loader`` and log validation metrics.

    The model is switched to eval mode and the whole loop runs without
    autograd. Each batch is moved to the device of the model's parameters.

    Args:
        model: Classifier whose output is fed to ``nn.CrossEntropyLoss``
            (assumed to be logits of shape ``(batch, num_classes)``).
        test_loader: Iterable of ``(data, target)`` tensor pairs.
        epoch: Zero-based epoch index; metrics are logged at ``step=epoch + 1``.
        log_fn: Callable invoked as ``log_fn(metrics, step=...)``. Defaults
            to ``wandb.log``.

    Returns:
        ``(val_loss, val_acc)``: per-sample mean loss and accuracy.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    log = wandb.log if log_fn is None else log_fn
    criterion = nn.CrossEntropyLoss()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    # The loss is only reported, never back-propagated, so no graph is needed.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            batch_size = data.size(0)
            n_samples += batch_size
            # Re-weight the batch-mean loss so a short final batch counts per sample.
            loss_sum += loss.item() * batch_size
            n_correct += (output.argmax(dim=1) == target).sum().item()

    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")

    val_loss = loss_sum / n_samples
    val_acc = n_correct / n_samples
    log(
        {
            "loss/val": val_loss,
            "acc/val": val_acc,
        },
        step=epoch + 1,
    )
    return val_loss, val_acc


# Not a pytest test despite the ``test_`` prefix: its required arguments would
# be mistaken for fixtures if this module is imported into a test file.
test_model.__test__ = False


def _build_cifar10_loaders(root="./data", batch_size=128):
    """Return ``(train_loader, val_loader)`` for CIFAR-10 resized to 224x224.

    The training split gets random horizontal flips and a padded random crop;
    both splits share the same normalization. Datasets are created with
    ``download=True`` under ``root``.
    """
    # Per-channel normalization constants (commonly used CIFAR-10 mean/std).
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]
    )
    train_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),  # Resize to 224x224 for ViT
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224, padding=4),
            transforms.ToTensor(),
            normalize,
        ]
    )
    val_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_dataset = datasets.CIFAR10(
        root=root, train=True, download=True, transform=train_transform
    )
    test_dataset = datasets.CIFAR10(
        root=root, train=False, download=True, transform=val_transform
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def _qk_norm_stats(attn_modules):
    """Return the mean nuclear and Frobenius norms of ``W_q^T W_k`` over ``attn_modules``.

    ``attn_modules`` must be a non-empty sequence of modules with a fused
    ``qkv`` linear layer whose weight stacks the query, key and value
    projections along dim 0, in that order and of equal size.

    Returns:
        ``(mean_nuclear_norm, mean_frobenius_norm)`` as Python floats.
    """
    nuc_total = 0.0
    fro_total = 0.0
    for module in attn_modules:
        # Detached float32 CPU copy: this is a read-only diagnostic, so no
        # autograd graph is needed. Computing on CPU also keeps the SVD behind
        # the nuclear norm off the accelerator.
        # NOTE(review): the CPU move is a precaution for backends with limited
        # SVD support (e.g. MPS) — confirm it is still needed.
        w_q, w_k, _ = module.qkv.weight.detach().float().cpu().chunk(3, dim=0)
        # nn.Linear weights are (out, in), so W_q^T W_k is the (in, in)
        # bilinear form behind the attention logits (biases ignored).
        qk = w_q.T @ w_k
        nuc_total += torch.linalg.norm(qk, ord="nuc").item()
        fro_total += torch.linalg.norm(qk, ord="fro").item()
    n_layers = len(attn_modules)
    return nuc_total / n_layers, fro_total / n_layers


def main(weight_decay):
    """Train a ViT-Tiny from scratch on CIFAR-10 with AdamW at ``weight_decay``.

    Each call opens its own W&B run (hyperparameters go into the run config)
    and always closes it, even on error. After every epoch, the mean nuclear
    and Frobenius norms of ``W_q^T W_k`` across the attention layers are
    logged to W&B and printed.

    Raises:
        ValueError: If the model contains no timm ``Attention`` modules.
    """
    import timm

    lr = 1e-3
    EPOCHS = 120
    batch_size = 128
    model_name = "vit_tiny_patch16_224.augreg_in21k_ft_in1k"
    device = torch.device(
        "cuda:0"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )

    model = timm.create_model(model_name, pretrained=False, num_classes=10).to(device)
    # Collect the attention modules once, and fail before training (not after
    # the first epoch) if there is nothing to measure.
    attn_modules = [
        module
        for module in model.modules()
        if isinstance(module, timm.models.vision_transformer.Attention)
    ]
    if not attn_modules:
        raise ValueError(f"{model_name} has no timm Attention modules to measure")

    train_loader, val_loader = _build_cifar10_loaders(batch_size=batch_size)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # T_0 == EPOCHS with one step per epoch: training sees a single cosine
    # decay from lr toward eta_min; the restart lands only after the final
    # epoch's scheduler step.
    lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer=optimizer, T_0=EPOCHS, eta_min=1e-6
    )

    # train_model/test_model call wandb.log, which requires an active run.
    wandb.init(
        name=f"wd={weight_decay}",
        config={
            "model": model_name,
            "lr": lr,
            "epochs": EPOCHS,
            "batch_size": batch_size,
            "weight_decay": weight_decay,
        },
    )
    try:
        for epoch in range(EPOCHS):
            train_model(model, train_loader, optimizer, lr_scheduler, epoch)
            test_model(model, val_loader, epoch)

            avg_nuc, avg_fro = _qk_norm_stats(attn_modules)
            # Ratio of the layer means (not the mean of per-layer ratios).
            nuc_by_fro = avg_nuc / avg_fro
            wandb.log(
                {
                    "qk_norm/nuc": avg_nuc,
                    "qk_norm/fro": avg_fro,
                    "qk_norm/nuc_by_fro": nuc_by_fro,
                },
                step=epoch + 1,
            )
            print(
                f"Epoch {epoch+1}/{EPOCHS}, avg nuc_norm_value: {avg_nuc}, avg fro_norm_value: {avg_fro}, avg nuc_by_fro_norm_value: {nuc_by_fro}"
            )
    finally:
        # Close this run so the next weight_decay in a sweep starts a fresh one.
        wandb.finish()


if __name__ == "__main__":
    # Weight-decay sweep, largest value first. main() builds a fresh model and
    # optimizer on every call, so each value is trained from scratch.
    sweep = (0.5, 0.1, 0.05, 0.01, 0.005, 0.001, 0.0005, 0.0001)
    for wd in sweep:
        main(wd)
