import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import json
import pickle
import numpy as np
import wandb
import fire
import logging

from data.MNIST_test import LitMNIST
from data.KMNIST_test import LitKMNIST
from data.FashionMNIST_test import LitFashionMNIST
from src.MNISTAttacksBase import BaseLeNet
from src.MNISTMagLeNetFinal import MagLeNet
from src.utils import fgsm_attack

def run_attack(model_base,model,test_dataloader,epsilon):
    '''
    Evaluate ``model`` on FGSM adversarials crafted against ``model_base``.

    Predictions are compared with ``.item()``, so ``test_dataloader`` must
    yield batches of size 1. For every ``(data, y)`` sample:

      1. skip it unless ``model`` classifies the clean input correctly;
      2. skip it unless ``model_base`` classifies the clean input correctly;
      3. build an FGSM perturbation from the gradient of ``model_base``'s
         NLL loss w.r.t. the input (via ``fgsm_attack``);
      4. skip it unless the perturbation actually fools ``model_base``;
         otherwise count it in ``num_adversarials``;
      5. count it in ``correct`` if ``model`` still predicts ``y``.

    Args:
        model_base: model the adversarials are generated against. Its
            ``adv_test_step`` output is fed to ``F.nll_loss``, so it presumably
            returns log-probabilities -- confirm against the model class.
        model: model evaluated on the transferred adversarials.
        test_dataloader: iterable of ``(data, y)`` pairs, batch size 1.
        epsilon: FGSM step size forwarded to ``fgsm_attack``.

    Returns:
        ``[epsilon, correct, num_adversarials, precision]`` where
        ``precision = correct / num_adversarials``, or ``nan`` when no
        genuine adversarial was produced (avoids a ZeroDivisionError).
    '''
    num_adversarials = 0
    correct = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model_base.to(device)
    for data, y in tqdm(test_dataloader):
        data = data.to(device)
        y = y.to(device)

        # Only consider samples the mag model gets right on clean input.
        out = model.adv_test_step(data)
        if out.argmax(dim=1).item() != y.item():
            continue

        # Track gradients w.r.t. the input so FGSM can use them.
        data.requires_grad = True
        out = model_base.adv_test_step(data)
        if out.argmax(dim=1).item() != y.item():
            continue

        loss = F.nll_loss(out, y)
        model_base.zero_grad()
        loss.backward()
        perturbed_data = fgsm_attack(data.cpu(), epsilon, data.grad.cpu()).to(device)

        # Only carry on if the perturbation genuinely fools the base model.
        out = model_base.adv_test_step(perturbed_data)
        if out.argmax(dim=1).item() == y.item():
            continue
        num_adversarials += 1

        # Does the mag model still classify the transferred adversarial correctly?
        out = model.adv_test_step(perturbed_data)
        if out.argmax(dim=1).item() == y.item():
            correct += 1

    # Guard against a run where no sample produced a genuine adversarial.
    precision = correct / num_adversarials if num_adversarials else float('nan')

    print(f'Epsilon: {epsilon}\t Precision: {correct}/{num_adversarials} = {precision}')
    return [epsilon,correct,num_adversarials,precision]

def test_model(model,test_dataloader):
    '''
    Compute the top-1 accuracy of ``model`` over ``test_dataloader``.

    The prediction is the argmax over dim 1 of ``model.adv_test_step(x)``
    (presumably shaped (batch, num_classes) -- confirm against the model).

    Args:
        model: module exposing ``adv_test_step``; it is moved to CUDA when
            available.
        test_dataloader: iterable of ``(x, y)`` batches.

    Returns:
        Fraction of correctly classified samples as a Python float (so it
        formats as a number rather than ``tensor(...)``), or ``nan`` when the
        loader yields no samples.
    '''
    correct = 0
    total = 0  # avoid shadowing the builtin ``all``
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    for x, y in tqdm(test_dataloader):
        total += y.shape[0]
        x = x.to(device)
        y = y.to(device)
        preds = model.adv_test_step(x).argmax(dim=1)
        correct += (preds == y).sum().item()
    if total == 0:
        return float('nan')
    return correct / total

def _load_or_train(model_cls, ckpt_name, make_model, datamodule, dev_mode):
    '''
    Return a model of ``model_cls``, reusing ``output/<ckpt_name>.ckpt`` if present.

    When the checkpoint is missing, ``make_model()`` builds a fresh model.
    That model is trained for up to 30 epochs on ``datamodule``, keeping the
    single checkpoint with the highest ``val_precision``, and is then tested.
    Outside ``dev_mode`` the best checkpoint is reloaded and returned. In
    ``dev_mode`` the in-memory trained model is returned as-is.
    '''
    ckpt_path = os.path.join('output', f'{ckpt_name}.ckpt')
    if os.path.exists(ckpt_path):
        return model_cls.load_from_checkpoint(checkpoint_path=ckpt_path)

    model = make_model()
    checkpoint_callback = ModelCheckpoint(dirpath='output',
                                          filename=ckpt_name,
                                          monitor='val_precision',
                                          mode='max',
                                          save_top_k=1,
                                          every_n_epochs=1)

    trainer = pl.Trainer(max_epochs=30,
                         callbacks=[checkpoint_callback],
                         fast_dev_run=dev_mode,
                         gpus=1 if torch.cuda.is_available() else 0)
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    if not dev_mode:
        print(checkpoint_callback.best_model_score.cpu())
        # load_from_checkpoint is a classmethod: call it on the class, not the instance.
        model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint_callback.best_model_path)
    return model


def main(dataset='MNIST',levels=100,p=1,power=1,l_grid=1.,l_pixel=1.,hamming=False,dev_mode=False):
    '''
    Train (or load) a base LeNet and a MagLeNet, then run transfer FGSM attacks.

    The clean accuracy of both models is printed and written to
    ``output/base_LeNet_model_MagLeNet_<dataset>_levels_<levels>_p_<p>_hamming_<hamming>.txt``.
    The same file also gets, for each epsilon, the fraction of base-model
    adversarials that the MagLeNet still classifies correctly.

    Args:
        dataset: one of ``'MNIST'``, ``'KMNIST'``, ``'FashionMNIST'``.
        levels, p, power, l_grid, l_pixel, hamming: MagLeNet hyper-parameters.
        dev_mode: forwarded to ``pl.Trainer(fast_dev_run=...)``; also skips
            reloading the best checkpoint after training.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    '''
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(message)s'
    )

    logging.info(f'Dataset: {dataset}')
    logging.info(f'Levels: {levels}')
    logging.info(f'P: {p}')
    logging.info(f'Power: {power}')
    logging.info(f'l_grid: {l_grid}')
    logging.info(f'l_pixel: {l_pixel}')
    logging.info(f'Hamming: {hamming}')

    # Explicit lookup instead of eval(): the dataset name comes from the CLI,
    # so it must never be executed as code.
    datamodules = {
        'MNIST': LitMNIST,
        'KMNIST': LitKMNIST,
        'FashionMNIST': LitFashionMNIST,
    }
    if dataset not in datamodules:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {sorted(datamodules)}")
    data_base = datamodules[dataset](val_set=True)
    data_base.setup()

    # The results file is written here even when no checkpoint was saved (e.g. dev_mode).
    os.makedirs('output', exist_ok=True)

    # Train the base model if it hasn't been trained already
    model_base = _load_or_train(BaseLeNet, f'LeNet_{dataset}', BaseLeNet, data_base, dev_mode)

    # Train the model if it hasn't been trained already.
    # NOTE(review): the checkpoint name does not encode power/l_grid/l_pixel, so a
    # cached checkpoint trained with different values of those would be reused.
    mag_name = f'MagLeNet_{dataset}_levels_{levels}_p_{p}_hamming_{hamming}'
    model = _load_or_train(
        MagLeNet,
        mag_name,
        lambda: MagLeNet(levels=levels, p=p, power=power, l_grid=l_grid,
                         l_pixel=l_pixel, hamming=hamming),
        data_base,
        dev_mode,
    )

    model.eval()
    model_base.eval()
    test_dataloader = data_base.test_dataloader(bs=1)

    out_file = os.path.join('output', f'base_LeNet_model_{mag_name}.txt')
    with open(out_file, 'w') as f:
        def report(msg):
            # Echo every result line to stdout and to the results file.
            print(msg)
            f.write(msg + '\n')

        accuracy = test_model(model_base, test_dataloader)
        report(f"Base accuracy LeNet: {accuracy * 100}")

        accuracy = test_model(model, test_dataloader)
        report(f"Base accuracy MagLeNet: {accuracy * 100}")

        # Run attacks
        epsilons = [.05, .1, .2, .3, .4, .5, .6, .7]
        for e in epsilons:
            _, _, _, accuracy = run_attack(model_base, model, test_dataloader, e)
            report(f"Accuracy on epsilon {e} adversarial test examples: {accuracy * 100}")

if __name__ == '__main__':
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)
