"""Rapidly exploring Random Graph Solver for the Environment for Path Planning Problem

author: Anonymous
"""
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import math
import random
import numpy as np
from collections import OrderedDict
import json

import networkx as nx

from scheduling.location import Location
from scheduling.obstacle import Obstacle

from path_planning.path import Path
from path_planning.edge import Edge

# from environment import SchedulingEnvironment

import matplotlib.pyplot as plt
import matplotlib.lines as lines

class RRG:
    """Rapidly-exploring Random Graph (RRG) path planner over an environment.

    Builds an undirected ``networkx`` graph whose nodes are locations and whose
    edges are collision-free straight segments carrying a ``length`` attribute,
    then stores a shortest path for every agent->task pair and every unordered
    task->task pair.

    Attributes:
        graph: the explored roadmap.
        agent_to_task_paths: ``{(agent_index, task_index): Path}``.
        task_to_task_paths: ``{(task_index, other_task_index): Path}`` with
            ``task_index < other_task_index``.
        fully_connected: True when every pair above has a stored path.
    """

    def __init__(self, env, save_file=None, num_neighbours=5, num_samples=100, mode='euclidean'):
        """Create the planner and either generate the paths or load them from disk.

        Args:
            env: environment exposing ``agents``, ``tasks``, ``obstacles``,
                ``width``, ``height``, ``num_agents`` and ``num_tasks``.
            save_file: optional JSON path. If it exists the paths are loaded
                from it; otherwise paths are generated and written to it.
                When None, paths are generated and not saved.
            num_neighbours: how many of the nearest nodes each random sample
                tries to connect to.
            num_samples: number of random samples drawn by the RRG loop.
            mode: stored on the instance; not read elsewhere in this class.
        """
        self.env = env

        self.num_neighbours = num_neighbours
        self.num_samples = num_samples
        self.mode = mode
        # The graph for locations that are explored
        self.graph = nx.Graph()
        # a path between each agent and task pair
        self.agent_to_task_paths = {}
        # a path between each unordered task pair, keyed lower index first
        self.task_to_task_paths = {}

        # not populated by this class
        self.shared_edge_set = {}
        # defined up front so the attribute exists on every construction path
        self.fully_connected = False

        if save_file is None:
            # No save file: generate paths only
            self.generate_paths()
        elif os.path.exists(save_file):
            # Existing save file: load it and check that every pair is present
            self.load(save_file)
            num_agents = self.env.num_agents
            num_tasks = self.env.num_tasks
            self.fully_connected = (
                len(self.agent_to_task_paths) == num_agents * num_tasks
                and len(self.task_to_task_paths) == num_tasks * (num_tasks - 1) // 2
            )
        else:
            # Save file requested but missing: generate paths and save them
            self.generate_paths()
            self.save(save_file)

    def generate_paths(self):
        """Run the RRG pipeline: seed graph, random expansion, path search, pruning."""
        self._build_initial_graph()
        self._run_rrg()
        self._get_missing_paths()
        self._refine_graph()

    def _add_edge(self, graph, start, end):
        """Add a collision-free edge between ``start`` and ``end`` to ``graph``.

        Returns:
            bool: True if the edge was added (with a ``length`` attribute),
            False if the straight segment collides with an obstacle, in which
            case the graph is left unchanged.
        """
        if self._has_collision(start, end):
            return False
        graph.add_edge(start, end, length=start.distance(end))
        return True

    def _build_initial_graph(self):
        """Seed the graph with agent/task locations and every direct collision-free edge.

        Direct agent->task and task->task edges are also recorded as two-point
        paths, so ``_get_missing_paths`` only has to search for the remainder.
        """
        agents = self.env.agents
        tasks = self.env.tasks
        for agent in agents:
            self.graph.add_node(agent.location)
        for task in tasks:
            self.graph.add_node(task.location)

        # agent <-> agent edges only improve connectivity; no path is stored for them
        for agent_id, agent in enumerate(agents):
            for other_agent in agents[agent_id + 1:]:
                self._add_edge(self.graph, agent.location, other_agent.location)

        # agent -> task direct routes
        for agent_id, agent in enumerate(agents):
            for task_id, task in enumerate(tasks):
                if self._add_edge(self.graph, agent.location, task.location):
                    self.agent_to_task_paths[(agent_id, task_id)] = Path([agent.location, task.location])

        # task -> task direct routes. ``start=task_id + 1`` makes the key carry the
        # other task's real index (as _get_incomplete_pairings expects), not its
        # offset inside the slice.
        for task_id, task in enumerate(tasks):
            for other_task_id, other_task in enumerate(tasks[task_id + 1:], start=task_id + 1):
                if self._add_edge(self.graph, task.location, other_task.location):
                    self.task_to_task_paths[(task_id, other_task_id)] = Path(
                        [task.location, other_task.location]
                    )

    def _run_rrg(self):
        """Draw ``num_samples`` random points and wire each one into the graph.

        Each sample is connected to the closest ``num_neighbours`` existing
        nodes (every node tied at the cutoff distance is included). If the
        very nearest node cannot be reached without collision, the sample is
        discarded.
        """
        for _ in range(self.num_samples):
            new_point = self._sample_from_map()
            nearest_points = self._ordered_neighbours_by_distance(self.graph, new_point)
            if not nearest_points:
                continue
            # take whole distance buckets, closest first, until there are enough
            candidates = []
            for distance in sorted(nearest_points):
                candidates.extend(nearest_points[distance])
                if len(candidates) >= self.num_neighbours:
                    break
            for rank, neighbour in enumerate(candidates):
                edge_added = self._add_edge(self.graph, neighbour, new_point)
                if rank == 0 and not edge_added:
                    break

    def _find_path_or_none(self, start, end):
        """Return the shortest path from ``start`` to ``end``, or None (after printing why) if none exists."""
        try:
            return self.get_path(start, end)
        except nx.NetworkXNoPath as e:
            print(e)
            return None

    def _get_missing_paths(self):
        """Search the graph for every agent->task and task->task pair that still lacks a path.

        Unreachable pairs are reported and skipped, so every reachable pair
        still receives a path.

        Returns:
            bool: True if every pair now has a path (also stored in
            ``self.fully_connected``).
        """
        agent_task_pairings, task_task_pairings = self._get_incomplete_pairings()
        agents = self.env.agents
        tasks = self.env.tasks
        fully_connected = True
        for agent_id, task_id in agent_task_pairings:
            path = self._find_path_or_none(agents[agent_id].location, tasks[task_id].location)
            if path is None:
                fully_connected = False
            else:
                self.agent_to_task_paths[(agent_id, task_id)] = path
        for task_id, other_task_id in task_task_pairings:
            path = self._find_path_or_none(tasks[task_id].location, tasks[other_task_id].location)
            if path is None:
                fully_connected = False
            else:
                self.task_to_task_paths[(task_id, other_task_id)] = path
        self.fully_connected = fully_connected
        return fully_connected

    def _get_incomplete_pairings(self):
        """List the pairs that have no stored path yet.

        Returns:
            tuple: ``(agent_task_pairs, task_task_pairs)``, where the second
            list only holds ``(i, j)`` with ``i < j``.
        """
        num_agents = len(self.env.agents)
        num_tasks = len(self.env.tasks)
        incomplete_agent_task = [
            (agent_id, task_id)
            for agent_id in range(num_agents)
            for task_id in range(num_tasks)
            if (agent_id, task_id) not in self.agent_to_task_paths
        ]
        incomplete_task_task = [
            (task_id, other_task_id)
            for task_id in range(num_tasks)
            for other_task_id in range(task_id + 1, num_tasks)
            if (task_id, other_task_id) not in self.task_to_task_paths
        ]
        return incomplete_agent_task, incomplete_task_task

    def _refine_graph(self):
        """Drop every graph node that is not a waypoint of some stored path."""
        used_points = set()
        for paths in (self.agent_to_task_paths, self.task_to_task_paths):
            for path in paths.values():
                used_points.update(path.get_waypoints())
        # materialise before removing so we don't mutate while iterating
        unused_points = [point for point in self.graph.nodes if point not in used_points]
        self.graph.remove_nodes_from(unused_points)

    def _sample_from_map(self):
        """Sample an integer-coordinate Location uniformly in [0, width) x [0, height)."""
        x = int(np.random.uniform(0, self.env.width))
        y = int(np.random.uniform(0, self.env.height))
        return Location(x, y)

    def _ordered_neighbours_by_distance(self, graph, point):
        """Group the graph's nodes (excluding ``point`` itself) by their distance to ``point``.

        Args:
            graph (nx.Graph): The graph to search.
            point (Location): The reference point.
        Returns:
            dict: ``{distance: List[Location]}``; keys are not sorted.
        """
        nearest_points = {}
        for other_point in graph.nodes:
            if other_point == point:
                continue
            nearest_points.setdefault(point.distance(other_point), []).append(other_point)
        return nearest_points

    def _has_collision(self, point, other_point):
        """Return True if any obstacle blocks the straight segment between the two points."""
        return any(obstacle.has_collision(point, other_point) for obstacle in self.env.obstacles)

    def get_path(self, start, end):
        """Return the length-weighted shortest Path between two graph nodes.

        Raises:
            nx.NetworkXNoPath: if ``end`` is unreachable from ``start``.
        """
        return Path(nx.shortest_path(self.graph, start, end, weight='length', method='dijkstra'))

    def find_shared_edges(self):
        """Find edges shared between pairs of stored paths (intended as spatial constraints).

        NOTE(review): the per-pair comparison is currently disabled
        (``shared_edges`` is always None), so this always returns ``{}``.

        Returns:
            dict: ``{((start_i, end_i), (start_j, end_j)): shared_edges}``.
        """
        shared_paths = {}
        all_paths = list(self.agent_to_task_paths.values()) + list(self.task_to_task_paths.values())
        for i, path in enumerate(all_paths):
            for other_path in all_paths[i + 1:]:
                # TODO(review): re-enable path.get_shared_edges(other_path) when ready
                shared_edges = None
                if shared_edges:
                    shared_paths[(path.start, path.end), (other_path.start, other_path.end)] = shared_edges
        return shared_paths

    def get_agent_to_task_distance(self, agent_id: int, task_id: int):
        """Return the length of the stored path from agent ``agent_id`` to task ``task_id``."""
        return self.agent_to_task_paths[(agent_id, task_id)].distance

    @staticmethod
    def _task_pair_key(task_id, other_task_id):
        """Return the key an unordered task pair is stored under (lower index first)."""
        if task_id <= other_task_id:
            return (task_id, other_task_id)
        return (other_task_id, task_id)

    def get_task_to_task_distance(self, task_id: int, other_task_id: int):
        """Return the path length between two tasks; argument order does not matter."""
        return self.task_to_task_paths[self._task_pair_key(task_id, other_task_id)].distance

    def save(self, filename="test/sample_problem_5_10_rrg.json"):
        """Write both path dictionaries to ``filename`` as JSON."""
        data = {
            "agent_to_task_paths": self._to_str(self.agent_to_task_paths),
            "task_to_task_paths": self._to_str(self.task_to_task_paths)
        }
        with open(filename, 'w') as f:
            json.dump(data, f, indent=4)

    def load(self, filename="test/sample_problem_5_10_rrg.json"):
        """Load both path dictionaries from ``filename`` and rebuild the graph from their waypoints."""
        with open(filename, 'r') as f:
            data = json.load(f)
        self.agent_to_task_paths = self._from_str(data["agent_to_task_paths"])
        self.task_to_task_paths = self._from_str(data["task_to_task_paths"])
        for paths in (self.agent_to_task_paths, self.task_to_task_paths):
            for path in paths.values():
                self._add_path_to_graph(path)

    def _add_path_to_graph(self, path):
        """Insert a path's waypoints and its consecutive segments (with ``length``) into the graph."""
        waypoints = path.waypoints
        # add nodes explicitly so a single-waypoint path still lands in the graph
        self.graph.add_nodes_from(waypoints)
        for start, end in zip(waypoints[:-1], waypoints[1:]):
            self.graph.add_edge(start, end, length=start.distance(end))

    def _to_str(self, path_dictionary):
        """Stringify keys (tuples -> ``"(i, j)"``) and values (``str(Path)``) for JSON output."""
        return {f"{key}": str(value) for key, value in path_dictionary.items()}

    def _from_str(self, path_dictionary):
        """Inverse of ``_to_str``: parse ``"(i, j)"`` keys and ``" -> "``-joined waypoint strings.

        Keys are parsed with ``ast.literal_eval`` instead of ``eval`` so a
        tampered save file cannot execute code; a non-literal key raises
        ``ValueError``.
        """
        import ast

        return {
            ast.literal_eval(key): Path([Location.from_str(v) for v in value.split(' -> ')])
            for key, value in path_dictionary.items()
        }

    @staticmethod
    def _draw_path(ax, path, color, linestyle):
        """Add one line per consecutive waypoint pair of ``path`` to ``ax``."""
        waypoints = path.waypoints
        for start, end in zip(waypoints[:-1], waypoints[1:]):
            ax.add_line(lines.Line2D([start.x, end.x], [start.y, end.y], color=color, linestyle=linestyle, alpha=1))

    def render(self, filename="figures/rrg.png"):
        """Draw the roadmap and each agent's assigned route, then save the figure to ``filename``."""
        fig, ax = plt.subplots()
        # roadmap edges as faint green lines
        for (start, end) in self.graph.edges:
            line = lines.Line2D([start.x, end.x], [start.y, end.y], color='green', alpha=0.1, label='RRG Paths')
            ax.add_line(line)
        labels = ['RRG Paths']
        legends = [lines.Line2D([0], [0], color='green', alpha=0.1, label='RRG Paths')]
        # a distinct colour and line style per agent
        cm = plt.get_cmap('gist_rainbow')
        LINE_STYLES = ['solid', 'dashed', 'dashdot', 'dotted']
        num_agents = len(self.env.agents)
        agent_colors = [cm((i + 1) / num_agents) if i >= 1 else cm(i / num_agents) for i in range(num_agents)]
        agent_style = [LINE_STYLES[i % len(LINE_STYLES)] for i in range(num_agents)]

        for agent in self.env.agents:
            plan = agent.assigned_tasks
            if not plan:
                continue
            # NOTE(review): assumes agent.id / task.id equal their indices in
            # env.agents / env.tasks, since path keys come from enumerate — confirm.
            color = agent_colors[agent.id]
            style = agent_style[agent.id]
            # agent -> first task
            self._draw_path(ax, self.agent_to_task_paths[(agent.id, plan[0].id)], color, style)
            # every consecutive task -> task leg, including the final one
            for previous_task, task in zip(plan[:-1], plan[1:]):
                if previous_task.id == task.id:
                    print(f"Task is the same {previous_task.id} == {task.id}")
                    continue
                # task pairs are stored lower index first, whatever the visit order
                path = self.task_to_task_paths[self._task_pair_key(previous_task.id, task.id)]
                self._draw_path(ax, path, color, style)
            legends.append(lines.Line2D([0], [0], color=color, linestyle=style, label=f'Agent {agent.id}'))
            labels.append(f'Agent {agent.id}')
        self.env._render_to_human(fig, ax, legends, labels)
        plt.savefig(filename, bbox_inches='tight')
        plt.close()
