import os
import logging
from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np

import torch

import settings
import pytorch_models
import caching


# Log record layout: "{<function name>:<line number>} <timestamp>: <message>".
# The function name must be spelled ``funcName``: ``func`` is not a LogRecord
# attribute and makes every log call fail while formatting.
# Alternative layouts kept for reference:
#   "%(asctime)s: %(message)s"
#   "{%(pathname)s:%(lineno)4s} %(asctime)s: %(message)s"
logging_format = "{%(funcName)s:%(lineno)4s} %(asctime)s: %(message)s"

# Numeric threshold between logging.DEBUG (10) and logging.INFO (20):
# INFO and above are emitted, DEBUG records are suppressed.
logging_level = 15
logging.basicConfig(level=logging_level,
                    format=logging_format)

logger = logging.getLogger(__name__)


def safe_int(n: float) -> int:
    """Convert ``n`` to ``int``, refusing to silently drop a fractional part.

    Args:
        n: A number expected to hold an integral value (e.g. ``4.0``).

    Returns:
        ``int(n)``.

    Raises:
        AssertionError: If ``int(n)`` differs from ``n``, i.e. the conversion
            would lose information.
    """
    intn = int(n)
    # Raise explicitly rather than via a bare ``assert`` so the check is still
    # enforced under ``python -O`` (which strips assert statements).
    # AssertionError is kept so the failure type seen by callers is unchanged.
    if intn != n:
        raise AssertionError(
            "expected an integral value, got {!r}".format(n))
    return intn


def _conv_front_end(input_dims: List[int],
                    kernel_size: Tuple[int, int]) -> Tuple[Any, Tuple[int, int, int]]:
    """Build the input ``FlatConv2d`` layer and compute its output shape.

    Uses 5 output channels and a stride of (1, 1).

    Args:
        input_dims: Input shape; indices 1 and 2 are the two dimensions the
            kernel slides over.
        kernel_size: Convolution kernel size along those two dimensions.

    Returns:
        ``(conv2d, out_shape)`` where ``out_shape`` is
        ``(num_out_channels, dim1, dim2)``.

    Raises:
        AssertionError: From ``safe_int`` if an output dimension is not
            integral.
    """
    num_out_channels = 5
    stride = (1, 1)
    conv2d = pytorch_models.FlatConv2d(input_dims,
                                       num_out_channels,
                                       kernel_size,
                                       stride)
    # (in - kernel) / stride + 1 per dimension.
    # NOTE(review): this formula assumes FlatConv2d applies no padding -- confirm.
    out_shape = (num_out_channels,
                 safe_int((input_dims[1] - kernel_size[0]) / stride[0] + 1),
                 safe_int((input_dims[2] - kernel_size[1]) / stride[1] + 1))
    return conv2d, out_shape


def _avgpool_2by2(in_shape: Tuple[int, int, int]) -> Tuple[Any, Tuple[int, int, int]]:
    """Build a 2x2 ``FlatAvgPool`` layer and compute its output shape.

    Args:
        in_shape: Shape fed to the pooling layer; indices 1 and 2 are halved.

    Returns:
        ``(avgpool, out_shape)``.

    Raises:
        AssertionError: From ``safe_int`` if either pooled dimension is odd.
    """
    avgpool = pytorch_models.FlatAvgPool(in_shape=in_shape,
                                         kernel_size=(2, 2))
    out_shape = (in_shape[0],
                 safe_int(in_shape[1] / 2),
                 safe_int(in_shape[2] / 2))
    return avgpool, out_shape


def _build_layer_list(model_name: str,
                      hidden_layer_widths: List[int],
                      input_dims: List[int],
                      output_width) -> List[Any]:
    """Assemble the ordered list of layers for the named architecture.

    Supported ``model_name`` values:

    * ``"simple_relunet"``: ReLU layers only, fed ``prod(input_dims)`` inputs.
    * ``"simple_relunet_with_input_conv"``: 2x2 convolution, then ReLU layers.
    * ``"simple_relunet_with_input_conv_and_avgpool"``: 2x2 convolution,
      2x2 average pooling, then ReLU layers.
    * ``"simple_mnist_classifier"``: 5x5 convolution, 2x2 average pooling,
      then ReLU layers.

    Args:
        model_name: One of the architecture names above.
        hidden_layer_widths: Passed to ``pytorch_models.build_relu_layers``.
        input_dims: Shape of one input example. The convolutional variants
            index it at positions 1 and 2.
        output_width: Passed to ``pytorch_models.build_relu_layers``.

    Returns:
        The layers in forward order: any convolution / pooling layers first,
        followed by the ReLU layers.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
        AssertionError: From ``safe_int`` if an intermediate shape is not
            integral.
    """
    # Conv kernel size for the variants that follow the conv with 2x2 pooling.
    conv_avgpool_kernels = {
        "simple_relunet_with_input_conv_and_avgpool": (2, 2),
        "simple_mnist_classifier": (5, 5),
    }

    if model_name == "simple_relunet":
        front_layers = []
        input_width = np.prod(input_dims)
    elif model_name == "simple_relunet_with_input_conv":
        conv2d, out_shape = _conv_front_end(input_dims, (2, 2))
        front_layers = [conv2d]
        input_width = out_shape[0] * out_shape[1] * out_shape[2]
    elif model_name in conv_avgpool_kernels:
        conv2d, conv_shape = _conv_front_end(input_dims,
                                             conv_avgpool_kernels[model_name])
        avgpool, out_shape = _avgpool_2by2(conv_shape)
        front_layers = [conv2d, avgpool]
        input_width = out_shape[0] * out_shape[1] * out_shape[2]
    else:
        raise ValueError(
            "I do not know about this model: {!r}".format(model_name))

    include_bias = True
    relu_layers = pytorch_models.build_relu_layers(input_width,
                                                   hidden_layer_widths,
                                                   output_width,
                                                   include_bias)
    if not front_layers:
        # Hand back the ReLU layers object itself, without copying it.
        return relu_layers
    return front_layers + relu_layers


def build_dnn(
    train_x: np.ndarray,
    train_y: np.ndarray,
    dnn_par: Dict[str, Any]) -> torch.nn.Module:
    """Build the network described by ``dnn_par`` and train it on the data.

    Args:
        train_x: Training inputs; converted to a float tensor.
        train_y: Training targets; converted to a long (integer) tensor.
        dnn_par: Configuration dictionary. Required keys:
            ``log_every_epoch`` (log when ``epoch_idx`` is a multiple of it),
            ``criterion_name`` (``"hinge"`` or ``"cross_entropy"``),
            ``optimizer_name`` (``"adam"`` or ``"sgd"``),
            ``batch_size``, ``epochs``,
            ``optim_kwargs`` (keyword arguments for the optimizer),
            ``model_name``, ``hidden_layer_widths``, ``input_dims`` and
            ``output_width`` (forwarded to ``_build_layer_list``).

    Returns:
        The trained model.

    Raises:
        KeyError: If a required key is missing from ``dnn_par``.
        ValueError: If ``optimizer_name`` or ``criterion_name`` is not
            supported, or (from ``_build_layer_list``) if ``model_name`` is
            unknown.
    """
    log_every_epoch = dnn_par["log_every_epoch"]
    criterion_name = dnn_par["criterion_name"]
    optimizer_name = dnn_par["optimizer_name"]
    batch_size = dnn_par["batch_size"]
    epochs = dnn_par["epochs"]
    optim_kwargs = dnn_par["optim_kwargs"]
    model_name = dnn_par["model_name"]
    hidden_layer_widths = dnn_par["hidden_layer_widths"]
    input_dims = dnn_par["input_dims"]
    output_width = dnn_par["output_width"]

    # Validate the names up front so a typo fails immediately with a clear
    # message instead of an unbound-variable error after the model is built.
    optimizer_classes = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
    }
    criterion_classes = {
        "hinge": torch.nn.MultiMarginLoss,
        "cross_entropy": torch.nn.CrossEntropyLoss,
    }
    if optimizer_name not in optimizer_classes:
        raise ValueError(
            "Unknown optimizer_name {!r}; expected one of {}".format(
                optimizer_name, sorted(optimizer_classes)))
    if criterion_name not in criterion_classes:
        raise ValueError(
            "Unknown criterion_name {!r}; expected one of {}".format(
                criterion_name, sorted(criterion_classes)))
    optim = optimizer_classes[optimizer_name]
    criterion = criterion_classes[criterion_name]()

    x_torch = torch.from_numpy(train_x).type(torch.FloatTensor)
    y_torch = torch.from_numpy(train_y).type(torch.LongTensor)
    train_dataset = torch.utils.data.TensorDataset(x_torch, y_torch)

    layer_list = _build_layer_list(model_name,
                                   hidden_layer_widths,
                                   input_dims,
                                   output_width)
    model = pytorch_models.Net(layer_list)

    dataloader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=batch_size)
    optimizer = optim(model.parameters(), **optim_kwargs)

    # Per-epoch loss of the last mini-batch; stays NaN for an epoch that saw
    # no batches (empty training set).
    losses = np.full((epochs,), np.nan)
    for epoch_idx in range(epochs):
        loss = None
        for x, y in dataloader:
            # Call the module rather than ``forward`` so registered hooks run.
            y_pred = model(x)
            loss = criterion(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is not None:
            losses[epoch_idx] = loss.item()
        if epoch_idx % log_every_epoch == 0:
            logger.info("Epoch completed {:4d} / {:4d} -- loss = {:4f}".format(epoch_idx, epochs, losses[epoch_idx]))
    return model


def assess_test_accuracy(model: pytorch_models.Net,
                         test_x: np.ndarray,
                         test_y: np.ndarray) -> float:
    """Return the fraction of test examples the model classifies correctly.

    The predicted class of an example is the arg-max of the model output
    along dimension 1. The model is switched to eval mode and is left in
    that mode when the function returns.

    Args:
        model: Network to evaluate.
        test_x: Test inputs; converted to a float tensor.
        test_y: Test labels; converted to a long (integer) tensor.

    Returns:
        Number of correct predictions divided by the number of test examples.
    """
    # Batching here only bounds memory use during evaluation.
    eval_batch_size = 64

    inputs = torch.from_numpy(test_x).type(torch.FloatTensor)
    labels = torch.from_numpy(test_y).type(torch.LongTensor)
    dataset = torch.utils.data.TensorDataset(inputs, labels)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=eval_batch_size)

    model.eval()
    num_correct = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            scores = model(batch_inputs)
            predicted = scores.argmax(dim=1, keepdim=True)
            # Reshape the labels to the prediction's shape before comparing.
            hits = predicted.eq(batch_labels.view_as(predicted))
            num_correct += hits.sum().item()

    return num_correct / len(loader.dataset)


# Script entry point: running this module directly currently does nothing;
# its functions are meant to be imported.
if __name__ == "__main__":
    pass
