import os
import datetime
import numpy as np
import sys
import argparse
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from collections import OrderedDict
import torch.nn.functional as F
import time
import pickle

# Command-line options. Defaults shown in --help are rendered from the actual
# default values via %(default)s so the text cannot drift from the code.
parser = argparse.ArgumentParser(description='PyTorch CIFAR MART Defense')
parser.add_argument('-d', '--data', default='cifar10', type=str,
                    help='dataset name (default: %(default)s)')
# NOTE(review): training batches use --base_batch_size; --batch-size is only
# used for the test-set DataLoaders in this file.
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='batch size for test-set evaluation; training uses '
                         '--base_batch_size (default: %(default)s)')
# NOTE(review): --test-batch-size is not read anywhere in this file.
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
# NOTE(review): train_model runs for --base_epochs; --epochs is not read here.
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--weight-decay', '--wd', default=2e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
# NOTE(review): --seed is not read; module-level seeds are hard-coded to 0.
parser.add_argument('--seed', type=int, default=0, metavar='S',
                    help='random seed (default: %(default)s)')
# NOTE(review): --log-interval is not read; train_model logs every epoch.
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model', default='resnet',
                    help='directory of model for saving checkpoint')
parser.add_argument('--save-freq', '-s', default=50, type=int, metavar='N',
                    help='save a checkpoint every N epochs (default: %(default)s)')
parser.add_argument('--save-dir', default="CIFAR10_models_LATEST", type=str,
                    help='directory for info.txt and log.txt (default: %(default)s)')
parser.add_argument('--save_dir_base_model', type=str, default='CIFAR10_models_LATEST',
                    help='directory for training checkpoints (default: %(default)s)')
parser.add_argument('--base_epochs', type=int, default=120)
# NOTE(review): train_model reads --lr for the starting learning rate, not --base_lr.
parser.add_argument('--base_lr', type=float, default=0.1)
parser.add_argument('--base_batch_size', type=int, default=128)
parser.add_argument('--base_weight_decay', type=float, default=2e-4)
parser.add_argument('--base_momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
parser.add_argument('--base_optimizer', '-optimizer', type=str, default='SGD',
                    help="optimizer for training: 'SGD' or 'Adam' (default: %(default)s)")

if __name__=="__main__":
    # Running as a script: extend sys.path so the project's modules resolve.
    # The script's grandparent directory is added first (presumably where
    # ``classifier_base`` lives), then its great-grandparent (presumably the
    # root containing the ``data`` package). Plain os.path is used so no
    # third-party ``path`` package is needed.
    folder_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(folder_path)
    print(folder_path)
    folder_path = os.path.dirname(folder_path)
    sys.path.append(folder_path)
    from classifier_base import Classifier
    # Wildcard import; presumably supplies ``get_dataset`` used below.
    from data.pytorch_datasets import *
    from resnet import ResNet18
else:
    # Imported as part of the ``models`` package.
    from models.classifier_base import Classifier
    from models.defense import resnet
    from models.defense.resnet import ResNet18

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed numpy's and torch's global RNGs with 0 at import time.
# NOTE(review): the --seed command-line option is not applied here.
np.random.seed(0)
torch.manual_seed(0)

class NN_CIFAR10(nn.Module, Classifier):
    """CIFAR-10 classifier that wraps a ``ResNet18`` behind the ``Classifier`` interface.

    The network is stored in ``self.model``. This wrapper adds:
    - loss computation,
    - a training loop with a step learning-rate schedule,
    - accuracy evaluation,
    - whole-model save/load helpers.
    """

    def __init__(self):
        super().__init__()
        self.model = ResNet18()
        # Learning rate currently applied to the optimizer. Set by
        # train_model / adjust_learning_rate and stored in checkpoints.
        self.curr_lr = None

    def forward(self, inputs):
        """Return the logits produced by the wrapped network for ``inputs``."""
        return self.model(inputs)

    def zero_grad(self, *args, **kwargs):
        """Clear the wrapped network's gradients.

        Any arguments (e.g. ``set_to_none``) are forwarded to
        ``nn.Module.zero_grad``. Callers using that signature keep working.
        """
        return self.model.zero_grad(*args, **kwargs)

    def get_loss(self, images, labels, requires_mean=True):
        """Return the cross-entropy loss of the network on ``images``/``labels``.

        Args:
            images: input batch forwarded through ``self.model``.
            labels: class-index targets for ``images``.
            requires_mean: if True, return the batch-mean loss; otherwise return
                the unreduced per-example losses.
        """
        output = self.model(images)
        reduction = "mean" if requires_mean else "none"
        return F.cross_entropy(output, labels, reduction=reduction)

    def train_model(self, dataset, args=None, val_ds=None):
        """Train ``self.model`` on ``dataset`` using the hyper-parameters in ``args``.

        After every epoch it prints:
        - the mean training loss,
        - validation accuracy (when ``val_ds`` is given),
        - test accuracy.

        Checkpoints are written every ``args.save_freq`` epochs and at every
        epoch from 95 onward.

        Args:
            dataset: training dataset yielding ``(images, labels)`` pairs.
            args: parsed argument namespace. Reads ``lr``, ``base_batch_size``,
                ``base_weight_decay``, ``base_epochs``, ``base_momentum``,
                ``base_optimizer``, ``save_freq``, ``save_dir_base_model``,
                ``batch_size`` and ``device``. ``args.data`` is overwritten
                with ``"cifar10"``.
            val_ds: optional validation dataset. When ``None``, the validation
                pass is skipped.

        Raises:
            ValueError: if ``args`` is None or ``args.base_optimizer`` is
                neither ``"SGD"`` nor ``"Adam"``.
        """
        if args is None:
            raise ValueError("train_model requires the parsed argument namespace `args`")
        self.model.train()
        # NOTE(review): the starting lr comes from ``args.lr``, while every other
        # hyper-parameter uses its ``base_*`` variant (``args.base_lr`` is not
        # read). Confirm which one is intended.
        lr = args.lr
        batch_size = args.base_batch_size
        weight_decay = args.base_weight_decay
        epochs = args.base_epochs
        momentum = args.base_momentum
        log_interval = 1  # log + validate after every epoch
        save_freq = args.save_freq
        device = args.device

        print(f"Batch size: {batch_size}, lr:{lr}, momentum:{momentum}, epochs:{epochs}, wt. decay: {weight_decay}")

        os.makedirs(args.save_dir_base_model, exist_ok=True)
        # Force the test split to CIFAR-10. This mutates the caller's namespace.
        args.data = "cifar10"
        # NOTE(review): get_dataset appears to be brought into scope only by the
        # script-mode ``from data.pytorch_datasets import *``. Confirm it is
        # available when this module is imported as a package.
        _, test_ds = get_dataset(args)
        # NOTE(review): training batches are not shuffled, so every epoch sees
        # the same order. Confirm this is intentional.
        dl = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        val_dl = None
        if val_ds is not None:
            val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
        test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True)

        self.optimizer = self._build_optimizer(args.base_optimizer, lr, momentum, weight_decay)
        # Adam keeps this lr for the whole run. For SGD it is overwritten each
        # epoch by adjust_learning_rate.
        self.curr_lr = lr

        for epoch in range(1, epochs + 1):
            self.model.train()
            start = time.time()
            if args.base_optimizer == "SGD":
                self.adjust_learning_rate(self.optimizer, epoch, lr)
            print(f"Epoch {epoch} of {epochs}...")

            track_loss = 0
            for idx, (images, labels) in enumerate(dl):
                if (idx + 1) % 50 == 0:
                    print(f"\t Batch {idx+1} of {len(dl)} done")
                images = images.to(device)
                labels = labels.to(device)

                self.optimizer.zero_grad()
                loss = F.cross_entropy(self.model(images), labels)
                loss.backward()
                # Weight by batch size so the epoch figure is a per-example mean.
                track_loss += float(loss) * len(images)
                self.optimizer.step()

            track_loss /= len(dataset)
            end = time.time()
            if epoch % log_interval == 0:
                print(f"\tEpoch {epoch}: Loss is {track_loss}, time taken is {round(end-start, 3)}")
                if val_dl is not None:
                    correct, total = self._count_correct(val_dl, device)
                    print(f"\tAccuracy on validation set is {self._percent(correct, total)}")

            # Periodic checkpoint, plus one at every epoch from 95 onward.
            if (save_freq > 0 and epoch % save_freq == 0) or epoch >= 95:
                self._save_checkpoint(args.save_dir_base_model, epoch)

            correct, total = self._count_correct(test_dl, device)
            acc = self._percent(correct, total)
            print(f"*****TEST accuracy after epoch {epoch} is {acc}")

    def _build_optimizer(self, name, lr, momentum, weight_decay):
        """Create the optimizer ``name`` ("SGD" or "Adam") over ``self.model``'s parameters."""
        if name == "SGD":
            return optim.SGD(self.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        if name == "Adam":
            # Adam is built without momentum/weight-decay arguments.
            return optim.Adam(self.model.parameters(), lr=lr)
        raise ValueError(f"Unsupported base_optimizer {name!r}; expected 'SGD' or 'Adam'")

    def _count_correct(self, loader, device):
        """Return ``(n_correct, n_examples)`` for ``self.model`` over ``loader`` in eval mode without gradients."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)
                labels = labels.to(device)
                _, predicted = torch.max(self.model(images), 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        return correct, total

    @staticmethod
    def _percent(correct, total):
        """Return ``correct / total`` as a percentage, or 0.0 when ``total`` is 0."""
        return (correct / total) * 100.0 if total else 0.0

    def _save_checkpoint(self, save_dir, epoch):
        """Save epoch, optimizer state, model state and lr to ``save_dir/model-{epoch}-checkpoint``."""
        checkpoint = {
            'epoch': epoch,
            'optimizer': self.optimizer.state_dict(),
            'state_dict': self.model.state_dict(),
            'lr': self.curr_lr,
        }
        save_path = os.path.join(save_dir, f"model-{epoch}-checkpoint")
        print(f"\tSaving model checkpoint to {save_path}")
        torch.save(checkpoint, save_path)

    def evaluate(self, images, labels):
        """Evaluate one batch in eval mode.

        Returns:
            dict with ``"accuracy"`` (percent correct, rounded to 2 decimals)
            and ``"loss"`` (mean cross-entropy as a Python float).
        """
        self.model.eval()
        images = images.to(device)
        labels = labels.to(device)
        outputs = self.model(images)
        _, predicted = torch.max(outputs, 1)

        n_examples = labels.size(0)
        n_correct = (predicted == labels).sum().item()

        val_loss = F.cross_entropy(outputs, labels)

        return {"accuracy": round((n_correct * 100.0 / n_examples), 2), "loss": val_loss.item()}

    def adjust_learning_rate(self, optimizer, epoch, base_lr):
        """Apply a step learning-rate schedule and record it in ``self.curr_lr``.

        Schedule by epoch:
        - below 75: ``base_lr``
        - 75-89: ``0.1 * base_lr``
        - 90-99: ``0.01 * base_lr``
        - 100 and above: ``0.001 * base_lr``

        A message is printed on the epoch each step begins.
        """
        lr = base_lr
        if epoch >= 100:
            lr = base_lr * 0.001
            if epoch == 100:
                print(f"Changing learning rate to {lr}")
        elif epoch >= 90:
            lr = base_lr * 0.01
            if epoch == 90:
                print(f"Changing learning rate to {lr}")
        elif epoch >= 75:
            lr = base_lr * 0.1
            if epoch == 75:
                print(f"Changing learning rate to {lr}")
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        self.curr_lr = lr

    def get_current_lr(self):
        """Return the most recently applied learning rate (None before training)."""
        return self.curr_lr

    def load_model(self, path):
        """Replace ``self.model`` with the whole module unpickled from ``path``.

        torch.load unpickles the file, which can execute arbitrary code; only
        load trusted files.
        """
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which rejects whole pickled modules such as those
        # written by save_model. Verify against the installed version.
        a = torch.load(path)
        self.model = a

    def save_model(self, path):
        """Pickle the whole wrapped network (not just its state_dict) to ``path``."""
        torch.save(self.model, path)
        print("Saving model at:", path)

if __name__=="__main__":
    args = parser.parse_args()

    # Select the compute device explicitly for this run.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('Using device:',device)

    args.device = device
    print(args)
    print("CIFAR10 net training!")
    os.makedirs(args.save_dir, exist_ok=True)

    # Record the full run configuration next to the logs.
    info_path = os.path.join(args.save_dir, "info.txt")
    with open(info_path, 'w') as fp:
        for k, v in vars(args).items():
            fp.write("%s : %s\n" % (k, v))

    # From here on, all print() output goes to <save_dir>/log.txt (line-buffered).
    log_path = os.path.join(args.save_dir, 'log.txt')
    sys.stdout = open(log_path, 'w', 1)

    _, test_ds = get_dataset(args)

    # NOTE(review): pickle.load can execute arbitrary code; only use a
    # dataset_split.pkl produced by a trusted source.
    with open('dataset_split.pkl', 'rb') as fp:
        split = pickle.load(fp)
    train_ds = split['train_ds']
    val_ds = split['val_ds']

    print(f"Size of datasets: train {len(train_ds)}, val {len(val_ds)}")

    cifar_net = NN_CIFAR10()
    cifar_net.to(device)

    cifar_net.train_model(train_ds, args, val_ds)
    print("Training done! Evaluating")
    test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True)
    correct = 0
    total = 0
    # Inference only: eval mode once, and no autograd graph per batch.
    cifar_net.model.eval()
    with torch.no_grad():
        for images, labels in test_dl:
            images = images.to(device)
            labels = labels.to(device)
            outputs = cifar_net.model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Fraction (not percent) of correctly classified test examples.
    acc = correct / total
    print(f"Total accuracy on test set is {acc}")


