"""
Train CIFAR10 with PyTorch.
Adds Jacobian Regularization via facebookresearch/jacobian_regularizer
"""

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import pickle
import matplotlib.pyplot as plt

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2

# Jacobian Regularization
from jacobian import JacobianReg

import argparse
import time

from model import get_model
from data import get_data, make_planeloader
from utils import (
    get_loss_function,
    get_scheduler,
    get_random_images,
    produce_plot,
    get_noisy_images,
    AttackPGD,
)
from evaluation import train, test, test_on_trainset, decision_boundary, test_on_adv
from options import options
from utils import (
    simple_lapsed_time,
    adjust_learning_rate,
    adjust_lambda_reg_linear,
    adjust_lambda_reg_sin,
    smooth_batched_along_resolution,
)
from tqdm import tqdm

from check_gpu import print_used_gpus
from set_seed import set_seed, set_seed_detailed


def train_with_jacobian(
    args,
    net,
    trainloader,
    optimizer,
    criterion,
    reg,
    lambda_jr,
    device,
):
    """Run one training epoch with a Jacobian-regularized objective.

    Every batch is optimized on
    ``criterion(outputs, targets) + lambda_jr * reg(inputs, outputs)``.

    Args:
        args: Parsed options; only ``args.dryrun`` is read here. When it is
            truthy the epoch stops after the first batch.
        net: Model to train; it is switched to train mode.
        trainloader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Supervised loss, called as ``criterion(outputs, targets)``.
        reg: Regularizer, called as ``reg(inputs, outputs)``; its result is
            scaled by ``lambda_jr`` and added to the supervised loss.
        lambda_jr: Weight of the regularization term.
        device: Device the batches are moved to.

    Returns:
        A tuple ``(train_acc, train_loss, reg_loss)``: accuracy in percent
        over the samples seen, and the means of the supervised loss and of
        the (unweighted) regularization term over the batches that were
        actually processed.

    Raises:
        ValueError: If ``trainloader`` produced no samples.
    """
    net.train()
    train_loss_total = 0.0
    reg_loss_total = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader), start=1):
        inputs, targets = inputs.to(device), targets.to(device)

        # The regularizer is given the inputs together with the outputs, so
        # the inputs must be part of the autograd graph before the forward
        # pass.
        inputs.requires_grad_(True)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss_super = criterion(outputs, targets)
        R = reg(inputs, outputs)
        loss = loss_super + lambda_jr * R

        loss.backward()
        optimizer.step()

        # Both terms are logged separately (and unweighted) so their
        # magnitudes can be compared.
        train_loss_total += loss_super.item()
        reg_loss_total += R.item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        num_batches = batch_idx

        if args.dryrun:
            break

    if total == 0:
        raise ValueError("trainloader produced no samples; cannot compute epoch statistics")

    # Average over the batches actually processed: with ``args.dryrun`` the
    # loop stops after one batch, so dividing by len(trainloader) would
    # under-report the losses.
    train_acc = 100.0 * correct / total
    train_loss = train_loss_total / num_batches
    reg_loss = reg_loss_total / num_batches
    return train_acc, train_loss, reg_loss


def plot_training_curves(
    train_accs,
    test_accs,
    train_losses,
    test_losses,
    reg_terms,
    save_net_name,
    save_path="training_plots",
):
    """Save accuracy/loss curves as a PNG and the raw series as a pickle.

    Two files are written into ``save_path`` (created if missing):
    ``<save_net_name>.png`` with an accuracy panel and a loss panel side by
    side, and ``<save_net_name>_data.pkl`` holding the five input series in
    a dict. The x axis is 1-based and sized from ``len(train_accs)``. The
    regularization curve is drawn on the loss panel only when at least one
    entry of ``reg_terms`` is truthy.
    """
    epoch_axis = range(1, len(train_accs) + 1)

    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: accuracy
    acc_ax.plot(epoch_axis, train_accs, "b-", label="Training Accuracy")
    acc_ax.plot(epoch_axis, test_accs, "r-", label="Test Accuracy")
    acc_ax.set_title("Training and Test Accuracy")
    acc_ax.set_xlabel("Epochs")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.legend()
    acc_ax.grid(True)

    # Right panel: losses, plus the regularization term when it is non-zero
    loss_ax.plot(epoch_axis, train_losses, "g-", label="Training Loss")
    loss_ax.plot(epoch_axis, test_losses, "orange", label="Test Loss")
    if any(reg_terms):
        loss_ax.plot(epoch_axis, reg_terms, "m--", label="Jacobian Reg Term")
    loss_ax.set_title("Training and Test Loss")
    loss_ax.set_xlabel("Epochs")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()
    loss_ax.grid(True)

    fig.tight_layout()
    os.makedirs(save_path, exist_ok=True)
    figure_file = os.path.join(save_path, f"{save_net_name}.png")
    fig.savefig(figure_file)
    plt.close(fig)

    # Keep the raw numbers next to the figure so curves can be re-plotted
    history = dict(
        train_accs=train_accs,
        test_accs=test_accs,
        train_losses=train_losses,
        test_losses=test_losses,
        reg_terms=reg_terms,
    )
    history_file = os.path.join(save_path, f"{save_net_name}_data.pkl")
    with open(history_file, "wb") as handle:
        pickle.dump(history, handle)


if __name__ == "__main__":
    # Script entry point: parse options, build data/model/optimizer, then
    # either train with Jacobian regularization (logging and plotting the
    # curves) or load weights from ``args.load_net``.
    args = options().parse_args()
    set_seed(args.set_seed)

    print("Args:")
    for k, v in vars(args).items():
        print("\t{}: {}".format(k, v))

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create directory for plots
    os.makedirs("training_plots", exist_ok=True)

    print_used_gpus()

    # Init logs (one entry appended per epoch)
    train_accs = []
    test_accs = []
    train_losses = []
    test_losses = []
    reg_terms = []
    num_classes = 10  # CIFAR-10

    # Data
    trainloader, testloader = get_data(args)

    # Raw trainset (no augmentation) if needed.
    # NOTE(review): the Normalize constants are presumably CIFAR-10
    # per-channel mean/std — confirm they match what get_data uses.
    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465),
                (0.2023, 0.1994, 0.2010),
            ),
        ]
    )
    raw_trainset = torchvision.datasets.CIFAR10(
        root="~/data",
        train=True,
        download=True,
        transform=transform_test,
    )
    raw_trainloader = torch.utils.data.DataLoader(
        raw_trainset,
        batch_size=args.bs,
        shuffle=True,
        num_workers=16,
    )

    # Train on the loader from get_data when augmentation is requested,
    # otherwise on the un-augmented loader built above.
    aug = args.use_data_aug
    use_train_loader = trainloader if aug else raw_trainloader

    # Re-seed immediately before building the model.
    # NOTE(review): presumably so that weight initialization does not depend
    # on how much randomness the data setup consumed — confirm.
    set_seed(args.set_seed)

    # Model
    net = get_model(args, device)

    # Evaluate the freshly built network once before any training.
    test_acc, predicted = test(args, net, testloader, device, 0)
    print("scratch prediction ", test_acc)

    # Loss
    criterion = get_loss_function(args)

    # Jacobian Regularization
    lambda_jr = args.lambda_reg
    print(f"Using Jacobian Regularization, lambda_JR = {lambda_jr}")
    reg = JacobianReg()

    # Optimizer
    if args.opt.lower() == "sgd":
        optimizer = optim.SGD(
            net.parameters(),
            lr=args.lr,
            momentum=0.9,
            weight_decay=args.weight_decay,
        )

    elif args.opt.lower() == "adam":
        optimizer = torch.optim.Adam(
            net.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )

    elif args.opt.lower() == "adamw":
        optimizer = torch.optim.AdamW(
            net.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    else:
        raise ValueError(f"Unknown optimizer: {args.opt}")

    print("Training the network or loading the network")

    start = time.time()
    best_acc = 0.0

    if args.load_net is None:
        # Directory for the best checkpoint; it does not change across epochs.
        model_path = f"saved_models/wd_reg/{str(args.set_seed)}/{args.save_net}"

        for epoch in range(args.epochs):
            # The schedule helper is called with a 1-based epoch number.
            lr = adjust_learning_rate(optimizer, epoch + 1, args)

            # Train with Jacobian Regularization (no MixUp, no SAM)
            train_acc, train_loss, reg_loss = train_with_jacobian(
                args,
                net,
                use_train_loader,
                optimizer,
                criterion,
                reg,
                lambda_jr,
                device,
            )
            train_accs.append(train_acc)
            train_losses.append(train_loss)
            reg_terms.append(reg_loss)

            # Test acc
            test_acc, predicted = test(args, net, testloader, device, epoch)
            test_accs.append(test_acc)

            # Test loss: mean of the per-batch criterion over the test loader
            # (supervised term only, no regularizer).
            net.eval()
            test_loss = 0.0
            with torch.no_grad():
                for data, target in testloader:
                    data, target = data.to(device), target.to(device)
                    output = net(data)
                    test_loss += criterion(output, target).item()
            test_loss /= len(testloader)
            test_losses.append(test_loss)
            net.train()

            print(
                f"EPOCH: {epoch}/{args.epochs}, LR: {lr:.6f}, "
                f"Train acc: {train_acc:.2f}, Test acc: {test_acc:.2f}, "
                f"Train loss: {train_loss:.5f}, Test loss: {test_loss:.5f}, "
                f"Jacobian reg: {reg_loss:.5f}"
            )

            # A dry run stops after a single epoch; the final plot below is
            # still written.
            if args.dryrun:
                break

            # Plot every 5 epochs
            if epoch % 5 == 0:
                plot_training_curves(
                    train_accs,
                    test_accs,
                    train_losses,
                    test_losses,
                    reg_terms,
                    save_net_name=args.save_net,
                    save_path="training_plots",
                )

            if test_acc > best_acc:
                print(f"The best epoch is: {epoch}")
                os.makedirs(model_path, exist_ok=True)
                print(f"{model_path}/{args.save_net}.pth")
                best_acc = test_acc
                # NOTE(review): checkpoint saving is currently disabled, so
                # the path printed above is not actually written. Re-enable
                # the lines below to persist the best model.
                # Save model
                # if torch.cuda.device_count() > 1 and isinstance(net, torch.nn.DataParallel):
                #     torch.save(net.module.state_dict(), f"{model_path}/{args.save_net}.pth")
                # else:
                #     torch.save(net.state_dict(), f"{model_path}/{args.save_net}.pth")

        # Final plot
        plot_training_curves(
            train_accs,
            test_accs,
            train_losses,
            test_losses,
            reg_terms,
            save_net_name=args.save_net,
            save_path="training_plots",
        )

    else:
        # Load model: unwrap DataParallel so the state dict is loaded into
        # the underlying module.
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        # map_location remaps the stored tensors onto the device selected
        # above, so a checkpoint written on a GPU still loads when the
        # script has fallen back to CPU.
        net.load_state_dict(torch.load(args.load_net, map_location=device))

    end = time.time()
    simple_lapsed_time("Time taken to train the model", end - start)
