import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.special import softmax
from sklearn import metrics
from netcal.metrics import ECE, MCE
import time


class Network(nn.Module):
    """Functional feed-forward network whose weights are supplied per call.

    The module registers no parameters of its own. It records the shapes of
    the tensors it expects (``theta_shapes``) and, in ``forward``, applies a
    caller-provided list ``theta`` of alternating weight / bias tensors.

    Args:
        network_specs: Mapping with two entries:
            ``'architecture'`` - sequence of layer specs. A 2-element spec
            ``(n_in, n_out)`` yields a linear layer (weight ``[n_out, n_in]``);
            any other spec is read as ``(c_in, kernel, c_out)`` and yields a
            square-kernel convolution (weight ``[c_out, c_in, kernel, kernel]``).
            ``'activation'`` - callable applied after every layer but the last.
        loss_type: ``'mse'``, ``'NLL'`` or ``'CE'``; selects the branch used by
            ``loss_neglikelihood``.
        probabilistic: When true, ``forward`` splits the output in two halves
            along the last axis and passes the second half through a softplus.
    """

    def __init__(self, network_specs, loss_type, probabilistic=False):
        super().__init__()

        self.loss_type = loss_type
        self.probabilistic = probabilistic

        # One weight shape followed by one bias shape per layer.
        self.theta_shapes = []
        for layer in network_specs['architecture']:
            if len(layer) == 2:
                self.theta_shapes.append([layer[1], layer[0]])
                self.theta_shapes.append([layer[1]])
            else:
                self.theta_shapes.append([layer[2], layer[0], layer[1], layer[1]])
                self.theta_shapes.append([layer[2]])
        self.activation = network_specs['activation']

        self.tot_params = sum(math.prod(shape) for shape in self.theta_shapes)

    def get_theta_shape(self):
        """Return the list of parameter shapes (weight, bias, weight, ...)."""
        return self.theta_shapes

    def init_params(self, theta):
        """Xavier-initialise, in place, every tensor in ``theta`` with >1 dim.

        One-dimensional tensors (the biases) are left untouched.
        """
        for theta_i in theta:
            if theta_i.dim() > 1:
                torch.nn.init.xavier_uniform_(theta_i)

    def get_flat_params(self, theta):
        """Concatenate the ``.data`` of every tensor in ``theta`` into a 1-D tensor."""
        # reshape (not view) so non-contiguous tensors, e.g. transposed
        # weights, are accepted instead of raising.
        return torch.cat([param.data.reshape(-1) for param in theta])

    def get_unflat_params(self, theta):
        """Split the flat vector ``theta`` into tensors shaped like ``theta_shapes``."""
        theta_params = []
        offset = 0
        for param_shape in self.theta_shapes:
            n_params = math.prod(param_shape)
            theta_params.append(theta[offset:offset + n_params].view(param_shape))
            offset += n_params
        return theta_params

    def forward(self, x, theta, c=None):
        """Apply the network defined by ``theta`` to ``x``.

        Args:
            x: Input tensor.
            theta: Sequence of alternating weight / bias tensors. Weights with
                more than two dims are applied as convolutions followed by
                2x2 average pooling; the others as linear layers.
            c: Optional tensor; when given, the input becomes
                ``cat([x, x * c], dim=-1)``.

        Returns:
            The output of the last linear layer. If ``self.probabilistic`` is
            set, the second half of the last axis is replaced by
            ``softplus(.) + 0.01``.
        """
        n = len(theta)
        h = x if c is None else torch.cat([x, x * c], dim=-1)

        # Every (weight, bias) pair except the last one is a hidden layer.
        for i in range(n // 2 - 1):
            weight, bias = theta[2 * i], theta[2 * i + 1]
            if weight.dim() > 2:
                h = self.activation(F.conv2d(h, weight, bias=bias))
                h = F.avg_pool2d(h, 2)
            else:
                if h.dim() == 4:
                    h = torch.flatten(h, 1)
                h = self.activation(F.linear(h, weight, bias=bias))

        # When the last hidden layer is a convolution, h is still a 4-D
        # feature map here; flatten it so the output layer sees (batch, features).
        if n >= 4 and theta[n - 4].dim() > 2 and h.dim() == 4:
            h = torch.flatten(h, 1)
        out = F.linear(h, theta[n - 2], bias=theta[n - 1])
        if not self.probabilistic:
            return out

        mu, sigma = torch.split(out, out.shape[-1] // 2, dim=-1)
        # softplus keeps the second half positive; 0.01 is a lower bound on it.
        return torch.cat([mu, F.softplus(sigma) + 0.01], -1)

    def loss_neglikelihood(self, y_pred, y):
        """Return the summed loss between ``y_pred`` and ``y``.

        ``'mse'`` uses half the squared error; ``'NLL'`` splits ``y_pred`` in
        two halves along the last axis, uses them as loc / scale of a Normal
        and takes the negative log-probability of ``y``; ``'CE'`` applies
        cross-entropy on logits after collapsing all leading axes.

        Raises:
            ValueError: If ``self.loss_type`` is none of the three above.
        """
        if self.loss_type == 'mse':
            L = (y_pred - y) ** 2 / 2
        elif self.loss_type == 'NLL':
            N = Normal(*torch.split(y_pred, y_pred.shape[-1] // 2, dim=-1))
            L = -N.log_prob(y)
        elif self.loss_type == 'CE':
            if y.dim() > 1:
                # Collapse leading axes so cross_entropy sees (N, classes) / (N,).
                y = y.reshape(-1)
                y_pred = y_pred.reshape(-1, y_pred.shape[-1])
            L = F.cross_entropy(y_pred, y.long(), reduction='none')
        else:
            raise ValueError(
                f"Unknown loss_type {self.loss_type!r}; expected 'mse', 'NLL' or 'CE'"
            )
        return L.sum()



def regression_metrics(model, loader):
    """Compute test-set MSE and Gaussian NLL from ``model.posterior``.

    Args:
        model: Object exposing ``posterior(x_test, loader)`` that returns four
            values; only the last one (``py``) is used here. ``py`` is indexed
            as ``py[:, :, 0]`` - presumably (samples, points, outputs); confirm
            against the model implementation.
        loader: Object exposing ``dataset.x_test`` and ``dataset.y_test``
            (``y_test`` is indexed as ``y_test[:, 0]``).

    Returns:
        dict with
        ``'MSE'``: ``[mean, std]`` of the per-entry MSE, one entry per index
        along the first axis of ``py``;
        ``'NLL'``: mean Gaussian negative log-likelihood of ``y_test[:, 0]``
        under the mean / variance of ``py`` taken over its first axis.
    """
    x_test, y_test = loader.dataset.x_test, loader.dataset.y_test
    _, _, _, py = model.posterior(x_test, loader)

    y = y_test.detach().cpu().numpy()[:, 0]
    dist_y_pred = py.detach().cpu().numpy()[:, :, 0]

    # One MSE per row of dist_y_pred, computed in a single vectorized pass;
    # float64 accumulation keeps the reported statistics in double precision.
    mse = np.mean((dist_y_pred - y[None, :]) ** 2, axis=1, dtype=np.float64)
    stats_metrics = {'MSE': [np.mean(mse), np.std(mse)]}

    mu_py = torch.mean(py, 0)[:, 0]
    var_py = torch.var(py, 0)[:, 0]
    nll = 0.5 * torch.mean(
        torch.log(2 * torch.pi * var_py) + (y_test[:, 0] - mu_py) ** 2 / var_py
    )
    stats_metrics['NLL'] = nll.detach().cpu().item()

    return stats_metrics

def classification_metrics(model, loader):
    """Compute test-set log-loss and the time taken by ``model.posterior``.

    Args:
        model: Object exposing ``posterior(x_test, loader)`` that returns four
            values; only the last one (``py``) is used here. ``py`` is treated
            as logits: softmax over the last axis, then averaged over axis 0.
        loader: Object exposing ``dataset.x_test`` and ``dataset.y_test``.

    Returns:
        dict with
        ``'NLL'``: ``sklearn.metrics.log_loss`` of the averaged probabilities
        against ``y_test``, with labels ``0 .. py.shape[-1] - 1``;
        ``'time'``: seconds spent inside the ``posterior`` call.
    """
    x_test, y_test = loader.dataset.x_test, loader.dataset.y_test

    # perf_counter is monotonic, so the interval cannot be distorted by
    # system clock adjustments the way time.time() can.
    # NOTE(review): if posterior runs asynchronously on a GPU, this may
    # under-report unless the device is synchronised - confirm.
    start_t = time.perf_counter()
    _, _, _, py = model.posterior(x_test, loader)
    time_elapsed = time.perf_counter() - start_t

    y = y_test.detach().cpu().numpy()
    dist_y_pred = torch.softmax(py, -1).mean(0).detach().cpu().numpy()

    return {
        'NLL': metrics.log_loss(y, dist_y_pred, labels=np.arange(py.shape[-1])),
        'time': time_elapsed,
    }





























