from torch.utils.data import DataLoader
from torch import nn
import torch
import numpy as np

class NeuralNetwork(nn.Module):
    """Single-hidden-layer ReLU regression network driven by a flat parameter vector.

    The packed vector has length
    ``d = num_features * num_hidden_nodes + 2 * num_hidden_nodes + 3`` and is
    laid out as::

        [W1 (hidden x features) | b1 (hidden) | W2 (hidden) | b2 | log_gamma | log_lambda]

    ``log_gamma`` is used as the log-precision of the Gaussian likelihood and
    ``log_lambda`` as the log-precision in the weight penalty (see
    ``_log_posterior``). Both are plain leaf tensors rather than
    ``nn.Parameter`` objects, so they are *not* returned by
    ``self.parameters()`` and are not touched by ``self.zero_grad()``.

    Data passed to the methods is expected as rows whose last column is the
    label and whose remaining columns are the features.
    """

    # log(2*pi) as a plain Python float so it combines cleanly with tensors.
    _LOG_2PI = float(np.log(2 * np.pi))

    def __init__(
            self,
            num_features,
            train_mean_X,
            train_mean_y,
            train_std_X,
            train_std_y,
            params=None,
            num_hidden_nodes=50,
            a0=1,
            b0=0.1,
            batch_size=100,
            clamp_precisions=False
        ):
        """Build the network and load its weights from ``params``.

        Args:
            num_features: Number of input features.
            train_mean_X, train_std_X: Statistics used to normalise features.
                Zero entries of ``train_std_X`` are replaced by 1 **in place**
                (the caller's object is modified) to avoid division by zero.
            train_mean_y, train_std_y: Statistics used to normalise labels and
                to map predictions back to the original scale.
            params: Packed 1-D parameter tensor of length ``d`` (required
                despite the ``None`` default).
            num_hidden_nodes: Width of the hidden layer.
            a0, b0: Coefficients of the prior terms on the two log-precisions.
            batch_size: Stored on the instance; the log-posterior itself uses
                the actual size of the batch it draws.
            clamp_precisions: Falsy to disable clamping, otherwise a bound
                ``c`` so that both log-precisions are kept in ``[-c, c]``.

        Raises:
            TypeError: If ``params`` is ``None``.
        """
        super().__init__()
        if params is None:
            raise TypeError(
                "params must be a packed 1-D parameter tensor, got None"
            )
        self.num_features = num_features
        self.train_mean_X = train_mean_X
        self.train_mean_y = train_mean_y
        self.train_std_X = train_std_X
        # In-place on purpose: other holders of this object (e.g. an ensemble
        # sharing the same statistics) see the zero-free version too.
        self.train_std_X[self.train_std_X == 0] = 1
        self.train_std_y = train_std_y
        self.num_hidden_nodes = num_hidden_nodes
        self.d = self.num_features * self.num_hidden_nodes + self.num_hidden_nodes * 2 + 3
        self.batch_size = batch_size
        self.a0 = a0
        self.b0 = b0
        self.log_gamma = params[-2].clone().detach().requires_grad_(True)
        self.log_lambda = params[-1].clone().detach().requires_grad_(True)
        self.clamp_precisions = clamp_precisions
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(self.num_features, num_hidden_nodes),
            nn.ReLU(),
            nn.Linear(num_hidden_nodes, 1)
        )
        # NOTE(review): the slices are assigned without cloning, so the layer
        # weights may share storage with ``params``; in-place updates made by
        # ``update_params`` can then write through to the caller's tensor.
        # ``set_params`` clones instead. Confirm which behaviour is intended.
        for weight, value in zip(self.parameters(), self.unpack_params(params)):
            weight.data = value

    def forward(self, x):
        """Return the network output for (already normalised) features ``x``."""
        return self.linear_relu_stack(x)

    def _log_precision_prior(self, log_precision):
        """Return ``(a0 - 1) * t - b0 * exp(t) + t`` for a log-precision ``t``."""
        return (self.a0 - 1) * log_precision - self.b0 * torch.exp(log_precision) + log_precision

    def _log_posterior(self, dataloader, temperature=1.0):
        """Return the minibatch estimate of the log-posterior (a 0-dim tensor).

        Only the first batch yielded by ``dataloader`` is used. The Gaussian
        log-likelihood of that batch is rescaled by
        ``len(dataloader.dataset) / n`` (``n`` being the number of rows in the
        batch) and divided by ``temperature``; the prior terms on
        ``log_gamma``, ``log_lambda`` and the weights are added unscaled.
        """
        batch = next(iter(dataloader))[0]
        features, labels = self.normalise(batch[:, :-1], batch[:, -1])
        residuals = self(features).squeeze(-1) - labels
        # Use the real batch size: the drawn batch can be smaller than
        # ``self.batch_size`` (e.g. when the dataset has fewer rows).
        n = labels.shape[0]
        data_size = len(dataloader.dataset)

        log_lik_data = -0.5 * n * (self._LOG_2PI - self.log_gamma) \
            - (torch.exp(self.log_gamma) / 2) * torch.sum(residuals ** 2)

        squared_norm = 0
        for p in self.parameters():
            squared_norm = squared_norm + torch.pow(p, 2).sum()
        # ``d - 2`` is the number of network weights (d minus the two log-precisions).
        log_prior_w = -(torch.exp(self.log_lambda) / 2) * squared_norm \
            - 0.5 * (self.d - 2) * (self._LOG_2PI - self.log_lambda) \
            + self._log_precision_prior(self.log_lambda)

        log_prior_data = self._log_precision_prior(self.log_gamma)

        return (log_lik_data * data_size / n) / temperature + log_prior_data + log_prior_w

    def score(self, **kwargs):
        """Return the packed gradient of the log-posterior.

        Keyword Args:
            dataloader: Loader whose first batch is used (required).
            temperature: Divisor applied to the rescaled likelihood term
                (default ``1.0``).

        Returns:
            1-D tensor of length ``d`` in the same layout as the packed
            parameters. The gradients are also left in the ``.grad`` fields of
            the weights, ``log_gamma`` and ``log_lambda``.
        """
        dataloader = kwargs['dataloader']
        temperature = kwargs.get('temperature', 1.0)

        self.zero_grad()
        # zero_grad() only covers registered nn.Parameters; the two
        # log-precisions are plain tensors and would otherwise accumulate
        # their gradient across calls.
        self.log_gamma.grad = None
        self.log_lambda.grad = None

        self._log_posterior(dataloader, temperature).backward()

        grads = [p.grad for p in self.parameters()]
        grads.append(self.log_gamma.grad)
        grads.append(self.log_lambda.grad)
        return self.pack_params(grads)

    def compute_mass(self, data_loader):
        """Return a scalar mass derived from one score evaluation.

        The value is the variance (``std ** 2``) taken across all ``d``
        entries of the packed gradient, plus ``1e-6``.

        NOTE(review): despite the name this is a single scalar, not one value
        per parameter; confirm that is what callers expect.
        """
        flat_grads = self.score(dataloader=data_loader).detach()
        return flat_grads.std(dim=0) ** 2 + 1e-6

    def log_prob(self, dataloader):
        """Return the minibatch log-posterior estimate at temperature 1."""
        return self._log_posterior(dataloader)

    def normalise(self, X, y):
        """Standardise features and labels with the stored training statistics."""
        X_normalised = (X - self.train_mean_X) / self.train_std_X
        y_normalised = (y - self.train_mean_y) / self.train_std_y
        return X_normalised, y_normalised

    def pack_params(self, params):
        """Flatten each tensor in ``params`` and concatenate them into one 1-D tensor."""
        return torch.cat([p.flatten() for p in params])

    def unpack_params(self, params):
        """Split a packed vector into ``[W1, b1, W2, b2, log_gamma, log_lambda]``.

        ``W1`` has shape ``(hidden, features)``, ``b1`` ``(hidden,)``, ``W2``
        ``(1, hidden)``; the last three entries are length-1 slices taken from
        the end of ``params``.
        """
        hidden = self.num_hidden_nodes
        w1_end = self.num_features * hidden
        b1_end = w1_end + hidden
        w2_end = b1_end + hidden
        return [
            params[:w1_end].reshape(hidden, self.num_features),
            params[w1_end:b1_end],
            params[b1_end:w2_end].unsqueeze(0),
            params[-3:-2],
            params[-2:-1],
            params[-1:]
        ]

    def _clamp_log_precisions(self):
        """Clamp both log-precisions to ``[-c, c]`` when ``clamp_precisions`` is truthy."""
        if not self.clamp_precisions:
            return
        bound = self.clamp_precisions
        self.log_gamma.data = torch.clamp(self.log_gamma.data, min=-bound, max=bound)
        self.log_lambda.data = torch.clamp(self.log_lambda.data, min=-bound, max=bound)

    def update_params(self, grads_packed):
        """Add a packed increment to all parameters in place, then clamp."""
        increments = self.unpack_params(grads_packed)
        # zip stops after the network weights; the trailing two entries are
        # the log-precision increments handled below.
        for p, delta in zip(self.parameters(), increments):
            p.data += delta
        self.log_gamma.data += increments[-2].item()
        self.log_lambda.data += increments[-1].item()
        self._clamp_log_precisions()

    def set_params(self, params_packed):
        """Overwrite all parameters with (clones of) a packed vector, then clamp."""
        values = self.unpack_params(params_packed)
        for p, value in zip(self.parameters(), values):
            p.data = value.clone()
        self.log_gamma.data = values[-2].clone().squeeze()
        self.log_lambda.data = values[-1].clone().squeeze()
        self._clamp_log_precisions()

    def evaluate(self, dataset):
        """Evaluate on the whole ``dataset`` in a single forward pass.

        ``dataset`` must support ``dataset[:, -1]`` style indexing that
        returns a tuple whose first element is the sliced tensor (e.g. a
        ``TensorDataset`` built from one 2-D tensor).

        Sets ``labels_pred`` (predictions on the original label scale, shape
        ``(n, 1)``), ``probs`` (per-record Gaussian density of the label under
        precision ``exp(log_gamma)``), ``ll`` (log of the mean density) and
        ``rmse``.
        """
        features, labels = dataset[:, :-1][0], dataset[:, -1][0]
        features_normalised, _ = self.normalise(features, labels)
        # Column vector so it lines up with the (n, 1) predictions instead of
        # broadcasting to an (n, n) matrix.
        targets = labels.unsqueeze(1)
        self.labels_pred = self(features_normalised) * self.train_std_y + self.train_mean_y
        precision = torch.exp(self.log_gamma)
        squared_error = torch.pow(self.labels_pred - targets, 2)
        self.probs = torch.sqrt(precision) \
            * torch.exp(-0.5 * squared_error * precision) \
            / float(np.sqrt(2 * np.pi))
        self.ll = torch.log(torch.mean(self.probs))
        self.rmse = torch.sqrt(torch.mean(squared_error))

    def evaluate_batch(self, dataset, evaluate_batch_size=1000):
        """Evaluate on ``dataset`` in chunks of ``evaluate_batch_size`` rows.

        ``dataset`` must expose ``.tensors`` and the same indexing as in
        ``evaluate``. Sets ``labels_pred`` (detached, shape ``(n, 1)``),
        ``probs`` (per-record Gaussian density), ``ll`` (log of the mean
        density, computed with ``logsumexp``; shape ``(1,)``) and ``rmse``.
        """
        dataloader = DataLoader(dataset, batch_size=evaluate_batch_size, shuffle=False)
        test_records = dataset.tensors[0].shape[0]
        targets = dataset[:, -1][0].unsqueeze(1)
        self.labels_pred = torch.zeros(test_records, 1)

        start = 0
        for batch in dataloader:
            features_batch, labels_batch = batch[0][:, :-1], batch[0][:, -1]
            features_batch_normalised, _ = self.normalise(features_batch, labels_batch)
            # Predictions are stored detached, so no graph is needed here.
            with torch.no_grad():
                labels_pred_batch = self(features_batch_normalised) * self.train_std_y + self.train_mean_y
            stop = start + labels_pred_batch.shape[0]
            self.labels_pred[start:stop] = labels_pred_batch
            start = stop

        # Computed once, after every prediction has been filled in.
        squared_error = torch.pow(self.labels_pred - targets, 2)
        log_prob = 0.5 * self.log_gamma \
            - 0.5 * torch.exp(self.log_gamma) * squared_error \
            - 0.5 * self._LOG_2PI
        self.probs = torch.exp(log_prob)
        self.ll = torch.logsumexp(log_prob, dim=0) - float(np.log(log_prob.shape[0]))
        self.rmse = torch.sqrt(torch.mean(squared_error))

class NeuralNetworkEnsemble():
    """A collection of ``NeuralNetwork`` instances, one per particle.

    ``particles`` is a 2-D tensor whose rows are packed parameter vectors;
    every row is used to build one network. All networks share the same
    normalisation statistics and hyper-parameters.
    """

    def __init__(
            self,
            num_features,
            train_mean_X,
            train_mean_y,
            train_std_X,
            train_std_y,
            particles=None,
            N=None,
            num_hidden_nodes=50,
            a0=1,
            b0=0.1,
            batch_size=100,
            clamp_precisions=False
        ):
        """Build one network per row of ``particles``.

        Args:
            num_features: Number of input features.
            train_mean_X, train_mean_y, train_std_X, train_std_y:
                Normalisation statistics, passed unchanged to every network.
            particles: Tensor of shape ``(N, d)`` (required despite the
                ``None`` default).
            N: Ignored; the ensemble size is always ``particles.shape[0]``.
                Kept so existing call sites continue to work.
            num_hidden_nodes, a0, b0, batch_size, clamp_precisions:
                Forwarded to every ``NeuralNetwork``.

        Raises:
            TypeError: If ``particles`` is ``None``.
            ValueError: If ``particles`` has no rows, or its second dimension
                differs from the networks' parameter count ``d``.
        """
        if particles is None:
            raise TypeError("particles must be a 2-D tensor of packed parameters, got None")
        if particles.shape[0] == 0:
            raise ValueError("particles must contain at least one row")

        self.N = particles.shape[0]
        self.num_features = num_features
        self.train_mean_X = train_mean_X
        self.train_mean_y = train_mean_y
        self.train_std_X = train_std_X
        self.train_std_y = train_std_y
        self.ensemble = [
            NeuralNetwork(
                self.num_features,
                self.train_mean_X,
                self.train_mean_y,
                self.train_std_X,
                self.train_std_y,
                params,
                num_hidden_nodes,
                a0,
                b0,
                batch_size,
                clamp_precisions=clamp_precisions
            )
            for params in particles
        ]
        self.d = self.ensemble[0].d
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if self.d != particles.shape[1]:
            raise ValueError(
                f"particles has {particles.shape[1]} columns but the networks expect {self.d}"
            )

    def score(self, dataloader=None, temperature=1.0):
        """Return an ``(N, d)`` tensor whose row ``i`` is network ``i``'s score."""
        score_tensor = torch.empty((self.N, self.d))
        for i, net in enumerate(self.ensemble):
            score_tensor[i, :] = net.score(dataloader=dataloader, temperature=temperature)
        return score_tensor

    def compute_ensemble_mass(self, data_loader):
        """Store in ``self.mass`` the mean of every network's ``compute_mass`` result."""
        masses = [net.compute_mass(data_loader) for net in self.ensemble]
        self.mass = torch.stack(masses, dim=0).mean(dim=0)

    def log_prob(self, dataloader):
        """Return a length-``N`` tensor with each network's ``log_prob``."""
        log_probs = torch.empty(self.N)
        for i, net in enumerate(self.ensemble):
            log_probs[i] = net.log_prob(dataloader)
        return log_probs

    def update_params(self, grads_packed):
        """Apply row ``i`` of ``grads_packed`` as an increment to network ``i``."""
        for i, net in enumerate(self.ensemble):
            net.update_params(grads_packed[i, :])

    def set_params(self, params_packed):
        """Overwrite network ``i``'s parameters with row ``i`` of ``params_packed``."""
        for i, net in enumerate(self.ensemble):
            net.set_params(params_packed[i, :])

    def normalise(self, X, y):
        """Standardise features and labels with the stored training statistics."""
        X_normalised = (X - self.train_mean_X) / self.train_std_X
        y_normalised = (y - self.train_mean_y) / self.train_std_y
        return X_normalised, y_normalised

    def evaluate(self, dataset):
        """Evaluate every network on ``dataset`` and aggregate the results.

        ``dataset`` must expose ``.tensors`` and support ``dataset[:, -1]``
        indexing that returns a tuple whose first element is the label column
        (e.g. a ``TensorDataset`` built from one 2-D tensor).

        Sets ``ll_array`` (float64 tensor with each network's ``ll``), ``ll``
        (their mean), ``labels_pred`` (``(N, n)`` tensor of per-network
        predictions) and ``rmse`` (error of the prediction averaged over the
        networks).
        """
        for network in self.ensemble:
            network.evaluate_batch(dataset)
        # float() pulls the value out of each one-element tensor even when it
        # tracks gradients; float64 matches what the NumPy round-trip produced.
        self.ll_array = torch.tensor(
            [float(network.ll) for network in self.ensemble], dtype=torch.float64
        )
        self.ll = torch.mean(self.ll_array)

        labels = dataset[:, -1][0]
        self.labels_pred = torch.empty(self.N, dataset.tensors[0].shape[0])
        for i, network in enumerate(self.ensemble):
            self.labels_pred[i, :] = network.labels_pred.squeeze()
        mean_prediction = torch.mean(self.labels_pred, dim=0)
        self.rmse = torch.sqrt(torch.mean(torch.pow(mean_prediction - labels, 2)))