import numpy as np
import gym
from gym import spaces
import matplotlib.pyplot as plt

import utils

class ContinuousGridWorld(gym.Env):
    """Continuous 2-D navigation environment with optional line-segment walls.

    The agent starts at ``(0, 0)`` and moves by a 2-D displacement each step.
    Displacements longer than 1 are rescaled to unit length, positions are
    clipped to ``[0, width] x [0, height]``, and a move whose path crosses a
    wall is rejected (the agent stays where it was).

    The reward is the negative Euclidean distance from the agent to the corner
    ``(width, height)``.

    NOTE(review): ``self.goal`` is ``(width - 1, height - 1)`` and is only used
    for drawing in ``render_plt``; the reward targets ``(width, height)``.
    Confirm which of the two is the intended goal.
    """

    def __init__(self, width=10, height=10, walls=True):
        """Build the environment and reset it to the start state.

        Args:
            width: Extent of the world along x.
            height: Extent of the world along y.
            walls: When true, install the walls from ``_create_walls``;
                otherwise the world has no walls.
        """
        super().__init__()

        self.width = width
        self.height = height
        self.goal = np.array([width - 1, height - 1])

        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(low=0, high=max(width, height), shape=(2,), dtype=np.float32)

        self.state = None
        # Kept so that ``close()`` and any external user of ``env.fig`` /
        # ``env.ax`` keep working; ``render`` builds its own figure per call.
        self.fig, self.ax = plt.subplots()
        self.walls = self._create_walls() if walls else []
        self.timestep = 0
        self.reset()

    def reset(self):
        """Put the agent back at ``(0, 0)``, zero the step counter.

        Returns:
            A copy of the initial state, so in-place edits by the caller
            cannot corrupt the environment's internal state.
        """
        self.state = np.array([0.0, 0.0])
        self.timestep = 0
        return self.state.copy()

    def step(self, action):
        """Apply one displacement and return ``(state, reward, done, info)``.

        Args:
            action: 2-D displacement (array-like). If its Euclidean norm
                exceeds 1 it is rescaled to unit length.

        Returns:
            Tuple of a copy of the new state, the reward (negative distance
            to ``(width, height)``), ``done`` (always ``False`` here; the
            caller is responsible for ending the episode) and an empty dict.
        """
        # Accept lists/tuples as well as arrays; a plain list would fail on
        # the division below.
        action = np.asarray(action)
        norm = np.linalg.norm(action)
        if norm > 1:
            action = action / norm

        new_state = np.clip(self.state + action, 0, [self.width, self.height])

        # A move that crosses a wall is discarded entirely.
        if not self._is_collision(self.state, new_state):
            self.state = new_state

        corner = np.array([self.width, self.height])
        distance_to_corner = np.linalg.norm(self.state - corner)
        reward = -distance_to_corner
        done = False
        self.timestep += 1

        return self.state.copy(), reward, done, {}

    def _create_walls(self):
        """Define walls as a list of line segments [(x1, y1, x2, y2), ...]"""
        return [
            (2, 2, 4, 2),  # Horizontal wall
            (2, 2, 2, 6),  # Vertical wall
        ]

    def _is_collision(self, start, end):
        """Return True if the segment ``start -> end`` crosses any wall."""
        return any(
            self._line_intersect(start, end, (x1, y1), (x2, y2))
            for x1, y1, x2, y2 in self.walls
        )

    def _line_intersect(self, p1, p2, q1, q2):
        """Check if line segment p1p2 intersects with q1q2.

        Uses the orientation ("counter-clockwise") test: the segments cross
        when the endpoints of each lie on opposite sides of the other.
        The comparison is strict, so collinear or endpoint-touching
        configurations get no special handling.
        """
        def ccw(A, B, C):
            return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])

        return ccw(p1, q1, q2) != ccw(p2, q1, q2) and ccw(p1, p2, q1) != ccw(p1, p2, q2)

    def render_plt(self):
        """Show the current state interactively with ``plt.show()``.

        Draws the grid, the walls, ``self.goal`` (green) and the agent (red).
        """
        plt.figure(figsize=(8, 6))
        plt.xlim(0, self.width)
        plt.ylim(0, self.height)

        plt.grid(True, which='both')
        plt.xticks(np.arange(0, self.width+1, 1))
        plt.yticks(np.arange(0, self.height+1, 1))

        # Draw the walls so this view agrees with ``render``.
        for wall in self.walls:
            plt.plot([wall[0], wall[2]], [wall[1], wall[3]], color='black', linewidth=3)

        plt.scatter(self.goal[0], self.goal[1], color='green', s=100, marker='o', label='Goal')

        plt.scatter(self.state[0], self.state[1], color='red', s=100, marker='o', label='Agent')

        plt.title('Continuous Grid World')
        plt.xlabel('X Position')
        plt.ylabel('Y Position')
        plt.legend()
        plt.show()

    def render(self, mode='human', show_reward=False):
        """Draw the current state off-screen and return it as an image.

        Args:
            mode: Accepted for Gym API compatibility; not used.
            show_reward: When true, draw the reward landscape (same formula
                as ``step``) as a heatmap with a colorbar behind the scene.

        Returns:
            ``np.ndarray`` built from the canvas RGBA buffer of the figure.
        """
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            if show_reward:
                x = np.linspace(0, self.width, 100)
                y = np.linspace(0, self.height, 100)
                X, Y = np.meshgrid(x, y)
                rewards = -np.sqrt((X - (self.width))**2 + (Y - (self.height))**2)
                # Draw the heatmap once and hand the same mappable to the
                # colorbar instead of plotting the image a second time.
                heatmap = ax.imshow(rewards, extent=[0, self.width, 0, self.height], origin='lower', cmap='viridis', interpolation='nearest')
                fig.colorbar(heatmap, ax=ax, label='Reward')

            # Draw the walls
            for wall in self.walls:
                ax.plot([wall[0], wall[2]], [wall[1], wall[3]], color='black', linewidth=3)

            ax.set_xlim(0, self.width)
            ax.set_ylim(0, self.height)
            ax.grid(True, which='both')
            ax.set_xticks(np.arange(0, self.width+1, 1))
            ax.set_yticks(np.arange(0, self.height+1, 1))

            ax.scatter(self.state[0], self.state[1], color='red', s=100, marker='o')

            ax.set_title(f'Continuous Grid Environment - Timestep {self.timestep}')
            ax.set_xlabel('X Position')
            ax.set_ylabel('Y Position')

            fig.canvas.draw()
            # np.array copies the buffer, so the image outlives the figure.
            img = np.array(fig.canvas.renderer.buffer_rgba())
        finally:
            # Always release the per-call figure, even if drawing failed.
            plt.close(fig)

        return img

    def close(self):
        """Close the figure created in ``__init__``."""
        plt.close(self.fig)


if __name__ == "__main__":
    # Smoke test: run a short random rollout, capture one frame per step and
    # pass the frames to utils.create_gif_from_rgba_arrays (presumably writes
    # a GIF named "Test" — confirm against utils).
    max_steps = 20

    env = ContinuousGridWorld()
    try:
        env.reset()
        done = False

        # First frame shows the initial state before any action is taken.
        record = [env.render()]
        # The environment never reports done itself, so cap the rollout.
        while not done and env.timestep < max_steps:
            action = env.action_space.sample()
            obs, reward, done, _ = env.step(action)
            record.append(env.render())
            print(f"State: {obs}, Reward: {reward}, Done: {done}")
        utils.create_gif_from_rgba_arrays(record, name="Test")
    finally:
        # Release the environment's figure even if the rollout fails.
        env.close()
