import gymnasium as gym
from gymnasium import spaces
import numpy as np

# Upper bound on battery charge shared by every environment in this module:
# baseEnv.reward caps the updated battery level at this value, and it is the
# upper bound of both the battery observation and the battery-use action.
# NOTE(review): units follow chargeBattery's output — confirm (kWh?).
MAX_STORAGE = 2

class baseEnv(gym.Env):
    """
    A foundational single-agent environment with a stateless reward helper.

    ``data`` has shape (3, T): row 0 is demand, row 1 is generation and row 2
    is price. Observations are [batteryCharge, demand, price, postponed] and
    actions are [consumed_grid, consumed_battery].

    Each episode starts at a random time index and lasts ``episode_length``
    calls to :meth:`step`; the time index wraps to 0 past the end of ``data``.
    """

    def __init__(self, data, nu_val=0, lambda_val=0, episode_length=80):
        super(baseEnv, self).__init__()

        # data: shape (3, T) => [demand, generation, price]
        self.data = data
        self.num_samples = data.shape[1]
        self.peakDemand = self.data[0, :].max()
        self.peakPrice = self.data[2, :].max()

        # Multipliers used by reward(): nu scales the postponed-demand term,
        # lambda scales the grid-usage (global constraint) term.
        self.nu_val = nu_val
        self.lambda_val = lambda_val
        self.episode_length = episode_length

        # Single-agent internal state
        self.current_step = 0   # column of `data` the current observation comes from
        self.episode_step = 0   # number of step() calls since the last reset()
        self.batteryCharge = 0.0
        self.postponed = 0.0

        # Single-agent state: [batteryCharge, currentDemand, price, postponed]
        self.state_space = spaces.Box(
            low=np.array([0, 0, 0, -self.peakDemand], dtype=float),
            high=np.array([MAX_STORAGE, self.peakDemand, self.peakPrice, self.peakDemand], dtype=float),
            dtype=np.float32
        )
        # Single-agent action: [consumed_grid, consumed_battery]
        self.action_space = spaces.Box(
            low=np.array([0, 0], dtype=float),
            high=np.array([self.peakDemand, MAX_STORAGE], dtype=float),
            dtype=np.float32
        )

        # Simple constraint threshold: grid consumption is compared against
        # thresh * peakDemand in step() (reported as "constSatisfaction").
        self.thresh = 1
        self.constraints = np.array([self.thresh * self.peakDemand])

        self.observation_space = self.state_space

    def reset(self, seed=None, options=None):
        """
        Start a new episode at a random time index with an empty battery.

        Args:
            seed: accepted for Gymnasium API compatibility. NOTE: it is not
                applied; the start index is drawn from the global
                ``np.random`` state.
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            (observation, info): a float32 copy of the state and an empty dict.
        """
        self.current_step = np.random.randint(0, self.num_samples)
        self.episode_step = 0
        self.batteryCharge = 0.0
        self.postponed = 0.0

        demand, _gen_en, price = self.data[:, self.current_step]
        self.state = np.array([self.batteryCharge, demand, price, self.postponed], dtype=float)

        # Return a float32 copy so it matches observation_space and callers
        # cannot mutate the internal state that step() reads from.
        return self.state.astype(np.float32), {}

    def step(self, action):
        """
        Apply ``action`` to the time step the agent observed, then advance time.

        The reward uses the demand, price and generation of the observed time
        index. After that the index is advanced (wrapping at the end of the
        data) and the next observation is built from the new index.

        Args:
            action: [consumed_grid, consumed_battery].

        Returns:
            (observation, reward, terminated, truncated, info). ``info`` holds
            "consumed_battery" and "constSatisfaction" (constraints minus
            grid consumption). ``terminated`` becomes True once
            ``episode_length`` steps have been taken since reset().
        """
        # Data of the time step the agent acted on (the one in self.state).
        demand = self.state[1]
        price = self.state[2]
        gen_en = self.data[1, self.current_step]

        rew, new_batt, new_postp, consumed_battery, consumed_grid = self.reward(
            batteryCharge=self.batteryCharge,
            demand=demand,
            price=price,
            genEn=gen_en,
            action=action,
            nu=self.nu_val,
        )

        info = {
            "consumed_battery": consumed_battery,
            "constSatisfaction": self.constraints - consumed_grid,
        }

        self.batteryCharge = new_batt
        self.postponed = new_postp

        # Advance time, wrapping around at the end of the data.
        self.current_step = (self.current_step + 1) % self.num_samples
        self.episode_step += 1

        next_demand, _next_gen_en, next_price = self.data[:, self.current_step]
        self.state = np.array([self.batteryCharge, next_demand, next_price, self.postponed], dtype=float)

        # Episode length is counted from reset(), not from the absolute data index
        # (the start index is random).
        terminated = self.episode_step >= self.episode_length
        truncated = False

        return self.state.astype(np.float32), rew, terminated, truncated, info

    def reward(self, batteryCharge, demand, price, genEn, action, nu):
        """
        Stateless helper for the reward and the battery/postponed update.

        Args:
            batteryCharge: current battery level.
            demand: current demand.
            price: current grid price.
            genEn: current generation (converted to charge by chargeBattery).
            action: [consumed_grid, consumed_battery] (battery use is capped
                at the available charge).
            nu: multiplier on postponed demand (scalar or size-1 array).

        ``self.lambda_val`` (scalar or array) multiplies the grid-usage term.
        Array multipliers are summed, which is the Lagrangian sum over
        constraints.

        Returns:
            (reward, newBattery, postponed, consumed_battery, consumed_grid),
            with ``reward`` as a Python float.
        """
        consumed_grid = action[0]
        battery_req = action[1]

        # Actual battery consumption is limited by what's available
        consumed_battery = np.minimum(batteryCharge, battery_req)
        postponed = demand - consumed_battery - consumed_grid

        # Discharge, then charge from generation, capped at MAX_STORAGE
        newBattery = batteryCharge - consumed_battery
        newBattery = np.minimum(MAX_STORAGE, newBattery + self.chargeBattery(genEn))

        # Normalized cost for grid usage
        cost = consumed_grid * price
        rew0 = -cost / (self.peakDemand * self.peakPrice)

        # Global constraint term (single constraint)
        global_rew = -consumed_grid / self.peakDemand

        # Postponed penalty/reward weighted by nu
        post_rew = np.sum(nu * postponed) / self.peakDemand
        # Lambda multiplier(s) for the global constraint
        lam_rew = np.sum(self.lambda_val * global_rew)

        augmentedReward = float(rew0 + post_rew + lam_rew)
        return augmentedReward, newBattery, postponed, consumed_battery, consumed_grid

    def chargeBattery(self, genEn, time_length=1, charge_efficiency=0.5):
        """
        Convert generation into battery charge.

        Returns charge_efficiency * (10.0 * genEn * time_length) / 1000.0.
        NOTE(review): the 10.0 factor is a hard-coded installed capacity and
        genEn's unit/scale depends on the data — confirm.
        """
        return charge_efficiency * (10.0 * genEn * time_length) / 1000.0



class DerEnv(baseEnv):
    """
    Single-agent environment that extends baseEnv to include
    Lagrange multipliers in the observation.

    Observation: [battery, demand, price, postponed, lambda..., nu]. One
    lambda is appended per entry of ``self.constraints``. Every reset()
    redraws the multipliers uniformly: lambda in [0, max_lamb] and nu in
    [-max_nu, max_nu].
    """

    def __init__(self, data, nu_val=0, lambda_val=0, episode_length=80):
        super(DerEnv, self).__init__(data, nu_val, lambda_val, episode_length)

        self.max_lamb = 15
        self.max_nu = 25

        # Base state (4 dims) + one lambda per constraint + one nu.
        n_lamb = len(self.constraints)
        self.observation_space = spaces.Box(
            low=np.concatenate([
                [0, 0, 0, -self.peakDemand],
                np.zeros(n_lamb),
                [-self.max_nu],
            ]).astype(float),
            high=np.concatenate([
                [MAX_STORAGE, self.peakDemand, self.peakPrice, self.peakDemand],
                np.full(n_lamb, float(self.max_lamb)),
                [self.max_nu],
            ]).astype(float),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None):
        """
        Reset the base state and redraw the multipliers.

        Args:
            seed: forwarded to baseEnv.reset (Gymnasium API compatibility).
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            (observation, info) with the augmented float32 observation.
        """
        _, info = super().reset(seed=seed)

        self.lambda_val = np.random.uniform(
            low=0.0,
            high=self.max_lamb,
            size=len(self.constraints),
        )
        self.nu_val = np.random.uniform(
            low=-self.max_nu,
            high=self.max_nu,
            size=(1,),
        )

        return self._get_obs(), info

    def step(self, action):
        """
        Run baseEnv.step and return the augmented observation.

        The multipliers are size-1 arrays after reset(), so the parent reward
        may be an array. It is reduced to a Python float as Gymnasium expects.
        """
        _, rew, terminated, truncated, info = super().step(action)
        return self._get_obs(), float(np.sum(rew)), terminated, truncated, info

    def _get_obs(self):
        """
        Return the parent's 4D state plus the multipliers [lambda..., nu] as float32.
        """
        return np.concatenate([
            self.state,
            np.ravel(self.lambda_val),
            np.ravel(self.nu_val),
        ]).astype(np.float32)



class MultiAgentDerEnv(DerEnv):
    """
    Multi-agent environment that extends DerEnv for
    centralized training with multiple agents.

    All agents share one time index.
    - Joint observation: [battery_i, demand_i, price_i, postponed_i, nu_i]
      for every agent i, followed by the global Lagrange multiplier(s) lambda.
    - Joint action: [consumed_grid_i, consumed_battery_i] for every agent i.
    """

    def __init__(self, data, num_agents=2, nu_val=0, lambda_val=0, episode_length=80):
        """
        data: shape (num_agents, 3, T)
              data[i, 0, :] = demand for agent i
              data[i, 1, :] = generation for agent i
              data[i, 2, :] = price for agent i
        episode_length: number of step() calls per episode.

        Raises:
            ValueError: if ``data`` is not 3-D with 3 rows per agent, or holds
                fewer than ``num_agents`` agents.
        """
        data = np.asarray(data)
        if data.ndim != 3 or data.shape[1] != 3:
            raise ValueError(f"data must have shape (num_agents, 3, T), got {data.shape}")
        if data.shape[0] < num_agents:
            raise ValueError(
                f"data holds {data.shape[0]} agents but num_agents={num_agents}"
            )

        # A dummy single-building dataset sets up the parent's basic fields;
        # the real per-agent data lives in multi_data.
        dummy_data = np.zeros((3, data.shape[-1]))
        super(MultiAgentDerEnv, self).__init__(dummy_data, nu_val, lambda_val)
        self.episode_length = episode_length

        self.multi_data = data
        self.num_agents = num_agents
        self.num_samples = data.shape[-1]
        self.current_step = 0
        self.episode_step = 0

        # Recompute peakDemand, peakPrice across all agents
        self.peakDemand = np.max(data[:, 0, :])
        self.peakPrice = np.max(data[:, 2, :])

        # The constraint set up by the parent was computed from the all-zero
        # dummy data (i.e. it was 0); rebuild it from the real peak demand.
        # NOTE(review): step() compares this against the *summed* grid usage
        # of all agents while using the per-agent peak demand — confirm
        # whether it should scale with num_agents.
        self.constraints = np.array([self.thresh * self.peakDemand])

        # For each agent: batteryCharge & postponed
        self.batteryCharge_agents = np.zeros(num_agents, dtype=float)
        self.postponed_agents = np.zeros(num_agents, dtype=float)

        # Per agent: [battery, demand, price, postponed, nu] => 5 dims,
        # then one lambda per constraint at the end.
        agent_low = np.array([0, 0, 0, -self.peakDemand, -self.max_nu], dtype=float)
        agent_high = np.array(
            [MAX_STORAGE, self.peakDemand, self.peakPrice, self.peakDemand, self.max_nu],
            dtype=float,
        )
        self.num_indiv_obs = len(agent_low)

        n_lamb = len(self.constraints)
        self.observation_space = spaces.Box(
            low=np.concatenate([np.tile(agent_low, num_agents), np.zeros(n_lamb)]),
            high=np.concatenate([
                np.tile(agent_high, num_agents),
                np.full(n_lamb, float(self.max_lamb)),
            ]),
            dtype=np.float32,
        )

        # Action space: 2 per agent => total = 2*num_agents
        single_action_low = np.array([0, 0], dtype=float)
        single_action_high = np.array([self.peakDemand, MAX_STORAGE], dtype=float)
        self.num_indiv_act = len(single_action_low)

        self.action_space = spaces.Box(
            low=np.tile(single_action_low, num_agents),
            high=np.tile(single_action_high, num_agents),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None):
        """
        Reset all agents and redraw the multipliers.

        Picks a random start index and empties every battery and postponed
        value. Then it draws lambda in [0, max_lamb] (one per constraint) and
        nu in [-max_nu, max_nu] (one per agent).

        Args:
            seed: accepted for Gymnasium API compatibility. NOTE: it is not
                applied; sampling uses the global ``np.random`` state.
            options: accepted for Gymnasium API compatibility; unused.

        Returns:
            (observation, info) with the joint float32 observation and an empty dict.
        """
        self.current_step = np.random.randint(0, self.num_samples)
        self.episode_step = 0

        self.batteryCharge_agents[:] = 0.0
        self.postponed_agents[:] = 0.0

        self.lambda_val = np.random.uniform(
            low=0.0,
            high=self.max_lamb,
            size=len(self.constraints),
        )
        self.nu_val = np.random.uniform(
            low=-self.max_nu,
            high=self.max_nu,
            size=(self.num_agents,),
        )

        return self._get_obs(), {}

    def step(self, action):
        """
        Apply every agent's action to the time step they observed, then advance time.

        Args:
            action: flat array of length ``num_indiv_act * num_agents``;
                agent i owns ``action[2*i : 2*i + 2]``.

        Returns:
            (observation, reward, terminated, truncated, info). ``reward`` is
            the sum of the agents' rewards as a float. ``info`` holds
            "globalConstSatisfaction" (constraints minus total grid
            consumption) and "total_consumed_battery".

        Raises:
            ValueError: if ``action`` has the wrong number of entries.
        """
        action = np.asarray(action).reshape(-1)
        n_act = self.num_indiv_act
        expected = n_act * self.num_agents
        if action.size != expected:
            raise ValueError(f"expected {expected} action entries, got {action.size}")

        # Rewards use the time index shown in the last observation, so that
        # demand/price/generation match what the agents acted on.
        t = self.current_step

        total_reward = 0.0
        total_consumed_grid = 0.0
        total_consumed_battery = 0.0

        for i in range(self.num_agents):
            agent_action = action[n_act * i : n_act * (i + 1)]
            demand_i, genEn_i, price_i = self.multi_data[i, :, t]

            rew_i, newBatt, newPostp, consumed_battery, consumed_grid = self.reward(
                batteryCharge=self.batteryCharge_agents[i],
                demand=demand_i,
                price=price_i,
                genEn=genEn_i,
                action=agent_action,
                nu=self.nu_val[i],
            )
            # lambda_val is an array, so the parent reward may be a size-1 array.
            total_reward += float(np.sum(rew_i))
            total_consumed_grid += consumed_grid
            total_consumed_battery += consumed_battery

            self.batteryCharge_agents[i] = newBatt
            self.postponed_agents[i] = newPostp

        # Advance time, wrapping around at the end of the data.
        self.current_step = (t + 1) % self.num_samples
        self.episode_step += 1

        tot_info = {
            "globalConstSatisfaction": self.constraints - total_consumed_grid,
            "total_consumed_battery": total_consumed_battery,
        }

        # Episode length is counted from reset(), not from the absolute index.
        terminated = self.episode_step >= self.episode_length
        truncated = False

        return self._get_obs(), total_reward, terminated, truncated, tot_info

    def _get_obs(self):
        """
        Build the joint observation as float32.

        The layout is [battery_i, demand_i, price_i, postponed_i, nu_i] for
        each agent i (demand/price at ``current_step``), followed by lambda.
        """
        n = self.num_agents
        t = self.current_step
        per_agent = np.column_stack([
            self.batteryCharge_agents,
            self.multi_data[:n, 0, t],
            self.multi_data[:n, 2, t],
            self.postponed_agents,
            self.nu_val,
        ])
        return np.concatenate([per_agent.ravel(), np.ravel(self.lambda_val)]).astype(np.float32)
