import torch
from torch import Tensor

from vmas.simulator.core import Agent, World, Sphere
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.environment import Environment
from vmas.simulator.utils import Color
from vmas.simulator.utils import ScenarioUtils 
import numpy as np
import uuid
import csv
from pathlib import Path
import os
import datetime    


############################
# Constants / Hyperparams
############################
# Battery capacity: the stored charge is clamped to [0, MAX_STORAGE] and the
# battery component of the action is scaled by this value.
# (physical unit is not stated in this file — TODO confirm)
MAX_STORAGE = 2.0
# Fraction of (generation / 1000) that is added to the battery every step.
CHARGE_EFFICIENCY = 0.5
# Default number of steps per episode; can be overridden through the
# `episode_length` keyword argument of `make_world`.
EPISODE_LENGTH = 80
# Reward weight applied to the backlog (postponed demand). It is negative, so
# a positive backlog lowers the reward and a negative one raises it.
PDM = -8  # Postponed Demand Multiplier (nu)
# Reward weight on grid consumption; applied as DRM * (-grid consumption).
DRM = 8  # Demand / Resource global constraint multiplier (lambda)

    
class SmartGridScenario(BaseScenario):
    """
    Multi-agent demand-response scenario for a simplified smart grid.

    One agent is created per entry of `building_types`; each agent owns a
    battery, a backlog of postponed demand and a (demand, generation, price)
    data trace loaded from disk.

    Time indexing:
      - `start_index[env, agent]` is the offset into the agent's trace at
        which the current episode started.
      - `current_step[env]` counts the steps taken since the last reset.
      - The sample in effect is `(start_index + current_step) % num_samples`,
        so the trace wraps around instead of running out.
    """

    def __init__(self, eval=False):
        """
        Args:
            eval: when True, `post_step` writes every agent's executed action
                to a per-agent CSV file inside $SMARTGRID_LOG_DIR.
        """
        super().__init__()
        self.uid = uuid.uuid4().hex[:4]   # random 4-char id
        self.eval = eval
        # agent index -> (csv.DictWriter, open file handle)
        self._log_writers = {}
        # Running statistics of the three raw reward signals. They are kept
        # up to date by `reward` but are not used to scale the reward itself.
        self.rms_cost = RunningMeanStd()
        self.rms_grid = RunningMeanStd()
        self.rms_backlg = RunningMeanStd()
        self.R_CLIP = 50.0      # final clamp applied to the reward

    # helper to load demand/generation/price traces
    def _load_data(self, building_types):
        """
        Load one `data.npy` per building type and stack them.

        Returns:
            float32 numpy array whose first axis indexes the agents. The rest
            of the class reads it as [agent, channel, sample] with channels
            0 = demand, 1 = generation, 2 = price.
        """
        return np.stack([
            np.load(f"./data/b{bt}/data.npy").astype(np.float32)
            for bt in building_types
        ])

    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        """
        Build the World, its agents and the per-environment time bookkeeping.

        Args:
            batch_dim: number of vectorised environments.
            device: torch device on which all state tensors are created.
            **kwargs: optional `episode_length` (default EPISODE_LENGTH),
                `verbose` (default False) and `building_types`
                (default [1, 2, 3]; one agent is created per entry).

        Returns:
            The populated World.
        """
        self.episode_length = kwargs.pop("episode_length", EPISODE_LENGTH)
        self.verbose = kwargs.pop("verbose", False)
        self.building_types = kwargs.pop("building_types", [1, 2, 3])

        # Tell VMAS that every keyword argument we care about was consumed.
        ScenarioUtils.check_kwargs_consumed(kwargs)

        self.n_agents = len(self.building_types)
        self.device = device
        self.data = self._load_data(self.building_types)
        self.num_samples = self.data.shape[-1]
        # Device-resident copy of the traces, built once so that resets and
        # steps only index into it instead of re-converting the numpy array.
        self._data_torch = torch.tensor(self.data, device=device)

        world = World(batch_dim, device, substeps=1, drag=0.0)
        self._world = world

        # Trace offset at which the episode started, shape [batch_dim, n_agents].
        self.start_index = torch.zeros(
            batch_dim, self.n_agents, device=device, dtype=torch.long
        )
        # Steps elapsed since the last reset, shape [batch_dim].
        self.current_step = torch.zeros(batch_dim, device=device, dtype=torch.long)

        for i in range(self.n_agents):
            agent = Agent(
                name=f"agent_{i}",
                u_range=1.0,
            )
            # Scenario-specific state, one scalar per environment.
            agent.state.battery_charge = torch.zeros(batch_dim, device=device)
            agent.state.postponed = torch.zeros(batch_dim, device=device)
            agent.state.demand = torch.zeros(batch_dim, device=device)
            agent.state.generation = torch.zeros(batch_dim, device=device)
            agent.state.price = torch.zeros(batch_dim, device=device)
            world.add_agent(agent)

        # Normalising constants: peaks taken over all agents and all samples.
        self.peak_demand = np.max(self.data[:, 0, :])
        self.peak_price = np.max(self.data[:, 2, :])
        self.peak_demand_torch = torch.tensor(self.peak_demand, device=device, dtype=torch.float)
        self.peak_price_torch = torch.tensor(self.peak_price, device=device, dtype=torch.float)

        return world

    def reset_world_at(self, env_index: int = 0):
        """
        Start a new episode in one environment, or in all of them.

        A random start offset in [0, num_samples - 1] is drawn independently
        for every (environment, agent) pair being reset, the relative step
        counter is zeroed, batteries and backlogs are emptied, and the first
        demand / generation / price sample is loaded.

        Args:
            env_index: index of the environment to reset. `None` resets the
                whole batch, drawing separate offsets for every environment.
        """
        if env_index is None:
            rows = slice(None)
            sample_shape = (self.start_index.shape[0], self.n_agents)
        else:
            rows = env_index
            sample_shape = (self.n_agents,)

        s0 = torch.randint(
            low=0, high=self.num_samples, size=sample_shape, device=self.device
        )
        self.start_index[rows] = s0
        self.current_step[rows] = 0

        for i, ag in enumerate(self.world.agents):
            ag.state.battery_charge[rows] = 0.0
            ag.state.postponed[rows] = 0.0
            # Offsets of agent i only: 0-dim for a single env, [batch] otherwise.
            idx = s0[..., i]
            traces = self._data_torch[i]
            ag.state.demand[rows] = traces[0][idx]
            ag.state.generation[rows] = traces[1][idx]
            ag.state.price[rows] = traces[2][idx]

    def reward(self, agent):
        """
        Energy cost + DRM-scaled grid penalty + PDM-scaled backlog term.

        The reward is built from the raw (un-normalised) signals and clamped
        to [-R_CLIP, R_CLIP]. Running statistics of the three signals are
        updated as a side effect but do not scale the reward.

        This method also advances the agent's battery charge and postponed
        demand, so it has to be evaluated exactly once per agent per step.

        Returns:
            Tensor of shape (batch, 1).
        """
        # ─── actions → physical quantities ──────────────────────────
        a = agent.action.u                                  # (batch, 2) in [-1, 1]
        consumed_grid = a[:, 0].clamp(min=0) * self.peak_demand_torch
        consumed_battery = a[:, 1].clamp(min=0) * MAX_STORAGE

        # ─── battery bookkeeping & backlog ──────────────────────────
        old_batt = agent.state.battery_charge
        # The battery cannot deliver more than it currently holds.
        batt_used = torch.minimum(old_batt, consumed_battery)

        total_demand = agent.state.demand + agent.state.postponed
        # Positive: unmet demand carried over; negative: over-supply.
        backlog = total_demand - batt_used - consumed_grid
        agent.state.postponed = backlog

        new_batt = old_batt - batt_used
        new_batt += CHARGE_EFFICIENCY * (agent.state.generation / 1000.0)
        agent.state.battery_charge = torch.clamp(new_batt, 0.0, MAX_STORAGE)

        # ─── raw (unnormalised) signals ─────────────────────────────
        cost_kw = consumed_grid * agent.state.price

        # ─── update running statistics (no-grad) ────────────────────
        with torch.no_grad():
            self.rms_cost.update(cost_kw)
            self.rms_grid.update(consumed_grid)
            self.rms_backlg.update(backlog)

        # ─── reward pieces with the original multipliers ────────────
        rew_cost = -cost_kw                 # cheaper step ⇒ higher reward
        lam_rew = DRM * (-consumed_grid)
        post_rew = PDM * backlog

        augmented = rew_cost + lam_rew + post_rew

        # clamp so critic targets never explode
        augmented = torch.clamp(augmented, -self.R_CLIP, self.R_CLIP)

        return augmented.unsqueeze(-1)                         # shape (batch, 1)

    def observation(self, agent):
        """
        Return the normalised observation of shape (batch, 4):
        [battery_charge, current demand, price, postponed demand].
        """
        obs_list = [
            agent.state.battery_charge / MAX_STORAGE,
            agent.state.demand / self.peak_demand_torch,
            agent.state.price / self.peak_price_torch,
            agent.state.postponed / self.peak_demand_torch,
        ]
        return torch.stack(obs_list, dim=1)

    def done(self):
        """
        Return a boolean Tensor (batch_dim,) that is True for every
        environment whose `current_step` has reached `episode_length`.
        """
        if self.verbose:
            print("episode_length:", self.episode_length, "|",
              "start_index:", self.start_index, "|",
              "current_step:", self.current_step)

        return self.current_step >= self.episode_length

    def info(self, agent: Agent):
        """
        Return per-agent, per-env metrics; every tensor has shape (batch_dim, 1).

        Grid and battery consumption are reported as zero while the agent has
        not received an action yet.
        """
        battery = agent.state.battery_charge.unsqueeze(-1)  # (batch, 1)

        if agent.action.u is None:
            batch_dim = battery.shape[0]
            grid_cons = torch.zeros(batch_dim, 1, device=self.device)
            batt_cons = torch.zeros(batch_dim, 1, device=self.device)
        else:
            frac = agent.action.u.clamp(min=0)                 # (batch, 2)
            grid_cons = frac[:, :1] * self.peak_demand_torch    # (batch, 1)
            batt_cons = frac[:, 1:] * MAX_STORAGE

        postponed = agent.state.postponed.unsqueeze(-1)        # (batch, 1)
        cost = grid_cons * agent.state.price.unsqueeze(-1)

        # Keys are deliberately left without any prefix.
        return {
            "battery":    battery,
            "grid_cons":  grid_cons,
            "batt_cons":  batt_cons,
            "postponed":  postponed,
            "cost":       cost,
        }

    def post_step(self):
        """
        Per-step hook invoked by the environment.

        1. Refresh demand / generation / price for every (env, agent).
        2. Log per-agent actions to CSV (when self.eval is True).
        3. Bump current_step.

        Raises:
            KeyError: in eval mode, if SMARTGRID_LOG_DIR is not set.
        """
        self._refresh_traces()
        if self.eval:
            self._log_actions()
        self.current_step += 1

    def _refresh_traces(self):
        """Load the data sample for the current step into every agent's state."""
        for agent_i, agent in enumerate(self.world.agents):
            idx_i = (self.start_index[:, agent_i] + self.current_step) % self.num_samples
            traces = self._data_torch[agent_i]
            agent.state.demand = traces[0][idx_i]
            agent.state.generation = traces[1][idx_i]
            agent.state.price = traces[2][idx_i]

    def _get_log_writer(self, agent_i):
        """Return the CSV writer of agent `agent_i`, creating the file on first use."""
        if agent_i not in self._log_writers:
            log_dir = Path(os.environ["SMARTGRID_LOG_DIR"])
            log_dir.mkdir(parents=True, exist_ok=True)
            fp = log_dir / f"agent_{agent_i}_actions.csv"
            # Kept open for the lifetime of the scenario; closed in __del__.
            fh = fp.open("w", newline="")
            writer = csv.DictWriter(
                fh,
                fieldnames=["env_id", "step", "a_grid", "a_batt", "post", "price", "demand"],
            )
            writer.writeheader()
            self._log_writers[agent_i] = (writer, fh)
        return self._log_writers[agent_i][0]

    def _log_actions(self):
        """Write one CSV row per (agent, environment) describing the executed action."""
        # The step counter of environment 0 labels the rows of every environment.
        step_idx = int(self.current_step[0])

        for agent_i, agent in enumerate(self.world.agents):
            writer = self._get_log_writer(agent_i)

            # Consumptions implied by the executed action, shape (batch,).
            u = agent.action.u.detach()
            # One bulk transfer per column instead of one per scalar.
            grid_cons = (u[:, 0].clamp(min=0) * self.peak_demand_torch).tolist()
            batt_cons = (u[:, 1].clamp(min=0) * MAX_STORAGE).tolist()
            # NOTE(review): whether `postponed` already includes this step's
            # update depends on when the environment evaluates `reward`
            # relative to `post_step` — confirm against the VMAS version in use.
            postponed = agent.state.postponed.detach().tolist()
            price = agent.state.price.tolist()
            demand = agent.state.demand.tolist()

            writer.writerows(
                {
                    "env_id": env_id,
                    "step":   step_idx,
                    "a_grid": row[0],
                    "a_batt": row[1],
                    "post":   row[2],
                    "price":  row[3],
                    "demand": row[4],
                }
                for env_id, row in enumerate(
                    zip(grid_cons, batt_cons, postponed, price, demand)
                )
            )

    def supports_continuous_actions(self):
        """This scenario only defines continuous actions."""
        return True

    def supports_discrete_actions(self):
        """Discrete actions are not supported."""
        return False

    def __del__(self):
        """Close any CSV handles that are still open (best effort)."""
        for _, fhandle in getattr(self, "_log_writers", {}).values():
            try:
                fhandle.close()
            except Exception:
                # A destructor must never raise, e.g. during interpreter shutdown.
                pass

import torch


class RunningMeanStd:
    """
    Running mean and variance over every sample seen so far.

    Each call to `update` reduces the incoming batch along dim 0 and merges
    its mean / variance into the running values, weighted by sample counts
    (all samples contribute equally; nothing is decayed over time).

    Works on any device: the first non-empty `.update()` call moves the
    running statistics to the device of its input.
    """

    def __init__(self, epsilon: float = 1e-4, shape=()):
        """
        Args:
            epsilon: initial pseudo-count; keeps the first merge free of a
                division by zero.
            shape: shape of a single sample (the input shape without dim 0).
        """
        self.mean = torch.zeros(shape)
        self.var = torch.ones(shape)
        self.count = torch.tensor(epsilon)
        self.device_set = False

    # ───────────────────────────── public API ──────────────────────────────
    def update(self, x: torch.Tensor) -> None:
        """
        Incorporate a new batch of samples `x`; dim 0 is the batch dimension.

        A 0-dim tensor is treated as a batch holding one sample, and an empty
        batch is ignored (its mean would be NaN and would corrupt the running
        statistics for good).
        """
        # Statistics never take part in autograd.
        x = x.detach()
        if x.dim() == 0:
            x = x.unsqueeze(0)

        batch_count = x.shape[0]
        if batch_count == 0:
            return

        if not self.device_set:       # move running stats to x's device once
            self.mean = self.mean.to(x.device)
            self.var = self.var.to(x.device)
            self.count = self.count.to(x.device)
            self.device_set = True

        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        self._update_from_moments(batch_mean, batch_var, batch_count)

    @property
    def std(self) -> torch.Tensor:
        """Current running standard deviation (adds a small ε for stability)."""
        return torch.sqrt(self.var + 1e-6)

    # ─────────────────────────── internal helpers ──────────────────────────
    @torch.no_grad()
    def _update_from_moments(self,
                             batch_mean: torch.Tensor,
                             batch_var: torch.Tensor,
                             batch_count: int) -> None:
        """Merge the moments of one batch into the running moments."""
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count

        # Combine the two sums of squared deviations; the last term accounts
        # for the distance between the running mean and the batch mean.
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count
        new_var = m2 / tot_count

        # commit
        self.mean = new_mean
        self.var = new_var
        self.count = tot_count
