"""
Define some utility functions
"""
from copy import deepcopy
import pickle as pk
import os
import numpy as np
import subprocess
import sys

import logging.config
import shutil
import pandas as pd
from bokeh.io import output_file, save, show
from bokeh.plotting import figure
from bokeh.layouts import column

import torch


def save_obj(obj, name, save_dir):
    """Pickle ``obj`` to ``<save_dir>/<name>.pkl``, creating ``save_dir`` if needed.

    Args:
        obj: Any picklable object.
        name: File stem; ``.pkl`` is appended.
        save_dir: Target directory. Trailing ``/`` and ``\\`` characters are
            stripped before ``/<name>.pkl`` is appended.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    # "\\/" is backslash + slash, the value the previous invalid "\/" escape
    # produced, without the invalid-escape SyntaxWarning.
    objfile = save_dir.rstrip("\\/") + "/" + name + ".pkl"
    with open(objfile, "wb") as f:
        pk.dump(obj, f, pk.HIGHEST_PROTOCOL)


def load_obj(name, save_dir):
    """Load and return the object pickled at ``<save_dir>/<name>.pkl``.

    The path is built the same way ``save_obj`` builds it: trailing ``/`` and
    ``\\`` characters are stripped from ``save_dir`` first.

    Warning: unpickling can execute arbitrary code; only load trusted files.
    """
    # "\\/" is backslash + slash, the value the previous invalid "\/" escape
    # produced, without the invalid-escape SyntaxWarning.
    objfile = save_dir.rstrip("\\/") + "/" + name + ".pkl"
    with open(objfile, "rb") as f:
        return pk.load(f)


def save_state(model, acc, args):
    """Save ``model``'s state dict and ``acc`` to ``args.save_name`` with ``torch.save``.

    A leading ``"module."`` prefix on any key is removed, presumably the prefix
    a DataParallel-style wrapper adds, so the checkpoint matches the unwrapped
    model. Only the leading prefix is stripped; a key such as
    ``"module.submodule.w"`` becomes ``"submodule.w"``. Key order is preserved.

    Args:
        model: Object exposing ``state_dict()``.
        acc: Value stored under ``"acc"`` in the checkpoint.
        args: Namespace providing ``save_name`` (the output path).
    """
    print("==> Saving model ...")
    state_dict = model.state_dict()
    prefix = "module."
    renamed = [
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    ]
    # Rebuild the mapping in place instead of popping keys while iterating it,
    # which raises RuntimeError. Reusing the same object keeps any attributes
    # the mapping carries.
    state_dict.clear()
    state_dict.update(renamed)
    state = {
        "acc": acc,
        "state_dict": state_dict,
    }
    torch.save(state, args.save_name)


def create_val_folder(args):
    """Reorganize the Tiny-ImageNet validation split into one folder per class.

    Reads ``<data_path>/val/val_annotations.txt``, a tab-separated file whose
    first column is an image file name and second column its class label. Each
    image is moved from ``<data_path>/val/images/<img>`` to
    ``<data_path>/val/images/<class>/<img>``. Images no longer in the flat
    directory are skipped, so running this twice is harmless. Blank or
    malformed annotation lines (fewer than two columns) are ignored.

    Only Tiny-ImageNet is supported (``args.dataset == "TINYIMAGENET200"``).
    """
    assert args.dataset == "TINYIMAGENET200"

    # Directory currently holding the flat list of validation images.
    path = os.path.join(args.data_path, "val/images")
    # File mapping each validation image to its class label.
    filename = os.path.join(args.data_path, "val/val_annotations.txt")

    # Map image file name -> class label.
    val_img_dict = {}
    with open(filename, "r") as fp:
        for line in fp:
            # Strip the line ending so a 2-column file does not leave "\n"
            # in the class (folder) name.
            words = line.rstrip("\r\n").split("\t")
            if len(words) < 2:
                continue  # e.g. a trailing empty line
            val_img_dict[words[0]] = words[1]

    # Create each class folder if needed and move the image into it.
    for img, folder in val_img_dict.items():
        newpath = os.path.join(path, folder)
        os.makedirs(newpath, exist_ok=True)
        src = os.path.join(path, img)
        if os.path.exists(src):  # not yet moved
            os.rename(src, os.path.join(newpath, img))


""" following is copied from BNN code: https://github.com/itayhubara/BinaryNet.pytorch
"""


def accuracy(output, target, topk=(1,), avg=False):
    """Count how many samples have their target among the top-k scores.

    Args:
        output: Score tensor; ``topk`` is taken along dim 1, so it is presumably
            shaped (batch, num_classes). Confirm with callers.
        target: Tensor of true class indices, one per row of ``output``.
        topk: The k values to evaluate.
        avg: If True, return the percentage of the batch instead of raw counts.

    Returns:
        A list with one single-element float tensor per entry of ``topk``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    _, top_idx = output.topk(largest_k, 1, True, True)
    # After transposing, row i holds each sample's i-th highest-scoring class.
    top_idx = top_idx.t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    def _score(k):
        n_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
        if not avg:
            return n_correct
        return n_correct.mul_(100.0 / n_samples)

    return [_score(k) for k in topk]


def setup_logging(log_file="log.txt"):
    """Configure root logging: DEBUG and above to ``log_file``, INFO and above to the console.

    ``log_file`` is opened in overwrite mode with timestamped records. The
    console handler prints bare messages and is added unconditionally, so
    calling this more than once adds more than one console handler.
    """
    logging.basicConfig(
        filename=log_file,
        filemode="w",
        level=logging.DEBUG,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter("%(message)s"))
    root_logger = logging.getLogger("")
    root_logger.addHandler(stream_handler)


class ResultsLog(object):
    """Accumulate result rows in a pandas DataFrame; persist to CSV and bokeh HTML.

    Args:
        path: CSV file written by ``save`` and read by ``load``.
        plot_path: HTML file for figures; defaults to ``path + ".html"``.
    """

    def __init__(self, path="results.csv", plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + ".html")
        self.figures = []  # bokeh figures waiting for save() / show()
        self.results = None  # DataFrame of rows added so far, or None

    def add(self, **kwargs):
        """Append one row whose column names are the keyword names."""
        df = pd.DataFrame([list(kwargs.values())], columns=list(kwargs.keys()))
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0; concat replaces it.
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title="Training Results"):
        """Write pending figures to ``plot_path`` (then clear them) and rows to ``path``."""
        if self.figures:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            # Module-level bokeh ``save``, not this method.
            save(plot)
            self.figures = []
        # Nothing has been added yet, so there is no table to write.
        if self.results is not None:
            self.results.to_csv(self.path, index=False, index_label=False)

    def load(self, path=None):
        """Replace ``results`` with the CSV at ``path`` (default ``self.path``) if it exists."""
        path = path or self.path
        if os.path.isfile(path):
            # read_csv is a pandas module function; store the frame it returns.
            self.results = pd.read_csv(path)

    def show(self):
        """Display the pending figures stacked in one column (nothing if none)."""
        if self.figures:
            plot = column(*self.figures)
            # Module-level bokeh ``show``, not this method.
            show(plot)

    def image(self, *kargs, **kwargs):
        """Create a new figure, draw ``fig.image(*kargs, **kwargs)`` on it and queue it."""
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)


def save_checkpoint(
    state, is_best, path=".", filename="checkpoint.pth.tar", save_all=False
):
    """Save ``state`` to ``<path>/<filename>`` and optionally copy it.

    Args:
        state: Object passed to ``torch.save``; must contain ``"epoch"`` when
            ``save_all`` is True.
        is_best: If True, also copy the file to ``<path>/model_best.pth.tar``.
        path: Output directory.
        filename: Name of the main checkpoint file.
        save_all: If True, also copy the file to
            ``<path>/checkpoint_epoch_<epoch>.pth.tar``.
    """
    ckpt_path = os.path.join(path, filename)
    torch.save(state, ckpt_path)

    def _copy_to(name):
        shutil.copyfile(ckpt_path, os.path.join(path, name))

    if is_best:
        _copy_to("model_best.pth.tar")
    if save_all:
        _copy_to("checkpoint_epoch_%s.pth.tar" % state["epoch"])


def save_model(state, filename):
    """Announce the save on stdout, then serialize ``state`` to ``filename`` with ``torch.save``."""
    banner = "==> Saving model ..."
    print(banner)
    torch.save(state, filename)
