from __future__ import print_function
import numpy as np
import matplotlib.pyplot  as plt
import matplotlib.patches as patches

from mpl_toolkits.mplot3d import Axes3D # <--- This is important for 3d plotting
from matplotlib.patches import Rectangle, Circle, Arrow
from matplotlib.ticker import NullLocator
import torch
from torch import tensor
import torch.nn as nn


##########################################################
# Models for land navigation and planning
##########################################################


class Model_land(nn.Module):
    """Geometry, reward and plotting shared by the 2-D land navigation models.

    The world is the square [-1, 7] x [-1, 7].  It contains a circular target,
    two axis-aligned rectangular unsafe regions and a circular start region.
    Subclasses supply the transition dynamics; this base class provides the
    reward, the termination test, start-state sampling and a live plot of the
    states recorded in ``self.trajectory``.
    """

    def __init__(self,):
        super(Model_land, self).__init__()

        # Goal region: a disc.
        self.target_center = tensor([6.0, 4.5])
        self.target_radius = 0.5

        # Unsafe rectangle 1.  ``width`` / ``length`` are half-extents along
        # x / y (see ``_dist_rec`` and ``render``); ``coeff`` scales its
        # contribution to the reward.
        self.unsafe1_x = 4.0
        self.unsafe1_y = 1.5
        self.unsafe1_center = tensor([self.unsafe1_x, self.unsafe1_y])
        self.unsafe1_width = 0.5
        self.unsafe1_length = 1.5
        self.unsafe1_coeff = 10

        # Unsafe rectangle 2, same conventions as rectangle 1.
        self.unsafe2_x = 1.5
        self.unsafe2_y = 4.5
        self.unsafe2_center = tensor([self.unsafe2_x, self.unsafe2_y])
        self.unsafe2_width = 1.5
        self.unsafe2_length = 0.5
        self.unsafe2_coeff = 10

        # Region from which start states are sampled: a disc.
        self.start_center = tensor([0.0, 0.0])
        self.start_radius = 0.7

        # Constant added to the reward at every step.
        self.time_penalty = - 1.0

        # Plotting state: the figure is created lazily on the first render().
        self.first_plot = True
        self.trajectory = []

    def reset_trajectory(self):
        """Forget all recorded states."""
        self.trajectory = []

    def save_trajectory_state(self, next_state):
        """Append ``next_state`` (converted to a plain list) to the trajectory."""
        self.trajectory.append(next_state.tolist())

    def render(self):
        """Draw (first call) or refresh (later calls) the trajectory plot.

        An empty trajectory is drawn as an empty line instead of raising.
        """
        # Refer1: https://stackoverflow.com/questions/4098131/how-to-update-a-plot-in-matplotlib
        # Refer2: https://stackoverflow.com/questions/9215658/plot-a-circle-with-pyplot

        traj = np.array(self.trajectory).T
        if traj.size == 0:
            # Nothing recorded yet: indexing traj[0] would raise IndexError.
            x, y = [], []
        else:
            x, y = traj[0], traj[1]

        if self.first_plot:
            # Interactive mode so that draw() does not block.
            plt.ion()

            # Set the figure window for the first time
            self.fig = plt.figure()
            self.ax = self.fig.add_subplot(111)
            self.ax.cla()  # clear things for fresh plot

            # Fix the view to the environment bounds.
            self.ax.set_xlim((-1.0, 7.0))
            self.ax.set_ylim((-1.0, 7.0))

            # Unsafe regions: lower-left corner = center - half-extent,
            # side = 2 * half-extent.
            rect1 = patches.Rectangle((self.unsafe1_x - self.unsafe1_width, self.unsafe1_y - self.unsafe1_length),
                                      2*self.unsafe1_width, 2*self.unsafe1_length, linewidth=1, edgecolor='r')
            rect2 = patches.Rectangle((self.unsafe2_x - self.unsafe2_width, self.unsafe2_y - self.unsafe2_length),
                                      2*self.unsafe2_width, 2*self.unsafe2_length, linewidth=1, edgecolor='r')

            # Target set
            circle_target = plt.Circle(self.target_center.tolist(), self.target_radius, color='g')

            # Initialization set
            circle_start = plt.Circle(self.start_center.tolist(), self.start_radius, color='b', fill=False)

            # Add all the static shapes to the axes
            self.ax.add_artist(rect1)
            self.ax.add_artist(rect2)
            self.ax.add_artist(circle_target)
            self.ax.add_artist(circle_start)

            # plot() returns a list of line objects, thus the comma
            self.line1, = self.ax.plot(x, y, color='k', linestyle='dashed', marker='o')

            self.first_plot = False

        else:
            # Update the existing trajectory line in place
            self.line1.set_ydata(y)
            self.line1.set_xdata(x)

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def _dist(self, pos, center, r):
        """Squared distance from ``pos`` to ``center`` minus ``r**2``.

        Negative inside the disc, zero on its boundary, positive outside.
        """
        return torch.sum((pos - center)**2) - r**2

    def _dist_rec(self, pos, center, wid, len):
        """Penalty term for an axis-aligned rectangle given by half-extents.

        Returns ``|pos - center|^2 - wid^2 - len^2`` (a tensor, <= 0) when
        ``pos`` lies inside or on the rectangle, and the float ``0.0``
        otherwise.
        """
        inside_x = (center[0] - wid) <= pos[0] and (center[0] + wid) >= pos[0]
        if inside_x:
            inside_y = (center[1] - len) <= pos[1] and (center[1] + len) >= pos[1]
            if inside_y:
                return torch.sum((pos - center)**2) - wid**2 - len**2

        return 0.0

    def _segment_is_unsafe(self, state, next_state):
        """Return True if any sampled point between the two states is unsafe.

        Five points are tested, starting at ``next_state`` (alpha = 0) and
        moving towards ``state``; ``state`` itself (alpha = 1) is not tested.
        """
        for alpha in (0.0, 0.2, 0.4, 0.6, 0.8):
            state_middle = alpha * state + (1 - alpha) * next_state
            d1 = self._dist_rec(state_middle, self.unsafe1_center, self.unsafe1_width, self.unsafe1_length)
            d2 = self._dist_rec(state_middle, self.unsafe2_center, self.unsafe2_width, self.unsafe2_length)
            if (d1 < 0) or (d2 < 0):
                return True

        return False

    def get_reward(self, next_state, state):
        """Compute the reward for arriving at ``next_state`` from ``state``.

        The reward is largest close to the target and is reduced inside an
        unsafe rectangle; a constant time penalty is added at every step.

        Returns:
            (reward, unsafe_flag): ``reward`` is a 0-dim tensor and
            ``unsafe_flag`` is always a plain ``bool`` telling whether the
            sampled segment from ``state`` to ``next_state`` touches an
            unsafe rectangle.
        """
        d1 = self._dist_rec(next_state, self.unsafe1_center, self.unsafe1_width, self.unsafe1_length)
        d2 = self._dist_rec(next_state, self.unsafe2_center, self.unsafe2_width, self.unsafe2_length)

        reward = - self._dist(next_state, self.target_center, self.target_radius)\
                 + self.unsafe1_coeff * d1  \
                 + self.unsafe2_coeff * d2  \
                 + self.time_penalty

        unsafe_flag = self._segment_is_unsafe(state, next_state)

        return reward, unsafe_flag

    def start_state(self):
        """Sample a start state inside the start disc and record it.

        The radius is drawn uniformly in [0, start_radius), so samples are
        denser near the center than a uniform-over-area draw would be.
        """
        r = np.random.rand() * self.start_radius
        theta = np.random.rand() * 2 * np.pi
        # NOTE(review): the coordinates are numpy scalars, so the dtype of
        # this tensor is whatever torch infers for them -- confirm it matches
        # what the policy / dynamics model expects.
        state = tensor([r * np.cos(theta), r * np.sin(theta)]) + self.start_center

        self.save_trajectory_state(state)
        return state

    def check_termination(self, state):
        """Return True when ``state`` is in the target disc or on/outside the world boundary."""
        if self._dist(state, self.target_center, self.target_radius) <= 0 \
                or state[0] <= -1 or state[0] >= 7 \
                or state[1] <= -1 or state[1] >= 7 :
            return True

        return False




class TrueModel_land(Model_land):
    """Land environment with the ground-truth (analytic) dynamics."""

    def __init__(self, hidden=8):
        # ``hidden`` is accepted but not used by this class.
        super(TrueModel_land, self).__init__()

    def get_next_state_and_reward(self, state, action):
        """Advance one step of 0.1 * speed along the heading given by ``action``.

        Args:
            state: current position, unpacked as (x, y).
            action: unpacked as (speed, heading angle).

        Returns:
            (next_state, reward, info) where ``info["unsafe"]`` is the flag
            reported by ``get_reward``.  The new state is also appended to
            the trajectory.
        """
        pos_x, pos_y = state
        speed, heading = action

        # Move along the heading, then keep the agent inside the world.
        moved = torch.stack(
            [
                pos_x + speed * torch.cos(heading) * 0.1,
                pos_y + speed * torch.sin(heading) * 0.1,
            ],
            dim=0,
        )
        bounded = torch.clamp(moved, min=-1, max=+7)

        reward, unsafe_flag = self.get_reward(bounded, state)
        self.save_trajectory_state(bounded)

        info = {"unsafe": unsafe_flag}
        return bounded, reward, info




class EstimatedModel_land(Model_land):
    """Land environment whose transition is predicted by a small neural net.

    The network maps an action (2 values) to a displacement (2 values); the
    current state is not an input to the network.

    Args:
        config: object providing ``optim`` (an optimizer factory called as
            ``optim(params, lr=...)``) and ``model_lr`` (the learning rate).
        hidden: width of the single hidden layer.
    """

    def __init__(self, config, hidden=8):
        super(EstimatedModel_land, self).__init__()
        self.config = config

        self.fc1 = nn.Linear(2, hidden)
        self.fc_state = nn.Linear(hidden, 2)

        # Optimizer over this module's parameters, built from the config.
        self.optim = self.config.optim(self.parameters(), lr=self.config.model_lr)

    def forward(self, action):
        """Predict the state displacement for ``action``; each component lies within [-0.5, 0.5]."""
        # Call the layers themselves (not ``.forward``) so that any hooks
        # registered on them are run.
        h = torch.tanh(self.fc1(action))          # hidden representation
        return torch.tanh(self.fc_state(h)) * 0.5

    def get_next_state_and_reward(self, state, action):
        """Apply the predicted displacement, clamp to the world, and score it.

        Returns:
            (next_state, reward, info) where ``info["unsafe"]`` is the flag
            reported by ``get_reward``.  The new state is also appended to
            the trajectory.
        """
        delta = self(action)
        next_state = state + delta
        next_state = torch.clamp(next_state, min=-1, max=+7)  # Range = [-1, 7]
        reward, unsafe_flag = self.get_reward(next_state, state)

        self.save_trajectory_state(next_state)

        return next_state, reward, {"unsafe": unsafe_flag}

    def update(self, loss, retain_graph=False, clip_norm=False):
        """Run one optimizer step on ``loss``.

        Args:
            loss: scalar tensor to back-propagate.
            retain_graph: forwarded to ``loss.backward``.
            clip_norm: if truthy, gradients are clipped to this norm first.
        """
        self.optim.zero_grad()                          # Reset the gradients
        loss.backward(retain_graph=retain_graph)        # Back-propagate
        self._step(clip_norm)                           # Use gradients to update the parameters

    def _step(self, clip_norm):
        """Optionally clip the gradient norm, then step the optimizer."""
        if clip_norm:
            torch.nn.utils.clip_grad_norm_(self.parameters(), clip_norm)
        self.optim.step()

    def save(self, filename):
        """Write this module's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load(self, filename):
        """Load a ``state_dict`` from ``filename`` into this module."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(filename))

    def reset(self):
        """No-op; this model keeps no per-episode state to reset."""
        return



##########################################################
# Models for air navigation and planning
##########################################################


class Model_air(nn.Module):
    """Geometry, reward and plotting shared by the 3-D air navigation models.

    The world contains a box-shaped target, two vertical cylindrical unsafe
    regions and a box-shaped start region.  Subclasses supply the transition
    dynamics; this base class provides the reward, the termination test,
    start-state sampling and a live 3-D plot of the states recorded in
    ``self.trajectory``.
    """

    def __init__(self,):
        super(Model_air, self).__init__()

        # Target box: lower corner plus extents along x / y / z.
        self.target_corner = tensor([3.4, 4.1, 4.0])
        self.target_width = 0.8
        self.target_length = 0.8
        self.target_height = 1.0

        # Unsafe cylinder 1: ``xyz`` is the center of its base, the axis is
        # vertical; ``coeff`` scales its contribution to the reward.
        self.unsafe1_xyz = tensor((4.0, 3.0, 0.0))
        self.unsafe1_radius = 0.5
        self.unsafe1_height = 6.0
        self.unsafe1_coeff = 10

        # Unsafe cylinder 2, same conventions as cylinder 1.
        self.unsafe2_xyz = tensor((2.0, 4.0, 0.0))
        self.unsafe2_radius = 0.5
        self.unsafe2_height = 6
        self.unsafe2_coeff = 10

        # Box from which start states are sampled: lower corner plus extents.
        self.start_corner = tensor([-1.0, -1.0, 0.0])
        self.start_width = 2.0
        self.start_length = 2.0
        self.start_height = 0.2

        # Constant added to the reward at every step.
        self.time_penalty = - 1.0

        # Plotting state: the figure is created lazily on the first render().
        self.first_plot = True
        self.trajectory = []

    def reset_trajectory(self):
        """Forget all recorded states."""
        self.trajectory = []

    def save_trajectory_state(self, next_state):
        """Append ``next_state`` (converted to a plain list) to the trajectory."""
        self.trajectory.append(next_state.tolist())

    def render(self):
        """Draw (first call) or refresh (later calls) the 3-D trajectory plot.

        Only the trajectory is drawn; the target and unsafe regions are not.
        An empty trajectory is drawn as an empty line instead of raising.
        """
        # Refer1: https://stackoverflow.com/questions/4098131/how-to-update-a-plot-in-matplotlib

        traj = np.array(self.trajectory).T
        if traj.size == 0:
            # Nothing recorded yet: indexing traj[0] would raise IndexError.
            x, y, z = [], [], []
        else:
            x, y, z = traj[0], traj[1], traj[2]

        if self.first_plot:
            # Interactive mode so that draw() does not block.
            plt.ion()

            # Set the figure window for the first time
            self.fig = plt.figure()
            self.ax = self.fig.add_subplot(111, projection='3d')
            self.ax.cla()  # clear things for fresh plot

            # Fix the view limits.
            self.ax.set_xlim((-1.5, 9.0))
            self.ax.set_ylim((-1.5, 9.0))
            self.ax.set_zlim((-1.5, 9.0))

            self.first_plot = False

        else:
            # Remove the previously drawn line so that line artists do not
            # pile up on the axes with every call.
            self.line1.remove()

        # plot() returns a list of line objects; keep the only one.
        self.line1 = self.ax.plot(x, y, z, color='k', linestyle='dashed', marker='o')[0]

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def _dist_sph(self, pos, center, r, h):
        """Squared distance from ``pos`` to ``center`` minus ``r**2``; ``h`` is unused."""
        return torch.sum((pos - center)**2) - r**2

    def _dist_cyl(self, pos, center, r, h):
        """Penalty term for a vertical cylinder with base center ``center``.

        Returns the (negative) squared horizontal distance to the axis minus
        ``r**2`` when ``pos`` is strictly inside the radius and its height is
        within ``[z, z + h]``; returns the float ``0.0`` otherwise.
        """
        base_z = center[2]
        if pos[2] >= base_z and pos[2] <= (base_z + h):
            a = torch.sum((pos[:2] - center[:2])**2) - r**2
            if a < 0:
                return a

        return 0.0

    def _dist_rec(self, pos, corner, wid, len, height):
        """Penalty/indicator term for an axis-aligned box given by its lower corner.

        Returns squared distance to the box center minus the squared
        half-diagonal (a tensor, < 0) when ``pos`` is strictly inside the
        box, and the float ``0.0`` otherwise.
        """
        x, y, z = corner
        offset = tensor([wid/2, len/2, height/2])
        center = corner + offset
        if pos[0] > x and pos[0] < (x + wid):
            if pos[1] > y and pos[1] < (y + len):
                if pos[2] > z and pos[2] < (z + height):
                    return torch.sum((pos - center) ** 2) - torch.sum(offset**2)

        return 0.0

    def _dist_rec_diff(self, pos, corner, wid, len, height):
        """Branch-free variant of ``_dist_rec``.

        Always returns squared distance to the box center minus the squared
        half-diagonal, whether or not ``pos`` is inside the box.
        """
        offset = tensor([wid / 2, len / 2, height / 2])
        center = corner + offset
        return torch.sum((pos - center) ** 2) - torch.sum(offset ** 2)

    def _segment_is_unsafe(self, state, next_state):
        """Return True if any sampled point between the two states is unsafe.

        Five points are tested, starting at ``next_state`` (alpha = 0) and
        moving towards ``state``; ``state`` itself (alpha = 1) is not tested.
        """
        for alpha in (0.0, 0.2, 0.4, 0.6, 0.8):
            state_middle = alpha * state + (1 - alpha) * next_state
            d1 = self._dist_cyl(state_middle, self.unsafe1_xyz, self.unsafe1_radius, self.unsafe1_height)
            d2 = self._dist_cyl(state_middle, self.unsafe2_xyz, self.unsafe2_radius, self.unsafe2_height)
            if (d1 < 0) or (d2 < 0):
                return True

        return False

    def get_reward(self, next_state, state):
        """Compute the reward for arriving at ``next_state`` from ``state``.

        The reward is largest close to the target and is reduced inside an
        unsafe cylinder; a constant time penalty is added at every step.

        Returns:
            (reward, unsafe_flag): ``reward`` is a 0-dim tensor and
            ``unsafe_flag`` is always a plain ``bool`` telling whether the
            sampled segment from ``state`` to ``next_state`` touches an
            unsafe cylinder.
        """
        d1 = self._dist_cyl(next_state, self.unsafe1_xyz, self.unsafe1_radius, self.unsafe1_height)
        d2 = self._dist_cyl(next_state, self.unsafe2_xyz, self.unsafe2_radius, self.unsafe2_height)

        # The target term uses the branch-free box distance so that the
        # reward always contains a term that depends on next_state, even
        # when both cylinder terms are the constant 0.0.
        reward = - self._dist_rec_diff(next_state, self.target_corner, self.target_width, self.target_length, self.target_height) \
                 + self.unsafe1_coeff * d1 \
                 + self.unsafe2_coeff * d2 \
                 + self.time_penalty

        unsafe_flag = self._segment_is_unsafe(state, next_state)

        return reward, unsafe_flag

    def start_state(self):
        """Sample a start state uniformly inside the start box and record it."""
        x = np.random.rand() * self.start_width
        y = np.random.rand() * self.start_length
        z = np.random.rand() * self.start_height

        state = self.start_corner + tensor([x, y, z])

        self.save_trajectory_state(state)
        return state

    def check_termination(self, state):
        """Return True when ``state`` is strictly inside the target box or on/outside the bounds.

        The bounds are x in (-1.5, 8), y in (-1.5, 7) and z in (0, 6).
        """
        if self._dist_rec(state, self.target_corner, self.target_width, self.target_length, self.target_height) < 0 \
                or state[0] <= -1.5 or state[0] >= 8 \
                or state[1] <= -1.5 or state[1] >= 7 \
                or state[2] <= 0.0 or state[2] >= 6 :
            return True

        return False




class TrueModel_air(Model_air):
    """Air environment with the ground-truth (analytic) dynamics."""

    def __init__(self, hidden=8):
        # ``hidden`` is accepted but not used by this class.
        super(TrueModel_air, self).__init__()

    def get_next_state_and_reward(self, state, action):
        """Advance one step of 0.1 * speed along the direction given by ``action``.

        Args:
            state: current position, unpacked as (x, y, z).
            action: unpacked as (speed, theta, phi); theta is the angle used
                for the vertical component and phi the angle within the
                horizontal plane.

        Returns:
            (next_state, reward, info) where ``info["unsafe"]`` is the flag
            reported by ``get_reward``.  The new state is not clamped; it is
            also appended to the trajectory.
        """
        pos_x, pos_y, pos_z = state
        speed, elevation, azimuth = action

        # Spherical-to-Cartesian displacement, scaled by the 0.1 step size.
        next_state = torch.stack(
            [
                pos_x + speed * torch.cos(elevation) * torch.cos(azimuth) * 0.1,
                pos_y + speed * torch.cos(elevation) * torch.sin(azimuth) * 0.1,
                pos_z + speed * torch.sin(elevation) * 0.1,
            ],
            dim=0,
        )

        reward, unsafe_flag = self.get_reward(next_state, state)
        self.save_trajectory_state(next_state)

        info = {"unsafe": unsafe_flag}
        return next_state, reward, info




class EstimatedModel_air(Model_air):
    """Air environment whose transition is predicted by a small neural net.

    The network maps an action (3 values) to a displacement (3 values); the
    current state is not an input to the network.

    Args:
        config: object providing ``optim`` (an optimizer factory called as
            ``optim(params, lr=...)``) and ``model_lr`` (the learning rate).
        hidden: width of the single hidden layer.
    """

    def __init__(self, config, hidden=8):
        super(EstimatedModel_air, self).__init__()
        self.config = config

        self.fc1 = nn.Linear(3, hidden)
        self.fc_state = nn.Linear(hidden, 3)

        # Optimizer over this module's parameters, built from the config.
        self.optim = self.config.optim(self.parameters(), lr=self.config.model_lr)

    def forward(self, action):
        """Predict the state displacement for ``action``; each component lies within [-0.5, 0.5]."""
        # Call the layers themselves (not ``.forward``) so that any hooks
        # registered on them are run.
        h = torch.tanh(self.fc1(action))          # hidden representation
        return torch.tanh(self.fc_state(h)) * 0.5

    def get_next_state_and_reward(self, state, action):
        """Apply the predicted displacement (no clamping) and score it.

        Returns:
            (next_state, reward, info) where ``info["unsafe"]`` is the flag
            reported by ``get_reward``.  The new state is also appended to
            the trajectory.
        """
        delta = self(action)
        next_state = state + delta
        reward, unsafe_flag = self.get_reward(next_state, state)

        self.save_trajectory_state(next_state)

        return next_state, reward, {"unsafe": unsafe_flag}

    def update(self, loss, retain_graph=False, clip_norm=False):
        """Run one optimizer step on ``loss``.

        Args:
            loss: scalar tensor to back-propagate.
            retain_graph: forwarded to ``loss.backward``.
            clip_norm: if truthy, gradients are clipped to this norm first.
        """
        self.optim.zero_grad()                          # Reset the gradients
        loss.backward(retain_graph=retain_graph)        # Back-propagate
        self._step(clip_norm)                           # Use gradients to update the parameters

    def _step(self, clip_norm):
        """Optionally clip the gradient norm, then step the optimizer."""
        if clip_norm:
            torch.nn.utils.clip_grad_norm_(self.parameters(), clip_norm)
        self.optim.step()

    def save(self, filename):
        """Write this module's ``state_dict`` to ``filename``."""
        torch.save(self.state_dict(), filename)

    def load(self, filename):
        """Load a ``state_dict`` from ``filename`` into this module."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.load_state_dict(torch.load(filename))

    def reset(self):
        """No-op; this model keeps no per-episode state to reset."""
        return


