import numpy as np
from gymnasium import spaces
from pettingzoo import ParallelEnv

############################
# Constants / Hyperparams
############################
# Battery capacity: per-agent battery charge is clamped to [0, MAX_STORAGE].
MAX_STORAGE: float = 2.0
# Multiplier applied to (generation / 1000) when charging the battery each step.
CHARGE_EFFICIENCY: float = 0.5
# Postponed Demand Multiplier: weight of the postponed-demand reward term.
PDM: int = 20
# Demand / Resource global constraint multiplier: weight of the grid-usage reward term.
DRM: int = 20

def env_creator(
    data,
    num_agents: int = 2,
    episode_length: int = 80,
    cooperative: bool = True,
):
    """
    Build the PettingZoo parallel smart-grid environment from in-memory data.

    Args:
        data: series array of shape (num_agents, 3, T), where for agent ``i``
            ``data[i, 0, :]`` is demand, ``data[i, 1, :]`` is generation and
            ``data[i, 2, :]`` is price.
        num_agents: number of agents in the environment.
        episode_length: maximum number of steps per episode.
        cooperative: forwarded unchanged to the environment.

    Returns:
        The ``SmartGridParallelEnv`` constructed with these arguments.

    NOTE(review): this forwards ``data`` as a keyword argument, so the
    ``SmartGridParallelEnv`` constructor must accept a ``data`` keyword.
    """
    env_kwargs = dict(
        data=data,
        num_agents=num_agents,
        episode_length=episode_length,
        cooperative=cooperative,
    )
    return SmartGridParallelEnv(**env_kwargs)

class SmartGridParallelEnv(ParallelEnv):
    """
    PettingZoo ParallelEnv for a multi-agent smart-grid problem, intended to
    mirror the logic of the VMAS scenario.

    Each agent owns a battery and a backlog of postponed demand. At every step
    an agent chooses how much energy to draw from the grid and from its
    battery. Whatever part of (current demand + postponed demand) is not
    covered is carried over as postponed demand. Generation charges the
    battery.

    Time handling:
      - ``start_index`` is sampled at every reset and shared by all agents.
      - ``current_step`` counts steps since the reset (starts at 0).
      - Data is read at ``(start_index + current_step) % T``, so the series wraps.

    Reward per agent:
        rew0 + lam_rew + post_rew
    where ``rew0`` penalizes grid cost, ``lam_rew`` penalizes grid usage
    (weight DRM) and ``post_rew`` penalizes postponed demand (weight PDM).
    """

    metadata = {"render_modes": ["human"], "name": "smart_grid_parallel_v0"}

    def __init__(
        self,
        building_types: list = None,
        num_agents: int = 2,
        episode_length: int = 80,
        cooperative: bool = True,
        data=None,
    ):
        """
        Args:
            building_types: building ids. The series for building ``b`` is
                loaded from ``./data//b{b}//data.npy``. Ignored when ``data``
                is given.
            num_agents: number of agents; agent ``i`` uses ``data[i]``.
            episode_length: number of steps after which the episode terminates.
            cooperative: stored on the instance; not used by the dynamics here.
            data: optional pre-built array of shape (num_agents, 3, T) with rows
                [demand, generation, price]. Lets callers (e.g. ``env_creator``)
                pass data directly instead of loading it from disk.

        Raises:
            ValueError: if neither ``building_types`` nor ``data`` is provided,
                or the data does not contain a (demand, generation, price)
                series for every agent.
        """
        super().__init__()

        # Basic config
        self._num_agents = num_agents
        self.possible_agents = [f"agent_{i}" for i in range(num_agents)]
        self.agents = self.possible_agents[:]
        self.cooperative = cooperative  # kept for callers; unused below

        # data shape: (num_agents, 3, T) => [demand, generation, price]
        if data is None:
            if building_types is None:
                raise ValueError("Either `building_types` or `data` must be provided.")
            data = self.build_data(building_types)
        self.data = np.asarray(data)
        if (
            self.data.ndim != 3
            or self.data.shape[0] < num_agents
            or self.data.shape[1] < 3
            or self.data.shape[2] < 1
        ):
            raise ValueError(
                f"data must have shape (num_agents, 3, T) with T >= 1; got "
                f"{self.data.shape} for num_agents={num_agents}"
            )
        self.T = self.data.shape[-1]  # length of the time series
        self.episode_length = episode_length

        # Global peaks (over all rows of the data) used for normalization.
        self.peak_demand = np.max(self.data[:, 0, :])
        self.peak_price = np.max(self.data[:, 2, :])

        # Per-agent state
        self.battery_charge = np.zeros(num_agents, dtype=float)
        self.postponed = np.zeros(num_agents, dtype=float)

        # One start index shared by all agents, so they proceed in sync.
        self.start_index = 0
        self.current_step = 0  # relative step since reset

        # Private RNG so reset(seed=...) does not reseed numpy's global state.
        self._rng = np.random.default_rng()

        # Bookkeeping for done/truncation
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}

        # --- Observation space: [battery, demand, price, postponed] ---
        # NOTE(review): postponed demand accumulates across steps and can exceed
        # peak_demand, so observations may fall outside this Box's upper bound.
        single_obs_low = np.zeros(4, dtype=np.float32)
        single_obs_high = np.array(
            [MAX_STORAGE, self.peak_demand, self.peak_price, self.peak_demand],
            dtype=np.float32,
        )
        self.single_observation_space = spaces.Box(
            low=single_obs_low, high=single_obs_high, dtype=np.float32
        )

        # --- Action space: [consumed_grid, consumed_battery] ---
        # consumed_grid in [0, peak_demand], consumed_battery in [0, MAX_STORAGE]
        single_act_low = np.zeros(2, dtype=np.float32)
        single_act_high = np.array([self.peak_demand, MAX_STORAGE], dtype=np.float32)
        self.single_action_space = spaces.Box(
            low=single_act_low, high=single_act_high, dtype=np.float32
        )

        self.observation_spaces = {
            agent: self.single_observation_space for agent in self.agents
        }
        self.action_spaces = {agent: self.single_action_space for agent in self.agents}

    def build_data(self, building_types):
        """Load one ``data.npy`` per building id and stack them along a new axis 0."""
        data_list = []
        for building_type in building_types:
            file_path = f"./data//b{building_type}//data.npy"
            data_list.append(np.load(file_path))
        return np.stack(data_list)

    def _data_index(self):
        """Return the column of ``self.data`` for the current step (wraps at T)."""
        return (self.start_index + self.current_step) % self.T

    def reset(self, seed=None, options=None):
        """
        Reset all agent state and sample a new random start index.

        Args:
            seed: if given, re-seeds this environment's private RNG.
            options: accepted for API compatibility; unused.

        Returns:
            (observations, infos) dicts keyed by agent.
        """
        if seed is not None:
            self._rng = np.random.default_rng(seed)

        self.start_index = int(self._rng.integers(0, self.T))
        self.current_step = 0

        self.battery_charge[:] = 0.0
        self.postponed[:] = 0.0

        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}

        return self._get_obs(), {agent: {} for agent in self.agents}

    def observation_space(self, agent):
        return self.single_observation_space

    def action_space(self, agent):
        return self.single_action_space

    def step(self, actions):
        """
        Apply one joint action and advance time by one step.

        Args:
            actions: mapping agent -> [consumed_grid, consumed_battery]. Values
                are clipped to the action space. Only entries for agents that
                are not yet done are read.

        Returns:
            (observations, rewards, terminations, truncations, infos), each a
            dict keyed by agent.

        The transition uses the demand/generation/price the agents just
        observed (the current step's index). ``current_step`` is incremented
        afterwards, so an episode applies exactly ``episode_length`` actions
        (``T`` actions if the series is shorter).
        """
        data_idx = self._data_index()
        demands = self.data[:, 0, data_idx]
        gens = self.data[:, 1, data_idx]
        prices = self.data[:, 2, data_idx]

        rewards = {}
        infos = {}
        for i, agent in enumerate(self.agents):
            # Agents that are already done are frozen.
            if self.terminations[agent] or self.truncations[agent]:
                rewards[agent] = 0.0
                infos[agent] = {}
                continue
            rewards[agent], infos[agent] = self._apply_action(
                i, actions[agent], demands[i], gens[i], prices[i]
            )

        # Advance time only after the actions have been applied, so the last
        # step's actions are not discarded.
        self.current_step += 1
        if self.current_step >= self.episode_length or self.current_step >= self.T:
            for agent in self.agents:
                self.terminations[agent] = True

        next_obs = self._get_obs()
        # Return copies so later in-place updates do not alter dicts already handed out.
        return next_obs, rewards, dict(self.terminations), dict(self.truncations), infos

    def _apply_action(self, i, action, demand_i, gen_i, price_i):
        """Update agent ``i``'s battery/backlog for one step; return (reward, info)."""
        # Clip to the declared action space. Out-of-range actions would
        # otherwise let an agent earn reward (negative grid use) or charge its
        # battery for free (negative battery use).
        clipped = np.clip(
            np.asarray(action, dtype=np.float64),
            self.single_action_space.low,
            self.single_action_space.high,
        )
        consumed_grid = float(clipped[0])
        consumed_battery = float(clipped[1])

        old_batt = self.battery_charge[i]
        old_postp = self.postponed[i]

        # 1) The battery can deliver at most what it currently stores.
        used_battery = min(old_batt, consumed_battery)

        # 2) Demand not covered by battery + grid is carried over.
        total_demand = demand_i + old_postp
        new_postp = max(total_demand - used_battery - consumed_grid, 0.0)

        # 3) Discharge, then charge from generation (scaled by 1/1000), then
        #    clamp to [0, MAX_STORAGE].
        new_battery = old_batt - used_battery + CHARGE_EFFICIENCY * (gen_i / 1000.0)
        new_battery = min(MAX_STORAGE, max(0.0, new_battery))

        self.battery_charge[i] = new_battery
        self.postponed[i] = new_postp

        # 4) Reward: all three terms penalize the agent.
        cost = consumed_grid * price_i
        rew0 = -cost / (self.peak_demand * self.peak_price)
        # "global_rew" is based only on this agent's own grid consumption.
        global_rew = -consumed_grid / self.peak_demand
        lam_rew = DRM * global_rew
        # Postponed demand is a penalty. With a positive sign, the reward would
        # be maximized by never serving any demand.
        post_rew = -(PDM * new_postp) / self.peak_demand

        reward = float(rew0 + lam_rew + post_rew)
        info = {
            "battery": new_battery,
            "postponed": new_postp,
            "consumed_grid": consumed_grid,
            "used_battery": used_battery,
        }
        return reward, info

    def _get_obs(self):
        """
        Build per-agent observations [battery_charge, demand, price, postponed]
        (float32) at index ``(start_index + current_step) % T``.
        """
        idx = self._data_index()
        return {
            agent: np.array(
                [
                    self.battery_charge[i],
                    self.data[i, 0, idx],
                    self.data[i, 2, idx],
                    self.postponed[i],
                ],
                dtype=np.float32,
            )
            for i, agent in enumerate(self.agents)
        }

    def render(self):
        """No-op: this environment has no renderer."""
        pass

    def global_state(self):
        """
        Centralized state for e.g. MAPPO/QMIX critics.

        Returns:
            float32 array of length ``5 * num_agents``, laid out per agent as
            [battery, postponed, demand, price, generation] at the current
            data index.
        """
        idx = self._data_index()
        n = self.num_agents
        per_agent = np.stack(
            [
                self.battery_charge,
                self.postponed,
                self.data[:n, 0, idx],
                self.data[:n, 2, idx],
                self.data[:n, 1, idx],
            ],
            axis=1,
        )
        return per_agent.reshape(-1).astype(np.float32)

    @property
    def num_agents(self):
        return self._num_agents
