import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import random
import numpy as np

import argparse

# Command-line interface: a single mandatory option naming the checkpoint file.
# NOTE: the arguments are parsed as soon as this module is executed/imported.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_path",
    required=True,
    type=str,
    help="Path for saving weights",
)
args = parser.parse_args()

# -------------------------------
# 1) Reproducibility
# -------------------------------
# One fixed seed shared by the PyTorch, NumPy and stdlib random generators.
_SEED = 42
torch.manual_seed(_SEED)
np.random.seed(_SEED)
random.seed(_SEED)

# -------------------------------
# 2) Device
# -------------------------------
# Use the GPU when CUDA reports itself as available, otherwise the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# -------------------------------
# 3) Hyperparams
# -------------------------------
# Number of passes over the training set: bounds the epoch loop in main()
# and appears in the "Epoch [i/N]" line printed by train().
num_epochs    = 120
# Mini-batch size used for both the training and the test DataLoader.
batch_size    = 128
# Learning rate handed to the Adam optimizer in main().
learning_rate = 0.001

# -------------------------------
# 4) Transforms
# -------------------------------
def _build_transform(augment):
    """Return the preprocessing pipeline; ``augment`` adds a random horizontal flip."""
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


# Training pipeline: resize -> random horizontal flip -> tensor -> normalize.
transform_train = _build_transform(augment=True)

# Test pipeline: same steps without the random flip.
transform_test = _build_transform(augment=False)

# -------------------------------
# 5) Residual MLP (3 linear layers + linear skip projection)
# -------------------------------
class MLPResidual3(nn.Module):
    """Three-layer MLP whose output is summed with a linear skip projection.

    The main path is ``fc1 -> ReLU -> fc2 -> ReLU -> fc3``. Its result is
    added to ``residual(x)``, a separate linear map from ``input_dim`` to
    ``output_dim`` applied to the unmodified input.

    Args:
        input_dim: Feature size expected by ``fc1`` and ``residual``.
        hidden_dim: Width of the two hidden layers.
        output_dim: Feature size produced by ``fc3`` and ``residual``.
    """

    def __init__(self, input_dim=512, hidden_dim=512, output_dim=512):
        super().__init__()
        # Sub-modules are registered in the order fc1, fc2, fc3, relu, residual.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.relu = nn.ReLU()
        self.residual = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``fc3(relu(fc2(relu(fc1(x))))) + residual(x)``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        main_path = self.fc3(hidden)
        skip_path = self.residual(x)
        return main_path + skip_path

# -------------------------------
# 6) Train / Eval loops
# -------------------------------
def train(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader`` and print the mean batch loss.

    Args:
        model: Module to optimise; switched to training mode first.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        optimizer: Optimizer that updates ``model``'s parameters.
        epoch: Epoch number, used only in the printed progress line.
    """
    model.train()
    batch_losses = []
    for batch_inputs, batch_targets in loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    # Mean of the per-batch losses (divides by the number of batches).
    mean_loss = sum(batch_losses, 0.0) / len(loader)
    print(f"Epoch [{epoch}/{num_epochs}]  Train Loss: {mean_loss:.4f}")


def evaluate(model, loader, criterion):
    """Score ``model`` on ``loader``, print loss and accuracy, return accuracy.

    Args:
        model: Module to evaluate; switched to eval mode first.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.

    Returns:
        Top-1 accuracy as a percentage in ``[0, 100]``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    # Gradients are not needed for scoring.
    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_targets).item()

            # Predicted class is the index of the largest output along dim 1.
            predicted = torch.argmax(logits, dim=1)
            n_correct += predicted.eq(batch_targets).sum().item()
            n_seen += batch_targets.size(0)

    acc = 100 * n_correct / n_seen
    mean_loss = loss_sum / len(loader)
    print(f"Test Loss: {mean_loss:.4f}  Test Acc: {acc:.2f}%")
    return acc

# -------------------------------
# 7) Main
# -------------------------------
def main():
    """Train a ResNet-18 + residual-MLP classifier on CIFAR-10 and save it.

    Steps:
      1. Build the model: a randomly initialised ResNet-18 with its last
         child module removed, followed by ``Flatten``, ``MLPResidual3`` and
         a 10-way linear classifier.
      2. Load the CIFAR-10 train/test splits (downloaded into ``./data``).
      3. Train for ``num_epochs`` epochs, evaluating every 10th epoch and
         after the final one.
      4. Write the model's ``state_dict`` to ``args.save_path``.
    """
    # Build model
    print("Building model...")
    resnet18 = models.resnet18(weights=None)
    # Keep every child except the last one (the final classifier layer in
    # torchvision's ResNet), so the backbone emits pooled features.
    backbone = nn.Sequential(*list(resnet18.children())[:-1])
    model = nn.Sequential(
        backbone,
        nn.Flatten(),
        # The class defined in this file is MLPResidual3; the previous name
        # `MLPResidual` does not exist and raised NameError here.
        MLPResidual3(512, 512, 512),
        nn.Linear(512, 10)
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Data
    print("Loading CIFAR-10...")
    train_ds = datasets.CIFAR10(
        root="./data", train=True,  download=True,  transform=transform_train
    )
    test_ds  = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2)

    # Train
    print("Starting training...")
    for epoch in range(1, num_epochs + 1):
        train(model, train_loader, criterion, optimizer, epoch)
        if epoch % 10 == 0 or epoch == num_epochs:
            evaluate(model, test_loader, criterion)

    # Save
    save_path = args.save_path
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError; only create the directory when there is one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

# Start training only when this file is run as a script.
if __name__ == "__main__":
    main()
