import torch
from torch import Tensor

from vmas.simulator.core import Agent, World, Sphere
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.environment import Environment
from vmas.simulator.utils import Color
import numpy as np

############################
# Constants / Hyperparams
############################
# Upper clamp applied to an agent's battery charge in SmartGridScenario.reward.
MAX_STORAGE = 2.0
# Factor applied to the (scaled) generation signal before it is added to the battery.
CHARGE_EFFICIENCY = 0.5
# Default number of steps per episode, used when the "episode_length" kwarg is absent.
EPISODE_LENGTH = 80
# Weight of the postponed-demand term in the reward.
PDM = 20  # Postponed Demand Multiplier
# Weight of the grid-consumption ("global") term in the reward.
DRM = 20  # Demand / Resource global constraint multiplier

class PatchedEnv(Environment):
    """
    Environment wrapper that casts observations to float32 and advances the
    scenario's `current_step` counter after every step.
    """

    def step(self, actions):
        """
        Run one environment step.

        Delegates to the parent `step`, casts the returned observations to
        float32 and, if the scenario exposes a `current_step` attribute,
        increments it by one.

        Returns the parent's result tuple with the observations replaced by
        their float32 version; all other elements are passed through as-is.
        """
        # Unpack generically so the wrapper does not depend on how many
        # elements the parent returns after the observations.
        obs, *rest = super().step(actions)
        obs = self._to_float32(obs)

        # Manually increment current_step in the scenario.
        # Note: this happens after the parent step has already produced the
        # observations/rewards/done flags returned below.
        if hasattr(self.scenario, "current_step"):
            self.scenario.current_step += 1

        return (obs, *rest)

    @staticmethod
    def _to_float32(obs):
        """Cast a tensor, or every tensor in a (nested) list/tuple/dict, to float32."""
        if isinstance(obs, torch.Tensor):
            return obs.to(torch.float32)
        if isinstance(obs, dict):
            return {key: PatchedEnv._to_float32(val) for key, val in obs.items()}
        if isinstance(obs, (list, tuple)):
            return type(obs)(PatchedEnv._to_float32(val) for val in obs)
        # Unknown container / non-tensor value: leave untouched.
        return obs
    
class SmartGridScenario(BaseScenario):
    """
    Scenario for multi-agent demand-response in a simplified smart grid setting.

    Time handling, per environment in the batch:
      - `start_index` is where in the data the episode begins.
      - `current_step` is the relative step counter (0 at reset).
      - The data index in use is (start_index + current_step) mod num_samples.

    `self.data` holds one row per signal: row 0 = demand, row 1 = generation,
    row 2 = price.
    """

    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        """
        Create and return a multi-agent World object.

        Recognised kwargs:
          - n_agents: number of agents (default 3).
          - episode_length: steps per episode (default EPISODE_LENGTH).
          - verbose: print step bookkeeping in `done` (default False).
          - data_path: path passed to `np.load`; expected shape (3, T).
            When omitted, 1000 samples of uniform random data are generated.

        Raises:
            ValueError: if the data is not 2-D with at least 3 rows and at
                least one sample.
        """
        self.n_agents = kwargs.get("n_agents", 3)
        self.episode_length = kwargs.get("episode_length", EPISODE_LENGTH)
        self.verbose = kwargs.get("verbose", False)

        # Load or create data
        data_path = kwargs.get("data_path", None)
        if data_path is not None:
            self.data = np.load(data_path).astype(np.float32)  # shape: (3, T)
        else:
            # Fallback synthetic data
            T = 1000
            demand = np.random.rand(T) * 2.0
            generation = np.random.rand(T) * 1.0
            price = np.random.rand(T) * 1.0
            # Cast to float32 so synthetic and loaded data share a dtype and
            # agent state is not silently promoted to float64.
            self.data = np.stack([demand, generation, price], axis=0).astype(np.float32)

        if self.data.ndim != 2 or self.data.shape[0] < 3 or self.data.shape[1] == 0:
            raise ValueError(
                f"Expected data of shape (3, T) with T > 0, got {self.data.shape}"
            )
        self.num_samples = self.data.shape[1]
        # Convert once here instead of on every reward/observation call.
        self.data_torch = torch.tensor(self.data, device=device)

        # Create the world
        world = World(batch_dim, device, substeps=1, drag=0.0)
        self._world = world

        # Create agents
        for i in range(self.n_agents):
            agent = Agent(
                name=f"agent_{i}",
                shape=Sphere(radius=0.03),
                collide=False,
                color=Color.BLUE,
            )
            agent.state.battery_charge = torch.zeros(batch_dim, device=device)
            agent.state.postponed = torch.zeros(batch_dim, device=device)
            world.add_agent(agent)

        # 1) We track a "start index" for each environment in the batch
        self.start_index = torch.zeros(batch_dim, device=device, dtype=torch.long)
        # 2) We track a "current_step" for each environment (relative to start_index)
        self.current_step = torch.zeros(batch_dim, device=device, dtype=torch.long)

        # Normalizing values: maxima of the demand and price rows.
        self.peak_demand = np.max(self.data[0])
        self.peak_price = np.max(self.data[2])
        self.peak_demand_torch = torch.tensor(self.peak_demand, device=device, dtype=torch.float)
        self.peak_price_torch = torch.tensor(self.peak_price, device=device, dtype=torch.float)

        return world

    def env_make_world(self, batch_dim, device, **kwargs):
        """Forward to `make_world` (which also stores the world on `self._world`)."""
        return self.make_world(batch_dim, device, **kwargs)

    def reset_world_at(self, env_index: int = 0):
        """
        Reset one environment of the batch, or all of them.

        A random start_index in [0, num_samples - 1] is drawn and current_step
        is set to zero, so the first data index is (start_index + 0) % num_samples.
        Battery charge and postponed demand of every agent are zeroed.

        Args:
            env_index: index of the environment to reset. `None` resets every
                environment, each with its own independently drawn start index.
        """
        device = self.start_index.device
        if env_index is None:
            # Whole-batch reset: one independent start per environment.
            self.start_index[:] = torch.randint(
                low=0,
                high=self.num_samples,
                size=(self.start_index.shape[0],),
                device=device,
            )
            self.current_step[:] = 0
            for ag in self.world.agents:
                ag.state.battery_charge[:] = 0.0
                ag.state.postponed[:] = 0.0
            return

        # Single-environment reset.
        s0 = torch.randint(low=0, high=self.num_samples, size=(1,), device=device)
        self.start_index[env_index] = s0
        # Relative step is zero each new episode
        self.current_step[env_index] = 0

        # Reset agent states
        for ag in self.world.agents:
            ag.state.battery_charge[env_index] = 0.0
            ag.state.postponed[env_index] = 0.0

    def _data_index(self):
        """Return the per-environment data index, (start_index + current_step) mod num_samples."""
        return (self.start_index + self.current_step) % self.num_samples

    def reward(self, agent):
        """
        Compute the agent-specific reward and advance the agent's state.

        Action column 0 is the energy drawn from the grid, column 1 the energy
        requested from the battery; negative values are clamped to zero.

        Side effects: overwrites `agent.state.battery_charge` and
        `agent.state.postponed` with their values after this step, so calling
        this more than once per step advances the state more than once.

        Returns:
            Tensor with one reward per environment.
        """
        action_tensor = agent.action.u  # shape (batch_dim, action_dim)
        consumed_grid = action_tensor[:, 0].clamp(min=0)
        consumed_battery = action_tensor[:, 1].clamp(min=0)

        idx = self._data_index()

        # demand, gen, price for the current data index of each environment
        demand_batched = self.data_torch[0][idx]
        generation_batched = self.data_torch[1][idx]
        price_batched = self.data_torch[2][idx]

        # The battery cannot deliver more than it currently holds.
        old_battery = agent.state.battery_charge
        actual_batt = torch.minimum(old_battery, consumed_battery)

        # Demand not covered this step (may be negative when over-supplied).
        total_demand = demand_batched + agent.state.postponed
        postponed = total_demand - actual_batt - consumed_grid

        # Update battery: discharge, then charge from generation, then clamp.
        # NOTE(review): the 1000.0 divisor looks like a unit conversion for the
        # generation signal -- confirm against the data source.
        new_battery = old_battery - actual_batt
        new_battery += CHARGE_EFFICIENCY * (generation_batched / 1000.0)
        new_battery = torch.clamp(new_battery, 0.0, MAX_STORAGE)

        agent.state.battery_charge = new_battery
        agent.state.postponed = torch.clamp(postponed, min=0.0)

        # Cost for grid usage, normalized by peak demand * peak price
        cost = consumed_grid * price_batched
        rew0 = -cost / (self.peak_demand_torch * self.peak_price_torch)

        # Extra terms
        global_rew = -consumed_grid / self.peak_demand_torch
        # NOTE(review): this term is positive for unmet demand (and uses the
        # unclamped value), i.e. postponing demand is rewarded. If it is meant
        # to be a penalty the sign should be flipped -- confirm intent.
        post_rew = (PDM * postponed) / self.peak_demand_torch
        lam_rew = DRM * global_rew

        augmented_reward = rew0 + lam_rew + post_rew
        return augmented_reward

    def observation(self, agent):
        """
        Return the agent's observation, one row per environment:
        [battery_charge, current demand, price, postponed demand].

        The four per-environment vectors are stacked along a new last
        dimension, giving shape (batch_dim, 4).
        """
        idx = self._data_index()

        demand_batched = self.data_torch[0][idx]
        price_batched = self.data_torch[2][idx]

        obs_list = [
            agent.state.battery_charge,
            demand_batched,
            price_batched,
            agent.state.postponed,
        ]
        # stack (not cat): each entry is 1-D over the batch, and consumers
        # index features as observation[:, k].
        obs = torch.stack(obs_list, dim=-1)
        return obs

    def done(self):
        """
        Return a boolean Tensor (batch_dim,) for whether each environment is done.

        We consider 'done' if current_step >= episode_length, i.e. the episode is over.
        """
        if self.verbose:
            print("episode_length:", self.episode_length, "|",
              "start_index:", self.start_index, "|",
              "current_step:", self.current_step)

        done_mask = (self.current_step >= self.episode_length)
        return done_mask

    def info(self, agent):
        """
        Return per-agent logging values: copies of battery charge and postponed demand.
        """
        return {
            "battery": agent.state.battery_charge.clone(),
            "postponed": agent.state.postponed.clone(),
        }

    def step_end(self):
        """
        Increment current_step by 1 for every environment.

        NOTE(review): nothing in this file calls `step_end`, and whether the
        simulator invokes it is not visible here. PatchedEnv.step also
        increments current_step; if both run, the counter advances twice per
        step -- confirm which hook is actually used.
        """
        self.current_step += 1

    def supports_continuous_actions(self):
        """This scenario uses continuous actions."""
        return True

    def supports_discrete_actions(self):
        """Discrete actions are not supported."""
        return False

from vmas.simulator.heuristic_policy import BaseHeuristicPolicy

class HeuristicPolicy(BaseHeuristicPolicy):
    """Fixed rule-based policy: buy from the grid only when the price is low."""

    def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Tensor:
        """
        Map a batch of observations to [from_grid, from_battery] actions.

        Reads column 0 (battery charge) and column 2 (price) of `observation`.
        The battery request is the charge capped at 0.5; the grid request is
        0.5, or 0 whenever the price exceeds 0.7. `u_range` is not used.
        """
        charge = observation[:, 0]
        cost = observation[:, 2]

        # Battery: never request more than 0.5 or more than what is stored.
        battery_cap = 0.5 * torch.ones_like(charge)
        battery_request = torch.minimum(charge, battery_cap)

        # Grid: a flat 0.5 request, dropped to zero when the price is high.
        is_expensive = cost > 0.7
        grid_off = torch.zeros_like(cost)
        grid_on = 0.5 * torch.ones_like(cost)
        grid_request = torch.where(is_expensive, grid_off, grid_on)

        return torch.stack([grid_request, battery_request], dim=1)
