from __future__ import print_function
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import MultiStepLR
import torchattacks

#writer = SummaryWriter('runs/experiment_1')

# Quantization resolution: Quantization floors activations to multiples of
# 1 / q_level, and CalculateLoss builds its penalty on the same grid.
q_level = 3
# Number of intermediate 64 -> 64 conv + batch-norm pairs built in VGG_16.
number_mediate_layers = 4

# Checkpoint read by main() before training, and the path main() writes the
# best model to.
load_name = "iclr25_weight_2bits_newtry/vgg16_helmet_4layers_T7_2bits_ouralgo90_1_8_1.pt"
save_name = "iclr25_weight_2bits_newtry/vgg16_helmet_4layers_T7_2bits_ouralgo90_1_8_2_noise.pt"

# Loss selector read by train() / train_1(): index 0 adds the orthogonality
# penalty, index 1 adds the activation penalty, index 2 is cross-entropy only
# (see those functions for the exact terms). The flags are tested in order
# with independent `if`s, so if several are True the last one wins.
train_step = [True,False,False]
#train_step = [False,True,False]
#train_step = [False,False,True]

# Scale of the identity target in orthogonality_loss (W^T W - step1_factor * I).
step1_factor = 0.001
#print(step1_factor*10)
# Weight applied to the activation penalty returned by the model
# (used as l_2 * step2_factor / scale_fa in the training loops).
step2_factor = 2e-6
# Layer specifications consumed by VGG_16._make_layers: an int is the output
# channel count of a 3x3 conv (padding 1), 'M' is a 2x2 max-pool, and a tuple
# is (out_channels, padding).
cfg = {
    'o' : [128,128,'M',256,256,'M',512,512,'M',(1024,0),'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


class Quantization(torch.autograd.Function):
    """Floor-quantize a tensor onto a grid with step ``1 / constant``.

    Forward computes ``floor(tensor * constant) / constant``. Because ``floor``
    has zero gradient almost everywhere, backward substitutes a straight-through
    gradient: the incoming gradient is passed on unchanged, except that every
    element is saturated to the range [-100, 100].
    """

    @staticmethod
    def forward(ctx, tensor, constant=1):
        """Return ``tensor`` rounded down to the nearest multiple of ``1 / constant``."""
        ctx.constant = constant
        scaled = torch.mul(tensor, constant)
        return torch.div(torch.floor(scaled), constant)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient through, clipped element-wise to [-100, 100].

        The second return value is ``None`` because ``constant`` takes no gradient.
        """
        # hardtanh saturates at +/-1, so scaling by 1/100 and back clips at +/-100.
        saturated = F.hardtanh(grad_output / 100)
        return 100 * saturated, None


# Functional alias used throughout the file.
Quantization_ = Quantization.apply

class Clamp_q_(nn.Module):
    """Activation that clamps to ``[min, max]`` and then floor-quantizes.

    Args:
        min: Lower clamp bound.
        max: Upper clamp bound.
        q_level: Quantization constant handed to ``Quantization_``; defaults to
            the module-level ``q_level`` as it was when the class was defined.
    """

    def __init__(self, min=0.0, max=1,q_level = q_level):
        super().__init__()
        self.min, self.max = min, max
        self.q_level = q_level

    def forward(self, x):
        """Return ``x`` clamped to the configured range and quantized."""
        bounded = x.clamp(min=self.min, max=self.max)
        return Quantization_(bounded, self.q_level)

class CalculateLoss(torch.nn.Module):
    """Penalty that pulls pre-activations towards the centres of quantization bins.

    With ``q = q_level`` the summed, per-element penalty is:

    * ``x <= 0`` or ``x >= c_max + 0.5 / q``: zero (the target equals ``x``);
    * ``0 <= x <= 1 / q``: ``(0.5 * x) ** 2 / 2`` (target 0, value halved);
    * otherwise: ``(x - c) ** 2 / 2`` where ``c = (m + 0.5) / q`` and
      ``m = round(x * q - 0.5 - 1e-5)``.

    The result is a sum over all elements, not a mean.

    Args:
        q_level: Number of quantization steps per unit interval.
        c_max: Upper bound used to decide where large values stop being penalised.
    """

    def __init__(self, q_level,c_max):
        super().__init__()
        self.q_level = q_level
        self.max_value = c_max

    def forward(self, x):
        """Return the scalar penalty for tensor ``x`` (any shape)."""
        levels = self.q_level
        upper = self.max_value

        # Regions of the input that receive special treatment.
        at_or_below_zero = x <= 0
        beyond_top = x >= upper + 0.5 / levels
        first_bin = (x >= 0) & (x <= 1 / levels)

        # odd_k / 2 is the half-integer selected for x * levels, so the target
        # below is a bin centre on the 1 / levels grid.
        odd_k = 2 * torch.round(x * levels - 0.5 - 1e-5) + 1
        target = (odd_k * 0.5) / levels

        # Later overrides take precedence, matching the order of the masks.
        target = torch.where(first_bin, 0, target)
        target = torch.where(at_or_below_zero, x, target)
        target = torch.where(beyond_top, x, target)

        # Inside the first bin the value itself is halved before comparison.
        value = torch.where(first_bin, 0.5 * x, x)

        return torch.sum(torch.pow(torch.abs(value - target), 2) / 2)


class AddGaussianNoise(object):
    """Callable transform that adds Gaussian noise to a tensor.

    The noise is ``torch.randn(tensor.size()) * std + mean``, sampled fresh on
    every call.

    Args:
        mean: Offset added to every element.
        std: Scale applied to the standard-normal sample.
    """

    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return ``tensor`` plus a newly drawn noise sample of the same shape."""
        scaled_noise = torch.randn(tensor.size()) * self.std
        return tensor + scaled_noise + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

class VGG_16(nn.Module):
    """VGG-style classifier preceded by a quantized convolutional front end.

    Data flow in ``forward``:

    1. ``conv1`` + ``bn1`` (3 -> 64 channels), then clamp to [0, 1] and
       floor-quantize with the module-level ``q_level``.
    2. ``number_mediate_layers`` repetitions of 64 -> 64 conv + batch-norm,
       each followed by the same clamp + quantize step.
    3. ``conv2`` (64 -> 3 channels), the VGG feature stack built from
       ``cfg[vgg_name]``, and a three-layer classifier with 10 outputs.

    Before every clamp the pre-activation is fed to ``CalculateLoss`` and the
    results are summed; that sum is returned next to the logits.

    Args:
        vgg_name: Key into the module-level ``cfg`` dictionary.
        quantize_factor: Accepted but not used by this class.
        clamp_max: Stored on the instance; the clamps in ``forward`` use a
            fixed upper bound of 1.
        bias: Whether the conv layers of the feature stack have a bias term.
    """

    def __init__(self, vgg_name, quantize_factor=-1, clamp_max=1.0, bias=True):
        super().__init__()
        self.clamp_max = clamp_max
        self.bias = bias

        # Quantized front end.
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu_ = nn.ReLU()  # registered but not applied in forward
        self.conv_layers = nn.ModuleList(
            [nn.Conv2d(64, 64, 3, 1, 1, bias=True) for _ in range(number_mediate_layers)]
        )
        self.bn_layers = nn.ModuleList(
            [nn.BatchNorm2d(64) for _ in range(number_mediate_layers)]
        )

        # Projection back to 3 channels so the VGG stack sees an image-like tensor.
        self.conv2 = nn.Conv2d(64, 3, 3, 1, 1, bias=True)

        self.features = self._make_layers(cfg[vgg_name])
        self.cal = CalculateLoss(q_level, 1)
        # The first Linear takes 512 inputs, so the flattened feature map must
        # contain 512 values per sample.
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return ``(logits, activation_penalty)`` for the input batch ``x``."""
        x = self.bn1(self.conv1(x))
        y = self.cal(x)
        x = Quantization_(torch.clamp(x, min=0, max=1), q_level)

        # Every intermediate stage is treated identically: conv, batch-norm,
        # accumulate the penalty, clamp, quantize.
        for conv, bn in zip(self.conv_layers, self.bn_layers):
            x = bn(conv(x))
            y += self.cal(x)
            x = Quantization_(torch.clamp(x, min=0, max=1), q_level)

        x = self.conv2(x)

        for stage in self.features:
            x = stage(x)
        flat = x.view(x.size(0), -1)
        return self.classifier(flat), y

    def _make_layers(self, cfg):
        """Build the feature stack described by ``cfg``.

        ``'M'`` becomes a 2x2 max-pool; an int ``n`` becomes a 3x3 conv with
        ``n`` output channels and padding 1; a tuple is read as
        ``(out_channels, padding)``. Every conv is followed by batch-norm and
        ReLU, and a 1x1 average pool closes the stack.
        """
        layers = []
        in_channels = 3

        for spec in cfg:
            if spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue

            if isinstance(spec, tuple):
                out_channels, padding = spec[0], spec[1]
            else:
                out_channels, padding = spec, 1

            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=self.bias),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
            in_channels = out_channels

        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)
          

def orthogonality_loss(model, beta):
    """Return ``beta / 2`` times the summed squared deviation of ``W^T W`` from a scaled identity.

    Every parameter whose name contains ``'weight'`` and that has more than one
    dimension contributes ``sum((W^T W - step1_factor * I) ** 2)``. Four-dimensional
    parameters are first reshaped to ``(size(0), -1)``; all others are used as they are.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        beta: Regularisation strength.

    Returns:
        The penalty; a plain float ``0.0`` scaled by ``beta / 2`` when no
        parameter qualifies.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if 'weight' not in name or param.dim() <= 1:
            continue

        W = param.view(param.size(0), -1) if param.dim() == 4 else param

        gram = torch.matmul(W.T, W)
        identity = torch.eye(gram.size(0), device=param.device)
        penalty += (gram - step1_factor * identity).pow(2).sum()

    return beta / 2.0 * penalty
def r_fgsm_attack(image, epsilon, data_grad):
    """Build an adversarial example from a random sign step plus an FGSM step.

    The image is first moved by ``epsilon`` along a random sign pattern and
    clamped to [0, 1]; it is then moved by ``epsilon`` along the sign of
    ``data_grad`` and clamped to [0, 1] again.

    Args:
        image: Input batch.
        epsilon: Step size used for both the random and the gradient step.
        data_grad: Gradient of the loss with respect to the input; must have
            a shape compatible with ``image``.

    Returns:
        The perturbed batch, on the same device as ``data_grad``.
    """
    # Follow the gradient's device instead of forcing CUDA, so the attack also
    # works on CPU-only runs (e.g. with --no-cuda).
    image = image.to(data_grad.device)

    random_step = torch.randn_like(image).sign() * epsilon
    perturbed_image = torch.clamp(image + random_step, 0, 1)

    perturbed_image = perturbed_image + epsilon * data_grad.sign()
    return torch.clamp(perturbed_image, 0, 1)

def train_1(args, model, device, train_loader, optimizer, epoch, noise, use_function, clamp_max, scale_fa):
    """Run one epoch of adversarial training on clean plus R-FGSM batches.

    For every batch, a first forward/backward pass yields the input gradient,
    which ``r_fgsm_attack`` turns into a perturbed copy using a step size drawn
    uniformly from [0, 8/255). The clean and perturbed batches are then
    concatenated and used for the actual optimisation step. The loss is picked
    by the module-level ``train_step`` flags; when several are set, the
    highest index wins.

    Args:
        args: Parsed options; ``log_interval`` and ``dry_run`` are read.
        model: Network returning ``(logits, activation_penalty)``.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimiser stepping the model parameters.
        epoch: Epoch number, used only for logging.
        noise, use_function, clamp_max: Accepted but not used.
        scale_fa: Divisor applied to the activation-penalty weight.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        clean_batch = data.clone()

        # Pass 1: gradient of the loss with respect to the input.
        data.requires_grad = True
        first_logits, _ = model(data)
        attack_loss = F.cross_entropy(first_logits, target)
        model.zero_grad()
        attack_loss.backward()
        input_grad = data.grad.data

        epsilon = torch.rand(1).item() * 8 / 255
        adversarial_batch = r_fgsm_attack(data, epsilon, input_grad)

        inputs = torch.cat([clean_batch, adversarial_batch], dim=0)
        labels = torch.cat([target, target], dim=0)

        # Pass 2: optimise on the combined batch.
        optimizer.zero_grad()
        output, l_2 = model(inputs)

        classification = F.cross_entropy(output, labels)
        if train_step[2]:
            loss = classification
        elif train_step[1]:
            loss = classification + l_2 * step2_factor / scale_fa
        elif train_step[0]:
            loss = orthogonality_loss(model, 2e-4) + classification

        model.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval != 0:
            continue

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

        # A dry run stops at the first logged batch.
        if args.dry_run:
            break

def train(args, model, device, train_loader, optimizer, epoch,noise,use_function,clamp_max,scale_fa):
    """Run one epoch of standard (non-adversarial) training.

    The loss is picked by the module-level ``train_step`` flags; when several
    are set, the highest index wins:

    * index 0: orthogonality penalty + cross-entropy + weighted activation penalty,
    * index 1: cross-entropy + weighted activation penalty,
    * index 2: cross-entropy only.

    Args:
        args: Parsed options; ``log_interval`` and ``dry_run`` are read.
        model: Network returning ``(logits, activation_penalty)``.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimiser stepping the model parameters.
        epoch: Epoch number, used only for logging.
        noise, use_function, clamp_max: Accepted but not used.
        scale_fa: Divisor applied to the activation-penalty weight.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output, l_2 = model(data)

        classification = F.cross_entropy(output, target)
        if train_step[2]:
            loss = classification
        elif train_step[1]:
            loss = classification + l_2 * step2_factor / scale_fa
        elif train_step[0]:
            loss = orthogonality_loss(model, 2e-5) + classification + l_2 * step2_factor / scale_fa

        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval != 0:
            continue

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

        # A dry run stops at the first logged batch.
        if args.dry_run:
            break


def test(model, device, test_loader,noise,use_function,clamp_max):
    """Evaluate on clean data and return the number of correct predictions.

    One pass over ``test_loader`` is made under ``torch.no_grad()``. A summary
    line ``"<mean>(<min - mean>, <max - mean>)"`` of the correct count is
    printed; with a single pass both deviations are zero.

    Args:
        model: Network returning ``(logits, activation_penalty)``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches.
        noise, use_function, clamp_max: Accepted but not used.

    Returns:
        The count of correctly classified samples (not a percentage).
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        correct = 0
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output, _ = model(data)
            # Summed batch loss; accumulated but not reported.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

        counts = np.array([correct])
        average = np.mean(counts)
        mean_val = average / 1
        min_val = (np.min(counts) - average) / 1
        max_val = (np.max(counts) - average) / 1
        print(f"{mean_val:.2f}({min_val:.2f}, {max_val:.2f})")

    return correct

def fgsm_attack(image, epsilon, data_grad):
    """Return ``image`` moved by ``epsilon`` along the sign of ``data_grad``.

    The result is clamped to [-1, 1].

    NOTE(review): r_fgsm_attack in this file clamps to [0, 1] instead; confirm
    which range matches the input scaling actually used.
    """
    step = epsilon * data_grad.sign()
    return torch.clamp(image + step, -1, 1)

def test_att(model, device, test_loader, epsilon):
    """Evaluate the model under a single-step FGSM attack.

    For every batch the input gradient of the summed cross-entropy loss is
    computed, ``fgsm_attack`` builds the perturbed batch, and the model's
    predictions on that batch are compared with the labels.

    Args:
        model: Network returning ``(logits, activation_penalty)``.
        device: Device the batches are moved to.
        test_loader: Loader of ``(data, target)`` batches; ``len(test_loader.dataset)``
            is used as the denominator of the printed accuracy.
        epsilon: FGSM step size.

    Returns:
        The number of adversarial samples classified correctly (a count, not a
        percentage), so callers comparing successive results keep working.
    """
    model.eval()
    correct = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True

        # Gradient of the loss with respect to the input.
        output, _ = model(data)
        loss = F.cross_entropy(output, target, reduction='sum')
        model.zero_grad()
        loss.backward()

        # Detach so the evaluation pass does not extend the autograd graph.
        perturbed_data = fgsm_attack(data, epsilon, data.grad.data).detach()
        with torch.no_grad():
            output_perturbed, _ = model(perturbed_data)
        pred_perturbed = output_perturbed.argmax(dim=1, keepdim=True)

        correct += pred_perturbed.eq(target.view_as(pred_perturbed)).sum().item()

    # Report a real percentage; the raw count used to be printed with a '%' sign.
    total = len(test_loader.dataset)
    accuracy_pct = 100.0 * correct / total if total else 0.0
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {total} = {accuracy_pct:.2f}%")
    return correct


def main():
    """Fine-tune the quantized VGG-16 on CIFAR-10 with adversarial training.

    Steps: parse the command line, build the training set (one clean copy of
    CIFAR-10 plus five randomly augmented views), load the checkpoint at
    ``load_name``, report clean and FGSM (epsilon = 8/255) results, then train
    with ``train_1`` for ``--epochs`` epochs. After every epoch the model is
    evaluated again and its weights are written to ``save_name`` whenever the
    FGSM correct count improves.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch CIFAR-10 adversarial training of a quantized VGG-16')
    parser.add_argument('--batch-size', type=int, default=768, metavar='N',
                        help='input batch size for training (default: 768)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000; '
                             'currently unused, the test loader uses 512)')
    parser.add_argument('--epochs', type=int, default=3, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=1, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--loss_scale', type=float, default=0, metavar='LR',
                        help='loss scale (default: 0; currently unused)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--noise', type=float, default=0, metavar='S',
                        help='noise level forwarded to the train/test functions '
                             '(default: 0; currently ignored by them)')
    parser.add_argument('--use_function', type=str, default='relu', metavar='S',
                        help='activation name forwarded to the train/test functions '
                             "(default: 'relu'; currently ignored by them)")
    parser.add_argument('--clamp_max', type=int, default=1, metavar='S',
                        help='clamp maximum forwarded to the train/test functions '
                             '(default: 1; currently ignored by them)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--T', type=int, default=5, metavar='N',
                        help='SNN time window')
    parser.add_argument('--resume', type=str, default=None, metavar='RESUME',
                        help='Resume model from checkpoint')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    #torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Shuffle on every device; it used to be enabled on the CUDA path only.
    loader_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
    if use_cuda:
        loader_kwargs.update({'num_workers': 32, 'pin_memory': True})

    plain_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    augment_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomCrop(32, padding=6),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        AddGaussianNoise(std=0.01),
    ])

    clean_train = datasets.CIFAR10('../data', train=True, download=True,
                                   transform=plain_transform)
    augmented_train = datasets.CIFAR10('../data', train=True, download=True,
                                       transform=augment_transform)
    # The augmentation is re-sampled on every access, so one augmented dataset
    # listed five times is equivalent to five separate instances while keeping
    # a single copy of the images in memory.
    train_set = torch.utils.data.ConcatDataset([clean_train] + [augmented_train] * 5)
    test_set = datasets.CIFAR10('../data', train=False,
                                transform=plain_transform)

    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=512)

    model = VGG_16('VGG16', clamp_max=1, bias=True).to(device)
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # run. strict=False tolerates missing or unexpected keys.
    state_dict = torch.load(load_name, map_location=device)
    model.load_state_dict(state_dict, strict=False)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr, weight_decay=0.001)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)

    # Baseline numbers for the loaded checkpoint.
    epsilon = 8 / 255
    test(model, device, test_loader, args.noise, args.use_function, args.clamp_max)
    test_att(model, device, test_loader, epsilon)

    best_adv_correct = 0
    for epoch in range(1, args.epochs + 1):
        # Divisor for the activation-penalty weight: 1 up to epoch 49,
        # 10 for epochs 50-70, 100 afterwards.
        if epoch < 50:
            scale_fa = 1
        elif epoch <= 70:
            scale_fa = 10
        else:
            scale_fa = 100

        train_1(args, model, device, train_loader, optimizer, epoch,
                args.noise, args.use_function, args.clamp_max, scale_fa=scale_fa)
        test(model, device, test_loader, args.noise, args.use_function, args.clamp_max)
        adv_correct = test_att(model, device, test_loader, epsilon)

        # Keep the checkpoint with the best adversarial result so far.
        if adv_correct > best_adv_correct:
            best_adv_correct = adv_correct
            torch.save(model.state_dict(), save_name)

        scheduler.step()


if __name__ == '__main__':
    main()





