import os
import sys
import math
import random
import time
from datetime import datetime
import argparse
import pickle
import ruamel.yaml as yaml
from tqdm import tqdm
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

import torch.autograd.profiler as profiler

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm

#import save_nlayer_weights as nl
#from setup_mnist import MNIST
#from setup_cifar import CIFAR
#from utils import generate_data
#from bounds.crown.get_bounds_ours import get_weights_list

import imageio

#from get_bounds_torch import Box, BoundLinear, BoundLinearDelta, BoundActivationDelta, BoundActivation, InnerCROWN
from bound_model_conv import TorchCNN

def debug(text, file="ffs.txt"):
    """Append `text` followed by a `---` separator line to the file `file`."""
    with open(file, "a") as log:
        log.writelines((text, "\n---\n"))

def read_schedule_linear(epoch, epoch_markers, schedule):
    """Piecewise-linear interpolation of `schedule` over `epoch_markers`.

    Args:
        epoch: Epoch at which to evaluate the schedule.
        epoch_markers: Sequence of epochs (expected in ascending order); the
            i-th marker is paired with the i-th schedule value.
        schedule: Sequence of values attained at the corresponding markers.

    Returns:
        `schedule[0]` before the first marker, `schedule[-1]` after the last
        marker, and the linear interpolation between the two surrounding
        schedule values otherwise.
    """
    # Before the first marker the schedule has not started yet. Without this
    # guard no segment matches and the final value would be returned instead.
    if epoch_markers and epoch < epoch_markers[0]:
        return schedule[0]

    segments = zip(epoch_markers, epoch_markers[1:], schedule, schedule[1:])
    for start, end, start_value, end_value in segments:
        if end < epoch:
            continue
        if start <= epoch:
            if end == start:
                # Zero-width segment: treat it as an instantaneous jump
                # rather than dividing by zero.
                return end_value
            return ((end - epoch) * start_value + (epoch - start) * end_value) / (end - start)
    return schedule[-1]

def read_schedule_gradual(epoch, epoch_markers, schedule):
    """Ramp a value from `schedule[0]` to `schedule[1]` between two epochs.

    Args:
        epoch: Epoch at which to evaluate the schedule.
        epoch_markers: `[ramp_start, ramp_end]` epochs.
        schedule: `[start_value, end_value]` for a linear ramp, or
            `[start_value, end_value, scale]` for an exponential ramp whose
            curvature is controlled by `scale`.

    Returns:
        `schedule[0]` before the ramp, `schedule[1]` after it, and the
        interpolated value in between.
    """
    if epoch < epoch_markers[0]:
        return schedule[0]
    if epoch > epoch_markers[1]:
        return schedule[1]

    span = epoch_markers[1] - epoch_markers[0]
    if span == 0:
        # Degenerate ramp (start == end == epoch): jump straight to the
        # target value instead of dividing by zero.
        return schedule[1]

    # A scale of 0 would make exp(scale) - 1 vanish; the exponential ramp
    # tends to the linear one as scale -> 0, so use the linear form there.
    if len(schedule) == 3 and schedule[2] != 0:
        scale = schedule[2]
        return (math.exp(scale * (epoch - epoch_markers[0]) / span) - 1) * (schedule[1] - schedule[0]) / (math.exp(scale) - 1) + schedule[0]
    return (epoch - epoch_markers[0]) / span * (schedule[1] - schedule[0]) + schedule[0]

def read_schedule(epoch, epoch_markers, schedule):
    """Evaluate the active schedule type (currently the gradual ramp) at `epoch`."""
    value = read_schedule_gradual(epoch, epoch_markers, schedule)
    return value

def standard_loss_fn(inputs, labels, model, epoch, cfg):
    """Plain cross-entropy loss on the model's outputs.

    `epoch` and `cfg` are accepted only so the signature matches the other
    loss functions; they are not used.

    Returns:
        `(loss, terms)` where `terms` maps "standard_loss" to the same tensor.
    """
    criterion = nn.CrossEntropyLoss()
    standard_loss = criterion(model(inputs), labels)
    return standard_loss, {"standard_loss": standard_loss}

def ibp_loss_fn(inputs, labels, model, epoch, cfg):
    """Blend the standard loss with a worst-case loss built from IBP bounds.

    Args:
        inputs: Input batch passed to the model.
        labels: Integer class labels, one per sample.
        model: Callable model that also exposes `IBP(inputs, eps, _, _)`
            returning `(lb, ub)`.
        epoch: Current epoch, used to read the kappa and eps schedules.
        cfg: Config dict; reads `cfg["loss"]["epoch_markers"]`,
            `["kappa_schedule"]` and `["eps_schedule"]`.

    Returns:
        `(loss, terms)` with `loss = (1 - kappa) * standard + kappa * ibp`
        and `terms` holding the individual loss tensors.
    """
    loss_cfg = cfg["loss"]
    criterion = nn.CrossEntropyLoss()

    kappa = read_schedule(epoch, loss_cfg["epoch_markers"], loss_cfg["kappa_schedule"])
    standard_loss = criterion(model(inputs), labels)

    eps = read_schedule(epoch, loss_cfg["epoch_markers"], loss_cfg["eps_schedule"])
    lb, ub = model.IBP(inputs, eps * torch.ones_like(inputs), None, None)

    # Worst-case logits: upper bound everywhere except the labelled class,
    # which takes its lower bound. `ub[:]` is a view of a tensor, not a copy,
    # so clone before the in-place write to leave the model's returned bound
    # untouched.
    rows = torch.arange(lb.size(0), device=lb.device)
    lub = ub.clone()
    lub[rows, labels] = lb[rows, labels]

    ibp_loss = criterion(lub, labels)
    loss = (1 - kappa) * standard_loss + kappa * ibp_loss
    return loss, {"standard_loss": standard_loss, "ibp_loss": ibp_loss}

def crown_ibp_loss_fn(inputs, labels, model, epoch, cfg):
    """Blend the standard loss with IBP and CROWN-IBP worst-case losses.

    Args:
        inputs: Input batch passed to the model.
        labels: Integer class labels, one per sample.
        model: Callable model exposing `IBP(...)` and `CROWN_IBP(...)`, each
            returning `(lb, ub)`.
        epoch: Current epoch, used to read the schedules.
        cfg: Config dict; reads `cfg["loss"]["epoch_markers"]`,
            `["kappa_schedule"]` and `["eps_schedule"]`.

    Returns:
        `(loss, terms)` with
        `loss = (1 - kappa) * standard + kappa * ((1 - beta) * ibp + beta * crown)`.
    """
    loss_cfg = cfg["loss"]
    markers = loss_cfg["epoch_markers"]
    criterion = nn.CrossEntropyLoss()

    kappa = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    standard_loss = criterion(model(inputs), labels)

    # NOTE(review): beta is read from the kappa schedule, so beta == kappa at
    # every epoch. Confirm this is intended rather than a missing
    # "beta_schedule" entry.
    beta = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    eps = read_schedule(epoch, markers, loss_cfg["eps_schedule"])

    # Worst-case logits: upper bound everywhere, lower bound at the label.
    # Clone rather than `ub[:]` (a view) so the model's returned tensors are
    # not modified in place.
    lb, ub = model.IBP(inputs, eps * torch.ones_like(inputs), None, None)
    rows = torch.arange(lb.size(0), device=lb.device)
    ibp_lub = ub.clone()
    ibp_lub[rows, labels] = lb[rows, labels]
    ibp_loss = criterion(ibp_lub, labels)

    lb, ub = model.CROWN_IBP(inputs, eps * torch.ones_like(inputs), None, None)
    rows = torch.arange(lb.size(0), device=lb.device)
    crown_lub = ub.clone()
    crown_lub[rows, labels] = lb[rows, labels]
    crown_loss = criterion(crown_lub, labels)

    loss = (1 - kappa) * standard_loss + kappa * ((1 - beta) * ibp_loss + beta * crown_loss)
    return loss, {"standard_loss": standard_loss, "ibp_loss": ibp_loss, "crown_loss": crown_loss}

def proven_ibp_loss_fn(inputs, labels, model, epoch, cfg):
    """Blend the standard loss with IBP and PROVEN-IBP worst-case losses.

    Args:
        inputs: Input batch passed to the model.
        labels: Integer class labels, one per sample.
        model: Callable model exposing `IBP(...)` and `PROVEN_IBP(...)`, each
            returning `(lb, ub)`.
        epoch: Current epoch, used to read the schedules.
        cfg: Config dict; reads `cfg["loss"]["epoch_markers"]`,
            `["kappa_schedule"]` and `["eps_schedule"]`.

    Returns:
        `(loss, terms)` with
        `loss = (1 - kappa) * standard + kappa * ((1 - beta) * ibp + beta * proven)`.
    """
    loss_cfg = cfg["loss"]
    markers = loss_cfg["epoch_markers"]
    criterion = nn.CrossEntropyLoss()

    kappa = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    # The cross-entropy is computed once; an earlier duplicate that was
    # immediately overwritten has been dropped.
    standard_loss = criterion(model(inputs), labels)

    # NOTE(review): beta is read from the kappa schedule, so beta == kappa at
    # every epoch. Confirm this is intended.
    beta = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    eps = read_schedule(epoch, markers, loss_cfg["eps_schedule"])

    # Worst-case logits: upper bound everywhere, lower bound at the label.
    # Clone rather than `ub[:]` (a view) so the model's returned tensors are
    # not modified in place.
    lb, ub = model.IBP(inputs, eps * torch.ones_like(inputs), None, None)
    rows = torch.arange(lb.size(0), device=lb.device)
    ibp_lub = ub.clone()
    ibp_lub[rows, labels] = lb[rows, labels]
    ibp_loss = criterion(ibp_lub, labels)

    # The trailing 1e-5 is forwarded positionally to PROVEN_IBP; the matching
    # argument in evaluate_epoch is named q_tot.
    lb, ub = model.PROVEN_IBP(inputs, eps * torch.ones_like(inputs), None, None, 1e-5)
    rows = torch.arange(lb.size(0), device=lb.device)
    proven_lub = ub.clone()
    proven_lub[rows, labels] = lb[rows, labels]
    proven_loss = criterion(proven_lub, labels)

    loss = (1 - kappa) * standard_loss + kappa * ((1 - beta) * ibp_loss + beta * proven_loss)
    return loss, {"standard_loss": standard_loss, "ibp_loss": ibp_loss, "proven_loss": proven_loss}

def proven_ibp_loss_fn_fixed(inputs, labels, model, epoch, cfg):
    """PROVEN-IBP loss that mixes the worst-case logits before cross-entropy.

    Unlike `proven_ibp_loss_fn`, the IBP and PROVEN-IBP worst-case logits are
    combined as `(1 - beta) * ibp_lub + beta * proven_lub` and a single
    cross-entropy is taken over the mixture. The separate `ibp_loss` and
    `proven_loss` are still computed, but only for reporting.

    Args:
        inputs: Input batch passed to the model.
        labels: Integer class labels, one per sample.
        model: Callable model exposing `IBP(...)` and
            `PROVEN_IBP(..., use_zero=...)`, each returning `(lb, ub)`.
        epoch: Current epoch, used to read the schedules.
        cfg: Config dict; reads `cfg["loss"]["epoch_markers"]`,
            `["kappa_schedule"]` and `["eps_schedule"]`.

    Returns:
        `(loss, terms)` where `terms` holds the standard, IBP and PROVEN
        loss tensors.
    """
    use_zero = True
    loss_cfg = cfg["loss"]
    markers = loss_cfg["epoch_markers"]
    criterion = nn.CrossEntropyLoss()

    kappa = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    # The cross-entropy is computed once; an earlier duplicate that was
    # immediately overwritten has been dropped.
    standard_loss = criterion(model(inputs), labels)

    # NOTE(review): beta is read from the kappa schedule, so beta == kappa at
    # every epoch. Confirm this is intended.
    beta = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    eps = read_schedule(epoch, markers, loss_cfg["eps_schedule"])

    # Worst-case logits: upper bound everywhere, lower bound at the label.
    # Clone rather than `ub[:]` (a view) so the model's returned tensors are
    # not modified in place.
    lb, ub = model.IBP(inputs, eps * torch.ones_like(inputs), None, None)
    rows = torch.arange(lb.size(0), device=lb.device)
    ibp_lub = ub.clone()
    ibp_lub[rows, labels] = lb[rows, labels]
    ibp_loss = criterion(ibp_lub, labels)

    lb, ub = model.PROVEN_IBP(inputs, eps * torch.ones_like(inputs), None, None, 1e-5, use_zero=use_zero)
    rows = torch.arange(lb.size(0), device=lb.device)
    proven_lub = ub.clone()
    proven_lub[rows, labels] = lb[rows, labels]
    proven_loss = criterion(proven_lub, labels)

    mixed_lub = (1 - beta) * ibp_lub + beta * proven_lub
    loss = (1 - kappa) * standard_loss + kappa * criterion(mixed_lub, labels)
    return loss, {"standard_loss": standard_loss, "ibp_loss": ibp_loss, "proven_loss": proven_loss}

def proven_ibp_alpha_loss_fn(inputs, labels, model, epoch, cfg):
    """Blend the standard loss with IBP and a PROVEN-IBP "alpha" penalty.

    The alpha term is `1e2 * mean(q)`, where `q` is the tensor returned by
    `model.PROVEN_IBP_alpha` with the entries at the labelled class zeroed.

    Args:
        inputs: Input batch passed to the model.
        labels: Integer class labels, one per sample.
        model: Callable model exposing `IBP(...)` (returns `(lb, ub)`) and
            `PROVEN_IBP_alpha(...)` (returns a single tensor indexed as
            `[sample, class]`).
        epoch: Current epoch, used to read the schedules.
        cfg: Config dict; reads `cfg["loss"]["epoch_markers"]`,
            `["kappa_schedule"]` and `["eps_schedule"]`.

    Returns:
        `(loss, terms)` with
        `loss = (1 - kappa) * standard + kappa * ((1 - beta) * ibp + beta * alpha)`.
    """
    loss_cfg = cfg["loss"]
    markers = loss_cfg["epoch_markers"]
    criterion = nn.CrossEntropyLoss()

    kappa = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    # The cross-entropy is computed once; an earlier duplicate that was
    # immediately overwritten has been dropped.
    standard_loss = criterion(model(inputs), labels)

    # NOTE(review): beta is read from the kappa schedule, so beta == kappa at
    # every epoch. Confirm this is intended.
    beta = read_schedule(epoch, markers, loss_cfg["kappa_schedule"])
    eps = read_schedule(epoch, markers, loss_cfg["eps_schedule"])

    # Worst-case logits: upper bound everywhere, lower bound at the label.
    # Clone rather than `ub[:]` (a view) so the model's returned tensors are
    # not modified in place.
    lb, ub = model.IBP(inputs, eps * torch.ones_like(inputs), None, None)
    rows = torch.arange(lb.size(0), device=lb.device)
    ibp_lub = ub.clone()
    ibp_lub[rows, labels] = lb[rows, labels]
    ibp_loss = criterion(ibp_lub, labels)

    q = model.PROVEN_IBP_alpha(inputs, eps * torch.ones_like(inputs), labels, None)
    # Exclude the labelled class from the penalty, working on a copy so the
    # model's returned tensor is left intact.
    masked_q = q.clone()
    masked_q[torch.arange(q.size(0), device=q.device), labels] = 0.0
    proven_alpha_loss = 1e2 * torch.mean(masked_q)

    loss = (1 - kappa) * standard_loss + kappa * ((1 - beta) * ibp_loss + beta * proven_alpha_loss)
    return loss, {"standard_loss": standard_loss, "ibp_loss": ibp_loss, "proven_alpha_loss": proven_alpha_loss}

def another_loss_fn(inputs, labels, model, epoch, cfg):
    """Blend the standard loss with a worst-case loss from `model.ANOTHER` bounds.

    Args:
        inputs: Input batch passed to the model.
        labels: Integer class labels, one per sample.
        model: Callable model exposing `ANOTHER(inputs, eps, _, _)` returning
            `(lb, ub)`.
        epoch: Current epoch, used to read the kappa and eps schedules.
        cfg: Config dict; reads `cfg["loss"]["epoch_markers"]`,
            `["kappa_schedule"]` and `["eps_schedule"]`.

    Returns:
        `(loss, terms)` with `loss = (1 - kappa) * standard + kappa * another`.
    """
    loss_cfg = cfg["loss"]
    criterion = nn.CrossEntropyLoss()

    kappa = read_schedule(epoch, loss_cfg["epoch_markers"], loss_cfg["kappa_schedule"])
    standard_loss = criterion(model(inputs), labels)

    eps = read_schedule(epoch, loss_cfg["epoch_markers"], loss_cfg["eps_schedule"])
    lb, ub = model.ANOTHER(inputs, eps * torch.ones_like(inputs), None, None)

    # Worst-case logits: upper bound everywhere, lower bound at the label.
    # Clone rather than `ub[:]` (a view) so the model's returned tensor is
    # not modified in place.
    rows = torch.arange(lb.size(0), device=lb.device)
    lub = ub.clone()
    lub[rows, labels] = lb[rows, labels]

    another_loss = criterion(lub, labels)
    loss = (1 - kappa) * standard_loss + kappa * another_loss
    return loss, {"standard_loss": standard_loss, "another_loss": another_loss}

def data_augmentation_loss_fn(inputs, labels, model, epoch, cfg):
    """Cross-entropy loss on inputs perturbed with random noise of scale eps.

    Args:
        inputs: Input batch passed to the model.
        labels: Integer class labels, one per sample.
        model: Callable model.
        epoch: Current epoch, used to read the eps schedule.
        cfg: Config dict; reads `cfg["loss"]["epoch_markers"]` and
            `["eps_schedule"]`. If `cfg["loss"]["distribution"] == "normal"`
            the noise is `eps * N(0, 1)`; otherwise it is uniform in
            `[-eps, eps)`.

    Returns:
        `(loss, terms)` where `terms` maps "standard_loss" to the same tensor.
    """
    loss_cfg = cfg["loss"]
    eps = read_schedule(epoch, loss_cfg["epoch_markers"], loss_cfg["eps_schedule"])
    if loss_cfg.get("distribution") == "normal":
        noise = eps * torch.randn_like(inputs)
    else:
        noise = eps * (2 * torch.rand_like(inputs) - 1)
    standard_loss = nn.CrossEntropyLoss()(model(inputs + noise), labels)
    return standard_loss, {"standard_loss": standard_loss}

def train_epoch(epoch, model, dataloader, optimizer, device, cfg):
    """Run one training epoch and report the averaged loss terms.

    Args:
        epoch: Current epoch index, forwarded to the loss function.
        model: Model to train; switched to train mode.
        dataloader: Iterable of `(inputs, labels)` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        cfg: Config dict. `cfg["loss"]["type"]` selects the loss function.
            Optional `cfg["training"]["grad_clip_norm"]` (default 0.1) and
            `cfg["training"]["detect_anomaly"]` (default True) override the
            gradient clipping norm and autograd anomaly detection.

    Returns:
        `(total_loss, avg_terms)`: the mean batch loss and a dict with the
        mean of each loss term.

    Raises:
        KeyError: If `cfg["loss"]["type"]` is not a known loss type.
        ValueError: If `dataloader` yields no batches.
    """
    import contextlib

    loss_fns = {
        "standard": standard_loss_fn,
        "ibp": ibp_loss_fn,
        "crown_ibp": crown_ibp_loss_fn,
        "proven_ibp": proven_ibp_loss_fn,
        "proven_ibp_fixed": proven_ibp_loss_fn_fixed,
        "proven_ibp_alpha": proven_ibp_alpha_loss_fn,
        "another": another_loss_fn,
        "aug": data_augmentation_loss_fn,
    }
    # Resolve the loss function once, up front, so a bad config fails with a
    # readable message before any batch is processed.
    loss_type = cfg["loss"]["type"]
    if loss_type not in loss_fns:
        raise KeyError("unknown loss type {!r}; expected one of {}".format(loss_type, sorted(loss_fns)))
    loss_fn = loss_fns[loss_type]

    training_cfg = cfg.get("training") or {}
    clip_norm = training_cfg.get("grad_clip_norm", 0.1)
    # Anomaly detection slows the backward pass considerably; it stays on by
    # default to preserve existing behaviour but can now be switched off.
    detect_anomaly = training_cfg.get("detect_anomaly", True)

    losses = []
    loss_terms = {}
    model.train()
    for batch in tqdm(dataloader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        anomaly_ctx = torch.autograd.detect_anomaly() if detect_anomaly else contextlib.nullcontext()
        with anomaly_ctx:
            loss, terms = loss_fn(inputs, labels, model, epoch, cfg)
            loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        losses.append(loss.item())
        for term, value in terms.items():
            loss_terms.setdefault(term, []).append(value.item())

    if not losses:
        raise ValueError("train_epoch: dataloader produced no batches")

    print("\nepoch", epoch)
    total_loss = sum(losses) / len(losses)
    print("total loss: {:.3f}".format(total_loss))
    avg_terms = {term: sum(values) / len(values) for term, values in loss_terms.items()}

    for term, value in avg_terms.items():
        print("{} loss: {:.3f}".format(term, value))

    return total_loss, avg_terms

def evaluate_epoch(epoch, model, dataloader, small_dataloader, device, cfg):
    """Compute standard and robustness error rates of `model`.

    The evaluation runs in three passes:

    1. Over `dataloader`: standard (clean) error, plus bound-based errors for
       the eval types "ibp", "ubp", "crown_ibp", "proven_ibp" and "another".
    2. Over `small_dataloader`: bound-based errors for "crown", "proven" and
       "proven_gaussian".
    3. Over `small_dataloader`: sampling-based errors for every eval type
       that has a "sample_num" entry ("avg_*", "strict_*", "normal",
       "uniform").

    Args:
        epoch: Epoch index; only used in the printed header.
        model: Model exposing the bound methods called below; switched to
            eval mode.
        dataloader: Loader used for pass 1.
        small_dataloader: Loader used for passes 2 and 3.
        device: Device the batches are moved to.
        cfg: Config dict; `cfg["evaluation"]["types"]` is a list of dicts with
            at least "name" and "eps", optionally "sample_num" and
            "tolerance".

    Returns:
        `(errors["standard"], errors)` where `errors` maps each eval type
        name (plus "standard") to its error rate.
    """
    use_zero = True #cfg["name"] == "temp"
    model.eval()
    # NOTE(review): `total`, `robust_total` and `hist` are never read in this
    # function (the only uses of `hist` are commented out).
    total = 0
    robust_total = 0
    # One error counter per configured eval type, keyed by its name.
    totals = {i["name"]: 0 for i in cfg["evaluation"]["types"]}
    totals["standard"] = 0
    errors = {}
    eps = 0
    # Passed to the PROVEN* bound methods below.
    q_tot = 1e-2
    hist = {
        "avg_uniform": [],
        "avg_normal": [],
        "strict_uniform": [],
        "strict_normal": []
    }
    # NOTE(review): no torch.no_grad() is used anywhere in this function;
    # whether the bound methods need autograd cannot be told from this file.

    # Pass 1: clean error and the bound-based eval types, on the full loader.
    for i, batch in enumerate(tqdm(dataloader)):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        #inputs = inputs.view(inputs.size(0), -1)
        #with profiler.profile(with_stack=True, profile_memory=True) as prof:
        outputs = model(inputs)

        incorrect = (labels != torch.argmax(outputs, dim=1))
        totals["standard"] += torch.sum(incorrect).item()


        #with profiler.profile(with_stack=True, profile_memory=True) as prof:
        #lb, ub = model.CROWN_IBP(inputs, eps * torch.ones_like(inputs), labels, None)
        #print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))
        for j in cfg["evaluation"]["types"]:
            eval_type = j["name"]
            eps = j["eps"]
            if eval_type in ["ibp", "ubp", "crown_ibp", "proven_ibp", "another"]:
                lb, ub = None, None
                if eval_type == "ibp":
                    lb, ub = model.IBP(inputs, eps * torch.ones_like(inputs), labels, None)
                elif eval_type == "ubp":
                    lb, ub = model.UBP_test(inputs, eps * torch.ones_like(inputs), labels, None)
                elif eval_type == "crown_ibp":
                    lb, ub = model.CROWN_IBP(inputs, eps * torch.ones_like(inputs), labels, None, use_zero=use_zero)
                elif eval_type == "proven_ibp":
                    lb, ub = model.PROVEN_IBP(inputs, eps * torch.ones_like(inputs), labels, None, q_tot, use_zero=use_zero)
                elif eval_type == "another":
                    lb, ub = model.ANOTHER(inputs, eps * torch.ones_like(inputs), labels, None)

                # Add a large offset at the labelled class so it cannot be the
                # row minimum; a sample is then counted as an error when any
                # other entry of lb is below -1e-5.
                # Presumably lb holds margins relative to the label when
                # labels are passed to the bound method — TODO confirm.
                lb[torch.arange(lb.size(0)),labels] += 1e3
                robust_incorrect = (torch.min(lb, dim=1)[0] < -1e-5)
                totals[eval_type] += torch.sum(robust_incorrect).item()

    # Pass 2: the remaining bound-based eval types, on the small loader only.
    for i, batch in enumerate(tqdm(small_dataloader)):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        #inputs = inputs.view(inputs.size(0), -1)
        #with profiler.profile(with_stack=True, profile_memory=True) as prof:
        # NOTE(review): `outputs` is not read in this loop; it is unclear from
        # this file whether the forward pass is needed for its side effects.
        outputs = model(inputs)
        #print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total', row_limit=5))

        for j in cfg["evaluation"]["types"]:

            eval_type = j["name"]
            eps = j["eps"]
            if eval_type in ["crown", "proven", "proven_gaussian"]:
                lb, ub = None, None
                if eval_type == "crown":
                    lb, ub = model.CROWN(inputs, eps * torch.ones_like(inputs), labels, None, use_zero=use_zero)
                elif eval_type == "proven":
                    lb, ub = model.PROVEN(inputs, eps * torch.ones_like(inputs), labels, None, q_tot, use_zero=use_zero)
                elif eval_type == "proven_gaussian":
                    lb, ub = model.PROVEN_gaussian(inputs, eps * torch.ones_like(inputs), labels, None, q_tot, use_zero=use_zero)

                # Same error criterion as in pass 1.
                lb[torch.arange(lb.size(0)),labels] += 1e3
                robust_incorrect = (torch.min(lb, dim=1)[0] < -1e-5)
                totals[eval_type] += torch.sum(robust_incorrect).item()


    # Pass 3: sampling-based eval types. Each batch is repeated sample_num
    # times and perturbed with fresh noise.
    for i, batch in enumerate(tqdm(small_dataloader)):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        for j in cfg["evaluation"]["types"]:
            if "sample_num" not in j:
                continue
            eval_type = j["name"]
            eps = j["eps"]

            # repeat(..., 1, 1, 1) requires `inputs` to have at most 4 dims.
            stacked_inputs = inputs.repeat(j["sample_num"], 1, 1, 1)
            stacked_labels = labels.repeat(j["sample_num"])
            #print(inputs.size(), stacked_inputs.size(), labels.size(), stacked_labels.size())

            # avg_*: add the fraction of perturbed copies that are misclassified.
            if eval_type == "avg_normal":
                perturbed_inputs = stacked_inputs + eps * torch.randn_like(stacked_inputs)
                perturbed_outputs = model(perturbed_inputs)
                perturbed_incorrect = (stacked_labels != torch.argmax(perturbed_outputs, dim=1))
                #print(perturbed_incorrect.size(), stacked_labels.size(), perturbed_outputs.size(), torch.argmax(perturbed_outputs, dim=1).size(), "size")
                #print(torch.sum(perturbed_incorrect).item(), j["sample_num"], "what")
                totals[eval_type] += torch.sum(perturbed_incorrect).item()/j["sample_num"]
                #hist[eval_type].append(torch.sum(perturbed_incorrect).item())

            if eval_type == "avg_uniform":
                perturbed_inputs = stacked_inputs + eps * (2 * torch.rand_like(stacked_inputs) - 1)
                perturbed_outputs = model(perturbed_inputs)
                perturbed_incorrect = (stacked_labels != torch.argmax(perturbed_outputs, dim=1))
                totals[eval_type] += torch.sum(perturbed_incorrect).item()/j["sample_num"]
                #hist[eval_type].append(torch.sum(perturbed_incorrect).item())

            # strict_*: add 1 per batch when any perturbed copy is wrong.
            # This is a per-example count only when the batch size is 1
            # (main builds the small loaders with batch_size=1).
            if eval_type == "strict_normal":
                perturbed_inputs = stacked_inputs + eps * torch.randn_like(stacked_inputs)
                perturbed_outputs = model(perturbed_inputs)
                perturbed_incorrect = (stacked_labels != torch.argmax(perturbed_outputs, dim=1))
                if torch.sum(perturbed_incorrect).item() != 0:
                    totals[eval_type] += 1
                #hist[eval_type].append(torch.sum(perturbed_incorrect).item())

            if eval_type == "strict_uniform":
                perturbed_inputs = stacked_inputs + eps * (2 * torch.rand_like(stacked_inputs) - 1)
                perturbed_outputs = model(perturbed_inputs)
                perturbed_incorrect = (stacked_labels != torch.argmax(perturbed_outputs, dim=1))
                if torch.sum(perturbed_incorrect).item() != 0:
                    totals[eval_type] += 1
                #hist[eval_type].append(torch.sum(perturbed_incorrect).item())

            # normal / uniform: add 1 per batch when more than `tolerance`
            # perturbed copies are wrong.
            if eval_type == "normal":
                perturbed_inputs = stacked_inputs + eps * torch.randn_like(stacked_inputs)
                perturbed_outputs = model(perturbed_inputs)
                perturbed_incorrect = (stacked_labels != torch.argmax(perturbed_outputs, dim=1))
                if torch.sum(perturbed_incorrect).item() > j["tolerance"]:
                    totals[eval_type] += 1
                #hist[eval_type].append(torch.sum(perturbed_incorrect).item())

            if eval_type == "uniform":
                perturbed_inputs = stacked_inputs + eps * (2 * torch.rand_like(stacked_inputs) - 1)
                perturbed_outputs = model(perturbed_inputs)
                perturbed_incorrect = (stacked_labels != torch.argmax(perturbed_outputs, dim=1))
                if torch.sum(perturbed_incorrect).item() > j["tolerance"]:
                    totals[eval_type] += 1
                #hist[eval_type].append(torch.sum(perturbed_incorrect).item())

    print("\nepoch", epoch)
    # Normalise each counter by the size of the dataset it was accumulated
    # over: the full loader for pass-1 types, the small loader otherwise.
    for i in totals:
        total_error = totals[i]
        if i in ["standard", "ibp", "ubp", "crown_ibp", "proven_ibp", "another"]:
            total_error /= len(dataloader.dataset)
        else:
            total_error /= len(small_dataloader.dataset)
        errors[i] = total_error
        print("{} error: {:.3f}".format(i, total_error))

    return errors["standard"], errors

def hm(epoch, model, dataloader, small_dataloader, device, cfg):
    """Save PROVEN heat-map images for the first batch of `small_dataloader`.

    Only the first batch is processed. `model.PROVEN_hm` is called with a
    fixed radius of 8/255, q_tot of 1e-2 and use_zero=False, and its first
    and third return values are rendered with `display_hm` to
    `experiments/<cfg["name"]>/hm_lower.png` and `hm_upper.png`. Nothing is
    written when `cfg["evaluation"]["types"]` is empty. `epoch` and
    `dataloader` are unused.
    """
    for batch in tqdm(small_dataloader):
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        print(labels, "LABELS")
        # Forward pass kept from the original; its result is not used here.
        model(inputs)

        # The configured eval types only gate whether anything happens; the
        # heat map itself always uses the fixed "proven" settings below.
        for _ in cfg["evaluation"]["types"]:
            radius = (8/255) * torch.ones_like(inputs)
            a_l, _b_l, a_u, _b_u = model.PROVEN_hm(inputs, radius, labels, None, 1e-2, use_zero=False)
            out_dir = os.path.join("experiments", cfg["name"])
            display_hm(a_l, os.path.join(out_dir, "hm_lower.png"))
            display_hm(a_u, os.path.join(out_dir, "hm_upper.png"))
            break
        break

def display_hm(a, fn):
    """Render ten 28x28 heat maps from `a` as one vertical strip image at `fn`.

    After squeezing and taking absolute values, channel 0 and the top-left
    28x28 window of each of the first ten entries are colour-mapped with
    `cm.jet` (scaled by half the global maximum) and stacked vertically with
    a one-pixel black border above and below each tile.
    """
    print(a.size(), "HMMMM")
    magnitudes = a.squeeze().abs()[:, 0, :28, :28]
    print(magnitudes.size(), "HMMMM")

    scaled = magnitudes / (0.5 * magnitudes.max())
    rgba = cm.jet(scaled.cpu().detach().numpy())
    ok = np.uint8(rgba * 255.0)
    print(ok.shape, "OK")

    tile_stride = 30
    data = np.zeros((10 * tile_stride, 28, 3), dtype=np.uint8)
    for k in range(10):
        top = k * tile_stride + 1
        # Drop the alpha channel; rows top..top+27 hold the k-th tile.
        data[top:top + 28, :, :] = ok[k, :, :, :3]
    print(data)

    Image.fromarray(data, "RGB").save(fn)

def train(model, dataloader, train_validation_dataloader, small_train_validation_dataloader, validation_dataloader, small_validation_dataloader, optimizer, scheduler, device, cfg, mode):
    """Train (or re-evaluate) the model for `cfg["training"]["epochs"]` epochs.

    In "eval" mode each epoch loads
    `./experiments/<name>/checkpoints/model{epoch:03d}.pth` and evaluates it;
    no training or scheduler step happens and a loss of 0 is recorded.
    Otherwise each epoch trains, evaluates on the validation loaders and steps
    the scheduler.

    Args:
        model: Model to train or evaluate.
        dataloader: Training loader.
        train_validation_dataloader: Unused here.
        small_train_validation_dataloader: Unused here.
        validation_dataloader: Loader passed to `evaluate_epoch`.
        small_validation_dataloader: Small loader passed to `evaluate_epoch`.
        optimizer: Optimizer used by `train_epoch`.
        scheduler: LR scheduler, stepped once per training epoch.
        device: Device for batches and for mapping loaded checkpoints.
        cfg: Config dict; reads `cfg["name"]` and `cfg["training"]["epochs"]`.
        mode: "eval" to evaluate saved checkpoints, anything else to train.

    Returns:
        `(losses, results)` where `results` has the keys "loss", "error",
        "train_errors" (always empty here) and "loss_terms".
    """
    losses = []
    errors = {}
    train_errors = {}
    terms = {}
    checkpoint_dir = os.path.join("./experiments", cfg["name"], "checkpoints")
    for epoch in tqdm(range(cfg["training"]["epochs"])):
        if mode == "eval":
            losses.append(0)
            # NOTE(review): nothing in this function writes these checkpoints
            # (the torch.save call had been commented out), so eval mode
            # depends on files produced elsewhere.
            checkpoint_path = os.path.join(checkpoint_dir, "model{:03d}.pth".format(epoch))
            # map_location lets checkpoints saved on another device (e.g. a
            # GPU) load when only `device` is available.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(state_dict)
        else:
            loss, loss_terms = train_epoch(epoch, model, dataloader, optimizer, device, cfg)
            losses.append(loss)
            for name, value in loss_terms.items():
                terms.setdefault(name, []).append(value)

        _, details = evaluate_epoch(epoch, model, validation_dataloader, small_validation_dataloader, device, cfg)
        for name, value in details.items():
            errors.setdefault(name, []).append(value)

        if mode != "eval":
            scheduler.step()
    return losses, {"loss": losses, "error": errors, "train_errors": train_errors, "loss_terms": terms}

def write_results(results, cfg):
    """Plot training loss, loss terms and error curves to PNG files.

    All figures are written to `./experiments/<cfg["name"]>/results`.

    Args:
        results: Dict with "loss" (list), "loss_terms", "error" and
            "train_errors" (each a dict mapping a name to a per-epoch list).
        cfg: Config dict; reads `cfg["name"]`, `cfg["training"]["epochs"]`,
            `cfg["loss"]["type"]` and `cfg["dataset"]`.
    """
    experiment_dir = os.path.join("./experiments", cfg["name"], "results")
    epochs = range(cfg["training"]["epochs"])

    train_title = {
        "standard": "Standard",
        "ibp": "IBP",
        "crown": "CROWN",
        "crown_ibp": "CROWN-IBP",
        "proven": "PROVEN",
        "proven_ibp": "PROVEN-IBP",
        "proven_ibp_fixed": "PROVEN-IBP",
        "proven_ibp_alpha": "PROVEN-IBP",
        "another": "Another",
        "aug": "Data Augmentation",
    }

    # Raw strings keep the LaTeX backslash without relying on "\i" being an
    # unrecognised (and deprecated) escape sequence.
    type_title = {
        "standard": "Standard",
        "ibp": r"IBP ($B_{\infty}(x, 0.3)$)",
        "crown": "CROWN",
        "crown_ibp": "CROWN-IBP",
        "proven": r"PROVEN ($B_{\infty}(x, 0.3)$)",
        "proven_ibp": "PROVEN-IBP",
        "avg_normal": "Normal Sampling (Averaged Error)",
        "avg_uniform": "Uniform Sampling (Averaged Error)",
        "strict_normal": "Normal Sampling (Strict Error)",
        "strict_uniform": "Uniform Sampling (Strict Error)",
        "normal": "Normal Sampling (1% Error)",
        "uniform": "Uniform Sampling (1% Error)",
        "another": "Another",
        "proven_gaussian": "PROVEN (Gaussian)",
    }

    dataset_title = {
        "mnist": "MNIST",
        "cifar": "CIFAR",
    }

    # Fall back to the raw key for names without a pretty title (for example
    # the "ubp" eval type) instead of failing with KeyError after a long run.
    loss_name = train_title.get(cfg["loss"]["type"], cfg["loss"]["type"])
    dataset_name = dataset_title.get(cfg["dataset"], cfg["dataset"])

    def save_curve(values, ylabel, title, filename):
        """Plot one series on the current figure, save it, and clear the figure."""
        plt.xlabel("Epochs")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.plot(epochs, values)
        plt.ylim(bottom=0.0, top=1.0)
        plt.savefig(os.path.join(experiment_dir, filename))
        plt.clf()

    # Overall training loss (no upper y-limit, unlike the other plots).
    plt.figure()
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.plot(epochs, results["loss"])
    plt.ylim(bottom=0.0)
    plt.savefig(os.path.join(experiment_dir, "loss.png"))
    plt.clf()

    for name, values in results["loss_terms"].items():
        save_curve(values, "Loss", name + " Error", name + "_term.png")

    for name, values in results["error"].items():
        save_curve(values, "Error", type_title.get(name, name) + " Error", name + "_error.png")

    for name, values in results["train_errors"].items():
        save_curve(values, "Error", type_title.get(name, name) + " Error", name + "train_error.png")

    # Combined training-set error plot.
    plt.xlabel("Epochs")
    plt.ylabel("Error")
    plt.title("Training Set Errors of {} Trained {} Model".format(loss_name, dataset_name))
    for name, values in results["train_errors"].items():
        plt.plot(epochs, values, label=type_title.get(name, name))
    plt.ylim(bottom=0.0, top=1.0)
    plt.legend()
    plt.savefig(os.path.join(experiment_dir, "overall_train_error.png"))
    plt.clf()

    # Combined validation error plot, restricted to the headline metrics.
    plt.xlabel("Epochs")
    plt.ylabel("Error")
    plt.title("Validation Set Errors of {} Trained {} Model".format(loss_name, dataset_name))
    plotted_any = False
    for name, values in results["error"].items():
        if name not in ("ibp", "proven", "standard"):
            continue
        plt.plot(epochs, values, label=type_title.get(name, name))
        plotted_any = True
    # Save once after all series are drawn; as before, nothing is written
    # when none of the headline metrics is present.
    if plotted_any:
        plt.ylim(bottom=0.0, top=1.0)
        plt.legend()
        plt.savefig(os.path.join(experiment_dir, "overall_error.png"))
    
def save_results(results, cfg, filename):
    """Pickle `results` to `./experiments/<cfg["name"]>/results/<filename>`."""
    target = os.path.join("./experiments", cfg["name"], "results", filename)
    with open(target, "wb") as handle:
        handle.write(pickle.dumps(results))

def save_gif(config_file):
    """Assemble per-epoch histogram PNGs into one GIF per sampling eval type.

    The previous body referenced `evaluation_results`, `cfg` and
    `experiment_dir`, none of which were defined, so any call raised
    NameError. The config is now loaded from `config_file`.

    For every eval type with a "sample_num" entry, the frames
    `./experiments/<name>/temp/hist_<type><epoch>.png` are combined into
    `./experiments/<name>/hist_<type>.gif`. Types with no frames, or with any
    frame missing, are skipped.

    Args:
        config_file: Path to the YAML experiment config.
    """
    with open(config_file, "r") as f:
        cfg = yaml.safe_load(f)

    # NOTE(review): the output directory is assumed to be the experiment
    # root, matching `experiment_dir` in main() — confirm.
    experiment_dir = os.path.join("./experiments", cfg["name"])
    frame_dir = os.path.join(experiment_dir, "temp")

    for j in cfg["evaluation"]["types"]:
        if "sample_num" not in j:
            continue
        eval_type = j["name"]
        frame_paths = [
            os.path.join(frame_dir, "hist_" + eval_type + str(i) + ".png")
            for i in range(cfg["training"]["epochs"])
        ]
        # Best effort: only build a GIF when every frame is present.
        if not frame_paths or not all(os.path.exists(path) for path in frame_paths):
            continue
        images = [imageio.imread(path) for path in frame_paths]
        imageio.mimsave(os.path.join(experiment_dir, "hist_" + eval_type + ".gif"), images)


def imshow(img):
    """Print the value range of `img`, then display it with matplotlib.

    The image is rescaled as `img / 2 + 0.5` and its axes are reordered with
    `(1, 2, 0)` before plotting.
    """
    print(img.max(), img.min())
    rescaled = img / 2 + 0.5
    plt.imshow(np.transpose(rescaled.numpy(), (1, 2, 0)))
    plt.show()

#------------------------------------------------------------------

def main(config_file, mode="train"):
    """Run a full experiment described by a YAML config file.

    Loads the config, prepares `./experiments/<name>/{checkpoints,temp,results}`,
    builds the datasets, loaders, model, optimizer and scheduler, runs
    `train`, then pickles and plots the results.

    Args:
        config_file: Path to the YAML experiment config.
        mode: "train" to train, "eval" to evaluate saved checkpoints. This
            used to be read from the module-level `args`, which only exists
            when the file is run as a script.

    Raises:
        ValueError: If `cfg["dataset"]` is neither "mnist" nor "cifar".
    """
    current_time = datetime.now()
    with open(config_file, "r") as f:
        cfg = yaml.safe_load(f)
    print(cfg)
    torch.manual_seed(cfg["seed"])

    experiment_dir = os.path.join("./experiments", cfg["name"])
    for dir_name in ["checkpoints", "temp", "results"]:
        os.makedirs(os.path.join(experiment_dir, dir_name), exist_ok=True)

    # The configured device is only honoured when CUDA is available;
    # otherwise everything runs on the CPU.
    device = torch.device("cpu")
    if torch.cuda.is_available() and "device" in cfg:
        device = torch.device(cfg["device"])
    print("USING DEVICE:", device)
    plt.figure()

    dataset_classes = {
        "mnist": torchvision.datasets.MNIST,
        "cifar": torchvision.datasets.CIFAR10,
    }
    if cfg["dataset"] not in dataset_classes:
        # Fail early instead of crashing later on a None dataset.
        raise ValueError("unknown dataset {!r}; expected one of {}".format(cfg["dataset"], sorted(dataset_classes)))
    dataset_cls = dataset_classes[cfg["dataset"]]
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = dataset_cls(root="./new_data", train=True, download=False, transform=transform)
    validation_dataset = dataset_cls(root="./new_data", train=False, download=False, transform=transform)

    # Random 10% subsets used for the more expensive evaluations.
    small_dataset = torch.utils.data.random_split(dataset, [len(dataset)//10, len(dataset)-len(dataset)//10])[0]
    small_validation_dataset = torch.utils.data.random_split(validation_dataset, [len(validation_dataset)//10, len(validation_dataset)-len(validation_dataset)//10])[0]
    print(len(small_dataset), len(dataset), "OK")

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg["training"]["batch_size"], shuffle=True, num_workers=0)
    train_validation_dataloader = torch.utils.data.DataLoader(small_dataset, batch_size=cfg["evaluation"]["batch_size"], shuffle=False, num_workers=0)
    small_train_validation_dataloader = torch.utils.data.DataLoader(small_dataset, batch_size=1, shuffle=False, num_workers=0)
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=cfg["evaluation"]["batch_size"], shuffle=False, num_workers=0)
    # batch_size=1: the per-batch counters in evaluate_epoch are divided by
    # the dataset size, which matches only with one example per batch.
    small_validation_dataloader = torch.utils.data.DataLoader(small_validation_dataset, batch_size=1, shuffle=False, num_workers=0)

    if "architecture" not in cfg["model"] or cfg["model"]["architecture"] == "mlp":
        # NOTE(review): TorchMLP is not imported anywhere in this file, so
        # this branch raises NameError until the import is restored.
        model = TorchMLP(cfg["model"]["dims"], nn.ReLU())
    else:
        model = TorchCNN(cfg["model"]["channels"], cfg["model"]["kernel"], cfg["model"]["stride"], cfg["model"]["padding"], nn.ReLU())
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=cfg["training"]["lr"])
    # Without a scheduler section the learning rate stays constant.
    scheduler_cfg = cfg["training"].get("scheduler")
    gamma = 1.0 if scheduler_cfg is None else scheduler_cfg["gamma"]
    milestones = [] if scheduler_cfg is None else scheduler_cfg["milestones"]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma)

    _, results = train(model, dataloader, train_validation_dataloader, small_train_validation_dataloader, validation_dataloader, small_validation_dataloader, optimizer, scheduler, device, cfg, mode)

    save_results(results, cfg, "results_" + str(current_time)+".pkl")
    write_results(results, cfg)


if __name__ == "__main__":
    # Example: python torch_train.py --config configs/test8.yaml
    parser = argparse.ArgumentParser(description='compute activation bound for CIFAR and MNIST')
    parser.add_argument('--config',
                required=True,
                help='config file')
    parser.add_argument('--mode',
                default = "train",
                choices = ["train", "eval"],
                help = "run mode")
    args = parser.parse_args()
    main(args.config, args.mode)
