import numpy as np
import torch
import random
from omegaconf import DictConfig
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
import gym
from gym import spaces
from gym.utils import seeding


def pd_controller(current_pos, target_pos, prev_error, Kp, Kd):
    """
    Compute a PD velocity command that moves ``current_pos`` toward ``target_pos``.

    Args:
        current_pos: Current position (array-like).
        target_pos: Desired position (array-like, same length as ``current_pos``).
        prev_error: Error returned by the previous call; the derivative term is
            the one-step difference ``error - prev_error``.
        Kp: Proportional gain.
        Kd: Derivative gain.

    Returns:
        Tuple ``(velocity, error)`` where ``error = target_pos - current_pos`` and
        ``velocity = Kp * error + Kd * (error - prev_error)``.
    """
    pos_error = np.array(target_pos) - np.array(current_pos)
    proportional = Kp * pos_error
    damping = Kd * (pos_error - prev_error)
    return proportional + damping, pos_error


def simulate_reference_trajectory_with_noise(start_pos, goal_pos, grid_size, n_simulation_steps, Kp, Kd, noise_scale):
    """
    Roll out a noisy PD-tracked path from ``start_pos`` toward ``goal_pos``.

    At step ``i`` the target is the point a fraction ``i / (n_simulation_steps - 1)``
    of the way along the straight segment from start to goal. The PD velocity
    from ``pd_controller`` is perturbed with zero-mean Gaussian noise of std
    ``noise_scale`` and added to the previous position.

    ``grid_size`` is accepted but unused: positions are not clipped to the grid.

    Returns:
        List of positions, ``n_simulation_steps`` long (a single entry when
        ``n_simulation_steps <= 1``). The first entry is ``start_pos`` itself;
        every later entry is a 2-element list.
    """
    # The derivative memory starts as noise rather than zero, so the very
    # first PD step is already perturbed.
    last_error = np.random.normal(0, noise_scale, size=2)
    path = [start_pos]

    span = n_simulation_steps - 1
    for step in range(1, n_simulation_steps):
        frac = step / span
        waypoint = [
            start_pos[0] + (goal_pos[0] - start_pos[0]) * frac,
            start_pos[1] + (goal_pos[1] - start_pos[1]) * frac,
        ]
        here = path[-1]
        command, last_error = pd_controller(here, waypoint, last_error, Kp, Kd)
        jitter = np.random.normal(0, noise_scale, size=2)
        path.append((np.array(here) + (command + jitter)).tolist())
    return path


def interpolate_trajectory(x, y, n_points, k=3):
    """
    Resample a 2D trajectory to ``n_points`` samples with spline interpolation.

    The input samples are given uniformly spaced parameters in [0, 1]. Separate
    B-splines are fitted for the x and y coordinates and evaluated at
    ``n_points`` uniformly spaced parameters in [0, 1].

    Args:
        x: Sequence of x coordinates.
        y: Sequence of y coordinates (same length as ``x``).
        n_points: Number of output samples.
        k: Requested spline degree (default 3, cubic). It is lowered to
            ``len(x) - 1`` when there are too few samples to fit degree ``k``,
            so 2 samples give linear and 3 samples give quadratic interpolation.

    Returns:
        Tuple ``(x_smooth, y_smooth)`` of arrays of length ``n_points``. A
        single input sample yields constant arrays.

    Raises:
        ValueError: if ``x`` is empty or ``x`` and ``y`` differ in length.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) == 0 or len(x) != len(y):
        raise ValueError(
            f"x and y must be non-empty and of equal length, got {len(x)} and {len(y)}"
        )
    if len(x) == 1:
        # No curve to fit: the trajectory is a single point.
        return np.full(n_points, x[0]), np.full(n_points, y[0])

    # make_interp_spline needs at least degree + 1 samples.
    degree = min(k, len(x) - 1)
    step_numbers = np.linspace(0, 1, len(x))
    spline_x = make_interp_spline(step_numbers, x, k=degree)
    spline_y = make_interp_spline(step_numbers, y, k=degree)
    steps_smooth = np.linspace(0, 1, n_points)
    return spline_x(steps_smooth), spline_y(steps_smooth)


def plant(current_pos, current_vel, velocity_command, time_step=1.0):
    """
    Advance a first-order point-mass model by one time step.

    The commanded velocity is applied directly (``current_vel`` is ignored), and
    the position is integrated with a forward-Euler step of length ``time_step``.

    Returns:
        Tuple ``(new_pos, new_vel)`` where ``new_vel`` is ``velocity_command``.
    """
    displacement = velocity_command * time_step
    return current_pos + displacement, velocity_command


class Diagonal2dEnvironment(gym.Env):
    """
    Gym environment for a 2D point mass moving between fixed start/goal pairs.

    On ``reset`` one (start, goal) pair is drawn uniformly at random and the start
    is perturbed with Gaussian noise of std ``noise_scale``. Each ``step`` applies
    the action as a velocity command through ``plant`` (unit time step).
    Observations are the current 2D position tiled ``repeat_observation`` times.
    The reward is 1.0 whenever the position is within 0.5 of the episode's goal,
    else 0.0. An episode is done after ``episode_len`` steps.
    """
    def __init__(self, start_positions=[[0, 0], [4, 0]], goal_positions=[[4, 4], [0, 4]], noise_scale=0.12,
                 episode_len=50, repeat_observation=10, grid_size=5):
        super(Diagonal2dEnvironment, self).__init__()
        if len(start_positions) != len(goal_positions):
            raise ValueError(
                f"start_positions and goal_positions must have the same length, "
                f"got {len(start_positions)} and {len(goal_positions)}"
            )
        # Copy so instances never alias the shared default lists (or lists the
        # caller may mutate later).
        self.start_positions = [list(pos) for pos in start_positions]
        self.goal_positions = [list(pos) for pos in goal_positions]
        self.goal_locations = None  # For consistency with d4rl environments
        self.noise_scale = noise_scale
        self.n_frames = episode_len
        self.repeat_observation = repeat_observation
        self.grid_size = grid_size
        self.start_pos = None  # noisy start of the current episode, set by reset()
        self.goal_pos = None  # goal of the current episode, set by reset()
        self.current_pos = None
        self.current_vel = None
        self.step_index = 0

        # Actions are unbounded 2D velocity commands; the observation is the 2D
        # position tiled repeat_observation times.
        self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2 * self.repeat_observation,),
                                            dtype=np.float32)

    def reset(self):
        """
        Start a new episode and return the initial (tiled) observation.

        Samples a (start, goal) pair uniformly, adds Gaussian noise to the start,
        and zeroes the velocity and step counter.

        NOTE: sampling uses the global ``np.random`` state, not ``self.np_random``,
        so ``seed()`` does not make resets reproducible; seed ``np.random`` instead.
        """
        choice = np.random.randint(0, len(self.start_positions))
        self.start_pos = np.array(self.start_positions[choice]) + np.random.normal(0, self.noise_scale, size=2)
        self.goal_pos = np.array(self.goal_positions[choice])
        self.goal_locations = [tuple(self.goal_pos.tolist())]

        self.current_pos = np.array(self.start_pos)
        self.current_vel = np.array([0, 0])
        self.step_index = 0

        return np.tile(np.array(self.current_pos), self.repeat_observation)

    def step(self, action):
        """
        Apply ``action`` as a velocity command for one unit time step.

        Returns:
            ``(observation, reward, done, info)`` in the classic gym API format.
        """
        velocity_command = np.array(action)
        self.current_pos, self.current_vel = plant(self.current_pos, self.current_vel, velocity_command)

        self.step_index += 1
        # ``>=`` keeps the episode terminated if it is stepped past its length.
        done = self.step_index >= self.n_frames
        # Distance is measured to this episode's sampled goal (not always the
        # first entry of goal_positions), so every start/goal pair can be rewarded.
        reward = 1.0 if np.linalg.norm(self.goal_pos - self.current_pos) < 0.5 else 0.0

        return np.tile(np.array(self.current_pos), self.repeat_observation), reward, done, {}

    def seed(self, seed=None):
        """Create ``self.np_random`` from ``seed`` and return ``[seed]`` (unused by reset/step)."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    @property
    def str_maze_spec(self):
        """
        ASCII map of the grid: a '#' border, 'S' at the start, 'G' at the goal.

        Rows are joined with a single backslash character. Cells are matched by
        exact equality with ``(x, y)``, so a noisy (non-integer) start or goal is
        not drawn. Requires ``reset()`` to have been called.
        """
        lines = []
        lines.append(f"{'#' * (self.grid_size + 2)}")
        for y in range(self.grid_size):
            line = "#"
            for x in range(self.grid_size):
                if (x, y) == tuple(self.start_pos):
                    line += "S"
                elif (x, y) == tuple(self.goal_pos):
                    line += "G"
                else:
                    line += " "
            line += "#"
            lines.append(line)
        lines.append(f"{'#' * (self.grid_size + 2)}")
        return "\\".join(lines)


class PDAgent:
    """
    PD agent that generates a noisy reference trajectory and follows it.

    ``generate_ref_trajectory`` simulates a noisy PD path from start to goal with
    ``cfg.n_simulation_steps`` points and resamples it with a spline to
    ``cfg.episode_len`` points. ``get_action`` then returns a PD velocity command
    that drives the current position toward the reference point for a step.
    """
    def __init__(self, cfg):
        self.cfg = cfg
        self.grid_size = cfg.grid_size
        self.Kp = cfg.Kp
        self.Kd = cfg.Kd
        self.ref_trajectory = None
        self.prev_error = np.zeros(2)
        self.noise_scale = cfg.noise_scale
        self.n_frames = cfg.episode_len
        self.n_simulation_steps = cfg.n_simulation_steps

    def generate_ref_trajectory(self, start_pos, goal_pos):
        """
        Build and store ``self.ref_trajectory``: ``n_frames`` (x, y) reference points.

        Also resets the PD derivative memory, because a new reference marks the
        start of a new episode.
        """
        ref_trajectory = simulate_reference_trajectory_with_noise(
            start_pos, goal_pos, self.grid_size, self.n_simulation_steps, self.Kp,
            self.Kd, self.noise_scale
        )
        x_ref, y_ref = zip(*ref_trajectory)
        x_ref, y_ref = interpolate_trajectory(x_ref, y_ref, self.n_frames)
        self.ref_trajectory = list(zip(x_ref, y_ref))
        # Without this reset, the first action of an episode would be kicked by
        # the derivative of the previous episode's final error.
        self.prev_error = np.zeros(2)

    def get_action(self, current_pos, current_vel, step_index):
        """
        Return the PD velocity command toward reference point ``step_index``.

        Args:
            current_pos: Current 2D position.
            current_vel: Unused; kept for interface compatibility.
            step_index: Index into ``self.ref_trajectory`` (must be < ``n_frames``).
        """
        ref_pos = self.ref_trajectory[step_index]
        velocity_command, self.prev_error = pd_controller(
            current_pos, ref_pos, self.prev_error, self.Kp, self.Kd)
        return velocity_command


class Diagonal2dOfflineRLDataset(torch.utils.data.IterableDataset):
    """
    Endless stream of PD-controlled episodes from ``Diagonal2dEnvironment``.

    Each item is ``(observations, actions, rewards, nonterminals)`` with one row
    per environment step (T = ``cfg.episode_len`` rows): observations float32
    ``(T, 2 * cfg.repeat_observation)``, actions float32 ``(T, 2)``, rewards
    float32 ``(T,)``, nonterminals bool ``(T,)``.

    NOTE: ``observations[t]`` is the observation *after* ``actions[t]`` was
    applied; the reset observation is not stored.
    """
    def __init__(self, cfg: DictConfig, split: str = "training"):
        super().__init__()
        self.cfg = cfg
        # ``split`` is currently unused: every split yields freshly simulated data.
        self.start_positions = list(list(pos) for pos in cfg.start_positions)
        self.goal_positions = list(list(pos) for pos in cfg.goal_positions)
        self.noise_scale = cfg.noise_scale
        self.n_frames = cfg.episode_len
        # Use the same grid size as the agent (which reads cfg.grid_size).
        self.env = Diagonal2dEnvironment(self.start_positions, self.goal_positions, self.noise_scale, self.n_frames,
                                         cfg.repeat_observation, grid_size=cfg.grid_size)
        self.agent = PDAgent(cfg)
        self.n_trajectories = len(self.env.start_positions)

    def __iter__(self):
        while True:
            # Run an episode
            obs = self.env.reset()
            self.agent.generate_ref_trajectory(self.env.start_pos, self.env.goal_pos)

            observations, actions, rewards, nonterminals = [], [], [], []

            done = False
            while not done:
                # The observation is the position tiled repeat_observation times,
                # so only its first two entries are the position; obs[2:] is not a
                # velocity, so take that from the environment state.
                current_pos = obs[:2]
                current_vel = self.env.current_vel
                action = self.agent.get_action(current_pos, current_vel, self.env.step_index)
                obs, reward, done, _ = self.env.step(action)

                observations.append(obs.tolist())
                actions.append(action.tolist())
                rewards.append(reward)
                nonterminals.append(not done)

            yield (
                torch.tensor(observations, dtype=torch.float32),
                torch.tensor(actions, dtype=torch.float32),
                torch.tensor(rewards, dtype=torch.float32),
                torch.tensor(nonterminals, dtype=torch.bool),
            )


if __name__ == "__main__":
    # Manual sanity check: sample a few episodes, plot their paths, and print
    # summary statistics of the observations, actions and rewards.
    import os
    from unittest.mock import MagicMock

    os.chdir("../..")

    # A MagicMock stands in for the OmegaConf config; these are the attributes
    # the dataset, environment and agent read.
    cfg = MagicMock()
    cfg.grid_size = 5
    cfg.start_positions = [[0, 0], [4, 0]]
    cfg.goal_positions = [[4, 4], [0, 4]]
    cfg.noise_scale = 0.12
    cfg.n_simulation_steps = 12
    cfg.episode_len = 50
    cfg.repeat_observation = 1
    cfg.Kp = 1.0
    cfg.Kd = 0.3

    # Plot styling.
    bg_color, grid_color = 'lightgray', 'white'
    cmap_name, cmap_start = 'Purples', 0.3
    linewidth = 2

    ds = Diagonal2dOfflineRLDataset(cfg)
    plt.figure(figsize=(6, 6))
    ax = plt.gca()
    ax.set_facecolor(bg_color)
    cmap = plt.get_cmap(cmap_name)

    obs_list, act_list, rew_list = [], [], []
    for _ in range(10):
        o, a, r, _nt = next(iter(ds))
        obs_list.append(o)
        act_list.append(a)
        rew_list.append(r)
        # Draw the path one segment at a time, darkening the colour over time.
        for i in range(1, cfg.episode_len):
            shade = cmap_start + (1 - cmap_start) * (i / len(o))
            plt.plot(o[i-1:i+1, 0], [o[i-1, 1], o[i, 1]], color=cmap(shade), lw=linewidth)

    plt.grid(True, color=grid_color, linewidth=linewidth)
    plt.xlim([-1, cfg.grid_size])
    plt.ylim([-1, cfg.grid_size])
    plt.title("Data")
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
                    labelbottom=False, labelleft=False)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_linewidth(linewidth)
    plt.show()

    all_o = torch.cat(obs_list, dim=0)
    all_a = torch.cat(act_list, dim=0)
    all_r = torch.cat(rew_list, dim=0)
    separator = '-' * 40
    print(all_o.shape, all_a.shape, all_r.shape)
    print(separator)
    print(all_o.min(0), all_o.max(0), all_o.mean(0), all_o.std(0))
    print(separator)
    print(all_a.min(0), all_a.max(0), all_a.mean(0), all_a.std(0))
    print(separator)
    print(all_r.min(), all_r.max(), all_r.mean(), all_r.std())
