import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import matplotlib.pyplot as plt
from src.env import DerEnv, baseEnv, MultiAgentDerEnv


class ACRLAgent:
    """Single-building PPO agent (Stable-Baselines3) for the DER environments.

    The environment is built from ``./data/b<currentBuilding>/data.npy`` and
    wrapped by a PPO model with an MLP policy.
    """

    def __init__(self, verb=0, currentBuilding=1, derenv=True, nu_val=0, lambda_val=0):
        """Load one building's data, build its environment and a PPO model.

        Args:
          verb (int): PPO verbosity level.
          currentBuilding (int): building id; selects ``./data/b<id>/data.npy``.
          derenv (bool): if True use ``DerEnv(data)``, otherwise
            ``baseEnv(data, nu_val, lambda_val)``.
          nu_val: forwarded to ``baseEnv`` only (ignored when ``derenv`` is True).
          lambda_val: forwarded to ``baseEnv`` only (ignored when ``derenv`` is True).
        """
        data_path = f'./data/b{currentBuilding}/data.npy'
        data = np.load(data_path)

        self.derenv = derenv
        # NOTE(review): not read anywhere in this class; presumably consumed by
        # external code — confirm before renaming (name looks like a typo).
        self.neighborhoodLambds = 0

        # The multipliers only matter for baseEnv; DerEnv is built from the data alone.
        self.env = DerEnv(data) if derenv else baseEnv(data, nu_val, lambda_val)

        # SB3 PPO configuration: undiscounted returns (gamma=1), generalized
        # state-dependent exploration, small entropy bonus, fixed seed.
        ppo_kwargs = dict(
            policy="MlpPolicy",
            env=self.env,
            gamma=1,
            learning_rate=0.0003,
            use_sde=True,
            ent_coef=0.01,
            verbose=verb,
            seed=123,
        )
        self.model = PPO(**ppo_kwargs)

    def train(self, T0=20, total_timesteps=1_000_000, model_name="ppo_acrl"):
        """Set the episode length, run PPO training and save the model.

        Args:
          T0 (int): value assigned to ``self.env.episode_length`` before training.
          total_timesteps (int): number of environment steps passed to ``learn``.
          model_name (str): file stem; the model is saved under ``./models/``
            (SB3 appends ``.zip``).
        """
        env, model = self.env, self.model
        env.episode_length = T0
        model.learn(total_timesteps=total_timesteps)
        save_path = "./models/" + model_name
        model.save(save_path)
        

class ACRLMultiAgent:
    """Centralized PPO trainer for several buildings sharing one MultiAgentDerEnv.

    One data array is loaded per agent, the arrays are stacked along a new
    leading axis, and a single Stable-Baselines3 PPO policy is trained on the
    joint environment.
    """

    def __init__(self,
                 verb=0,
                 building_types=None,
                 nu_val=0,
                 lambda_val=0,
                 n_agents=2):
        """
        Multi-agent version of ACRLAgent for centralized training using MultiAgentDerEnv.

        Args:
          verb (int): Verbosity level for PPO logs.
          building_types (list | None): Building type for each agent (e.g. [1, 2, 1]);
            agent i loads ``./data/b{building_types[i]}/data.npy``. Defaults to
            building type 1 for every agent.
          nu_val (float): Initial value for local constraint multiplier.
          lambda_val (float): Initial value for global constraint multiplier.
          n_agents (int): Number of agents in the environment (must be >= 1).

        Raises:
          ValueError: if ``n_agents < 1``, if ``len(building_types) != n_agents``,
            or if the loaded building arrays do not all share the same shape
            (required to stack them into one array).
        """
        # Fail fast with a clear message; otherwise an empty agent list only
        # surfaces later as an opaque "need at least one array to stack" error.
        if n_agents < 1:
            raise ValueError(f"n_agents must be at least 1, got {n_agents}.")

        if building_types is None:
            building_types = [1] * n_agents  # Default to building type 1 if not provided

        if len(building_types) != n_agents:
            raise ValueError("Length of building_types must match the number of agents (n_agents).")

        # ----------------------------------------------------------------------
        # 1. Load the data for each building (author's convention: shape (3, T))
        # ----------------------------------------------------------------------
        building_data_list = [np.load(f'./data/b{btype}/data.npy') for btype in building_types]

        # ----------------------------------------------------------------------
        # 2. Stack them into shape (n_agents, 3, T)
        # ----------------------------------------------------------------------
        # np.stack requires identical shapes; check explicitly so a mismatch
        # (e.g. buildings with different horizons T) names the culprit.
        expected_shape = building_data_list[0].shape
        for agent_idx, (btype, building_data) in enumerate(zip(building_types, building_data_list)):
            if building_data.shape != expected_shape:
                raise ValueError(
                    f"All building data arrays must have the same shape to be stacked: "
                    f"agent {agent_idx} (building type {btype}) has shape {building_data.shape}, "
                    f"expected {expected_shape} (from agent 0)."
                )

        # Per the original author's layout: agent i sees building_data_arr[i, 0, t]
        # as demand, [i, 1, t] as generation and [i, 2, t] as price at time t.
        building_data_arr = np.stack(building_data_list, axis=0)  # shape => (n_agents, 3, T)

        # ----------------------------------------------------------------------
        # 3. Create the multi-agent environment
        # ----------------------------------------------------------------------
        self.env = MultiAgentDerEnv(
            data=building_data_arr,
            num_agents=n_agents,
            nu_val=nu_val,
            lambda_val=lambda_val
        )

        # ----------------------------------------------------------------------
        # 4. Initialize PPO with that environment
        #    (undiscounted returns, gSDE exploration, small entropy bonus, fixed seed)
        # ----------------------------------------------------------------------
        self.model = PPO(
            policy="MlpPolicy",
            env=self.env,
            gamma=1.0,
            learning_rate=0.0003,
            use_sde=True,
            ent_coef=0.01,
            verbose=verb,
            seed=123
        )

    def train(self, T0=20, total_timesteps=1_000_000, model_name="ppo_acrl_multi"):
        """
        Train the multi-agent model for total_timesteps and save it.

        Args:
          T0 (int): value assigned to ``self.env.episode_length`` before training
            (presumably the per-episode step limit — confirm in MultiAgentDerEnv).
          total_timesteps (int): how many timesteps to train in total
          model_name (str): file stem for the saved model under ``./models/``
            (SB3 appends ``.zip``).
        """
        # Set on the env instance PPO was constructed with, before learn() resets it.
        self.env.episode_length = T0

        # Learn and save the model
        self.model.learn(total_timesteps=total_timesteps)
        self.model.save("./models/" + model_name)  # default => ./models/ppo_acrl_multi.zip


    

        




           