import numpy as np
import torch, torchvision
from torchvision import datasets, transforms, models
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from PIL import Image
import random
import copy
import pickle
import argparse
import sys
from torch.autograd import Variable, Function
from sklearn.model_selection import StratifiedKFold, train_test_split, ShuffleSplit
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, Subset, DataLoader, TensorDataset, random_split, ConcatDataset
import h5py
from sklearn.metrics import confusion_matrix

import datetime, time
# UTC timestamp with ':' removed (filename-safe). Built from an aware "now" and
# stripped back to naive so the string format matches utcnow() (deprecated in 3.12).
date_time = (
    datetime.datetime.now(datetime.timezone.utc)
    .replace(tzinfo=None)
    .isoformat()
    .replace(":", "")
)

# Command-line configuration for a curriculum adversarial-training run.
parser = argparse.ArgumentParser(description='Curriculum Adversarial Training')
parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'cifar10'],
                    help='training/test dataset')
parser.add_argument('--arch', type=str, default='cnn', choices=['cnn', 'resnet', 'linear', 'twolayernet', 'net'],
                    help='model architecture')
parser.add_argument('--opt', type=str, default='adam', choices=['adam', 'sgd'],
                    help='optimizer')
parser.add_argument('--alpha', type=float, default=0.3,
                    help='final PGD perturbation budget; stage k of K trains with alpha*k/K')
parser.add_argument('--T', type=int, default=100,
                    help='training epochs per (lr, l2_reg) hyperparameter trial')
parser.add_argument('--K', type=int, default=3,
                    help='number of curriculum stages')
parser.add_argument('--steps', type=int, default=10, help='number of pgd steps')
parser.add_argument('--norm', type=str, default='linf', choices=['linf', 'l2'],
                    help='PGD threat model')
parser.add_argument('--seed', default=0, type=int, help='seed')
args = parser.parse_args()


def set_seed(seed):
    """Seed every RNG this script touches and force deterministic cuDNN.

    Seeds Python's `random`, NumPy, and torch (CPU plus all CUDA devices), then
    disables cuDNN autotuning in favour of deterministic kernels.

    Returns:
        A `worker_init_fn`-style callable that reseeds `random` and NumPy with
        `seed + worker_id`, for use by DataLoader worker processes.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Reproducibility over speed: no benchmark-driven kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    def seed_worker(worker_id):
        derived = seed + worker_id
        random.seed(derived)
        np.random.seed(derived)

    return seed_worker
    
# Target device for the script ('cuda' when a GPU is visible).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Global seeding; the returned hook can serve as a DataLoader worker_init_fn.
seed = args.seed
seed_worker = set_seed(seed)

class linear(nn.Module):
    """Single affine layer mapping a flattened input to 10 logits.

    Args:
        d: number of input features after flattening (784 by default).
    """

    def __init__(self, d=784):
        super().__init__()
        self.net = nn.Linear(d, 10, bias=True)

    def forward(self, x):
        # Keep the batch dimension, collapse the rest (works on non-contiguous x).
        flat = x.reshape(x.size(0), -1)
        return self.net(flat)

class twolayernet(nn.Module):
    """One-hidden-layer ReLU MLP mapping a flattened input to 10 logits.

    Args:
        d: number of input features after flattening (784 by default).
        width: size of the hidden layer.
    """

    def __init__(self, d=784, width=100):
        super().__init__()
        self.fc1 = nn.Linear(d, width, bias=True)
        self.fc2 = nn.Linear(width, 10, bias=True)

    def forward(self, x):
        hidden = F.relu(self.fc1(x.view(x.size(0), -1)))
        return self.fc2(hidden)

class cnn(nn.Module):
    """Two conv/max-pool stages followed by a two-layer classifier head.

    The hard-coded 320 = 20 channels * 4 * 4 is what two (5x5 conv, 2x2 pool)
    stages leave from a 1x28x28 input (28 -> 24 -> 12 -> 8 -> 4).

    Args:
        width: size of the hidden fully-connected layer.
    """

    def __init__(self, width=100):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, width)
        self.fc2 = nn.Linear(width, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool2d(conv(x), 2))
        hidden = F.relu(self.fc1(x.view(-1, 320)))
        return self.fc2(hidden)


class net(nn.Module):
    """Two 'same'-padded 5x5 conv stages (32, 64 channels) + 1024-unit FC head.

    Accepts either (batch, 1, 28, 28) images or flat (batch, 784) vectors, which
    are reshaped to 1x28x28. Each 2x2 max-pool halves the spatial size
    (28 -> 14 -> 7), hence the 7*7*64 input to fc1.
    """

    def __init__(self):
        super().__init__()
        # padding=2 with a 5x5 kernel keeps height/width unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(7 * 7 * 64, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        if x.dim() == 2:  # flat vectors -> single-channel 28x28 images
            x = x.view(-1, 1, 28, 28)
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)
        features = x.view(x.size(0), -1)
        return self.fc2(F.relu(self.fc1(features)))
        
# === PGD Adversarial Attack ===
def pgd_attack(model, images, labels, alpha, iters, norm='linf'):
    """Craft adversarial examples with projected gradient descent (PGD).

    Starting from delta = 0, takes `iters` ascent steps of size 2*alpha/iters on
    the cross-entropy loss. After each step it projects delta back onto the
    `norm` ball of radius `alpha`, then clips images + delta to [0, 1]. This
    clipping assumes input pixels live in [0, 1].

    The model is put in eval mode for the attack and its previous train/eval
    mode is restored afterwards. Gradients are taken w.r.t. delta only, so the
    model parameters' `.grad` fields are left untouched.

    Args:
        model: classifier returning logits.
        images: input batch; any shape whose first dim is the batch.
        labels: integer class targets for `images`.
        alpha: perturbation budget (radius of the norm ball).
        iters: number of PGD steps. alpha == 0 or iters <= 0 returns the clean
            images unchanged.
        norm: 'linf' (sign-gradient steps) or 'l2' (normalised-gradient steps).

    Returns:
        Detached adversarial images with the same shape as `images`.

    Raises:
        ValueError: for an unsupported `norm`.
    """
    was_training = model.training
    model.eval()
    try:
        if alpha == 0 or iters <= 0:
            return images.detach()
        images = images.clone().detach()
        labels = labels.to(images.device)
        step = alpha / iters * 2
        # Per-sample norms are broadcast back over all non-batch dimensions.
        bshape = (-1,) + (1,) * (images.dim() - 1)

        delta = torch.zeros_like(images, requires_grad=True)
        for _ in range(iters):
            loss = F.cross_entropy(model(images + delta), labels)
            (grad,) = torch.autograd.grad(loss, delta)
            with torch.no_grad():
                if norm == 'linf':
                    delta = (delta + step * grad.sign()).clamp(-alpha, alpha)
                elif norm == 'l2':
                    g_norm = grad.reshape(grad.size(0), -1).norm(dim=1).view(bshape) + 1e-8
                    delta = delta + step * grad / g_norm
                    d_norm = delta.reshape(delta.size(0), -1).norm(dim=1).view(bshape)
                    delta = delta * torch.clamp(alpha / d_norm, max=1.0)
                else:
                    raise ValueError(f"unsupported norm {norm!r}; expected 'linf' or 'l2'")
                # Keep the perturbed image inside the valid pixel range.
                delta = torch.clamp(images + delta, 0, 1) - images
            delta.requires_grad_()

        return torch.clamp(images + delta, 0, 1).detach()
    finally:
        # Callers such as the training loop rely on the mode they set.
        model.train(was_training)


# === Training Function ===
def train(model, opt, train_loader, val_loader, T=100, lr=1e-3, prev_state_dict=None, reg_lambda=0.0, alpha=0.3, atk_steps=10, norm='linf', val_alpha=None):
    """Adversarially train `model` for T epochs, keeping the best-validating weights.

    Every minibatch is replaced by its PGD adversarial counterpart (budget
    `alpha`, `atk_steps` steps, `norm` threat model) before the gradient step.
    When `prev_state_dict` is given, training starts from those weights. The
    loss then adds `reg_lambda * sum ||p - p_prev||^2` over all parameters,
    which pulls the model towards `prev_state_dict`.

    Args:
        model: network to train, updated in place.
        opt: 'sgd' (momentum 0.9) or 'adam'; both use weight decay 1e-4.
        train_loader: iterable of (images, labels) training batches.
        val_loader: iterable of (images, labels) validation batches.
        T: number of epochs.
        lr: learning rate.
        prev_state_dict: optional warm-start weights and regularisation anchor.
        reg_lambda: strength of the L2 penalty towards `prev_state_dict`.
        alpha: PGD budget used on training batches.
        atk_steps: PGD steps for both training and validation attacks.
        norm: PGD threat model, 'linf' or 'l2'.
        val_alpha: PGD budget for robust validation accuracy; when None, the
            module-level ALPHA is used.

    Returns:
        (best_state, best_val_acc): a detached copy of the state dict from the
        epoch with the highest robust validation accuracy, and that accuracy.
        best_state is None when T <= 0.

    Raises:
        ValueError: for an unknown `opt`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if opt == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    elif opt == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    else:
        raise ValueError(f"unknown optimizer {opt!r}; expected 'sgd' or 'adam'")

    if val_alpha is None:
        val_alpha = ALPHA

    criterion = nn.CrossEntropyLoss()

    prev_weights = None
    if prev_state_dict:
        model.load_state_dict(prev_state_dict)
        # Frozen anchor, moved to the model's device so the penalty never mixes devices.
        prev_weights = {
            n: prev_state_dict[n].detach().clone().to(device)
            for n, _ in model.named_parameters()
        }

    best_val_acc = -100.0
    best_model_state = None
    loss = None  # stays None if train_loader yields no batches
    for epoch in range(T):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            adv_images = pgd_attack(model, images, labels, alpha=alpha, iters=atk_steps, norm=norm)

            optimizer.zero_grad()
            loss = criterion(model(adv_images), labels)

            if prev_weights is not None and reg_lambda:
                # Squared distance as a plain sum of squares: same value as
                # norm()**2 but without sqrt's undefined gradient at p == prev.
                l2_reg = sum(
                    (p - prev_weights[n]).pow(2).sum()
                    for n, p in model.named_parameters()
                )
                loss = loss + reg_lambda * l2_reg

            loss.backward()
            optimizer.step()

        _, val_acc = evaluate_under_attack(model, val_loader, val_alpha, atk_steps, norm)
        if (epoch + 1) % 10 == 0:
            train_loss = loss.item() if loss is not None else None
            print("epoch:", epoch, "train loss:", train_loss, "adv val_acc:", val_acc)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = {k: v.clone().detach() for k, v in model.state_dict().items()}

    return best_model_state, best_val_acc


# === Evaluation Function ===
def evaluate_under_attack(model, test_loader, alpha, steps, norm='linf'):
    """Measure clean and PGD-robust accuracy of `model` over `test_loader`.

    Args:
        model: classifier returning logits; moved to the available device and
            left in eval mode.
        test_loader: iterable of (images, labels) batches.
        alpha: PGD budget used for the robust accuracy.
        steps: number of PGD steps.
        norm: PGD threat model, 'linf' or 'l2'.

    Returns:
        (clean_accuracy, robust_accuracy), both fractions in [0, 1].

    Raises:
        ValueError: if the loader yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct, base_correct, total = 0, 0, 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # The attack needs input gradients, so it must run outside no_grad.
        adv_images = pgd_attack(model, images, labels, alpha=alpha, iters=steps, norm=norm)

        # Pure inference: skip building autograd graphs for the scoring passes.
        with torch.no_grad():
            base_predicted = model(images).argmax(dim=1)
            predicted = model(adv_images).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        base_correct += (base_predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_under_attack: loader yielded no samples")
    return base_correct / total, correct / total



# Run configuration taken from the command line.
T = args.T          # epochs per (lr, l2_reg) trial
K = args.K          # number of curriculum stages
ALPHA = args.alpha  # final PGD budget; stage k uses ALPHA * k / K
norm = args.norm
if args.steps == 0:
    # Zero PGD steps means natural (non-adversarial) training: zero the budget
    # so pgd_attack short-circuits instead of computing alpha / iters with iters=0.
    ALPHA = 0

BATCH_SIZE = 128
VAL_RATIO = 0.2  # fraction of the training set held out for model selection
# Hyperparameter grids searched at every stage (L2_REG_LIST only from stage 2 on).
L2_REG_LIST = [1e-5, 1e-4, 1e-3, 1e-2]
LR_LIST = [1e-5, 1e-4, 1e-3, 1e-2]



import os

# Run identifier shared by the log file and the saved checkpoint.
name = (
    f"cat_{args.dataset}_{args.opt}_{args.arch}_T{T}_K{args.K}"
    f"_norm{args.norm}_alpha{args.alpha}_steps{args.steps}_{seed}"
)
os.makedirs('adv', exist_ok=True)
log_filename = 'adv/' + name + '.txt'
# Line-buffered so progress is visible during training and survives a crash.
log = open(log_filename, 'w', buffering=1)
# Everything print()ed from here on goes to the run log.
sys.stdout = log


    
# Datasets (pixels kept in [0, 1], which pgd_attack's clipping assumes) and model.
if args.dataset == 'mnist':
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    if args.arch == 'cnn':
        model = cnn()
    elif args.arch == 'linear':
        model = linear()
    elif args.arch == 'twolayernet':
        model = twolayernet()
    elif args.arch == 'net':
        model = net()
    else:
        # ResNet-18 adapted to 1-channel 28x28 inputs: 3x3 stride-1 stem, no stem max-pool.
        model = models.resnet18(num_classes=10)
        model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()
elif args.dataset == 'cifar10':
    # Only ResNet-18 is defined for 3-channel inputs; fail before downloading data.
    if args.arch != 'resnet':
        raise ValueError(f"--arch {args.arch!r} is not supported for cifar10; use --arch resnet")
    # No Normalize: pgd_attack clamps adversarial images to [0, 1], so inputs
    # must stay in that range (Normalize(0.5, 0.5) would map them to [-1, 1]).
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    model = models.resnet18(num_classes=10)



# Hold out VAL_RATIO of the training set for per-epoch / per-trial model selection.
n_total = len(dataset)
train_len = int((1 - VAL_RATIO) * n_total)
val_len = n_total - train_len
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])

# Only the training loader shuffles; evaluation order is fixed.
train_loader, val_loader, test_loader = (
    DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(val_dataset, batch_size=BATCH_SIZE),
    DataLoader(test_dataset, batch_size=BATCH_SIZE),
)


import os

# Freshly initialised weights. Stage-1 trials have no earlier stage to start
# from, so every (l2_reg, lr) trial restarts here rather than continuing from
# whatever the previous trial left in `model`.
init_state_dict = copy.deepcopy(model.state_dict())

best_state_dict = None
for stage in range(1, K + 1):
    # Curriculum: the attack budget grows linearly up to ALPHA over K stages.
    stage_alpha = ALPHA * stage / K
    print(f"Stage {stage}: Training with PGD alpha={stage_alpha:.4f}, norm={norm}")

    # Stage 1 has no previous-stage weights to regularise towards.
    L2_REG = [0] if stage == 1 else L2_REG_LIST

    # Per-stage grid search, selected by robust validation accuracy.
    best_val_acc = float('-inf')
    prev_best_state = None
    chosen_lr = chosen_reg = None
    for l2_reg in L2_REG:
        for lr in LR_LIST:
            start_state = best_state_dict if best_state_dict is not None else init_state_dict
            model.load_state_dict(start_state)
            print("l2_reg:", l2_reg, "lr:", lr)
            prev_state_dict, val_acc = train(
                model, args.opt, train_loader, val_loader, T=T, lr=lr,
                prev_state_dict=best_state_dict,
                reg_lambda=l2_reg,
                alpha=stage_alpha,
                atk_steps=args.steps,
                norm=norm
            )
            print("val acc:", val_acc)
            print("\n")

            if prev_state_dict is not None and val_acc > best_val_acc:
                best_val_acc = val_acc
                prev_best_state = prev_state_dict
                chosen_lr = lr
                chosen_reg = l2_reg

    if prev_best_state is None:
        raise RuntimeError(f"stage {stage}: no trial produced a checkpoint (is --T >= 1?)")
    best_state_dict = prev_best_state
    print("chosen lr:", chosen_lr, "chosen_reg:", chosen_reg, "best rob val_acc", best_val_acc)
    print("\n")

if best_state_dict is None:
    raise RuntimeError("no curriculum stages were run (is --K >= 1?)")

# Final test evaluation at the full budget (ALPHA is 0 when --steps 0).
model.load_state_dict(best_state_dict)
nat_acc, rob_acc = evaluate_under_attack(model, test_loader, alpha=ALPHA, steps=args.steps, norm=norm)
print("test nat_acc:", nat_acc, "test rob_acc:", rob_acc)

os.makedirs('adv/ckpt', exist_ok=True)
model_path = 'adv/ckpt/' + name + '.pth'
torch.save(model.state_dict(), model_path)

# Restore the real stdout and flush/close the run log.
sys.stdout = sys.__stdout__
log.close()
