import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import json
import pickle
import numpy as np
import wandb
import fire
import logging

from data.MNIST_test import LitMNIST
from data.KMNIST_test import LitKMNIST
from data.FashionMNIST_test import LitFashionMNIST
from src.MNISTAttacksBase import BaseLeNet
from src.utils import fgsm_attack

def run_attack(model,test_dataloader,epsilon):
    '''
    Measure how well ``model`` withstands an FGSM attack of strength ``epsilon``.

    First check whether the base model classifies each sample correctly.
    For correctly classified samples only, build an FGSM adversarial example
    from the gradient of the NLL loss with respect to the input. Then
    re-evaluate the model on that adversarial example.

    Predictions and labels are compared with ``.item()``, so
    ``test_dataloader`` must yield batches of size 1.

    Args:
        model: module exposing ``adv_test_step(x)``. Its output is fed to
            ``F.nll_loss``, so it is presumably log-probabilities (confirm).
        test_dataloader: iterable of ``(data, label)`` pairs, batch size 1.
        epsilon: perturbation magnitude forwarded to ``fgsm_attack``.

    Returns:
        ``[epsilon, correct, num_adversarials, precision]``, where:

        - ``correct`` counts samples classified correctly before the attack.
        - ``num_adversarials`` counts how many of those the attack flipped.
        - ``precision`` is the fraction of ``correct`` samples that kept
          their label. It is 0.0 when no sample was classified correctly,
          instead of dividing by zero.
    '''
    num_adversarials = 0
    correct = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    for data, y in tqdm(test_dataloader):
        data = data.to(device)
        y = y.to(device)

        # Gradients w.r.t. the input are what FGSM perturbs along.
        data.requires_grad = True
        out = model.adv_test_step(data)
        y_hat = out.argmax(dim=1, keepdim=False)
        if y_hat.item() != y.item():
            # Only attack samples the base model already gets right.
            continue

        correct += 1

        loss = F.nll_loss(out, y)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad
        # Detach so the perturbed input does not keep a link to the old graph.
        perturbed_data = fgsm_attack(data.detach().cpu(), epsilon, data_grad.cpu())

        # Re-evaluation is pure inference: no autograd graph needed.
        with torch.no_grad():
            out = model.adv_test_step(perturbed_data.to(device))
        y_hat = out.argmax(dim=1, keepdim=False)  # index of the max log-probability
        if y_hat.item() != y.item():
            # Prediction flipped: a genuine adversarial example.
            num_adversarials += 1

    # Guard against division by zero when nothing was classified correctly.
    precision = (correct - num_adversarials) / correct if correct else 0.0

    print(f'Epsilon: {epsilon}\t Precision: {correct-num_adversarials}/{correct} = {precision}')
    return [epsilon, correct, num_adversarials, precision]

def test_model(model,test_dataloader):
    '''
    Compute the clean (un-attacked) accuracy of ``model`` on ``test_dataloader``.

    Args:
        model: module exposing ``adv_test_step(x)``, which returns per-class
            scores along dim 1.
        test_dataloader: iterable of ``(x, y)`` batches.

    Returns:
        The fraction of correctly classified samples as a 0-dim tensor.
        Returns ``torch.tensor(0.0)`` when the loader yields no samples,
        instead of dividing by zero.
    '''
    correct = 0
    total = 0  # avoid shadowing the builtin ``all``
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Pure inference: skip building autograd graphs for every batch.
    with torch.no_grad():
        for x, y in tqdm(test_dataloader):
            total += y.shape[0]
            x = x.to(device)
            y = y.to(device)
            preds = model.adv_test_step(x).argmax(dim=1, keepdim=False)
            correct += (preds == y).cpu().sum()
    if total == 0:
        return torch.tensor(0.0)
    return correct / total

def main(dataset='MNIST',dev_mode=False):
    '''
    Train or load a LeNet on ``dataset`` and evaluate it against FGSM attacks.

    If ``output/LeNet_{dataset}.ckpt`` exists, it is loaded. Otherwise a
    ``BaseLeNet`` is trained for up to 30 epochs, keeping the checkpoint with
    the best ``val_precision``.

    The clean accuracy and the accuracy under each FGSM epsilon are printed
    and written to ``output/LeNet_{dataset}.txt``.

    Args:
        dataset: one of ``'MNIST'``, ``'KMNIST'``, ``'FashionMNIST'``.
        dev_mode: forwarded to ``pl.Trainer(fast_dev_run=...)``. When set,
            the best-checkpoint reload after training is skipped.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    '''
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s'
    )

    logging.info(f'Dataset: {dataset}')

    # Explicit lookup instead of eval() on a command-line-supplied string.
    datamodules = {
        'MNIST': LitMNIST,
        'KMNIST': LitKMNIST,
        'FashionMNIST': LitFashionMNIST,
    }
    if dataset not in datamodules:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(datamodules)}"
        )
    data_base = datamodules[dataset](val_set=True)
    data_base.setup()

    output_dir = 'output'
    ckpt_path = os.path.join(output_dir, f'LeNet_{dataset}.ckpt')

    # Train the base model if it hasn't been trained already
    if os.path.exists(ckpt_path):
        model = BaseLeNet.load_from_checkpoint(checkpoint_path=ckpt_path)
    else:
        model = BaseLeNet()
        checkpoint_callback = ModelCheckpoint(dirpath=output_dir,
                                              filename=f'LeNet_{dataset}',
                                              monitor='val_precision',
                                              mode='max',
                                              save_top_k=1,
                                              every_n_epochs=1)

        trainer = pl.Trainer(max_epochs=30,
                             callbacks=[checkpoint_callback],
                             fast_dev_run=dev_mode,
                             gpus=1 if torch.cuda.is_available() else 0)
        trainer.fit(model,datamodule=data_base)
        trainer.test(model,datamodule=data_base)

        if not dev_mode:
            print(checkpoint_callback.best_model_score.cpu())
            # load_from_checkpoint is a classmethod: call it on the class.
            model = BaseLeNet.load_from_checkpoint(
                checkpoint_path=checkpoint_callback.best_model_path
            )

    model.eval()

    test_dataloader = data_base.test_dataloader(bs=1)

    # In dev_mode no checkpoint is guaranteed to be written, so the output
    # directory may not exist yet.
    os.makedirs(output_dir, exist_ok=True)
    out_file = os.path.join(output_dir, f'LeNet_{dataset}.txt')

    with open(out_file, 'w') as f:
        # float() so the report shows a number rather than "tensor(...)".
        accuracy = float(test_model(model, test_dataloader))
        print(f"Base accuracy LeNet: {accuracy * 100}")
        f.write(f"Base accuracy LeNet: {accuracy * 100}\n")

        # Run attacks
        epsilons = [.05, .1, .2, .3, .4, .5, .6, .7]
        for e in epsilons:
            _, _, _, robust_accuracy = run_attack(model, test_dataloader, e)

            print(f"Accuracy on epsilon {e} adversarial test examples: {robust_accuracy * 100}")
            f.write(f"Accuracy on epsilon {e} adversarial test examples: {robust_accuracy * 100}\n")

if __name__ == "__main__":
    # Script entry point: hand main() to python-fire, which builds the
    # command-line interface from its signature.
    fire.Fire(component=main)
