# Code based on kentaroy47/vision-transformers-cifar10
from __future__ import print_function

import argparse
import csv
import os
import random
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from adversarial_superposition.cifar.utils.model import ViT as OriginalViT
from adversarial_superposition.cifar.utils.training import RandAugment
from adversarial_superposition.cifar.utils.utils import progress_bar

# parsers
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 Training")
parser.add_argument(
    "--lr", default=1e-4, type=float, help="learning rate"
)  # resnets.. 1e-3, Vit..1e-4
# Restricted to the optimizers that are actually constructed later in this
# script; any other value used to fall through and leave `optimizer` unbound.
parser.add_argument("--opt", default="adam", choices=["adam", "sgd"])
parser.add_argument(
    "--resume", "-r", action="store_true", help="resume from checkpoint"
)  # NOTE(review): this flag is not read anywhere in this script
# action="store_false": args.noaug is True by default (RandAugment ON) and
# becomes False when --noaug is passed.
parser.add_argument("--noaug", action="store_false", help="disable use randomaug")
parser.add_argument(
    "--noamp",
    action="store_true",
    help="disable mixed precision training. for older pytorch versions",
)
parser.add_argument("--nowandb", action="store_true", help="disable wandb")
parser.add_argument(
    "--mixup", action="store_true", help="add mixup augumentations"
)  # NOTE(review): this flag is not read anywhere in this script
parser.add_argument("--net", default="vit")
parser.add_argument("--dp", action="store_true", help="use data parallel")
# --bs and --size are kept as strings (no type=) for compatibility with saved
# args; the script converts them with int() where they are used.
parser.add_argument("--bs", default="512")
parser.add_argument("--size", default="32")
parser.add_argument("--n_epochs", type=int, default=200)
parser.add_argument("--patch", default=4, type=int, help="patch for ViT")
parser.add_argument("--dimhead", default=512, type=int)
parser.add_argument(
    "--convkernel", default=8, type=int, help="parameter for convmixer"
)
parser.add_argument("--seed", default=10, type=int, help="random seed")
parser.add_argument(
    "--training_stage",
    type=str,
    default="pretrain",
    choices=["pretrain", "finetune", "finetune_full"],
    help="Training stage: pretrain (from scratch), finetune (freeze base, train head), finetune_full (train all with loaded weights)",
)
parser.add_argument(
    "--pretrained_ckpt",
    type=str,
    default=None,
    help="Path to pretrained checkpoint for finetuning stages",
)
parser.add_argument(
    "--run_name",
    type=str,
    default="",
    help="Custom name prefix for checkpoint and log files",
)


# Parse the real command line once, at import time; the rest of the script
# reads the resulting `args` namespace as a module-level global.
args = parser.parse_args()


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        # Seed the current device and then every visible GPU.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed every RNG before any model or data objects are created.
set_seed(args.seed)

# take in args
# Use logical `not`, not bitwise `~`: on a bool, `~False == -1` and
# `~True == -2` are both truthy, so `~args.nowandb` made --nowandb a no-op.
usewandb = not args.nowandb
if usewandb:
    import wandb

    # Include seed and stage in wandb run name
    watermark = "{}_{}_lr{}_seed{}".format(
        args.net, args.training_stage, args.lr, args.seed
    )
    if args.run_name:
        watermark = f"{args.run_name}_{watermark}"  # Prepend run_name if provided
    wandb.init(project="cifar10-challange", name=watermark)
    wandb.config.update(args)

bs = int(args.bs)
imsize = int(args.size)

use_amp = not args.noamp
# --noaug is declared with action="store_false", so args.noaug is True unless
# the flag is passed; i.e. it already means "apply RandAugment".
aug = args.noaug

device = "cuda" if torch.cuda.is_available() else "cpu"
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print("==> Preparing data..")
# The "vit_timm" net is fed 384x384 inputs; every other net uses --size.
size = 384 if args.net == "vit_timm" else imsize

# Per-channel normalization statistics shared by the train and test splits.
_cifar_mean = (0.4914, 0.4822, 0.4465)
_cifar_std = (0.2023, 0.1994, 0.2010)

# Train: pad-and-crop to 32x32, resize to the model input, random horizontal
# flip, then tensor conversion and normalization.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
]
transform_train = transforms.Compose(_train_steps)

# Test: deterministic resize + normalization only.
transform_test = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ]
)

# Optionally prepend RandAugment to the training pipeline. N and M are the
# two hyperparameters passed to the project's RandAugment helper (presumably
# number of ops and magnitude; confirm in utils.training).
if aug:
    N, M = 2, 14
    transform_train.transforms.insert(0, RandAugment(N, M))


def _cifar10_split(train, transform):
    """Return one CIFAR-10 split stored under ./data (downloaded if missing)."""
    return torchvision.datasets.CIFAR10(
        root="./data", train=train, download=True, transform=transform
    )


# Prepare dataset
trainset = _cifar10_split(True, transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=bs, shuffle=True, num_workers=8
)

testset = _cifar10_split(False, transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=8
)

# Human-readable class names, indexed by CIFAR-10 label.
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

# Model factory..
print("==> Building model..")


def _load_pretrained_weights(model, ckpt_path):
    """Load a checkpoint's state dict into `model` (non-strict) and report it.

    Accepts checkpoints that store the state dict under "model" (the key this
    script's test() writes) or "net". Loading uses `strict=False`, so keys
    absent from one side are printed rather than fatal. PyTorch still raises
    on *shape* mismatches for keys present in both.

    Raises:
        ValueError: if `ckpt_path` is None.
        FileNotFoundError: if `ckpt_path` does not exist.
        KeyError: if the checkpoint has neither a "model" nor a "net" entry.
    """
    if ckpt_path is None:
        raise ValueError("Must provide --pretrained_ckpt for finetuning stage")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Pretrained checkpoint not found: {ckpt_path}")
    print(f"==> Loading pre-trained weights from {ckpt_path}")
    # weights_only=False is needed because the checkpoint also pickles `args`.
    # This unpickles arbitrary objects: only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Decide which state_dict key to use ('model' or 'net')
    model_key = "model" if "model" in checkpoint else "net"
    if model_key not in checkpoint:
        raise KeyError(
            f"Could not find model state_dict ('model' or 'net') in checkpoint: {ckpt_path}"
        )

    missing_keys, unexpected_keys = model.load_state_dict(
        checkpoint[model_key], strict=False
    )
    print(f"Missing keys: {missing_keys}")
    print(f"Unexpected keys: {unexpected_keys}")
    print("==> Pre-trained weights loaded.")


# Instantiate the base ViT model
net = OriginalViT(
    image_size=size,
    patch_size=args.patch,
    num_classes=10,  # Original number of classes
    dim=int(args.dimhead),
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1,
)

if args.training_stage == "finetune":
    print("==> Setting up for fine-tuning stage (Frozen Base)...")
    # Replacement head squeezes features through a 2-d bottleneck.
    bottleneck_dim = 2
    num_classes = 10
    # Feature width feeding the original head's final Linear layer.
    # NOTE(review): assumes OriginalViT.mlp_head is either
    # Sequential(LayerNorm, Linear, ...) or a bare Linear; confirm in utils.model.
    try:
        original_dim = net.mlp_head[1].in_features
    except (TypeError, IndexError):
        # A bare nn.Linear is not subscriptable (TypeError).
        print(
            "Warning: Assuming net.mlp_head is a single Linear layer for getting in_features."
        )
        original_dim = net.mlp_head.in_features

    bottleneck_head = nn.Sequential(
        nn.LayerNorm(original_dim),
        nn.Linear(original_dim, bottleneck_dim),
        nn.Linear(bottleneck_dim, num_classes),
    )

    # Load into the model while it still has its original head, so head keys
    # in the checkpoint match; the head is swapped out afterwards.
    _load_pretrained_weights(net, args.pretrained_ckpt)

    # Freeze all parameters
    print("==> Freezing encoder weights...")
    for param in net.parameters():
        param.requires_grad = False

    # Replace the head; its freshly created parameters are the only trainable ones.
    net.mlp_head = bottleneck_head
    print("==> Replaced MLP head with bottleneck version.")
    for param in net.mlp_head.parameters():
        param.requires_grad = True

elif args.training_stage == "finetune_full":
    # As documented by --training_stage: start from the pretrained weights and
    # keep every parameter trainable (no freezing, standard head).
    print("==> Setting up for full fine-tuning stage (all weights trainable)...")
    _load_pretrained_weights(net, args.pretrained_ckpt)

elif args.training_stage == "pretrain":
    print("==> Setting up for pre-training stage..")
    # The standard ViT head is already in place; nothing to change.

# Place the model on the chosen device. It must already be on the GPU before
# being wrapped in DataParallel.
on_gpu = "cuda" in device
if on_gpu:
    print(device)
net = net.to(device)
if on_gpu and args.dp:
    print("using data parallel")
    net = torch.nn.DataParallel(net)  # make parallel
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()

# Setup optimizer based on training stage
print(f"==> Setting up optimizer for stage: {args.training_stage}")
if args.training_stage == "finetune":
    # Only parameters left trainable by the model setup (the new head) are
    # optimized. Build the list once: a `filter` object is a one-shot iterator,
    # so counting it would exhaust it before the optimizer saw anything.
    trainable_params = [p for p in net.parameters() if p.requires_grad]
    num_trainable = sum(p.numel() for p in trainable_params)
    print(f"==> Optimizing ONLY the new head parameters ({num_trainable} parameters).")
    if not trainable_params:
        raise RuntimeError(
            "No trainable parameters found for the optimizer in finetune stage. Check model setup."
        )
elif args.training_stage in ("pretrain", "finetune_full"):
    # Optimize all parameters for pretrain or finetune_full
    trainable_params = list(net.parameters())
    num_trainable = sum(p.numel() for p in trainable_params if p.requires_grad)
    print(f"==> Optimizing ALL model parameters ({num_trainable} parameters).")
else:
    # Should not happen due to choices in argparse
    raise ValueError(f"Unknown training stage: {args.training_stage}")


if args.opt == "adam":
    optimizer = optim.Adam(trainable_params, lr=args.lr)
elif args.opt == "sgd":
    optimizer = optim.SGD(
        trainable_params, lr=args.lr, momentum=0.9, weight_decay=5e-4
    )  # Added momentum/decay like common settings
else:
    # Fail here with a clear message instead of a NameError on `optimizer`
    # when the scheduler is built below.
    raise ValueError(f"Unknown optimizer: {args.opt!r} (expected 'adam' or 'sgd')")


# Cosine decay over the whole run; main() steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.n_epochs)

# With enabled=False (--noamp) the scaler passes losses and steps through unchanged.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train(epoch):
    """Run one training epoch over `trainloader` with optional mixed precision.

    Uses the module-level `net`, `criterion`, `optimizer` and `scaler`.

    Args:
        epoch: Epoch index, used only for the console banner.

    Returns:
        Tuple of (mean per-batch training loss, training accuracy in percent).
    """
    print("\nEpoch: %d" % epoch)
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)

        # Forward under autocast (disabled by --noamp), then a scaled
        # backward / optimizer step; gradients are cleared after the update.
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = net(images)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Running statistics for the progress bar and the return value.
        running_loss += loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        mean_loss = running_loss / (step + 1)
        accuracy = 100.0 * n_correct / n_seen
        progress_bar(
            step,
            len(trainloader),
            "Loss: %.3f | Acc: %.3f%% (%d/%d)"
            % (mean_loss, accuracy, n_correct, n_seen),
        )
    return running_loss / (step + 1), 100.0 * n_correct / n_seen


def test(epoch):
    """Evaluate on `testloader`, checkpoint on a new best, and append a log line.

    When this epoch's accuracy beats the module-level `best_acc`, the model,
    optimizer and scaler state plus `args` are saved under ./checkpoint/ and
    `best_acc` is updated. A one-line summary is always appended to
    ./log/<name>.txt.

    Args:
        epoch: Epoch index, stored in the checkpoint and the log line.

    Returns:
        Tuple of (mean per-batch test loss, test accuracy in percent).
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)"
                % (
                    test_loss / (batch_idx + 1),
                    100.0 * correct / total,
                    correct,
                    total,
                ),
            )

    acc = 100.0 * correct / total
    val_loss = test_loss / (batch_idx + 1)
    # Optional run-name prefix shared by checkpoint and log file names.
    name_part = f"{args.run_name}_" if args.run_name else ""

    # Save checkpoint.
    if acc > best_acc:
        print("Saving..")
        # Unwrap based on the model's actual type: --dp is ignored on CPU, so
        # `args.dp` alone does not guarantee a `.module` attribute.
        base_net = net.module if isinstance(net, nn.DataParallel) else net
        state = {
            # Use 'model' key consistent with timm/huggingface conventions
            "model": base_net.state_dict(),
            "acc": acc,
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict() if use_amp else None,
            "args": args,  # Save args for reproducibility
        }
        os.makedirs("checkpoint", exist_ok=True)
        ckpt_filename = f"{name_part}{args.net}-{args.training_stage}-p{args.patch}-seed{args.seed}-ckpt.t7"
        torch.save(state, os.path.join("./checkpoint", ckpt_filename))
        best_acc = acc

    os.makedirs("log", exist_ok=True)
    log_filename = f"{name_part}log_{args.net}_{args.training_stage}_p{args.patch}_seed{args.seed}.txt"
    content = (
        time.ctime()
        + " "
        + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {val_loss:.5f}, acc: {acc:.5f}'
    )
    print(content)
    with open(os.path.join("log", log_filename), "a") as appender:
        appender.write(content + "\n")
    return val_loss, acc


def main():
    """Train for `args.n_epochs` epochs, evaluating and logging after each one.

    Each epoch runs train(), then test() (which also checkpoints on a new best
    accuracy), steps the LR scheduler, logs to wandb if enabled, and appends
    one row to ./log/<name>_summary.csv. When wandb is enabled, a final
    checkpoint is saved to the working directory, uploaded, and the run is
    closed.
    """
    list_loss = []
    list_acc = []

    if usewandb:
        wandb.watch(net)

    # `best_acc` is deliberately NOT assigned in this function. Assigning it
    # would make it a local that shadows the module-level value test()
    # updates, so "best_val_acc" below would always be logged as 0.0.

    # Output names depend only on args, so build them once.
    name_part = f"{args.run_name}_" if args.run_name else ""
    log_filename_csv = f"{name_part}log_{args.net}_{args.training_stage}_p{args.patch}_seed{args.seed}_summary.csv"
    csv_path = os.path.join("log", log_filename_csv)

    for epoch in range(start_epoch, start_epoch + args.n_epochs):
        start = time.time()
        trainloss, trainacc = train(epoch)
        val_loss, acc = test(epoch)

        scheduler.step()  # Step each epoch

        list_loss.append(val_loss)
        list_acc.append(acc)

        # Log training..
        if usewandb:
            wandb.log(
                {
                    "epoch": epoch,
                    "train_loss": trainloss,
                    "train_acc": trainacc,
                    "val_loss": val_loss,
                    "val_acc": acc,
                    "best_val_acc": best_acc,  # global, already updated by test()
                    "lr": optimizer.param_groups[0]["lr"],
                    "epoch_time": time.time() - start,
                }
            )

        # Append one CSV row per epoch; write the header only on file creation.
        os.makedirs("log", exist_ok=True)
        file_exists = os.path.isfile(csv_path)
        with open(csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            if not file_exists:
                writer.writerow(
                    ["epoch", "train_loss", "train_acc", "val_loss", "val_acc", "lr"]
                )
            writer.writerow(
                [
                    epoch,
                    trainloss,
                    trainacc,
                    val_loss,
                    acc,
                    optimizer.param_groups[0]["lr"],
                ]
            )

    # writeout wandb
    if usewandb:
        # `epoch` and `acc` are only bound if at least one epoch ran.
        if list_acc:
            final_ckpt_filename = f"{name_part}{args.net}-{args.training_stage}-p{args.patch}-seed{args.seed}-final-ckpt.t7"
            # Unwrap by type: --dp is ignored on CPU, so args.dp is not enough.
            base_net = net.module if isinstance(net, nn.DataParallel) else net
            final_state = {
                "model": base_net.state_dict(),
                "acc": acc,
                "epoch": epoch,
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict() if use_amp else None,
                "args": args,
            }
            torch.save(final_state, final_ckpt_filename)
            wandb.save(final_ckpt_filename)
        wandb.finish()  # End the wandb run explicitly


if __name__ == "__main__":
    # freeze_support() only has an effect in a frozen Windows executable; in
    # every other case it does nothing.
    from multiprocessing import freeze_support

    freeze_support()
    main()  # Run the full training loop
