import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.optim.lr_scheduler import CosineAnnealingLR
from sam.sam import SAM
import time

class FourLayerMLP(nn.Module):
    """MLP with three hidden Linear+GELU layers and a linear output layer.

    Inputs are flattened to ``(batch, -1)`` before the first layer, so the
    product of the non-batch dimensions must equal ``input_features``
    (the default 3072 corresponds to 3 x 32 x 32 images).

    The activations feeding the output layer are cached on every forward
    pass in ``self.penultimate`` so that :meth:`get_penultimate_reg` can
    build a regularisation term from them.

    Args:
        input_features: Size of the flattened input vector.
        hidden_features: Width of each of the three hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_features=3072, hidden_features=3072, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_features, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, hidden_features),
            nn.GELU(),
        )
        self.last_layer = nn.Linear(hidden_features, num_classes)
        # Filled in by forward(); None means "no forward pass has run yet".
        self.penultimate = None

    def forward(self, x):
        """Return the class logits for a batch ``x`` and cache the penultimate activations."""
        # reshape (rather than view) also accepts non-contiguous batches,
        # e.g. tensors produced by permute().
        x = x.reshape(x.size(0), -1)
        out = self.net(x)
        # Kept attached to the autograd graph so the regulariser is differentiable.
        self.penultimate = out
        return self.last_layer(out)

    def get_penultimate_reg(self):
        """Return half the squared L2 norm of the cached penultimate activations.

        The norm is taken over the whole cached tensor (all samples of the
        last batch together), not averaged per sample.

        Raises:
            RuntimeError: If no forward pass has been run yet.
        """
        if self.penultimate is None:
            raise RuntimeError(
                "get_penultimate_reg() called before forward(); "
                "no penultimate activations are cached."
            )
        return self.penultimate.norm(p=2).square() / 2
    

def train_standard(model, dataloader, optimizer, criterion, scheduler, device, sigma=0.0):
    """Train ``model`` for one pass over ``dataloader`` with a regular optimizer.

    The per-batch objective is ``criterion(outputs, targets)`` plus
    ``sigma**2 * model.get_penultimate_reg()``. The regulariser is only
    evaluated when ``sigma`` is non-zero, so models that do not provide
    ``get_penultimate_reg`` can be trained with the default ``sigma=0.0``.

    Args:
        model: Module to train; put into training mode here.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function; the running average assumes it returns a
            per-batch mean.
        scheduler: Learning-rate scheduler, stepped once per batch.
        device: Device the batches are moved to.
        sigma: Square root of the regulariser weight.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over the samples that were
        actually processed; ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.train()
    reg_weight = sigma ** 2  # loop-invariant
    total_loss, total_correct, num_seen = 0.0, 0, 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        if reg_weight != 0:
            loss = loss + reg_weight * model.get_penultimate_reg()
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item() * batch_size
        _, preds = outputs.max(1)
        total_correct += preds.eq(targets).sum().item()
        # Count what was really seen: len(dataloader.dataset) over-counts
        # when the loader drops the last batch or uses a sub-sampler.
        num_seen += batch_size

    if num_seen == 0:
        return 0.0, 0.0
    return total_loss / num_seen, total_correct / num_seen

def train_sam(model, dataloader, optimizer, criterion, scheduler, device):
    """Train ``model`` for one pass over ``dataloader`` with a two-step SAM optimizer.

    Each batch needs two forward/backward passes: one before
    ``optimizer.first_step`` and one before ``optimizer.second_step``.
    The reported loss and accuracy both come from the first pass, i.e.
    from the weights the model had when the batch started.

    Args:
        model: Module to train; put into training mode here.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer exposing ``zero_grad()``,
            ``first_step(zero_grad=...)`` and ``second_step(zero_grad=...)``.
        criterion: Loss function; its output is reduced with ``.mean()``.
        scheduler: Learning-rate scheduler, stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over the samples that were
        actually processed; ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.train()
    total_loss, total_correct, num_seen = 0.0, 0, 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        # First pass at the current weights. Metrics are taken from here so
        # that loss and accuracy describe the same set of weights.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets).mean()
        _, preds = outputs.max(1)
        loss.backward()
        # NOTE(review): in the `sam` package, first_step is expected to move
        # the weights to the perturbed point and second_step to restore them
        # and apply the base-optimizer update - confirm against the
        # installed version.
        optimizer.first_step(zero_grad=True)

        # Second pass; its outputs are only used for the gradient.
        criterion(model(inputs), targets).mean().backward()
        optimizer.second_step(zero_grad=True)

        scheduler.step()

        total_loss += loss.item() * batch_size
        total_correct += preds.eq(targets).sum().item()
        # Count what was really seen: len(dataloader.dataset) over-counts
        # when the loader drops the last batch or uses a sub-sampler.
        num_seen += batch_size

    if num_seen == 0:
        return 0.0, 0.0
    return total_loss / num_seen, total_correct / num_seen

def evaluate(model, dataloader, criterion, device):
    """Compute the mean loss and accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and left in it; gradients are
    disabled for the duration of the pass.

    Args:
        model: Module to evaluate.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function; the running average assumes it returns a
            per-batch mean.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over the samples that were
        actually processed; ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.eval()
    total_loss, total_correct, num_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item() * batch_size
            _, preds = outputs.max(1)
            total_correct += preds.eq(targets).sum().item()
            # Count what was really seen: len(dataloader.dataset) over-counts
            # when the loader drops the last batch or uses a sub-sampler.
            num_seen += batch_size

    if num_seen == 0:
        return 0.0, 0.0
    return total_loss / num_seen, total_correct / num_seen

def print_args(args):
    """Print every attribute of ``args`` on its own line as ``name: value``."""
    settings = vars(args)
    for name in settings:
        print(f"{name}: {settings[name]}")

def main(args):
    """Train a FourLayerMLP on CIFAR-10 and log the selected test accuracy.

    10,000 images are split off the CIFAR-10 training set for validation.
    Training images are augmented with random crops and horizontal flips;
    validation and test images are only converted to tensors and normalised,
    so evaluation is deterministic. Every ``args.freq`` epochs the model is
    evaluated on both sets. After training, the test accuracy recorded at
    the evaluation with the highest validation accuracy (the first one on
    ties) is appended to ``results_sam_mlp.txt``.

    Args:
        args: Namespace providing ``lr``, ``wd``, ``momentum_local``,
            ``rho_sam``, ``sigma``, ``train_mode`` (``"standard"`` or
            ``"sam"``), ``batch_size``, ``epochs`` and ``freq``.

    Raises:
        ValueError: If ``args.train_mode`` is neither ``"standard"`` nor ``"sam"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    normalize = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),  # CIFAR-10 mean
        std=(0.2470, 0.2435, 0.2616),   # CIFAR-10 std
    )
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation must not be augmented: random crops/flips would make the
    # validation and test metrics noisy and non-reproducible.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    full_train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    # Second view of the same training images with the evaluation transform;
    # the files already exist after the call above, so no download is needed.
    full_train_eval_view = datasets.CIFAR10(root="./data", train=True, download=False, transform=eval_transform)
    test_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=eval_transform)

    val_set_size = 10000
    train_set_size = len(full_train_set) - val_set_size
    train_set, val_split = random_split(full_train_set, [train_set_size, val_set_size])
    # Same held-out indices, but read through the un-augmented view.
    val_set = torch.utils.data.Subset(full_train_eval_view, val_split.indices)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)

    model = FourLayerMLP().to(device)

    criterion = nn.CrossEntropyLoss()

    # The schedulers are stepped once per batch, so the cosine period is
    # expressed in optimizer steps rather than epochs.
    total_steps = args.epochs * len(train_loader)
    if args.train_mode == "standard":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.wd,
            momentum=args.momentum_local
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
    elif args.train_mode == "sam":
        optimizer = SAM(
            model.parameters(),
            torch.optim.SGD,
            lr=args.lr,
            weight_decay=args.wd,
            momentum=args.momentum_local,
            rho=args.rho_sam,
            adaptive=False
        )
        # The schedule drives the wrapped SGD optimizer.
        scheduler = CosineAnnealingLR(optimizer.base_optimizer, T_max=total_steps)
    else:
        raise ValueError(f"Unknown train_mode: {args.train_mode!r} (expected 'standard' or 'sam')")

    val_accs, test_accs = [], []

    for epoch in range(args.epochs):
        start_time = time.time()

        if args.train_mode == "standard":
            train_loss, train_acc = train_standard(model, train_loader, optimizer, criterion, scheduler, device, sigma=args.sigma)
        else:
            train_loss, train_acc = train_sam(model, train_loader, optimizer, criterion, scheduler, device)

        # Only the training pass is timed, not the evaluation below.
        end_time = time.time()
        last_lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch + 1}/{args.epochs}:")
        print(f"  Train Loss = {train_loss:.4f}, Train Accuracy = {train_acc * 100:.2f}% (lr={last_lr:.4f})")

        if (epoch + 1) % args.freq == 0:
            val_loss, val_acc = evaluate(model, val_loader, criterion, device)
            _, test_acc = evaluate(model, test_loader, criterion, device)
            val_accs.append(val_acc * 100)
            test_accs.append(test_acc * 100)

            print(f"  Val Loss   = {val_loss:.4f}, Val Accuracy   = {val_acc * 100:.2f}%")
            print(f"  Test Accuracy  = {test_acc * 100:.2f}%")

        print(f"  Time       = {end_time - start_time:.2f}s")

    if not val_accs:
        # Happens when no epoch hit the evaluation frequency (e.g. freq > epochs).
        print("No evaluation was run; nothing written to results_sam_mlp.txt")
        return

    # Model selection on validation accuracy; report the matching test accuracy.
    best_idx_val = val_accs.index(max(val_accs))
    test_acc_idx_val = test_accs[best_idx_val]
    with open("results_sam_mlp.txt", "a") as f:
        f.write(f"rho: {args.rho_sam}, sigma: {args.sigma} wd: {args.wd} test_acc: {test_acc_idx_val} \n")

if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the options, echo them, then train.
    parser = argparse.ArgumentParser()

    # Float-valued hyper-parameters, registered in their original order.
    for flag, default in (
        ("--lr", 0.1),
        ("--sigma", 0.0),
        ("--wd", 0.0),
        ("--momentum_local", 0.0),
        ("--rho_sam", 0.0),
    ):
        parser.add_argument(flag, type=float, default=default)

    parser.add_argument(
        "--train_mode",
        type=str,
        default="standard",
        help="standard or sam optimization",
        choices=["standard", "sam"],
    )
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument(
        "--freq",
        type=int,
        default=1,
        help="Frequency to log validation and test accuracy",
    )

    args = parser.parse_args()
    print_args(args)
    main(args)