import math
from typing import Optional, Tuple, Union

import numpy as np
import os
import pandas as pd
from collections import defaultdict

import gymnasium as gym
from gymnasium import logger, spaces
from gymnasium.envs.classic_control import utils
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.vector import VectorEnv
from gymnasium.vector.utils import batch_space
from gym.envs.registration import register
from imitations.data.types import TrajectoryWithRew


class MeerkatEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
    """Discrete environment in which the chosen action becomes the next state.

    There are 25 states and 25 actions. ``reset`` draws a uniformly random
    initial state; each ``step(action)`` moves to state ``action``. After 29
    steps the episode terminates. The per-step reward is 1.0, except 0.0 on
    the terminating step.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 50,
    }

    # Step count at which an episode terminates.
    _TERMINAL_STEP = 29
    # Rows per demonstration chunk in the CSV files: the initial state plus one
    # row per step, so each chunk yields exactly _TERMINAL_STEP transitions.
    _CHUNK_ROWS = _TERMINAL_STEP + 1

    def __init__(self, render_mode: Optional[str] = None):
        self.action_space = spaces.Discrete(25)
        self.observation_space = spaces.Discrete(25)

        self.render_mode = render_mode

        self.screen_width = 600
        self.screen_height = 400
        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None
        self.timestemp = 0  # steps taken in the current episode
        self.steps_beyond_terminated = None

    def step(self, action):
        """Move to the state named by ``action``.

        Returns:
            ``(obs, reward, terminated, truncated, info)``: ``obs`` is the new
            state as a 0-d int64 array; ``reward`` is 1.0, or 0.0 on the
            terminating step; ``truncated`` is always False; ``info`` is empty.
        """
        assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
        assert self.state is not None, "Call reset before using step method."

        self.state = action
        self.timestemp += 1

        terminated = bool(self.timestemp >= self._TERMINAL_STEP)

        if not terminated:
            reward = 1.0
        else:
            # Restart the counter so that stepping on without reset() counts a
            # fresh episode.
            self.timestemp = 0
            reward = 0.0

        return np.array(self.state, dtype=np.int64), reward, terminated, False, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ):
        """Start a new episode from a uniformly random state in ``[0, 25)``.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` for ``np_random``.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(obs, info)`` with ``obs`` a 0-d int64 array and an empty ``info``.
        """
        super().reset(seed=seed)
        self.state = self.np_random.integers(low=0, high=25)
        # Clear the episode step counter; otherwise an episode cut short (e.g. by
        # a TimeLimit wrapper) would leak its step count into the next episode.
        self.timestemp = 0
        self.steps_beyond_terminated = None

        if self.render_mode == "human":
            # NOTE(review): this class does not override render(); gymnasium's
            # base Env.render raises NotImplementedError, so render_mode="human"
            # fails here until a render() is implemented.
            self.render()
        return np.array(self.state, dtype=np.int64), {}

    @staticmethod
    def _read_state_columns(folder_path):
        """Return the third column of every ``.csv`` file in ``folder_path``, in file-name order."""
        columns = []
        # Sorted so the output order does not depend on the OS directory listing.
        for file_name in sorted(os.listdir(folder_path)):
            if file_name.endswith('.csv'):
                file_path = os.path.join(folder_path, file_name)
                data = pd.read_csv(file_path, header=None, usecols=[2])  # Reading only the third column
                columns.append(data[2].to_numpy())
        return columns

    def rollout(self, folder_path):
        """Load expert demonstrations from CSV files as ``TrajectoryWithRew`` objects.

        Every ``.csv`` file in ``folder_path`` is read (no header, third column
        only) and split into non-overlapping chunks of 30 rows; a trailing
        partial chunk is dropped. A chunk of states ``s_0 .. s_29`` becomes one
        terminal trajectory with ``obs = [s_0 .. s_29]`` (30 entries),
        ``acts = [s_1 .. s_29]`` (29 entries, because the action is the next
        state) and a reward of 1.0 per action.

        State labels are mapped to integer IDs by their sorted order over all
        files, so the IDs do not depend on directory listing order.

        Args:
            folder_path: Directory containing the demonstration CSV files.

        Returns:
            A list of ``TrajectoryWithRew``, in file-name order, then row order.
        """
        columns = self._read_state_columns(folder_path)

        all_states = set()
        for column in columns:
            all_states.update(column)
        state_action_mapping = {state: idx for idx, state in enumerate(sorted(all_states))}
        if len(state_action_mapping) > self.observation_space.n:
            logger.warn(
                f"Demonstrations contain {len(state_action_mapping)} distinct states, "
                f"but the observation space only has {self.observation_space.n}."
            )

        chunk_rows = self._CHUNK_ROWS
        all_trajectories = []
        for column in columns:
            ids = np.array([state_action_mapping[state] for state in column], dtype=np.int64)
            # Start only where a full chunk fits; the trailing partial chunk is skipped.
            for start in range(0, len(ids) - chunk_rows + 1, chunk_rows):
                chunk = ids[start:start + chunk_rows]
                trajectory = TrajectoryWithRew(
                    obs=chunk.copy(),  # one more observation than actions
                    acts=chunk[1:].copy(),  # the action taken is the next state
                    rews=np.ones(chunk_rows - 1, dtype=np.float32),  # Reward of 1 for each action
                    terminal=True,
                    infos=None,
                )
                all_trajectories.append(trajectory)

        return all_trajectories


class MeerkatVectorEnv(VectorEnv):
    """Vectorized counterpart of ``MeerkatEnv``: ``num_envs`` copies stepped in lockstep.

    In each sub-environment the chosen action (0..24) becomes the next state.
    A sub-environment terminates after 29 steps, or is truncated once it has
    taken ``max_episode_steps`` steps without terminating. Finished
    sub-environments are reset within the same ``step`` call. Every
    sub-environment receives a reward of 1.0 per step.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 50,
    }

    # Step count at which a sub-environment's episode terminates.
    _TERMINAL_STEP = 29
    # Rows per demonstration chunk in the CSV files: the initial state plus one
    # row per step, so each chunk yields exactly _TERMINAL_STEP transitions.
    _CHUNK_ROWS = _TERMINAL_STEP + 1

    def __init__(
        self,
        num_envs: int = 8,
        max_episode_steps: int = 500,
        render_mode: Optional[str] = None,
    ):
        super().__init__()
        self.num_envs = num_envs
        self.max_episode_steps = max_episode_steps

        self.single_action_space = spaces.Discrete(25)
        self.action_space = batch_space(self.single_action_space, num_envs)
        self.single_observation_space = spaces.Discrete(25)
        self.observation_space = batch_space(self.single_observation_space, num_envs)

        # Steps taken in the current episode, per sub-environment.
        self.steps = np.zeros(num_envs, dtype=np.int32)
        self.render_mode = render_mode

        self.screen_width = 600
        self.screen_height = 400
        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None
        self.timestemp = 0  # unused; per-env step counts live in self.steps
        self.steps_beyond_terminated = None

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
        """Move every sub-environment to the state named by its action.

        Args:
            action: One action per sub-environment (must be in ``action_space``).

        Returns:
            ``(obs, reward, terminated, truncated, info)``, each array of length
            ``num_envs``. For sub-environments that finished on this step,
            ``obs`` already holds the new random initial state (the final state
            before the reset equals the action that was taken); ``info`` is empty.
        """
        assert self.action_space.contains(
            action
        ), f"{action!r} ({type(action)}) invalid"
        assert self.state is not None, "Call reset before using step method."

        # np.array (not asarray) copies, so the in-place reset below never
        # mutates the caller's action array.
        self.state = np.array(action, dtype=np.int64)
        self.steps += 1

        terminated = self.steps >= self._TERMINAL_STEP
        truncated = (self.steps >= self.max_episode_steps) & ~terminated
        done = terminated | truncated

        reward = np.ones(self.num_envs, dtype=np.float32)

        if done.any():
            # Same-step autoreset: draw fresh random initial states for the
            # finished sub-environments and restart their step counters.
            self.state[done] = self.np_random.integers(
                low=0, high=25, size=int(done.sum()), dtype=np.int64
            )
            self.steps[done] = 0

        return self.state.copy(), reward, terminated, truncated, {}

    def reset(
            self,
            *,
            seed: Optional[int] = None,
            options: Optional[dict] = None,
    ):
        """Reset all sub-environments to uniformly random states in ``[0, 25)``.

        Args:
            seed: Optional seed forwarded to ``VectorEnv.reset`` for ``np_random``.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(obs, info)`` with ``obs`` an int64 array of length ``num_envs``.
        """
        super().reset(seed=seed)
        # np.random.Generator exposes integers(), not the legacy randint().
        self.state = self.np_random.integers(low=0, high=25, size=self.num_envs, dtype=np.int64)
        self.steps = np.zeros(self.num_envs, dtype=np.int32)
        self.timestemp = 0
        self.steps_beyond_terminated = None

        return self.state.copy(), {}

    @staticmethod
    def _read_state_columns(folder_path):
        """Return the third column of every ``.csv`` file in ``folder_path``, in file-name order."""
        columns = []
        # Sorted so the output order does not depend on the OS directory listing.
        for file_name in sorted(os.listdir(folder_path)):
            if file_name.endswith('.csv'):
                file_path = os.path.join(folder_path, file_name)
                data = pd.read_csv(file_path, header=None, usecols=[2])  # Reading only the third column
                columns.append(data[2].to_numpy())
        return columns

    def rollout(self, folder_path: Optional[str] = None):
        """Load expert demonstrations from CSV files as ``TrajectoryWithRew`` objects.

        Every ``.csv`` file is read once (no header, third column only) and
        split into non-overlapping chunks of 30 rows; a trailing partial chunk
        is dropped. A chunk of states ``s_0 .. s_29`` becomes one terminal
        trajectory with ``obs = [s_0 .. s_29]`` (30 entries),
        ``acts = [s_1 .. s_29]`` (29 entries, because the action is the next
        state) and a reward of 1.0 per action.

        State labels are mapped to integer IDs by their sorted order over all
        files.

        Args:
            folder_path: Directory containing the CSV files. Defaults to
                ``./imitations/env/output3``, built with ``os.path.join`` so it
                works on every OS.

        Returns:
            A list of ``TrajectoryWithRew``, in file-name order, then row order.
        """
        if folder_path is None:
            folder_path = os.path.join(".", "imitations", "env", "output3")

        columns = self._read_state_columns(folder_path)

        # Create a sorted list of unique states and map them to IDs
        all_states = set()
        for column in columns:
            all_states.update(column)
        state_action_mapping = {state: idx for idx, state in enumerate(sorted(all_states))}
        if len(state_action_mapping) > self.single_observation_space.n:
            logger.warn(
                f"Demonstrations contain {len(state_action_mapping)} distinct states, "
                f"but the observation space only has {self.single_observation_space.n}."
            )

        chunk_rows = self._CHUNK_ROWS
        all_trajectories = []
        for column in columns:
            ids = np.array([state_action_mapping[state] for state in column], dtype=np.int64)
            # Every full 30-row chunk is used, including one ending exactly at
            # the last row; no extra row beyond the chunk is required.
            for start in range(0, len(ids) - chunk_rows + 1, chunk_rows):
                chunk = ids[start:start + chunk_rows]
                trajectory = TrajectoryWithRew(
                    obs=chunk.copy(),  # one more observation than actions
                    acts=chunk[1:].copy(),  # the action taken is the next state
                    rews=np.ones(chunk_rows - 1, dtype=np.float32),  # Reward of 1 for each action
                    terminal=True,
                    infos=None,
                )
                all_trajectories.append(trajectory)

        return all_trajectories













