import pickle
import tqdm
import re
import os

import torch
import torch.nn as nn
from torch import Tensor
import numpy as np

import foolbox

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from sklearn import metrics

from convexrobust.model.base_certifiable import BaseCertifiable, Certificate, Norm
from convexrobust.model.base_certifiable import BaseCertifiable, Certificate, Norm
from convexrobust.model.randsmooth_certifiable import RandsmoothCertifiable
from convexrobust.utils import dirs, file_utils, pretty
from convexrobust.utils import torch_utils as TU
from convexrobust.main import balance, simple_train


from typing import Type, Dict, List, Optional
from dataclasses import dataclass


@dataclass
class ModelBlueprint:
    """Recipe describing how ``create_models`` builds, trains or loads one model.

    Fields:
        model_class: Certifiable model class. ``create_models`` instantiates it with
            ``constructor_params`` plus a ``datamodule`` keyword argument.
        epochs: Number of training epochs (also stored on the model as ``epochs_n``).
        force_load: If True, load the model from its checkpoint even when training
            is enabled globally.
        constructor_params: Extra keyword arguments for ``model_class``.
        training_mode: ``'lightning'`` (train with a PyTorch Lightning ``Trainer``) or
            ``'simple'`` (train with ``simple_train.simple_train``).
        force_load_eval: If True, ``eval_models`` skips evaluating this model and
            keeps the previously stored results for it.

    Raises:
        ValueError: If ``training_mode`` is not one of the supported modes.
    """
    model_class: Type[BaseCertifiable]
    epochs: int
    force_load: bool
    constructor_params: Dict
    training_mode: str = 'lightning' # 'lightning' or 'simple'
    force_load_eval: bool = False

    def __post_init__(self):
        # Fail fast on a typo: create_models only knows how to train these two modes.
        if self.training_mode not in ('lightning', 'simple'):
            raise ValueError(
                f"Unknown training_mode {self.training_mode!r}; "
                "expected 'lightning' or 'simple'")


@dataclass
class Result:
    """Evaluation outcome of one model on one test sample.

    ``eval_models`` constructs these as
    ``Result(None, target, pred, certificate, emp_certificate)``.
    """
    signal: Optional[Tensor]  # Images might be too big to store signal (eval_models stores None)
    target: Tensor  # Label from the test dataloader, with a leading batch dim of 1 added
    pred: Tensor  # Model prediction; eval_models asserts its shape is torch.Size([1])
    certificate: Optional[Certificate]  # Only if target is class 0; zeroed when pred != target
    empirical_certificate: Optional[Certificate]  # Attack-found distances; None unless empirical_cert is enabled


# Model name -> blueprint describing how to build / train / load that model.
BlueprintDict = Dict[str, ModelBlueprint]
# Model name -> trained or loaded model instance.
ModelDict = Dict[str, BaseCertifiable]
# Model name -> per-sample evaluation results, in the order samples were evaluated.
ResultDict = Dict[str, List[Result]]


def create_models(blueprints: BlueprintDict, global_params) -> ModelDict:
    """Train or load every model described by ``blueprints``.

    Each model lives under ``dirs.out_path(global_params.data, 'model')/<name>``,
    with its weights at ``<name>/checkpoints/model.ckpt``. A model is trained when
    ``global_params.train`` is set and its blueprint does not request
    ``force_load``; otherwise it is loaded from that checkpoint.

    Args:
        blueprints: Mapping from model name to the blueprint used to build it.
        global_params: Run configuration. Reads ``data``, ``train``,
            ``clear_models``, ``tensorboard``, ``datamodule``, ``balance`` and
            ``rebalance``.

    Returns:
        Mapping from model name to the trained or loaded model, moved to
        ``TU.device()``.

    Raises:
        ValueError: If a blueprint that must be trained has an unknown
            ``training_mode``. This is checked before its directory is wiped.
    """
    model_root = dirs.out_path(global_params.data, 'model')
    if global_params.train and global_params.clear_models:
        file_utils.create_empty_directory(model_root)

    if global_params.tensorboard:
        TU.launch_tensorboard(dirs.path(model_root), 6006, erase=False)

    def create_model(name: str, blueprint: ModelBlueprint):
        """Train (or load) a single model and optionally (re)balance it."""
        model_path = dirs.path(model_root, name)
        constructor_params = {
            **blueprint.constructor_params, 'datamodule': global_params.datamodule
        }

        checkpoint_dir = dirs.path(model_path, 'checkpoints')
        checkpoint_path = dirs.path(checkpoint_dir, 'model.ckpt')

        if global_params.train and not blueprint.force_load:
            # Validate before wiping model_path. Otherwise an unknown mode would
            # delete a previously trained model and then save no checkpoint at all.
            if blueprint.training_mode not in ('lightning', 'simple'):
                raise ValueError(
                    f"Unknown training_mode {blueprint.training_mode!r} for model "
                    f"{name!r}; expected 'lightning' or 'simple'")

            pretty.subsection_print(f'Training classifier {name}')

            file_utils.create_empty_directory(model_path)

            model = blueprint.model_class(**constructor_params).to(TU.device())
            model.epochs_n = blueprint.epochs

            file_utils.create_empty_directory(checkpoint_dir)

            if blueprint.training_mode == 'lightning':
                logger = TensorBoardLogger(
                    model_root, name, version='tensorboard', default_hp_metric=False)
                checkpoint_callback = ModelCheckpoint(
                    checkpoint_dir, filename='model', monitor='val_loss')
                lr_monitor = LearningRateMonitor(logging_interval='epoch')

                trainer = pl.Trainer(max_epochs=blueprint.epochs, logger=logger,
                                     gpus=TU.gpu_n(), num_sanity_val_steps=0,
                                     callbacks=[checkpoint_callback, lr_monitor])
                trainer.fit(model, global_params.datamodule)

                if blueprint.epochs == 0:
                    # Presumably no checkpoint callback fires without training
                    # epochs, so write one explicitly for the load below.
                    trainer.save_checkpoint(checkpoint_path)

                # Load model with best validation performance, not just most recent
                model = TU.load_model_from_checkpoint(
                    checkpoint_path, blueprint, constructor_params)
            else:  # 'simple' (validated above)
                simple_train.simple_train(model, global_params.datamodule,
                                          num_epochs=blueprint.epochs,
                                          noise=getattr(model, 'noise', None))
                # Same {'state_dict': ...} layout as a Lightning checkpoint.
                torch.save({'state_dict': model.state_dict()}, checkpoint_path)

            assert model.class_balance == 0.0

            if global_params.balance:
                pretty.subsection_print(f'Balancing classifier {name}')
                balance.balance_classifier(model, global_params)
                balance.save_class_balance(model.class_balance.item(), checkpoint_path)
        else:
            pretty.subsection_print(f'Loading classifier {name}')
            model = TU.load_model_from_checkpoint(checkpoint_path, blueprint, constructor_params)

            if global_params.rebalance:
                pretty.subsection_print(f'(Re) balancing classifier {name}')
                # Reset any stored balance so balancing starts from scratch.
                model.class_balance.fill_(0.0)
                balance.balance_classifier(model, global_params)
                balance.save_class_balance(model.class_balance.item(), checkpoint_path)

        return model.to(TU.device())

    return {name: create_model(name, blueprint) for (name, blueprint) in blueprints.items()}


def eval_models(models: ModelDict, blueprints: BlueprintDict, global_params) -> ResultDict:
    """Evaluate (or reload evaluations of) every model on the test set.

    Results are pickled to ``dirs.out_path(global_params.data, 'results.pkl')``.

    When ``global_params.eval`` is False, the stored results are loaded and
    returned unchanged. Otherwise, each model with ``force_load_eval`` whose
    results already exist in that file keeps them. Every other model is evaluated
    on up to ``global_params.eval_n`` test samples:

    - class-0 samples are certified with ``model.certify``; the certificate is
      zeroed if the prediction is wrong.
    - other samples are only predicted and get no certificate.

    Args:
        models: Mapping from model name to model. Models must be in eval mode.
        blueprints: Blueprints for the same names; ``force_load_eval`` is read.
        global_params: Run configuration. Reads ``data``, ``eval``,
            ``datamodule``, ``eval_n``, ``verify_cert`` and ``empirical_cert``.

    Returns:
        Mapping from model name to its list of ``Result`` entries.
    """
    results_path = dirs.out_path(global_params.data, 'results.pkl')

    if not global_params.eval:
        pretty.subsection_print('Loading results')
        # NOTE: pickle is only safe because this file is produced by this program.
        with open(results_path, 'rb') as f:
            return pickle.load(f)

    results: ResultDict = {}
    if os.path.isfile(results_path):
        pretty.subsection_print('Loading previous results for potential eval skip')
        with open(results_path, 'rb') as f:
            results = pickle.load(f)

    # Reuse stored results only when they actually exist. A force_load_eval model
    # with nothing stored must still be evaluated rather than end up empty.
    reused = {name for name in models
              if blueprints[name].force_load_eval and name in results}
    to_eval = {name: model for (name, model) in models.items() if name not in reused}

    for (name, model) in models.items():
        if name not in reused:
            results[name] = []
        assert not model.training
        print(f'Eval {name} with balance: {model.class_balance.item()}: loading {name in reused}')

    if to_eval:
        dataloader = global_params.datamodule.test_dataloader()
        for (signal, target) in TU.fetch_dataloader(dataloader, global_params.eval_n, True):
            # Restore the batch dimension: models are evaluated one sample at a time.
            signal, target = signal.unsqueeze(0), target.unsqueeze(0)
            for (name, model) in to_eval.items():
                if target == 0:
                    pred, certificate = model.certify(signal, target)
                    if pred != target:
                        # A misclassified sample is certified at radius zero.
                        certificate = certificate.zero()
                else:
                    pred, certificate = model.predict(signal), None

                assert pred.shape == torch.Size([1])

                if global_params.verify_cert and certificate is not None:
                    verify_radii(model, certificate, signal, target)

                if global_params.empirical_cert:
                    emp_certificate = empirical_certificate(model, signal, target)
                else:
                    emp_certificate = None

                results[name].append(Result(None, target, pred, certificate, emp_certificate))

    pretty.subsection_print('Writing results')
    with open(results_path, 'wb') as f:
        pickle.dump(results, f)

    return results


def verify_radii(model, certificate, signal, target):
    """Check a certificate by attacking the model inside each certified radius.

    For each norm whose certified radius is positive, this runs a foolbox
    projected-gradient-descent attack bounded by that radius. If the attack finds
    an adversarial example, the certificate is unsound.

    Randomized-smoothing models are skipped because their certificates are
    nondeterministic.

    Args:
        model: Certifiable model. If ``model.single_logit`` is set, it is wrapped
            with ``TU.SingleLogitWrapper`` before attacking.
        certificate: Certificate whose ``radius`` is indexed by ``Norm``.
        signal: Input batch; the attack assumes values lie in [0, 1].
        target: Labels for ``signal``.

    Raises:
        AssertionError: If an attack finds an adversarial example within a
            certified radius.
    """
    if isinstance(model, RandsmoothCertifiable):
        return # Don't verify nondeterministic certificates

    wrap_model = TU.SingleLogitWrapper(model).eval() if model.single_logit else model
    fb_model = foolbox.models.PyTorchModel(wrap_model, bounds=(0, 1))
    attacks = {
        Norm.L1: foolbox.attacks.L1ProjectedGradientDescentAttack(steps=100),
        Norm.L2: foolbox.attacks.L2ProjectedGradientDescentAttack(steps=100),
        Norm.LInf: foolbox.attacks.LinfProjectedGradientDescentAttack(steps=300)
    }

    for norm in [Norm.L1, Norm.L2, Norm.LInf]:
        radius = certificate.radius[norm]
        if radius > 0.0:
            _, _, is_adv = attacks[norm](fb_model, signal, target, epsilons=[radius])
            # With a list of epsilons, is_adv has one row per epsilon; flag any
            # success instead of relying on a single-element truth value.
            if bool(is_adv.any()):
                raise AssertionError(
                    f'Certificate violated: {norm} attack found an adversarial '
                    f'example within certified radius {radius}')


def empirical_certificate(model, signal, target):
    """Measure the perturbation found by a foolbox L-inf Fast Minimum Norm attack.

    The attack runs unbounded (``epsilons=None``). The L1, L2 and L-infinity
    norms of ``adversarial - signal`` are packed into a ``Certificate``. These are
    attack-found distances, not proven radii.
    """
    if model.single_logit:
        target_model = TU.SingleLogitWrapper(model).eval()
    else:
        target_model = model
    attack_model = foolbox.models.PyTorchModel(target_model, bounds=(0, 1))
    fmn = foolbox.attacks.LInfFMNAttack()
    # NOTE(review): the success flag is ignored; if the attack fails on an input,
    # the reported distances may not describe a real adversarial example.
    adversarial, _, _ = fmn(attack_model, signal, target, epsilons=None)

    perturbation = adversarial - signal
    norm_orders = ((Norm.L1, 1), (Norm.L2, 2), (Norm.LInf, float('inf')))
    return Certificate({norm: perturbation.norm(order).item() for norm, order in norm_orders})
