from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn


class Controller_land(nn.Module):
    """Two-layer tanh MLP mapping a 2-feature state to a bounded 2-component action.

    The action's last dimension holds ``[velocity, angle]`` with velocity
    squashed into [0, 5] and angle into [-pi, +pi].

    Args:
        config: Object providing ``optim`` (called as
            ``optim(parameters, lr=...)`` to build the optimizer),
            ``controller_lr`` (learning rate passed to it) and ``debug``
            (enables gradient statistics in ``get_grads``).
        hidden: Width of the single hidden layer.
    """

    def __init__(self, config, hidden=8):
        super(Controller_land, self).__init__()
        self.config = config

        self.fc1 = nn.Linear(2, hidden)
        self.fc_2 = nn.Linear(hidden, 2)

        self.optim = self.config.optim(self.parameters(), lr=self.config.controller_lr)

        print("Controller: ", [(name, param.shape) for name, param in self.named_parameters()])

    def get_grads(self):
        """Return the mean absolute gradient of every parameter.

        Returns:
            An empty list when ``config.debug`` is falsy. Otherwise one value
            per parameter, in ``self.parameters()`` order; a parameter whose
            gradient has not been computed yet is reported as ``0.0`` so the
            list stays aligned with the parameters.
        """
        if not self.config.debug:
            return []

        grads = []
        for param in self.parameters():
            if param.grad is None:
                # No backward pass has populated this gradient yet.
                grads.append(0.0)
            else:
                grads.append(np.mean(np.abs(param.grad.detach().cpu().numpy())))
        return grads

    def get_action(self, state):
        """Compute the bounded action for ``state``.

        Args:
            state: Tensor whose last dimension has size 2. Any leading
                (batch) dimensions are preserved.

        Returns:
            Tensor with the same leading shape as ``state`` and a last
            dimension of size 2: ``[velocity, angle]``.
        """
        h = torch.tanh(self.fc1(state))
        raw = self.fc_2(h)

        # Index the last axis so a batch of states is squashed per action
        # component, and build a fresh tensor instead of writing into the
        # linear layer's output in place.
        velocity = torch.sigmoid(raw[..., 0]) * 5       # Velocity Range = [0, 5]
        angle = torch.tanh(raw[..., 1]) * np.pi         # Angle Range = [-pi, +pi]

        return torch.stack([velocity, angle], dim=-1)

    def update(self, loss, retain_graph=False, clip_norm=1):
        """Run one optimization step on ``loss``.

        Args:
            loss: Scalar tensor to back-propagate.
            retain_graph: Forwarded to ``loss.backward``.
            clip_norm: Maximum gradient norm; a falsy value (``None``/``0``)
                disables clipping.
        """
        self.optim.zero_grad()  # Reset the gradients
        loss.backward(retain_graph=retain_graph)
        self._step(clip_norm)

    def _step(self, clip_norm):
        """Optionally clip the gradient norm, then apply the optimizer step."""
        if clip_norm:
            torch.nn.utils.clip_grad_norm_(self.parameters(), clip_norm)
        self.optim.step()

    def save(self, filename):
        """Write the module's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load(self, filename):
        """Load parameters from a ``state_dict`` file written by ``save``.

        The checkpoint is mapped to CPU while reading so that a file saved on
        a GPU can be loaded on a CPU-only host; ``load_state_dict`` then copies
        the values into the existing parameters, which stay on their device.
        """
        self.load_state_dict(torch.load(filename, map_location="cpu"))

    def reset(self):
        """No-op; presumably kept for interface parity with other controllers."""
        return




class Controller_air(nn.Module):
    """Two-layer tanh MLP mapping a 3-feature state to a bounded 3-component action.

    The action's last dimension holds ``[velocity, theta, phi]`` with velocity
    squashed into [0, 5] and both angles into [-pi, +pi].

    Args:
        config: Object providing ``optim`` (called as
            ``optim(parameters, lr=...)`` to build the optimizer),
            ``controller_lr`` (learning rate passed to it) and ``debug``
            (enables gradient statistics in ``get_grads``).
        hidden: Width of the single hidden layer.
    """

    def __init__(self, config, hidden=8):
        super(Controller_air, self).__init__()
        self.config = config

        self.fc1 = nn.Linear(3, hidden)
        self.fc_2 = nn.Linear(hidden, 3)

        self.optim = self.config.optim(self.parameters(), lr=self.config.controller_lr)

        print("Controller: ", [(name, param.shape) for name, param in self.named_parameters()])

    def get_grads(self):
        """Return the mean absolute gradient of every parameter.

        Returns:
            An empty list when ``config.debug`` is falsy. Otherwise one value
            per parameter, in ``self.parameters()`` order; a parameter whose
            gradient has not been computed yet is reported as ``0.0`` so the
            list stays aligned with the parameters.
        """
        if not self.config.debug:
            return []

        grads = []
        for param in self.parameters():
            if param.grad is None:
                # No backward pass has populated this gradient yet.
                grads.append(0.0)
            else:
                grads.append(np.mean(np.abs(param.grad.detach().cpu().numpy())))
        return grads

    def get_action(self, state):
        """Compute the bounded action for ``state``.

        Args:
            state: Tensor whose last dimension has size 3. Any leading
                (batch) dimensions are preserved.

        Returns:
            Tensor with the same leading shape as ``state`` and a last
            dimension of size 3: ``[velocity, theta, phi]``.
        """
        h = torch.tanh(self.fc1(state))
        raw = self.fc_2(h)

        # Index the last axis so a batch of states is squashed per action
        # component, and build a fresh tensor instead of writing into the
        # linear layer's output in place.
        velocity = torch.sigmoid(raw[..., 0]) * 5       # Velocity Range = [0, 5]
        theta = torch.tanh(raw[..., 1]) * np.pi         # Angle theta = [-pi, +pi]
        phi = torch.tanh(raw[..., 2]) * np.pi           # Angle phi = [-pi, +pi]

        return torch.stack([velocity, theta, phi], dim=-1)

    def update(self, loss, retain_graph=False, clip_norm=1):
        """Run one optimization step on ``loss``.

        Args:
            loss: Scalar tensor to back-propagate.
            retain_graph: Forwarded to ``loss.backward``.
            clip_norm: Maximum gradient norm; a falsy value (``None``/``0``)
                disables clipping.
        """
        self.optim.zero_grad()  # Reset the gradients
        loss.backward(retain_graph=retain_graph)
        self._step(clip_norm)

    def _step(self, clip_norm):
        """Optionally clip the gradient norm, then apply the optimizer step."""
        if clip_norm:
            torch.nn.utils.clip_grad_norm_(self.parameters(), clip_norm)
        self.optim.step()

    def save(self, filename):
        """Write the module's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load(self, filename):
        """Load parameters from a ``state_dict`` file written by ``save``.

        The checkpoint is mapped to CPU while reading so that a file saved on
        a GPU can be loaded on a CPU-only host; ``load_state_dict`` then copies
        the values into the existing parameters, which stay on their device.
        """
        self.load_state_dict(torch.load(filename, map_location="cpu"))

    def reset(self):
        """No-op; presumably kept for interface parity with other controllers."""
        return

