import logging
import pathlib
import pickle
import os
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical
import torch.optim as optim
from argparse import ArgumentParser
from torchnet import meter
from torch.utils.data import DataLoader
from tqdm import tqdm
# from src.models import *
# from src.noises import *
# from src.smooth import *
# from src.attacks import pgd_attack_smooth
# from src.datasets import get_dataset, get_dim
# from src.lib.dncnn import DnCNN
from convexrobust.model.randsmooth_certifiable import RandsmoothCertifiable
from convexrobust.utils import torch_utils as TU
import warnings

# if __name__ == "__main__":
    # argparser = ArgumentParser()
    # argparser.add_argument("--device", default="cuda", type=str)
    # argparser.add_argument("--lr", default=0.1, type=float)
    # argparser.add_argument("--batch-size", default=64, type=int)
    # argparser.add_argument("--num-workers", default=min(os.cpu_count(), 8), type=int)
    # argparser.add_argument("--num-epochs", default=120, type=int)
    # argparser.add_argument("--print-every", default=20, type=int)
    # argparser.add_argument("--save-every", default=50, type=int)
    # argparser.add_argument("--experiment-name", default="cifar", type=str)
    # argparser.add_argument("--noise", default="Clean", type=str)
    # argparser.add_argument("--sigma", default=None, type=float)
    # argparser.add_argument("--adv", default=2, type=int)
    # argparser.add_argument("--eps", default=0.0, type=float)
    # argparser.add_argument("--k", default=None, type=int)
    # argparser.add_argument("--seed", default=None, type=int)
    # argparser.add_argument("--j", default=None, type=int)
    # argparser.add_argument("--a", default=None, type=int)
    # argparser.add_argument("--lambd", default=None, type=float)
    # argparser.add_argument("--model", default="WideResNet", type=str)
    # argparser.add_argument("--dataset", default="cifar", type=str)
    # argparser.add_argument("--adversarial", action="store_true")
    # argparser.add_argument("--stability", action="store_true")
    # argparser.add_argument("--direct", action="store_true")
    # argparser.add_argument("--save-path", type=str, default=None)
    # argparser.add_argument('--output-dir', type=str, default=os.getenv("PT_OUTPUT_DIR"))
    # argparser.add_argument("--denoiser-path", type=str, default=None)
    # argparser.add_argument("--resume_epoch", default=None, type=int)
    # args = argparser.parse_args()

def simple_train(model, datamodule, lr=0.1, num_epochs=120, print_every=20, save_every=50,
                 noise=None, stability=False, direct=False):
    """Train ``model`` in place on the training split of ``datamodule``.

    The optimizer and learning-rate scheduler come from
    ``model.configure_optimizers()``; only the first optimizer and the first
    scheduler it returns are used. The scheduler is stepped once per epoch.

    Loss selection, in priority order:

    * ``direct``: negative mean of ``direct_train_log_lik(...)`` computed on
      the un-noised batch.
    * ``stability``: negative log-likelihood of ``y`` on one noisy draw plus
      ``6.0`` times the KL divergence between the predictions on two
      independent noisy draws of the same batch.
    * otherwise: categorical negative log-likelihood over ``model.forward(x)``
      when ``model`` is a ``RandsmoothCertifiable``; else
      ``model.loss_func(-model.forward(x), y.float()) + model.extra_loss(x, y)``.

    Args:
        model: Module exposing ``train()``, ``forward()`` and
            ``configure_optimizers()``; additionally ``forecast()`` for
            stability training, or ``loss_func()`` / ``extra_loss()`` when it
            is not a ``RandsmoothCertifiable``.
        datamodule: Object whose ``train_dataloader()`` yields ``(x, y)``
            batches.
        lr: Currently unused (the optimizer is built by the model); kept for
            interface compatibility.
        num_epochs: Number of passes over the training loader.
        print_every: Log, record and reset the running mean loss every this
            many iterations within an epoch.
        save_every: Currently unused (checkpointing is disabled below); kept
            for interface compatibility.
        noise: Optional object with a ``sample(tensor)`` method. It is called
            on the batch flattened to ``(len(x), -1)`` and the result is
            reshaped back to ``x.shape``.
        stability: Enable the stability (NLL + KL) objective. Needs ``noise``
            unless ``direct`` is also set.
        direct: Enable the direct log-likelihood objective.

    Returns:
        list: The running mean loss captured at every logging interval, in
        order of occurrence.

    Raises:
        ValueError: If ``stability`` is requested without ``noise`` (and
            ``direct`` is off), since the objective needs two noisy draws.
    """
    # Fail fast: without noise the stability branch would read an unassigned
    # ``x_tilde`` and die with UnboundLocalError on the first batch.
    if stability and not direct and noise is None:
        raise ValueError(
            "stability training requires `noise` to sample the perturbed inputs"
        )

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    model.train()

    train_loader = datamodule.train_dataloader()

    # Generalize for mnist: let the model supply its own optimizer/scheduler
    # instead of the hard-coded SGD + cosine annealing kept below for reference.
    optimizer, annealer = model.configure_optimizers()
    optimizer, annealer = optimizer[0], annealer[0]
    # optimizer = optim.SGD(model.parameters(),
                          # lr=lr,
                          # momentum=0.9,
                          # weight_decay=1e-4,
                          # nesterov=True)
    # annealer = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    loss_meter = meter.AverageValueMeter()
    time_meter = meter.TimeMeter(unit=False)

    # The model type cannot change mid-training, so resolve it once.
    is_randsmooth = isinstance(model, RandsmoothCertifiable)

    # noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))
    train_losses = []
    start_epoch = 0
    for epoch in range(start_epoch, num_epochs):

        for i, (x, y) in enumerate(train_loader):

            device = TU.device()
            x, y = x.to(device), y.to(device)

            if noise is not None:
                if stability:
                    # Two independent draws around the same clean batch;
                    # x_tilde must be sampled before x is overwritten.
                    x_tilde = noise.sample(x.view(len(x), -1)).view(x.shape)
                    x = noise.sample(x.view(len(x), -1)).view(x.shape)
                elif not direct:
                    x = noise.sample(x.view(len(x), -1)).view(x.shape)

            if direct:
                # NOTE(review): `direct_train_log_lik` is not defined or
                # imported in the visible part of this file (the `src.*`
                # imports are commented out) -- confirm it is in scope.
                loss = -direct_train_log_lik(model, x, y, noise, sample_size=16).mean()
            elif stability:
                pred_x = model.forecast(model.forward(x))
                pred_x_tilde = model.forecast(model.forward(x_tilde))
                loss = -pred_x.log_prob(y) + 6.0 * torch.distributions.kl_divergence(pred_x, pred_x_tilde)
                loss = loss.mean()
            elif is_randsmooth:
                forecast = Categorical(logits=model.forward(x))
                loss = -forecast.log_prob(y).mean()
            else:
                # The model output is negated before the loss; presumably this
                # matches the model's own sign convention -- TODO confirm.
                loss = model.loss_func(-model.forward(x), y.float())
                loss = loss + model.extra_loss(x, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() yields a detached Python float for the meter.
            loss_meter.add(loss.item(), n=1)

            if i % print_every == 0:
                logger.info(f"Epoch: {epoch}\t"
                            f"Itr: {i} / {len(train_loader)}\t"
                            f"Loss: {loss_meter.value()[0]:.2f}\t"
                            f"Mins: {(time_meter.value() / 60):.2f}")
                train_losses.append(loss_meter.value()[0])
                loss_meter.reset()

        # Disabled: periodic checkpointing (depends on the removed CLI `args`).
        # if (epoch + 1) % save_every == 0:
            # save_path = f"{args.output_dir}/{args.experiment_name}/{epoch}/"
            # pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
            # torch.save(model.state_dict(), f"{save_path}/model_ckpt.torch")

        annealer.step()

    # Disabled: final persistence of weights, args and losses (depends on the
    # removed CLI `args`). The recorded losses are returned instead.
    # pathlib.Path(f"{args.output_dir}/{args.experiment_name}").mkdir(parents=True, exist_ok=True)
    # save_path = f"{args.output_dir}/{args.experiment_name}/model_ckpt.torch"
    # torch.save(model.state_dict(), save_path)
    # args_path = f"{args.output_dir}/{args.experiment_name}/args.pkl"
    # pickle.dump(args, open(args_path, "wb"))
    # save_path = f"{args.output_dir}/{args.experiment_name}/losses_train.npy"
    # np.save(save_path, np.array(train_losses))
    return train_losses

