import warnings
# Ignore every UserWarning for the whole process. The filter is installed
# globally at import time, so it is not limited to warnings raised by this
# module (presumably meant to quiet third-party library noise — confirm).
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader

from torch.nn.functional import relu
import gpytorch

import itertools
import numpy as np
import math


# from utils.utils import device

# from utils.spectral import SpectralScoreEstimator

# Select the module-level compute device: the first CUDA GPU when one is
# available, otherwise the CPU. The choice is reported on stdout.
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("Device set to : cpu")
else:
    device = torch.device('cuda:0')
    # Release cached (unused) allocator blocks before training starts.
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))

def real_imaginary_relu(z):
    """Apply ReLU separately to the real and imaginary parts of ``z``.

    Args:
        z: complex tensor.

    Returns:
        Complex tensor ``relu(Re z) + 1j * relu(Im z)``.
    """
    rectified_re, rectified_im = relu(z.real), relu(z.imag)
    return rectified_re + 1.j * rectified_im

def weights_init(m):
    """Re-initialise an ``nn.Linear`` module in place; ignore other modules.

    The weight gets Xavier-normal initialisation. The bias, when the layer
    has one, is drawn from ``nn.init.normal_`` with its default mean and std.
    The function is suitable for ``module.apply(weights_init)``.

    Args:
        m: any ``nn.Module``.
    """
    if not isinstance(m, nn.Linear):
        return
    # nn.init functions already run under no_grad, so ``.data`` is unnecessary.
    nn.init.xavier_normal_(m.weight)
    # ``nn.Linear(..., bias=False)`` stores ``bias = None``; skip it rather
    # than crashing with an AttributeError.
    if m.bias is not None:
        nn.init.normal_(m.bias)

def init_norm_layer(input_dim, norm_layer):
    """Build the normalisation layer selected by ``norm_layer``.

    Args:
        input_dim: number of features to normalise.
        norm_layer: ``"batchnorm"`` or ``None``.

    Returns:
        For ``"batchnorm"``, an ``nn.BatchNorm1d`` with no affine parameters
        and no running statistics, so it always normalises with the current
        batch's statistics. For ``None``, an ``nn.Identity``.

    Raises:
        ValueError: for any other ``norm_layer`` value. Previously such a
            value returned ``None``, and the error only surfaced later when
            the result was called.
    """
    if norm_layer is None:
        return nn.Identity()
    if norm_layer == "batchnorm":
        # NOTE(review): eps=0 means no variance floor; a feature that is
        # constant within a batch divides 0 by 0 and yields NaN.
        return nn.BatchNorm1d(input_dim, eps=0, momentum=None,
                              affine=False, track_running_stats=False)
    raise ValueError(
        f"Unsupported norm_layer {norm_layer!r}; expected 'batchnorm' or None")


def sample_measurement_set(X, num_data):
    """Draw a random subset of rows from ``X`` without replacement.

    Args:
        X: tensor whose first dimension indexes data points.
        num_data: number of rows to draw. Accepts an int, or anything ``int()``
            accepts (floats are truncated toward zero, one-element tensors work).
            A value larger than ``X.size(0)`` is clamped by slicing, so every
            row is returned in random order.

    Returns:
        ``X[idx]`` with ``min(num_data, X.size(0))`` randomly chosen rows.

    Raises:
        ValueError: if ``num_data`` is negative. A negative count would
            otherwise slice as ``perm[:-k]`` and silently drop rows.
    """
    # Keep the count as an exact Python int; a float32 tensor cannot
    # represent every integer above 2**24.
    n = int(num_data)
    if n < 0:
        raise ValueError(f"num_data must be non-negative, got {num_data!r}")
    # The permutation is generated on the CPU. Indexing a tensor on another
    # device with a CPU index tensor is supported.
    perm = torch.randperm(X.size(0))
    return X[perm[:n]]


class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regression model with a linear mean and a scaled RBF kernel.

    Args:
        train_x: training inputs, forwarded to ``gpytorch.models.ExactGP``.
        train_y: training targets, forwarded to ``gpytorch.models.ExactGP``.
        likelihood: GPyTorch likelihood, forwarded to ``ExactGP``.
        input_dim: number of input features; sets the size of ``LinearMean``.
    """

    def __init__(self, train_x, train_y, likelihood, input_dim):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.LinearMean(input_dim)
        # Alternatives that were kept commented out: ScaleKernel over
        # Periodic + RBF, over RBF * Periodic, and over Matern(nu=0.5).
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the GP ``MultivariateNormal`` evaluated at inputs ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


class BatchIndependentMultitaskGPModel(gpytorch.models.ExactGP):
    """Multitask exact GP: one independent GP per task, evaluated as a batch.

    Args:
        train_x: training inputs, forwarded to ``ExactGP``.
        train_y: training targets, forwarded to ``ExactGP``.
        likelihood: GPyTorch likelihood, forwarded to ``ExactGP``. It should
            presumably be a multitask likelihood built for the same
            ``num_tasks`` (confirm against caller).
        num_tasks: number of independent output tasks. Defaults to 2, the
            value that used to be hard-coded.
    """

    def __init__(self, train_x, train_y, likelihood, num_tasks=2):
        super().__init__(train_x, train_y, likelihood)
        # One batch entry per task, so each task gets its own mean and
        # kernel hyperparameters.
        batch_shape = torch.Size([num_tasks])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=batch_shape),
            batch_shape=batch_shape,
        )

    def forward(self, x):
        """Return a ``MultitaskMultivariateNormal`` built from the batched MVN."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        # from_batch_mvn turns the task batch dimension into the task dimension.
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(
            gpytorch.distributions.MultivariateNormal(mean_x, covar_x))

class iNN(nn.Module):
    """Two-hidden-layer MLP with tanh activations (one ensemble member).

    Args:
        input_dim: size of each input row; ``forward`` flattens inputs to
            ``[-1, input_dim]``.
        output_dim: size of each output row.
        hidden_dims: sequence of hidden sizes; only ``hidden_dims[0]`` and
            ``hidden_dims[1]`` are used.
    """

    def __init__(self, input_dim, output_dim, hidden_dims):
        super().__init__()

        self.input_dim = input_dim

        first_hidden, second_hidden = hidden_dims[0], hidden_dims[1]
        # Creation order is significant. It fixes the RNG draw order of the
        # default initialisation and the order of parameters()/state_dict.
        self.input_layer = nn.Linear(input_dim, first_hidden)
        self.mid_layer = nn.Linear(first_hidden, second_hidden)
        self.output_layer = nn.Linear(second_hidden, output_dim)

        self.activation_fn = torch.tanh

    def forward(self, X):
        """Map ``X`` (any shape with ``input_dim`` trailing elements per row)
        to a ``[rows, output_dim]`` tensor."""
        hidden = X.view(-1, self.input_dim)
        for layer in (self.input_layer, self.mid_layer):
            hidden = self.activation_fn(layer(hidden))
        return self.output_layer(hidden)

class ifBDE(nn.Module):
    """Ensemble of independently initialised ``iNN`` member networks.

    Every ``forward*`` method evaluates all members. The variants differ in
    how inputs are routed to the members (shared input, per-member input,
    optional noise) and in whether the per-member outputs are returned
    stacked on a leading ensemble axis or averaged over it.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation_fn, has_continuous_action_space,
                 num_ensemble=5, norm_layer=None):
        """Initialization.

        Args:
            input_dim: int, the size of the input data.
            output_dim: int, the size of the output data.
            hidden_dims: list of int, the sizes of the hidden layers
                (forwarded unchanged to ``iNN``).
            activation_fn: str or callable. The names ``'cos'``, ``'tanh'``,
                ``'relu'``, ``'softplus'`` and ``'leaky_relu'`` map to the
                corresponding function; any other value is stored as-is.
            has_continuous_action_space: stored on the instance.
            num_ensemble: int, number of member networks.
            norm_layer: stored on the instance.

        NOTE(review): no method of this class reads ``self.activation_fn`` or
        ``self.norm_layer``. The members are ``iNN`` networks whose activation
        is fixed to ``torch.tanh``. Confirm intent before wiring these
        options through, because doing so would change model outputs.
        """
        super(ifBDE, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.norm_layer = norm_layer

        # Resolve known activation names to functions; pass anything else through.
        options = {'cos': torch.cos, 'tanh': torch.tanh, 'relu': F.relu,
                   'softplus': F.softplus, 'leaky_relu': F.leaky_relu}
        self.activation_fn = options.get(activation_fn, activation_fn)

        # Members are created one after another (fixing the RNG draw order)
        # and placed on the module-level ``device``.
        self.num_ensemble = num_ensemble
        self.ensemble_list = nn.ModuleList(
            iNN(input_dim, output_dim, hidden_dims).to(device)
            for _ in range(self.num_ensemble)
        )

        # Number of scalars in one member (all weights and biases). This is
        # the row width of the flattened parameter matrix from forward_wb.
        self.inn_parameter_num = sum(p.numel() for p in self.ensemble_list[0].parameters())

        # Number of scalars in one member's output layer (weight + bias).
        self.last_layer_wb_size = sum(p.numel() for p in self.ensemble_list[0].output_layer.parameters())

    def _stack_member_outputs(self, batch_size, member_fn):
        """Collect ``member_fn(i, member)`` for every member along a new axis 0.

        Args:
            batch_size: number of rows each member call produces.
            member_fn: callable ``(index, member) -> [batch_size, output_dim]``.
                It is called once per member, in index order.

        Returns:
            torch.tensor, [num_ensemble, batch_size, output_dim], allocated on
            the module-level ``device`` with the default dtype.
        """
        outputs = torch.zeros(self.num_ensemble, batch_size, self.output_dim, device=device)
        for i in range(self.num_ensemble):
            # In-place copy into the buffer keeps the autograd graph of each member.
            outputs[i, :, :] = member_fn(i, self.ensemble_list[i])
        return outputs

    def add_noise(self, X, i):
        """Input-perturbation hook for member ``i``; currently returns ``X`` unchanged."""
        # TODO: add member-specific random noise to X.
        return X

    def add_noise2(self, X, i):
        """Return ``X`` plus i.i.d. standard-normal noise scaled by 1.0.

        Also sets ``self.noise_std = 1.0`` as a side effect. ``i`` is unused.
        """
        # TODO: make the noise depend on the member index.
        self.noise_std = 1.0
        X = X + torch.randn_like(X) * self.noise_std

        return X

    def forward(self, X):
        """Run every member on the same input, after ``add_noise``.

        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.

        Returns:
            torch.tensor, [num_ensemble, batch_size, output_dim], the
            per-member outputs.
        """
        return self._stack_member_outputs(
            X.shape[0], lambda i, inn: inn(self.add_noise(X, i)))

    def forward_de(self, X):
        """Run member ``i`` on its own input slice ``X[i]``, after ``add_noise``.

        Args:
            X: torch.tensor, [ensemble_size, batch_size, input_dim], the input data.

        Returns:
            torch.tensor, [ensemble_size, batch_size, output_dim], the output data.
        """
        return self._stack_member_outputs(
            X.shape[1], lambda i, inn: inn(self.add_noise(X[i, :, :], i)))

    def forward_de2(self, X):
        """Like ``forward_de``, but each slice first gets Gaussian noise via ``add_noise2``.

        Args:
            X: torch.tensor, [ensemble_size, batch_size, input_dim], the input data.

        Returns:
            torch.tensor, [ensemble_size, batch_size, output_dim], the output data.
        """
        return self._stack_member_outputs(
            X.shape[1], lambda i, inn: inn(self.add_noise2(X[i, :, :], i)))

    def forward_eval(self, X, num_sample):
        """Run every member on the unmodified input (no noise hook).

        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.
            num_sample: accepted for interface compatibility; not used.

        Returns:
            torch.tensor, [num_ensemble, batch_size, output_dim], the
            per-member outputs.
        """
        return self._stack_member_outputs(X.shape[0], lambda i, inn: inn(X))

    def get_ensemble_weights(self, i):
        """Return member ``i``'s parameters flattened into one 1-D tensor.

        The parameters themselves are used (not ``.data``), so the result
        stays attached to the autograd graph.
        """
        inn = self.ensemble_list[i]
        param_vec = torch.cat([p.view(-1) for p in inn.parameters()]).to(device)

        return param_vec

    def forward_sum(self, X):
        """Return the ensemble-mean prediction.

        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.

        Returns:
            torch.tensor, [batch_size, output_dim], the mean over members.
        """
        outputs = self._stack_member_outputs(X.shape[0], lambda i, inn: inn(X))
        return torch.mean(outputs, dim=0)

    def forward_sum_cb(self, X):
        """Return the ensemble-mean prediction with size-1 dimensions squeezed out.

        The old buffer was sized ``[num_ensemble, output_dim]``, so it only
        worked for a single input row. The buffer is now sized by the row
        count that ``iNN`` sees after flattening to ``[-1, input_dim]``, so
        batches larger than 1 also work. Results for a single row (including
        a 1-D ``X``) are unchanged.

        Args:
            X: torch.tensor, [batch_size, input_dim] or [input_dim], the input data.

        Returns:
            torch.tensor, the mean over members of shape [batch_size, output_dim],
            with all size-1 dimensions removed by ``squeeze()``.
        """
        rows = X.numel() // self.input_dim
        outputs = self._stack_member_outputs(rows, lambda i, inn: inn(X))
        return torch.mean(outputs, dim=0).squeeze()

    def forward_wb(self, X):
        """Return the per-member outputs and the flattened member parameters.

        Args:
            X: torch.tensor, [batch_size, input_dim], the input data.

        Returns:
            Tuple of
            - torch.tensor, [num_ensemble, batch_size, output_dim], per-member outputs;
            - torch.tensor, [num_ensemble, inn_parameter_num], each member's
              parameters flattened from ``.data`` (detached from autograd).
        """
        batch_size = X.shape[0]
        output_list = torch.zeros(self.num_ensemble, batch_size, self.output_dim, device=device)
        wb_list = torch.zeros(self.num_ensemble, self.inn_parameter_num, device=device)

        for i in range(self.num_ensemble):
            inn = self.ensemble_list[i]
            output_list[i, :, :] = inn(X)
            wb_list[i, :] = torch.cat([p.data.view(-1) for p in inn.parameters()])

        return output_list, wb_list

    def save(self, checkpoint_path):
        """Write this module's ``state_dict`` to ``checkpoint_path``."""
        torch.save(self.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load a ``state_dict`` from ``checkpoint_path`` into this module.

        Tensors are deserialised onto the CPU, then copied into the existing
        parameters, which stay on their current device.
        NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))

