import torch
from torch.functional import F
from torch import nn
from old_torch_init import trunc_normal_
from torch.autograd import Variable


class Swish(nn.Module):
    """Element-wise Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        """Return ``x`` scaled element-wise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate


def soft_clamp(x: torch.Tensor, _min=None, _max=None, device=torch.device("cpu")):
    """Softly bound ``x`` with softplus so gradients are maintained.

    ``x`` is first moved to ``device``.  When ``_max`` is given the value
    becomes ``_max - softplus(_max - x)``; when ``_min`` is given the
    (possibly already upper-bounded) value becomes
    ``_min + softplus(value - _min)``.  The upper bound is applied before the
    lower bound.

    Args:
        x: tensor to bound.
        _min: optional tensor lower bound (moved to ``device``).
        _max: optional tensor upper bound (moved to ``device``).
        device: device on which the computation is carried out.

    Returns:
        The bounded tensor on ``device``.
    """
    bounded = x.to(device)
    if _max is not None:
        upper = _max.to(device)
        bounded = upper - F.softplus(upper - bounded)
    if _min is not None:
        lower = _min.to(device)
        bounded = lower + F.softplus(bounded - lower)
    return bounded


class EnsembleLinear(torch.nn.Module):
    """A stack of ``ensemble_size`` independent linear layers applied via einsum.

    Parameters are stored as ``weight`` with shape
    ``(ensemble_size, in_features, out_features)`` and ``bias`` with shape
    ``(ensemble_size, 1, out_features)``.  ``select`` holds the indices of the
    ensemble members used by ``forward`` (all members by default).
    """

    def __init__(self, in_features, out_features, ensemble_size=7, device=torch.device('cpu'), cuda_index=0):
        """Create zero-initialised parameters and truncated-normal weights.

        Args:
            in_features: size of the input feature dimension.
            out_features: size of the output feature dimension.
            ensemble_size: number of ensemble members.
            device: any device other than CPU places the parameters on
                ``cuda:cuda_index``.
            cuda_index: CUDA device index used when ``device`` is not CPU.
        """
        super().__init__()

        self.ensemble_size = ensemble_size

        weight_init = torch.zeros(ensemble_size, in_features, out_features)
        bias_init = torch.zeros(ensemble_size, 1, out_features)
        on_gpu = device != torch.device("cpu")
        if on_gpu:
            weight_init = weight_init.cuda(cuda_index)
            bias_init = bias_init.cuda(cuda_index)

        # Registration order (weight, then bias) fixes the order of parameters().
        self.register_parameter('weight', torch.nn.Parameter(weight_init))
        self.register_parameter('bias', torch.nn.Parameter(bias_init))

        # Project-local initialiser; the bias stays at zero.
        trunc_normal_(self.weight, std=1 / (2 * in_features ** 0.5))

        self.select = list(range(0, self.ensemble_size))

    def forward(self, x, dev):
        """Apply the selected ensemble members to ``x`` on device ``dev``.

        A 2-D ``x`` is shared by every selected member; otherwise the leading
        dimension of ``x`` is matched one-to-one with the selected members.
        The result has the selected members along dimension 0.
        """
        members = self.select
        w = self.weight[members].to(dev)
        b = self.bias[members].to(dev)
        inputs = x.to(dev)

        pattern = 'ij,bjk->bik' if inputs.dim() == 2 else 'bij,bjk->bik'
        return torch.einsum(pattern, inputs, w) + b

    def set_select(self, indexes):
        """Choose which ensemble members ``forward`` uses.

        Asserts that there are at most ``ensemble_size`` indices and that the
        largest one is below ``ensemble_size``.
        """
        limit = self.ensemble_size
        assert len(indexes) <= limit
        assert max(indexes) < limit
        self.select = indexes


class EnsembleTransition(torch.nn.Module):
    """Ensemble of Gaussian transition models built from ``EnsembleLinear`` layers.

    The network maps a concatenated (observation, action) input to a
    ``torch.distributions.Normal`` over ``obs_dim`` outputs, plus one extra
    reward output when ``with_reward`` is true.  In ``'local'`` mode the
    predicted observation mean is an offset added to the input observation.

    The module owns an Adam optimizer over its own parameters.
    ``inner_training`` and ``test_training`` add a squared-L2 penalty between
    this module's parameters and those of a reference module ``theta``.

    Args:
        args: configuration object; ``args.alpha`` (penalty weight used by
            ``inner_training``) and ``args.learning_rate`` are read.
        obs_dim: number of observation features.
        action_dim: number of action features.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden ``EnsembleLinear`` layers.
        ensemble_size: number of ensemble members.
        mode: ``'local'`` adds the input observation to the predicted mean;
            any other value leaves the mean as produced by the network.
        with_reward: whether one reward dimension is predicted as well.
        device: device used by ``forward``; anything other than CPU also
            places parameters on ``cuda:cuda_index``.
        cuda_index: CUDA device index used when ``device`` is not CPU.
    """

    def __init__(self, args, obs_dim, action_dim, hidden_features, hidden_layers, ensemble_size=7, mode='local',
                 with_reward=True, device=torch.device("cpu"), cuda_index=0):
        super().__init__()
        self.alpha = args.alpha
        # NOTE(review): this boolean shadows ``nn.Module.cuda()`` on the
        # instance, so ``model.cuda()`` is not callable. The name is kept
        # because outside code may read it; consider renaming project-wide.
        self.cuda = bool(device != torch.device("cpu"))
        self.cuda_index = cuda_index
        self.learning_rate = args.learning_rate
        self.device = device

        self.obs_dim = obs_dim
        self.mode = mode
        self.with_reward = with_reward
        self.ensemble_size = ensemble_size

        self.activation = Swish()

        # Registration order (backbones, output_layer, max_logstd, min_logstd)
        # determines the order of parameters(), which the penalty relies on
        # when pairing parameters with those of ``theta`` by position.
        self.backbones = torch.nn.ModuleList([
            EnsembleLinear(obs_dim + action_dim if i == 0 else hidden_features, hidden_features, ensemble_size,
                           device=self.device, cuda_index=self.cuda_index)
            for i in range(hidden_layers)
        ])

        # Twice the target width: one half is the mean, the other the log-std.
        self.output_layer = EnsembleLinear(hidden_features, 2 * (obs_dim + self.with_reward), ensemble_size,
                                           device=self.device, cuda_index=self.cuda_index)

        # Learnable soft bounds on the predicted log-std, initialised to 1 and -5.
        target_dim = obs_dim + self.with_reward
        max_logstd = torch.ones(target_dim) * 1
        min_logstd = torch.ones(target_dim) * -5
        if self.cuda:
            max_logstd = max_logstd.cuda(self.cuda_index)
            min_logstd = min_logstd.cuda(self.cuda_index)
        self.register_parameter('max_logstd', torch.nn.Parameter(max_logstd, requires_grad=True))
        self.register_parameter('min_logstd', torch.nn.Parameter(min_logstd, requires_grad=True))

        self.optim = torch.optim.Adam(self.parameters(), lr=args.learning_rate, weight_decay=0.000075)

    def forward(self, obs_action):
        """Return the predictive ``Normal`` distribution for ``obs_action``.

        The last dimension of ``obs_action`` is expected to hold the
        observation features first (the first ``obs_dim`` entries are used as
        the residual base in ``'local'`` mode).
        """
        # The layers move their input to self.device internally; move it once
        # here so the residual connection below stays on a single device.
        obs_action = obs_action.to(self.device)

        output = obs_action
        for layer in self.backbones:
            output = self.activation(layer(output, self.device))
        mu, logstd = torch.chunk(self.output_layer(output, self.device), 2, dim=-1)
        logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd, device=self.device)
        if self.mode == 'local':
            if self.with_reward:
                obs, reward = torch.split(mu, [self.obs_dim, 1], dim=-1)
                obs = obs + obs_action[..., :self.obs_dim]
                mu = torch.cat([obs, reward], dim=-1)
            else:
                mu = mu + obs_action[..., :self.obs_dim]
        return torch.distributions.Normal(mu, torch.exp(logstd))

    def set_select(self, indexes):
        """Forward the ensemble-member selection to every layer."""
        for layer in self.backbones:
            layer.set_select(indexes)
        self.output_layer.set_select(indexes)

    def get_torch_variable(self, inp):
        """Wrap ``inp`` in ``Variable`` and move it to ``cuda:cuda_index`` when CUDA is in use."""
        if self.cuda:
            return Variable(inp).cuda(self.cuda_index)
        else:
            return Variable(inp)

    def _proximal_penalty(self, theta):
        """Return ``sum_i ||phi_i - theta_i||_2^2`` as a shape-``(1,)`` tensor.

        ``phi_i`` are this module's parameters and ``theta_i`` the parameters
        of ``theta``, paired by position in ``parameters()``.
        """
        theta_params = list(theta.parameters())
        norm_sum = torch.zeros(1)
        if self.cuda:
            norm_sum = norm_sum.cuda(self.cuda_index)
        for ind, phi in enumerate(self.parameters()):
            # Index (rather than zip) so a shorter ``theta`` still raises IndexError.
            diff = torch.reshape(phi, (-1,)) - torch.reshape(theta_params[ind], (-1,))
            norm_sum = norm_sum + torch.norm(diff, 2) ** 2
        return norm_sum

    def inner_training(self, task_obs, task_act, next_task_obs, reward, theta, training=True):
        """Compute the penalised negative log-likelihood and optionally take one Adam step.

        The cost is the mean negative log-likelihood of
        ``[next_task_obs, reward]`` under the model, plus
        ``0.01 * mean(max_logstd) - 0.01 * mean(min_logstd)``, plus
        ``alpha`` times the squared parameter distance to ``theta``.

        Args:
            task_obs: observations, concatenated with ``task_act`` along dim 1.
            task_act: actions.
            next_task_obs: next observations, concatenated with ``reward``
                along the last dim to form the labels.
            reward: rewards.
            theta: reference module whose parameters anchor the penalty.
            training: when true, back-propagate and step the optimizer.

        Returns:
            The cost as a NumPy array of shape ``(1,)``.
        """
        samples = torch.cat((task_obs, task_act), dim=1)
        labels = torch.cat([next_task_obs, reward], dim=-1)
        torch_samples = self.get_torch_variable(samples)
        torch_labels = self.get_torch_variable(labels)
        torch_output = self.forward(torch_samples)

        loss = -torch_output.log_prob(torch_labels)
        loss = loss.mean()

        norm_coef = self.alpha * torch.ones(1)
        if self.cuda:
            norm_coef = norm_coef.cuda(self.cuda_index)
        norm_sum = self._proximal_penalty(theta)

        torch_cost = loss + 0.01 * self.max_logstd.mean() - 0.01 * self.min_logstd.mean() + norm_coef * norm_sum
        cost = torch_cost.clone().detach().cpu().numpy()

        if training:
            self.optim.zero_grad()
            torch_cost.backward()
            # The penalty also back-propagates into theta; discard those
            # gradients so this step only updates this module.
            theta.zero_grad()
            self.optim.step()
        return cost

    def test_training(self, task_obs, task_act, next_task_obs, reward, theta, training=True):
        """Evaluate squared prediction error plus the parameter-distance penalty.

        The squared error between the predictive mean and
        ``[next_task_obs, reward]`` is averaged over dimensions 1 and 2,
        leaving dimension 0 (the ensemble axis produced by ``EnsembleLinear``).
        The penalty is added with weight 1.
        NOTE(review): ``inner_training`` weights the penalty by ``alpha``;
        confirm the weight of 1 here is intended.

        No parameters are updated; ``training`` is accepted for signature
        parity with ``inner_training`` and is not used.

        Returns:
            The cost as a NumPy array.
        """
        # Evaluation only: the result leaves as NumPy, so skip building the autograd graph.
        with torch.no_grad():
            samples = torch.cat((task_obs, task_act), dim=1)
            labels = torch.cat([next_task_obs, reward], dim=-1)
            torch_samples = self.get_torch_variable(samples)
            torch_labels = self.get_torch_variable(labels)
            torch_output = self.forward(torch_samples)

            norm_coef = torch.ones(1)
            if self.cuda:
                norm_coef = norm_coef.cuda(self.cuda_index)
            norm_sum = self._proximal_penalty(theta)

            torch_cost = ((torch_output.mean - torch_labels) ** 2).mean(dim=(1, 2)) \
                         + norm_coef * norm_sum
            cost = torch_cost.clone().detach().cpu().numpy()

        return cost
