from typing import List, Optional
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from args import Args
import b_plus_tree as bpt
from b_plus_tree import calculate_length_max_depth_of_tree

class BScheduler(gym.Env):
    """Gymnasium environment for choosing the order of B+-tree operations.

    Each episode builds a random B+-tree plus a fixed batch of pending
    operations: ``num_inserts`` inserts (indices ``[0, num_inserts)``) followed
    by ``num_deletes`` deletes (indices ``[num_inserts, num_operations)``).
    An action selects the index of a pending operation; that operation is
    applied to the tree and its slot is overwritten with ``-1``. The episode
    terminates once every slot is ``-1``. The per-step reward is the negated
    value of ``BPlusTree.calculate_reward()``.
    """

    def __init__(self, args: Optional[Args] = None, render_mode: Optional[str] = None):
        """Derive the action and observation spaces from the configuration.

        Args:
            args: Environment configuration; reads ``env_num_inserts``,
                ``env_num_deletes``, ``env_max_tree_values`` and
                ``env_max_values_per_node``. When ``None`` the defaults from
                ``Args().parse_args([])`` are used.
            render_mode: Stored per the Gymnasium API; no rendering is
                implemented by this environment.
        """
        if args is None:
            # Parse defaults per instance rather than once at import time.
            args = Args().parse_args([])
        self.render_mode = render_mode
        self.num_inserts = args.env_num_inserts
        self.num_deletes = args.env_num_deletes
        self.num_operations = self.num_inserts + self.num_deletes
        self.max_tree_values = args.env_max_tree_values
        self.max_values_per_node = args.env_max_values_per_node
        # One discrete action per pending operation slot.
        self.action_space = spaces.Discrete(self.num_operations)
        # -1 marks an applied operation; NOTE(review): -2 presumably leaves
        # headroom for values used by the tree features — confirm.
        self.low = -2
        # Keys are drawn from np.arange(1, self.high) in reset.
        self.high = self.max_tree_values + self.num_operations
        self.len_tree_obs_space, self.max_possible_tree_depth = calculate_length_max_depth_of_tree(
            self.max_tree_values, self.max_values_per_node
        )
        # Observation = pending-operation vector followed by tree features.
        self.observation_space = spaces.Box(
            low=self.low,
            high=self.high,
            shape=(self.len_tree_obs_space + self.num_operations,),
            dtype=np.float32,
        )
        self.tree = None
        self.tree_representation = None
        self.tree_numbers = None
        self.operations = None
        self.inserts = None
        self.deletes = None
        self.rng = np.random.default_rng(None)

    def _get_obs(self):
        """Return the float32 observation: operations vector, then tree features.

        Also refreshes ``self.tree_representation`` from the current tree.
        """
        self.tree_representation = self.tree.get_obs_space_feature_representation(
            self.max_possible_tree_depth
        )
        assert len(self.tree_representation) == self.len_tree_obs_space
        # NOTE(review): the last tree feature is expected to be 0; its meaning
        # is defined by get_obs_space_feature_representation — confirm.
        assert self.tree_representation[-1] == 0
        return np.concatenate([self.operations, self.tree_representation], dtype=np.float32)

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode with a fresh random tree and operation batch.

        Args:
            seed: When given, reseeds ``self.rng`` so the episode is reproducible.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.rng = np.random.default_rng(seed=seed)
        self.tree = bpt.BPlusTree(maximum=self.max_values_per_node)
        # Distinct random keys: the first max_tree_values are candidates for
        # the initial tree, the remaining num_inserts are the pending inserts.
        self.tree_numbers = self.rng.choice(
            a=np.arange(1, self.high), size=self.max_tree_values + self.num_inserts, replace=False
        )
        inserts = self.tree_numbers[self.max_tree_values:]
        # Start with max_tree_values - num_inserts keys so that after every
        # pending insert the tree holds at most max_tree_values keys.
        initial_size = max(self.max_tree_values - self.num_inserts, 0)
        for value in self.tree_numbers[:initial_size]:
            self.tree.insert(value, value)

        # tree_numbers is already in random order, so its first num_deletes
        # entries form a random subset; the choice() call is kept so seeded
        # runs keep producing the same RNG stream.
        # NOTE(review): these keys are only guaranteed to be in the tree when
        # num_deletes <= max_tree_values - num_inserts.
        deletes = self.rng.choice(
            self.tree_numbers[:self.num_deletes], self.num_deletes, replace=False
        )

        sorted_deletes = np.sort(deletes)
        sorted_inserts = np.sort(inserts)
        self.tree.calculate_reward()  # to reset counters
        # Layout relied on by step(): inserts first, then deletes.
        self.operations = np.concatenate([sorted_inserts, sorted_deletes])
        return self._get_obs(), {}

    def step(self, action):
        """Apply the pending operation at index ``action``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The reward
            is ``-1 * BPlusTree.calculate_reward()`` after the operation;
            ``terminated`` is True once every operation has been applied;
            ``truncated`` is always False and ``info`` is empty.

        Raises:
            ValueError: If ``action`` selects an already-applied operation.
                Masking with ``action_masks`` prevents this.
        """
        operation = self.operations[action]
        if operation == -1:
            raise ValueError(
                f"This should not happen if you use MaskablePPO: action {action} "
                f"selects an already-applied operation (operations={self.operations})"
            )
        # Inserts occupy the first num_inserts slots of self.operations (see
        # reset), deletes the remainder — this holds even when the counts differ.
        if action < self.num_inserts:
            self.tree.insert(operation, operation)
        else:
            self.tree.delete(operation)
        self.operations[action] = -1

        terminated = bool((self.operations == -1).all())
        observation = self._get_obs()
        reward = -1 * self.tree.calculate_reward()
        return observation, reward, terminated, False, {}

    def action_masks(self) -> np.ndarray:
        """Return a boolean array that is True for operations still pending.

        Presumably consumed by sb3-contrib's MaskablePPO (see the error raised
        in ``step``) to prevent selecting already-applied operations.
        """
        return self.operations != -1


# Make the environment constructible via gym.make("BScheduler-v0");
# the class itself is passed as the entry point.
gym.register("BScheduler-v0", entry_point=BScheduler)