from pathlib import Path

import gym
import numpy as np
from matplotlib import pyplot as plt
from envs.autoencoder import autoencoder
import torch

class TonesEnv(gym.Env):
    """Evidence-accumulation corridor task exposed through the (old-style) Gym API.

    `reset` returns only the observation and `step` returns the 4-tuple
    `(obs, reward, done, info)`.
    """

    metadata = {
        'render_modes': ['rgb_array'],
        'render_fps': 10,
    }

    observation_space: gym.spaces.Box
    # Only forward / right / left are advertised to agents; `step` additionally
    # accepts action 0 (backward) when it is passed explicitly.
    action_space = gym.spaces.Discrete(3, start=1)  # (1, 2, 3)

    def __init__(
        self,
        seq_len: int,
        num_interval_on_left: tuple[int, int],
        num_interval_on_right: tuple[int, int],
        reward_turning_wrong_way_at_the_end: float = -5,
        reward_turning_correct_way_at_the_end: float = 5,
        reward_going_backward: float = -3,
        reward_turning_into_wall: float = 0,
        reward_turning_into_back_wall: float = -5,
        _max_episode_steps: int = 500,
        pixel_output: bool = False,
        pixel_output_shape: tuple[int, int] = (9, 9),
        decaying_walls: bool = False,
        decaying_rate_walls: float = 0.4,
        flag_end_wall: bool = True,
        encode_obs_using_autoencoder: bool = False,
    ):
        """
        Summary:

        Gym environment inspired by the experiment in https://www.nature.com/articles/s41586-021-03652-7

        The agent can move along a straight column with walls on two sides. A finite number of towers are placed on each of the walls.
        The goal of the agent is to choose to take steps towards the end of the column and turn in the direction of the wall with the most number of towers.
        At every step the agent is allowed to choose an action from the given set of actions - [step backward, step forward, step right, step left]

        Args:
            seq_len (int): length of the sequence
            num_interval_on_left (tuple[int, int]): the number of towers on the left is sampled uniformly from this interval (passed to `np_random.integers`, so the upper bound is exclusive)
            num_interval_on_right (tuple[int, int]): the number of towers on the right is sampled uniformly from this interval (passed to `np_random.integers`, so the upper bound is exclusive)
            reward_turning_wrong_way_at_the_end (float, optional): reward received for making the wrong decision at the end (kept negative). Defaults to -5.
            reward_turning_correct_way_at_the_end (float, optional): reward received for making the correct decision at the end (kept positive). Defaults to 5.
            reward_going_backward (float, optional): reward received for choosing the action of going backwards (kept negative). Defaults to -3.
            reward_turning_into_wall (float, optional): reward received for turning left/right before the end of the track. Defaults to 0.
            reward_turning_into_back_wall (float, optional): reward received for stepping forward while already at the end of the track (kept negative). Defaults to -5.
            _max_episode_steps (int, optional): maximum steps in an episode. Defaults to 500.
            pixel_output (bool, optional): if True, the observation returned is an image of shape `pixel_output_shape` which reflects the first person view of the rat, and if False, returns a list for the current step described by [left wall, center wall, right wall] (left and right walls are 1 if there is a tower on the respective wall else 0. The center wall is 1 when the agent reaches the end of track and finds a wall in front of it). Defaults to False.
            pixel_output_shape (tuple[int, int], optional): shape of the image if pixel_output is True; axis 1 must be divisible by 3. Defaults to (9, 9).
            decaying_walls (bool, optional): if True, the shape of the towers changes as the agent gets closer to them, similar to the real-life agent (only affects the observation if pixel_output is set to True). Defaults to False.
            decaying_rate_walls (float, optional): the rate at which the shape of the towers should change. Defaults to 0.4.
            flag_end_wall (bool, optional): whether the agent sees the final wall in front of it, i.e. whether the center wall lights up at the end of the track. Defaults to True.
            encode_obs_using_autoencoder (bool, optional): if True, the observation is the output of the encoder half of a trained autoencoder; requires pixel_output with shape (60, 60). Defaults to False.
        """

        self.seq_len = seq_len
        self.num_interval_on_left = num_interval_on_left
        self.num_interval_on_right = num_interval_on_right
        self.decaying_walls_flag = decaying_walls
        self.decaying_rate_walls = decaying_rate_walls
        self.flag_end_wall = flag_end_wall
        self.encode_obs_using_autoencoder = encode_obs_using_autoencoder

        if pixel_output:
            self.observation_space = gym.spaces.Box(low=0, high=1, shape=pixel_output_shape)

            assert self.observation_space.shape[0] != 0
            assert self.observation_space.shape[1] != 0
            assert self.observation_space.shape[1] % 3 == 0, "axis 1 of pixel_output_shape must be divisible by 3."
        else:
            self.observation_space = gym.spaces.Box(low=0, high=1, shape=(3,))

        if encode_obs_using_autoencoder:
            assert pixel_output, "encoder can only be used with pixel output when the pixel output size is (60,60)"
            assert pixel_output_shape[0] == 60 and pixel_output_shape[1] == 60, "encoder can only be used with pixel output when the pixel output size is (60,60)"
            # The model is loaded without an explicit device; observations are
            # moved to CPU before being returned.
            self.model = autoencoder(input_image_shape=(1, 60, 60), encoder_output_dim=3)
            checkpoint = torch.load(Path(__file__).parent / 'autoencoder_weights_sigmoid.pt')
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
            # The encoded observation replaces the image: a (3, 1) column vector.
            self.observation_space = gym.spaces.Box(low=0, high=1, shape=(3, 1))

        # Action encoding understood by `step`:
        # 0 - backward
        # 1 - forward
        # 2 - right
        # 3 - left

        reward_list = [
            reward_turning_wrong_way_at_the_end,
            reward_turning_correct_way_at_the_end,
            reward_going_backward,
            reward_turning_into_wall,
            reward_turning_into_back_wall
        ]
        self.reward_range = (min(reward_list), max(reward_list))

        self.correct_side = None  # set by `reset`: 2 (right) or 3 (left)
        self.reward_turning_wrong_way_at_the_end = reward_turning_wrong_way_at_the_end
        self.reward_turning_correct_way_at_the_end = reward_turning_correct_way_at_the_end
        self.reward_going_backward = reward_going_backward
        self.reward_turning_into_wall = reward_turning_into_wall
        self.reward_turning_into_back_wall = reward_turning_into_back_wall
        self._max_episode_steps = _max_episode_steps
        self.pixel_output = pixel_output
        self.pixel_output_shape = pixel_output_shape

        self.step_count = 0

    # ------------------------------------------------------------------
    # Observation helpers (shared by `reset` and `step`)
    # ------------------------------------------------------------------

    def _raw_observation(self):
        """Return [left, center, right] wall values at the current position."""
        pos = self.current_position
        return [self.left_wall[pos], self.center_wall[pos], self.right_wall[pos]]

    def _decay_profile(self, wall):
        """Return the clipped, exponentially decayed look-ahead signal for `wall`.

        Walking from the far end of the track back towards the start, each entry
        is `exp(-rate) * next_entry + wall[i]`, capped at 1, so a tower becomes
        visible before the agent reaches it and grows as the agent approaches.
        """
        factor = np.exp(-self.decaying_rate_walls)
        profile = np.zeros(len(wall))
        carry = 0.0
        for idx in range(len(wall) - 1, -1, -1):
            carry = min(1.0, factor * carry + wall[idx])
            profile[idx] = carry
        return profile

    def _to_pixels(self, obs):
        """Render `obs` ([left, center, right]) as an image of `pixel_output_shape`.

        The image is split into three vertical bands of equal width. The center
        band is always filled with obs[1]. Without decaying walls the side bands
        are filled with obs[0] / obs[2].

        NOTE: with decaying walls the side towers are drawn from the decay
        profile at the current position and obs[0] / obs[2] are not used, so an
        all-zero `obs` only blanks the center band in that mode.
        """
        width = self.pixel_output_shape[1] // 3
        frame = np.zeros(self.pixel_output_shape)
        frame[:, width:2 * width] = obs[1]

        if not self.decaying_walls_flag:
            frame[:, 0:width] = obs[0]
            frame[:, 2 * width:3 * width] = obs[2]
            return frame

        height = self.pixel_output_shape[0] // 3
        left = self.left_wall_decay[self.current_position]
        right = self.right_wall_decay[self.current_position]

        # Each tower is a rectangle centred on the middle third of the rows; it
        # grows vertically and horizontally (from the outer edge) with the
        # decayed intensity in [0, 1].
        left_grow = int(height * left)
        right_grow = int(height * right)
        frame[height - left_grow:2 * height + left_grow, 0:int(left * width)] = 1
        frame[height - right_grow:2 * height + right_grow, 3 * width - int(right * width):3 * width] = 1
        return frame

    def _encode(self, frame):
        """Pass `frame` through the autoencoder's encoder; return a column vector."""
        batch = torch.Tensor(frame[None, None])  # add batch and channel axes
        with torch.no_grad():
            code = torch.squeeze(self.model.encode(batch))
        return code.cpu().numpy()[:, None]

    def _finalize_observation(self, obs):
        """Convert a raw [left, center, right] list into the configured output format."""
        if not self.pixel_output:
            return obs
        frame = self._to_pixels(obs)
        if self.encode_obs_using_autoencoder:
            return self._encode(frame)
        return frame

    # ------------------------------------------------------------------
    # Gym API
    # ------------------------------------------------------------------

    def reset(self, **kwargs):
        """Start a new episode and return the initial observation.

        Wall arrays have length seq_len + 1. The agent starts at index 1 and
        towers are only placed at indices 3 .. seq_len - 1. Index seq_len is the
        end of the track, where the center wall is lit if `flag_end_wall` is set.

        Raises:
            ValueError: if the two count intervals can only ever produce the same
                single value (no side could be correct), or if a sampled tower
                count does not fit into the available positions.
        """
        super().reset(**kwargs)

        self.current_position = 1  # agent starts at step 1, not 0.
        self.step_count = 0  # resetting the step count

        self.right_wall = np.zeros(self.seq_len + 1)
        self.left_wall = np.zeros(self.seq_len + 1)
        self.center_wall = np.zeros(self.seq_len + 1)
        self.correct_side = None

        # Guard against rejection sampling that can never terminate: both
        # intervals hold exactly one (identical) value.
        left_interval = tuple(self.num_interval_on_left)
        right_interval = tuple(self.num_interval_on_right)
        if (len(left_interval) == 2 and left_interval == right_interval
                and left_interval[1] - left_interval[0] == 1):
            raise ValueError(
                "num_interval_on_left and num_interval_on_right only allow the same "
                "single tower count, so neither side can hold more towers"
            )

        # Re-sample until the two sides differ so one side is always correct.
        while True:
            self.num_on_left = self.np_random.integers(*self.num_interval_on_left)  # Num of stimulus on left
            self.num_on_right = self.np_random.integers(*self.num_interval_on_right)  # Num of stimulus on right

            if self.num_on_left != self.num_on_right:
                break

        self.correct_side = 2 if self.num_on_right > self.num_on_left else 3

        tower_slots = range(3, self.seq_len)  # pillars added from third step
        if max(self.num_on_left, self.num_on_right) > len(tower_slots):
            raise ValueError(
                f"cannot place {max(self.num_on_left, self.num_on_right)} towers in "
                f"{len(tower_slots)} available positions; increase seq_len or shrink the intervals"
            )

        index_on_left = self.np_random.choice(tower_slots, self.num_on_left, replace=False)
        index_on_right = self.np_random.choice(tower_slots, self.num_on_right, replace=False)

        self.left_wall[index_on_left] = 1
        self.right_wall[index_on_right] = 1
        if self.flag_end_wall:
            self.center_wall[self.seq_len] = 1

        if self.decaying_walls_flag:
            self.left_wall_decay = self._decay_profile(self.left_wall)
            self.right_wall_decay = self._decay_profile(self.right_wall)

        return self._finalize_observation(self._raw_observation())

    def render(self, mode="human"):
        """Return a (3, seq_len + 1, 3) uint8 top-down image of the track.

        Rows are the left, center and right walls (lit cells are white) and the
        agent is drawn in red on the center row. Only mode 'rgb_array' is
        implemented; any other mode raises NotImplementedError.
        """
        if mode != 'rgb_array':
            raise NotImplementedError

        frame = np.zeros([3, self.seq_len + 1, 3], dtype="uint8")
        for row, wall in enumerate((self.left_wall, self.center_wall, self.right_wall)):
            frame[row] = (wall * 255)[:, None]  # same value in all three colour channels
        frame[1, self.current_position] = [255, 0, 0]
        return frame

    def step(self, action):
        """
        Apply `action` (0 backward, 1 forward, 2 right, 3 left).

        Returns:
            obs : observation of the environment
            reward : reward for taking action
            done : True if the agent turned at the end of the track or the step limit was reached
            info :
                info : is always None at the moment, would like to remove it after notifying
                evidence : towers on the right minus towers on the left at positions before the current one
                position : gives current position
                done_reason : this is either None, "correct", "wrong", or "timeout" depending on the reason for done being set to True

        Raises:
            ValueError: if `action` is not one of 0, 1, 2, 3.
        """

        self.step_count += 1
        done = False
        done_reason = None

        if action == 0:  # backward
            if self.current_position != 1:
                self.current_position -= 1
            obs = self._raw_observation()
            reward = self.reward_going_backward

        elif action == 1:  # forward
            if self.current_position != self.seq_len:
                self.current_position += 1
                reward = 0
            else:
                # Already at the end of the track: walking forward hits the wall.
                reward = self.reward_turning_into_back_wall
            obs = self._raw_observation()

        elif action == 2 or action == 3:  # right, left
            obs = [0, 0, 0]  # hide towers when moving sideways
            if self.current_position == self.seq_len:
                # Turning at the end of the track is the decision and ends the episode.
                done = True
                if action == self.correct_side:
                    reward = self.reward_turning_correct_way_at_the_end
                    done_reason = "correct"
                else:
                    reward = self.reward_turning_wrong_way_at_the_end
                    done_reason = "wrong"
            else:
                reward = self.reward_turning_into_wall
        else:
            raise ValueError(f"invalid action {action!r}; expected 0, 1, 2 or 3")

        # Only report a timeout when the episode has not already ended through a
        # decision on this very step; otherwise the real reason would be lost.
        if not done and self.step_count >= self._max_episode_steps:
            done = True
            done_reason = "timeout"

        evidence = self.right_wall[:self.current_position].sum() - self.left_wall[:self.current_position].sum()
        info = {"info": None, "evidence": evidence, "position": self.current_position, "done_reason": done_reason}

        return self._finalize_observation(obs), reward, done, info


if __name__ == "__main__":
    # Smoke test: build the environment, reset it, then walk forward 20 steps,
    # printing every observation / step result.
    # -------- autoencoder only works for the following setting --------
    demo_config = {
        "seq_len": 100,
        "num_interval_on_left": (1, 15),
        "num_interval_on_right": (1, 15),
        "reward_turning_wrong_way_at_the_end": 0,
        "reward_turning_correct_way_at_the_end": 10,
        "reward_going_backward": -1,
        "reward_turning_into_wall": -0.1,
        "reward_turning_into_back_wall": -1,
        "_max_episode_steps": 120,
        "pixel_output": True,
        "pixel_output_shape": (60, 60),
        "decaying_walls": True,  # False,
        "decaying_rate_walls": 0.1,
        "encode_obs_using_autoencoder": True,
    }
    env = TonesEnv(**demo_config)

    print(env.reset())
    for _ in range(20):
        print(env.step(1))
