import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import json
import fire
import pickle

from src.CifarMagLeNet import MagLeNet_1,MagLeNet_1_abl_3,MagLeNet_2_grey
from src.CifarLeNet_models import LeNet
from src.utils import fgsm_attack
from data.CIFAR10_test import LitCIFAR
from data.FashionMNIST_test import LitFashionMNIST_canny


class FeatureModel(pl.LightningModule):
    """Lightning module that, so far, only stores three configuration values.

    No forward pass, training step or optimizer is defined yet, so the class
    cannot be trained as-is.
    """

    def __init__(self, l_pixel, l_grid, threshold):
        """Keep the constructor arguments as instance attributes.

        Args:
            l_pixel: stored unchanged on ``self.l_pixel``.
            l_grid: stored unchanged on ``self.l_grid``.
            threshold: stored unchanged on ``self.threshold``.

        NOTE(review): the meaning and expected types of these values are not
        visible in this file -- confirm against the intended usage.
        """
        # The parent initialiser has to run before attributes are assigned so
        # that the nn.Module bookkeeping (parameters, submodules, hooks) exists.
        super().__init__()
        self.l_pixel = l_pixel
        self.l_grid = l_grid
        self.threshold = threshold

    # TODO: an unfinished, unnamed ``def`` followed the constructor and made
    # the whole module fail to parse; the remaining methods still need to be
    # written.

def main(p=0, dev_mode=False, model_version=1):
    """Train a MagLeNet variant and a LeNet baseline, attack both, save results.

    Both models are fitted on the ``LitFashionMNIST_canny`` data module, then
    ``run_attack`` is called for each model and each epsilon, and the collected
    results are written as JSON lines to
    ``experiments/adversarial_MNIST/experiment_12_v_<model_version>_p_<p>_out.txt``.

    Args:
        p: forwarded as the ``p`` keyword argument of the selected MagLeNet
            class, and embedded in the output file name.
        dev_mode: forwarded to ``fast_dev_run`` of both trainers.
        model_version: suffix choosing the MagLeNet class; one of ``1``,
            ``'1_abl_3'`` or ``'2_grey'`` (the variants imported by this file).

    Raises:
        ValueError: if ``model_version`` does not name an imported variant.
    """
    # Explicit lookup instead of eval() on a string assembled from command-line
    # arguments: same set of reachable classes, no arbitrary code execution.
    model_classes = {
        '1': MagLeNet_1,
        '1_abl_3': MagLeNet_1_abl_3,
        '2_grey': MagLeNet_2_grey,
    }
    model_key = str(model_version)
    if model_key not in model_classes:
        raise ValueError(
            f'Unknown model_version {model_version!r}; '
            f'expected one of {sorted(model_classes)}'
        )

    # Create the output directory before training so a missing folder cannot
    # discard the results at the very end of the run.
    out_dir = os.path.join('experiments', 'adversarial_MNIST')
    out_file = os.path.join(out_dir, f'experiment_12_v_{model_version}_p_{p}_out.txt')
    os.makedirs(out_dir, exist_ok=True)

    # The data module has to be an instance; the class object itself cannot be
    # set up or handed to a Trainer.
    # NOTE(review): assumes the constructor needs no arguments -- confirm.
    data = LitFashionMNIST_canny()
    data.setup()
    model = model_classes[model_key](p=p)
    model_base = LeNet()

    # NOTE(review): this logger is created but not passed to either trainer.
    wandb_logger = WandbLogger(project='magnitude-efforts',entity="edebrouwer",log_model=False,name = "EdgeDetection")

    n_gpus = 1 if torch.cuda.is_available() else 0

    # NOTE(review): this model is trained for 1 epoch while the baseline gets
    # 35 -- confirm that the asymmetry is intended.
    trainer = pl.Trainer(max_epochs=1,
                         fast_dev_run=dev_mode,  # honour dev_mode like the baseline
                         gpus=n_gpus)
    trainer.fit(model, datamodule=data)

    trainer_base = pl.Trainer(max_epochs=35,
                              fast_dev_run=dev_mode,
                              gpus=n_gpus)
    trainer_base.fit(model_base, datamodule=data)

    model.eval()
    model_base.eval()

    test_dataloader = data.test_dataloader(bs=1)
    output = []
    output_base = []
    epsilons = [0., .05, .1, .15, .2, .25, .3]

    # NOTE(review): ``run_attack`` is neither defined nor imported in the
    # visible part of this file; it must be provided before this loop can run.
    for epsilon in epsilons:
        output_base.append(run_attack(model_base, test_dataloader, epsilon))
        output.append(run_attack(model, test_dataloader, epsilon))

    with open(out_file, 'w') as f:
        f.write('Output pixel:\n')
        for line in output_base:
            json.dump(line, f)
            f.write('\n')
        f.write('Output mag vec:\n')
        for line in output:
            json.dump(line, f)
            f.write('\n')

if __name__ == "__main__":
    # Expose ``main`` as a command-line interface; its keyword arguments
    # become CLI flags.
    fire.Fire(main)
