import numpy as np
from PIL import Image
import torch
import os

from gridworld.map_generator import GridWorldWithRewards


class PersuadableGWState:
    """Agent state holding a ``position`` tensor and an ``angle_id`` tensor.

    Both attributes are stored as given; ``clone`` deep-copies them with
    ``Tensor.clone`` so the copy can be mutated independently.
    """

    def __init__(self, position: torch.Tensor, angle_id: torch.Tensor):
        super().__init__()
        self.position = position
        self.angle_id = angle_id

    def get_state_shape(self):
        """Return the declared per-field shapes of the state."""
        shapes = dict(position=[2], angle_id=[1])
        return shapes

    def get_state_dict(self):
        """Return the state fields (the live tensors, not copies) by name."""
        return dict(position=self.position, angle_id=self.angle_id)

    def clone(self):
        """Return a new state whose tensors are copies of this state's."""
        position_copy = self.position.clone()
        angle_copy = self.angle_id.clone()
        return PersuadableGWState(position_copy, angle_copy)


class PersuadableGWEnvironment:
    """Single-agent grid world built on a ``GridWorldWithRewards`` map.

    ``self.map`` is an int tensor in which 1 marks a free cell and 0 an
    obstacle (start positions are drawn from the 1-cells, and collisions are
    cells equal to 0). The agent state holds a 2-D cell position and a
    heading index into ``self.action_angles``. ``self.step`` is bound to
    ``_step_angle`` when ``action_space == "rot"`` and to ``_step_dir``
    otherwise.
    """

    default_map_config = {
        "size": 200,
        "num_rooms": 9,
        "room_scale": 2,
        "passage_scale": 2,
        "seed": 192,
        "num_extra_passages": 1,
    }

    def __init__(
        self,
        device,
        step_size=1.0,
        num_rotation_steps=8,
        scan_resolution=256,
        scan_range=100,
        scan_angle_range=2 * torch.pi,
        observation_types=None,
        rotation_invariant_observation=False,
        action_space="rot",
        scale=None,
        patch_size=5,  # used for top-down images
        map_config=None,
    ):
        """Build the map and place the agent at a random valid position.

        Args:
            device: torch device for every tensor the environment creates.
            step_size: distance moved per step. It is also the lattice spacing
                of valid start positions and sets the agent's drawn size in
                top-down observations.
            num_rotation_steps: number of discrete headings in a full turn.
            scan_resolution: number of lidar rays.
            scan_range: lidar ray length, in cells.
            scan_angle_range: angular width of the lidar fan, in radians.
            observation_types: any of "lidar", "position", "angle" and
                "topdown". Defaults to ``["lidar"]``.
            rotation_invariant_observation: if True, the lidar fan is always
                centred on angle 0 instead of the agent's heading.
            action_space: "rot" selects heading-relative actions; any other
                value selects axis-aligned actions.
            scale: stored as ``self.scale`` (1 when None). It is not read
                anywhere else in this class.
            patch_size: side length of the top-down observation patch.
            map_config: keyword arguments for ``GridWorldWithRewards``.
                Defaults to ``default_map_config``.
        """
        self.device = device
        self.step_size = step_size  # also equals the "size" of the agent (radius)

        if map_config is None:
            map_config = self.default_map_config
        self.world = GridWorldWithRewards(**map_config)
        self.map = torch.tensor(self.world.grid > 0, device=self.device).int()

        self.observation_types = (
            ["lidar"] if observation_types is None else observation_types
        )
        self.scan_resolution = scan_resolution
        self.scan_range = scan_range
        self.scan_angle_range = scan_angle_range
        self.num_rotation_steps = num_rotation_steps
        self.action_angles = [
            (2 * torch.pi / num_rotation_steps) * i for i in range(num_rotation_steps)
        ]
        self.scale = 1 if scale is None else scale
        self.rotation_invariant_observation = rotation_invariant_observation
        self.action_space = action_space
        if action_space == "rot":
            self.step = self._step_angle
        else:
            self.step = self._step_dir

        self.patch_size = patch_size
        available_positions = torch.nonzero(self.map == 1)  # compute this once
        # Keep only free cells lying on the step_size lattice.
        on_lattice = torch.logical_and(
            available_positions[:, 0] % self.step_size == 0,
            available_positions[:, 1] % self.step_size == 0,
        )
        self.valid_positions = available_positions[on_lattice].float()

        self.env_state = PersuadableGWState(
            position=self._get_random_pos(),
            angle_id=torch.tensor(0, device=self.device),
        )

    def get_reward_pos(self, reward_function, position):
        """Evaluate ``reward_function(x, y)`` at ``position`` cast to ints."""
        return reward_function(int(position[0].item()), int(position[1].item()))

    def get_reward(self, reward_function, env_state=None):
        """Evaluate ``reward_function`` at ``env_state`` (default: current state)."""
        if env_state is None:
            env_state = self.env_state
        return self.get_reward_pos(reward_function, env_state.position)

    def _get_img_rgb(self, path):
        """Load ``maps/<path>`` (relative to this file) as a binary tensor.

        RGB channels are summed and the result is transposed, so the first
        index is the image column. Any non-black pixel becomes 1.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(current_dir, "maps", path)
        # Context manager releases the underlying file handle.
        with Image.open(file_path) as img:
            img_rgb = np.array(img.convert("RGB"))
        img_rgb = torch.tensor(img_rgb, device=self.device).sum(dim=-1).transpose(0, 1)
        img_rgb[img_rgb > 0] = 1  # binary
        return img_rgb

    def _get_random_pos(self, batch_size=1):
        """Sample ``batch_size`` valid positions, flattened to shape (2 * batch_size,)."""
        idx = torch.randint(0, self.valid_positions.shape[0], (batch_size,), device=self.device)
        return self.valid_positions[idx].reshape((-1,))

    def _scan_lines(self, position, start_angle, range, angle_range):
        """Cast ``scan_resolution`` rays from ``position`` and return hit distances.

        Rays are spread evenly over ``angle_range`` centred on ``start_angle``.
        When ``rotation_invariant_observation`` is set they are centred on 0
        instead. Each ray is sampled at integer steps ``0 .. range - 1``, and
        the sample points are clamped to the map.

        Note: ``range`` shadows the builtin but is kept because callers pass
        it by keyword.

        Returns:
            Float tensor of shape (scan_resolution,). Each entry is the index
            of the first obstacle sample on that ray divided by ``range``, or
            1.0 if the ray hits nothing.
        """
        if self.rotation_invariant_observation:
            start_angle = 0

        start_a_radians = start_angle - angle_range / 2
        end_a_radians = (
            start_angle + angle_range / 2 - angle_range / self.scan_resolution
        )

        angles = torch.linspace(
            start_a_radians, end_a_radians, self.scan_resolution, device=self.device
        )
        scan_step = 1

        dx = torch.cos(angles)
        dy = torch.sin(angles)

        step_indices = torch.arange(range, device=self.device)
        all_x = torch.outer(dx, step_indices * scan_step) + position[0]
        all_y = torch.outer(dy, step_indices * scan_step) + position[1]

        all_x = torch.round(all_x).int()
        all_y = torch.round(all_y).int()
        all_x = torch.clamp(all_x, 0, self.map.shape[0] - 1)
        all_y = torch.clamp(all_y, 0, self.map.shape[1] - 1)

        occupied = self.map[all_x, all_y]
        # A ray whose samples are all free cells detects nothing.
        no_detection = occupied.sum(dim=-1) == range

        # argmin returns the first 0 (obstacle) sample along each ray.
        distance = occupied.argmin(dim=-1)
        distance[no_detection] = range

        return distance.float() / range

    def _try_move(self, env_state, dx, dy):
        """Move ``env_state`` by ``(dx, dy)`` unless the straight path collides.

        ``dx`` is added to ``position[0]`` and ``dy`` to ``position[1]``.
        The target is clamped to the map and rounded to an integer cell
        before the collision check. On collision the state is left unchanged.
        """
        upper = torch.tensor(
            [self.map.shape[0] - 1, self.map.shape[1] - 1], device=self.device
        )
        new_pos = env_state.position + torch.tensor([dx, dy], device=self.device)
        new_pos = torch.clamp(new_pos, torch.tensor([0, 0], device=self.device), upper)
        new_pos = torch.round(new_pos).int()
        if not self._check_collision(env_state.position, new_pos):
            env_state.position = new_pos

    def _dir_action_delta(self, action_id):
        """Map an axis-aligned action index to a ``(dx, dy)`` displacement.

        0: +dy, 1: +dx, 2: -dy, 3: -dx. Any other index means no movement.
        """
        if action_id == 0:
            return 0, self.step_size
        if action_id == 1:
            return self.step_size, 0
        if action_id == 2:
            return 0, -self.step_size
        if action_id == 3:
            return -self.step_size, 0
        return 0, 0

    def _step_dir(self, action, env_state=None):
        """Apply an axis-aligned action (argmax of ``action``) and observe.

        Args:
            action: action scores; the argmax over the last dim is taken.
            env_state: state to update in place (default: ``self.env_state``).

        Returns:
            The observation dict for the updated state.
        """
        if env_state is None:
            env_state = self.env_state
        dx, dy = self._dir_action_delta(action.argmax(dim=-1))
        self._try_move(env_state, dx, dy)
        return self.get_observation(env_state)

    def _step_angle(self, action, env_state=None):
        """Apply a heading-relative action (argmax of ``action``) and observe.

        The action meanings are:
            0: step backwards along the heading.
            1: step forwards along the heading.
            2: rotate one heading step (+1).
            3: rotate one heading step (-1).

        Args:
            action: action scores; the argmax over the last dim is taken.
            env_state: state to update in place (default: ``self.env_state``).

        Returns:
            The observation dict for the updated state.
        """
        if env_state is None:
            env_state = self.env_state
        action = action.argmax(dim=-1)
        angle = self.action_angles[env_state.angle_id]
        if action == 0:
            dy = -self.step_size * np.cos(angle)
            dx = -self.step_size * np.sin(angle)
            self._try_move(env_state, dx, dy)
        elif action == 1:
            dy = self.step_size * np.cos(angle)
            dx = self.step_size * np.sin(angle)
            self._try_move(env_state, dx, dy)
        elif action == 2:
            env_state.angle_id = (env_state.angle_id + 1) % self.num_rotation_steps
        elif action == 3:
            env_state.angle_id = (env_state.angle_id - 1) % self.num_rotation_steps

        return self.get_observation(env_state)

    def _get_random_state(self):
        """Return a new state at a random valid position.

        The heading is random in "rot" mode and 0 otherwise.
        """
        agent_pos = self._get_random_pos()
        if self.action_space == "rot":
            # Sample on the default (CPU) generator as before, then move to the
            # environment device so it matches positions and the initial state.
            angle_id = torch.randint(0, self.num_rotation_steps, ()).to(self.device)
        else:
            angle_id = torch.tensor(0, device=self.device)
        return PersuadableGWState(position=agent_pos, angle_id=angle_id)

    def _check_collision(self, start, end, num_points=100):
        """Return True if the straight segment ``start -> end`` hits an obstacle.

        The segment is sampled at ``num_points`` evenly spaced points. Each
        point is truncated (not rounded) to an integer cell before lookup.
        """
        t = torch.linspace(0, 1, num_points, device=self.device).view(-1, 1)
        points = (start + t * (end - start)).view(num_points, -1).int()
        return bool(torch.any(self.map[points[:, 0], points[:, 1]] == 0))

    def _binary_map(self):
        """Return the map as uint8 with obstacles as 1 and free cells as 0."""
        img = 1 - self.map.detach().cpu().numpy()
        return img.astype(np.uint8)

    def simulate_trajectory(self, trajectory_length, env_state=None):
        """Simulate a random-walk trajectory in the environment.

        The environment is reset to ``env_state`` (a random state when None).
        ``self.env_state`` is left at the last state of the rollout.

        At each step, actions are resampled until one would change the agent's
        position under axis-aligned (``_step_dir``) semantics. That action is
        then applied with ``self.step``.

        NOTE(review): with action_space="rot" the probe still uses axis-aligned
        semantics, while ``self.step`` interprets the action as a heading
        command. The resampling loop also never ends if no axis move is
        possible from the current cell. Confirm both are intended.

        Returns:
            A dict with keys:
                "observation": shape (1, T, obs_dim).
                "position": shape (1, T, 2).
                "angle_id": shape (1, T, 1).
            T is max(trajectory_length, 1).
        """
        if env_state is None:
            env_state = self._get_random_state()

        obs_sequence, pos_sequence, angle_sequence = [], [], []

        def record(observation):
            # Store the observation plus the state it was taken from.
            obs_sequence.append(observation["observation"].reshape(-1))
            pos_sequence.append(self.env_state.position.reshape(2))
            angle_sequence.append(self.env_state.angle_id.reshape(1))

        record(self.reset(env_state))

        num_actions = self.get_action_shape()
        for _ in range(trajectory_length - 1):  # initial state already recorded
            probe = PersuadableGWState(
                position=self.env_state.position.clone(),
                angle_id=self.env_state.angle_id.clone(),
            )
            while True:
                action = torch.zeros(num_actions, device=self.device)
                action[torch.randint(0, num_actions, (1,), device=self.device)] = 1
                # Dry-run the move on the probe only; no observation is needed.
                self._try_move(probe, *self._dir_action_delta(action.argmax(dim=-1)))
                if not torch.all(probe.position == self.env_state.position):
                    break
            record(self.step(action))

        return {
            "observation": torch.stack(obs_sequence, dim=0).unsqueeze(0),
            "position": torch.stack(pos_sequence, dim=0).unsqueeze(0),
            "angle_id": torch.stack(angle_sequence, dim=0).unsqueeze(0),
        }

    def sample_valid_positions(self, sample_size=1000, resolution=None):
        """Return up to ``sample_size`` distinct valid positions in random order.

        ``resolution`` is accepted for interface compatibility and ignored.
        """
        perm = torch.randperm(self.valid_positions.shape[0], device=self.device)
        return self.valid_positions[perm][:sample_size]

    def get_observation(self, env_state=None):
        """Concatenate the configured observation types into one row.

        The result has shape (1, get_observation_shape()). Parts are
        concatenated in the order lidar, position, angle, topdown.
        """
        if env_state is None:
            env_state = self.env_state
        output = []
        if "lidar" in self.observation_types:
            data = self._scan_lines(
                env_state.position,
                self.action_angles[env_state.angle_id],
                range=self.scan_range,
                angle_range=self.scan_angle_range,
            )
            output.append(data.reshape((1, -1)))
        if "position" in self.observation_types:
            map_size = torch.tensor(
                [self.map.shape[0], self.map.shape[1]], device=self.device
            ).reshape((1, -1))
            output.append(env_state.position.reshape((1, -1)).float() / map_size)
        if "angle" in self.observation_types:
            # Normalise the heading index to [0, 1]; guard n == 1 against 0/0.
            denom = max(self.num_rotation_steps - 1, 1)
            output.append(
                torch.as_tensor(env_state.angle_id / denom, device=self.device).reshape((1, -1))
            )
        if "topdown" in self.observation_types:
            output.append(self.get_topdown_observation(env_state).reshape((1, -1)))

        return {"observation": torch.cat(output, dim=1)}

    def get_topdown_observation(self, env_state):
        """Return a ``patch_size`` x ``patch_size`` map crop centred on the agent.

        Cells outside the map are padded with 0 (the obstacle value). The
        crop is divided by 2, so obstacles read 0, free cells 0.5 and the
        agent's cell(s) 1.
        """
        # position[0] indexes map rows, position[1] map columns.
        center_x = int(env_state.position[1])
        center_y = int(env_state.position[0])
        patch_size = self.patch_size
        half_size = patch_size // 2
        rows, cols = self.map.shape[0], self.map.shape[1]

        # Patch boundaries (end exclusive) before clipping to the map.
        start_x = center_x - half_size
        end_x = center_x + half_size + (patch_size % 2)
        start_y = center_y - half_size
        end_y = center_y + half_size + (patch_size % 2)

        # Plain ints so F.pad receives a tuple of Python integers.
        pad_left = max(0, -start_x)
        pad_right = max(0, end_x - cols)
        pad_top = max(0, -start_y)
        pad_bottom = max(0, end_y - rows)

        # Clone only the crop so drawing the agent never touches self.map.
        patch = self.map[
            max(0, start_y):min(rows, end_y), max(0, start_x):min(cols, end_x)
        ].clone()
        if pad_left or pad_right or pad_top or pad_bottom:
            patch = torch.nn.functional.pad(patch, (pad_left, pad_right, pad_top, pad_bottom))

        # Draw agent
        agent_size = int(self.step_size // 2)
        if agent_size == 0:  # step_size < 2
            patch[half_size, half_size] = 2
        else:
            # NOTE(review): this covers [half - a, half + a), a 2a x 2a square
            # that is not symmetric around the agent cell. Confirm intended.
            lo, hi = half_size - agent_size, half_size + agent_size
            patch[lo:hi, lo:hi] = 2
        return patch / 2

    def sample_new_goal(self, batch_size=1, trajectory_length=1):
        """Sample goal observations without disturbing the current state.

        If ``trajectory_length > 1``, one random trajectory is simulated (see
        ``simulate_trajectory``) and the current state is restored afterwards.
        Otherwise ``batch_size`` independent random states are observed. In
        that case the returned tensors have shapes:
            "observation": (batch_size, 1, obs_dim).
            "position": (batch_size, 1, 2).
            "angle_id": (batch_size, 1, 1).

        Raises:
            NotImplementedError: if both ``batch_size`` and
                ``trajectory_length`` exceed 1.
        """
        if trajectory_length > 1:
            if batch_size > 1:
                raise NotImplementedError(
                    "sample_new_goal: batch_size > 1 with trajectory_length > 1 is not supported"
                )
            saved_state = PersuadableGWState(
                self.env_state.position.clone(), self.env_state.angle_id.clone()
            )
            result = self.simulate_trajectory(trajectory_length=trajectory_length, env_state=None)
            self.reset(saved_state)
            return result

        obs_batch, pos_batch, angle_batch = [], [], []
        for _ in range(batch_size):
            rnd_env_state = self._get_random_state()
            observation = self.get_observation(env_state=rnd_env_state)["observation"]
            obs_batch.append(observation.reshape((1, -1)))
            pos_batch.append(rnd_env_state.position.reshape((1, 2)))
            angle_batch.append(rnd_env_state.angle_id.reshape((1, 1)))
        return {
            "observation": torch.stack(obs_batch, dim=0),
            "position": torch.stack(pos_batch, dim=0),
            "angle_id": torch.stack(angle_batch, dim=0),
        }

    def reset(self, env_state=None):
        """Set the current state (random when None) and return its observation.

        The given ``env_state`` object is stored as-is, not copied.
        """
        if env_state is None:
            self.env_state = self._get_random_state()
        else:
            self.env_state = env_state
        return self.get_observation()

    def render_map(self):
        """Render the map as an RGB uint8 image: free cells black, obstacles white."""
        grid = self.map.detach().cpu().numpy()
        img = np.zeros((grid.shape[0], grid.shape[1], 3))
        img[grid == 1] = (0, 0, 0)  # (160, 160, 160)
        img[grid == 0] = (255, 255, 255)  # (224, 224, 224)
        return img.astype(np.uint8)

    def get_observation_shape(self):
        """Return the flat observation length for the configured types."""
        shape = 0
        if "lidar" in self.observation_types:
            shape += self.scan_resolution
        if "position" in self.observation_types:
            shape += 2
        if "angle" in self.observation_types:
            shape += 1
        if "topdown" in self.observation_types:
            shape += self.patch_size * self.patch_size
        return shape

    def get_action_shape(self):
        """Return the number of discrete actions (4 in both action spaces)."""
        return 4
