import numpy as np
from src.env import DerEnv, baseEnv, MultiAgentDerEnv
import networkx as nx
import matplotlib.pyplot as plt

from stable_baselines3 import PPO

class marlAcrl:
    """Evaluate pre-trained PPO agents coordinated through shared multipliers.

    One environment (``DerEnv`` or ``baseEnv``) and one PPO policy are created
    per agent. Agents exchange their global multiplier ``lambda_val`` over an
    undirected communication graph. For every agent, ``step_all`` updates:

    * a consensus multiplier (``self.mus``);
    * the global multiplier (``agent.lambda_val``);
    * a local multiplier (``agent.nu_val``).

    The ``run_*`` / ``compare_*`` methods run episodes, record histories and
    plot or summarise them.
    """

    def __init__(self, models, connections, buildings=None, derenv=True, nu_val=0, lambda_val=0, thresh=0.27):
        """Build one environment and one PPO policy per agent.

        Args:
            models: sequence of paths accepted by ``PPO.load``; entry ``i`` is
                the policy of agent ``i`` (needs at least ``len(buildings)``
                entries).
            connections: iterable of ``(src, dest)`` agent-index pairs; the
                communication graph is treated as undirected.
            buildings: building id per agent, used to load
                ``./data/b{id}/data.npy``. Defaults to ``[1, 1, 1, 1]``.
            derenv: use ``DerEnv`` if True, otherwise ``baseEnv``.
            nu_val, lambda_val: initial multipliers, forwarded to ``baseEnv``
                only.
            thresh: fraction of the summed peak demand used as the global
                constraint (see ``set_costraint``).
        """
        # ``None`` default avoids one mutable list being shared by every call.
        buildings = [1, 1, 1, 1] if buildings is None else list(buildings)
        self.num_agents = len(buildings)
        self.buildings = buildings
        if derenv:
            self.agents = [DerEnv(np.load(f'./data/b{b}/data.npy')) for b in buildings]
        else:
            self.agents = [baseEnv(np.load(f'./data/b{b}/data.npy'), nu_val, lambda_val) for b in buildings]
        self.graph = self.create_graph(self.num_agents, connections)

        # Multiplier state: mean neighbour lambda and consensus multiplier per agent.
        self.nei_lambda_means = np.zeros(self.num_agents)
        self.mus = np.zeros(self.num_agents)

        # NOTE(review): assumes all agents share agent 0's data/space shapes -- confirm.
        self.num_samples = self.agents[0].data.shape[1]
        self.num_actions = self.agents[0].action_space.shape[0]
        self.num_states = self.agents[0].observation_space.shape[0]
        self.models = [PPO.load(models[i]) for i in range(self.num_agents)]
        self.set_costraint(thresh=thresh)

    @staticmethod
    def _scalar(value):
        """Return a scalar or size-1 array as a Python float.

        Raises ValueError if ``value`` holds more than one element. Storing
        size-1 arrays into scalar slots is deprecated since NumPy 1.25.
        """
        return np.asarray(value, dtype=float).item()

    @staticmethod
    def _running_mean(arr):
        """Return the expanding-window (running) mean of a 1-D array."""
        return np.cumsum(arr) / np.arange(1, len(arr) + 1)

    def set_costraint(self, thresh=0.5):
        """Set the global constraint and split it evenly across agents.

        ``self.constraint`` becomes the 1-element array
        ``[thresh * sum(agent.peakDemand)]``. Each agent's ``constraints``
        attribute receives ``self.constraint / num_agents``.
        The method name's spelling is kept for backward compatibility.
        """
        total_demand = np.sum([agent.peakDemand for agent in self.agents])
        self.constraint = np.array([thresh * total_demand])
        for agent in self.agents:
            # A fresh array per agent, so no two agents share one object.
            agent.constraints = self.constraint / self.num_agents

    def create_graph(self, num_agents, connections):
        """Return an undirected adjacency list for ``num_agents`` nodes.

        Each ``(src, dest)`` pair in ``connections`` is recorded in both
        directions. ``graph[i]`` lists the neighbours of agent ``i``.
        """
        graph = [[] for _ in range(num_agents)]
        for src, dest in connections:
            graph[src].append(dest)
            graph[dest].append(src)
        return graph

    def update_consensus(self):
        """Recompute ``self.nei_lambda_means`` from the agents' current lambdas.

        Entry ``i`` is the mean ``lambda_val`` of agent ``i``'s neighbours.
        An agent with no neighbours falls back to its own ``lambda_val``.
        """
        new_lambda_means = np.zeros(self.num_agents)
        for i in range(self.num_agents):
            neighbours = self.graph[i]
            if neighbours:
                new_lambda_means[i] = np.mean([self.agents[n].lambda_val for n in neighbours])
            else:
                new_lambda_means[i] = self._scalar(self.agents[i].lambda_val)
        self.nei_lambda_means = new_lambda_means

    def step_all(self, obs, lamdaStep, muStep, nuStep, consensus=True, consensus_steps=1):
        """Advance every agent one step and update its multipliers.

        For agent ``i`` the policy acts deterministically, the environment is
        stepped, and then, in this order:

        1. ``mus[i] += consensus_steps * muStep * (lambda_i - nei_lambda_means[i])``,
           clipped to ``[-20, 20]``. This is updated even when ``consensus``
           is False.
        2. ``lambda_i -= consensus_steps * lamdaStep * (constSatisfaction + c)``,
           where ``c = mus[i]`` if ``consensus`` else 0. The result is clipped
           to ``[0, agent.max_lamb]``.
        3. ``nu_i -= nuStep * agent.postponed * agent.peakDemand``, clipped to
           ``[-agent.max_nu, agent.max_nu]``.

        Neighbour means are refreshed once, after all agents have stepped.

        Args:
            obs: ``(num_agents, num_states)`` observations; updated in place.
            lamdaStep, muStep, nuStep: step sizes of the three updates.
            consensus: whether ``mus`` feeds into the lambda update.
            consensus_steps: multiplier applied to the lambda and mu steps.

        Returns:
            ``(obs, rew, actions, lambdas, nus)``. ``actions[i]`` is
            ``[policy action component 0, info['consumed_battery']]``, which
            assumes ``num_actions == 2``.
        """
        rew = np.zeros(self.num_agents)
        actions = np.zeros((self.num_agents, self.num_actions))
        lambdas = np.zeros(self.num_agents)
        nus = np.zeros(self.num_agents)
        for i, agent in enumerate(self.agents):
            actions[i], _ = self.models[i].predict(obs[i], deterministic=True)
            obs[i], rew[i], _, _, info = agent.step(actions[i])
            # Record what the environment reports as drawn from the battery.
            actions[i] = np.array([actions[i, 0], info['consumed_battery']])

            # Consensus multiplier: grows with disagreement against neighbours.
            mu_delta = consensus_steps * muStep * (agent.lambda_val - self.nei_lambda_means[i])
            self.mus[i] = np.clip(self.mus[i] + self._scalar(mu_delta), -20, 20)

            # Global multiplier, kept non-negative and bounded.
            cons = self.mus[i] if consensus else 0
            agent.lambda_val -= consensus_steps * lamdaStep * (info['constSatisfaction'] + cons)
            agent.lambda_val = np.clip(agent.lambda_val, 0, agent.max_lamb)
            lambdas[i] = self._scalar(agent.lambda_val)

            # Local multiplier.
            agent.nu_val -= nuStep * agent.postponed * agent.peakDemand
            agent.nu_val = np.clip(agent.nu_val, -agent.max_nu, agent.max_nu)
            nus[i] = self._scalar(agent.nu_val)

        self.update_consensus()
        return obs, rew, actions, lambdas, nus

    def step_all_no_consensus(self, obs):
        """Advance every agent one step without touching any multiplier.

        Unlike ``step_all``, actions are sampled with ``deterministic=False``.
        ``obs`` is updated in place.

        Returns:
            ``(obs, rew, actions)``. ``actions[i]`` is
            ``[policy action component 0, info['consumed_battery']]``.
        """
        rew = np.zeros(self.num_agents)
        actions = np.zeros((self.num_agents, self.num_actions))
        for i, agent in enumerate(self.agents):
            actions[i], _ = self.models[i].predict(obs[i], deterministic=False)
            obs[i], rew[i], _, _, info = agent.step(actions[i])
            actions[i] = np.array([actions[i, 0], info['consumed_battery']])
        return obs, rew, actions

    def run_episode(self, epochs=100, lamdaStep=0.01, muStep=0.01, nuStep=0.01, T0=1, consensus=True, consensus_steps=1, from_start=False):
        """Reset agents and multipliers, then run ``epochs * T0 - 1`` joint steps.

        Every agent is reset and its ``lambda_val`` and ``nu_val`` are set to
        ``np.zeros(1)``. With ``from_start``, ``agent.current_step`` is also
        set to 0.

        The following histories get ``epochs * T0`` rows each:
        ``obs_history``, ``actions_list``, ``lambdas_list``, ``nus_list``,
        ``mus_list`` and ``neighboor_lambda_means``. Row 0 is never written and
        stays zero. Row ``j >= 1`` holds the state after step ``j``.
        """
        total_steps = epochs * T0
        obs = np.zeros((self.num_agents, self.num_states))
        self.mus = np.zeros(self.num_agents)
        self.nei_lambda_means = np.zeros(self.num_agents)
        for i, agent in enumerate(self.agents):
            obs[i], _ = agent.reset()
            agent.lambda_val = np.zeros(1)
            agent.nu_val = np.zeros(1)
            if from_start:
                agent.current_step = 0

        # Sized by the total number of steps, so T0 > 1 stays in bounds.
        self.obs_history = np.zeros((total_steps, self.num_agents, self.num_states))
        self.actions_list = np.zeros((total_steps, self.num_agents, self.num_actions))
        self.lambdas_list = np.zeros((total_steps, self.num_agents))
        self.nus_list = np.zeros((total_steps, self.num_agents))
        self.mus_list = np.zeros((total_steps, self.num_agents))
        self.neighboor_lambda_means = np.zeros((total_steps, self.num_agents))

        for j in range(1, total_steps):
            obs, rew, actions, lambdas, nus = self.step_all(
                obs, lamdaStep, muStep, nuStep, consensus=consensus, consensus_steps=consensus_steps
            )
            self.obs_history[j] = obs
            self.actions_list[j] = actions
            self.lambdas_list[j] = lambdas
            self.nus_list[j] = nus
            self.mus_list[j] = self.mus
            self.neighboor_lambda_means[j] = self.nei_lambda_means

    def run_episode_no_consensus(self, epochs=100):
        """Reset agents and run ``epochs - 1`` steps via ``step_all_no_consensus``.

        Only ``obs_history`` and ``actions_list`` are (re)allocated. Their
        row 0 stays zero. Multipliers and the other histories are untouched.
        """
        obs = np.zeros((self.num_agents, self.num_states))
        for i, agent in enumerate(self.agents):
            obs[i], _ = agent.reset()

        self.obs_history = np.zeros((epochs, self.num_agents, self.num_states))
        self.actions_list = np.zeros((epochs, self.num_agents, self.num_actions))

        for j in range(1, epochs):
            obs, rew, actions = self.step_all_no_consensus(obs)
            self.obs_history[j] = obs
            self.actions_list[j] = actions

    def compare_algs(self, epochs=100, consensus_steps=1, from_start=False, cum=True):
        """Plot running-mean cost with consensus, without consensus and without control.

        Runs one ``run_episode`` with consensus and one without, plotting the
        lambdas of each through ``plot_multipliers``. The cost is the running
        mean of ``sum_i actions_list[:, i, 0] * obs_history[:, 0, 2] / 100``.
        The "without control" curve uses observation column 1 (the demand)
        from the no-consensus episode in place of the action. The figure is
        saved to ``./results/costComp.pdf`` and shown.

        NOTE(review): obs column 2 of agent 0 is presumably the price -- confirm.

        Args:
            epochs: steps per episode.
            consensus_steps: forwarded to the consensus episode.
            from_start: forwarded to ``run_episode``.
            cum: unused; kept for backward compatibility.
        """
        self.run_episode(epochs=epochs, consensus=True, consensus_steps=consensus_steps, from_start=from_start)
        cost_consenso = self._running_mean(self.actions_list[..., 0].sum(1) * self.obs_history[:, 0, 2]) / 100
        self.plot_multipliers(window_size=1, plot_nus=False, plot_mus=False, cons=True)

        self.run_episode(epochs=epochs, consensus=False, from_start=from_start)
        cost_sin_cons = self._running_mean(self.actions_list[..., 0].sum(1) * self.obs_history[:, 0, 2]) / 100
        self.plot_multipliers(window_size=1, plot_nus=False, plot_mus=False, cons=False)

        # Baseline: total demand (obs column 1) in place of the grid action.
        cost_sin_marl = self._running_mean(self.obs_history[..., 1].sum(1) * self.obs_history[:, 0, 2]) / 100

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(cost_consenso, label='Cost with Consensus', linewidth=2)
        ax.plot(cost_sin_cons, label='Cost without Consensus', alpha=0.7, linewidth=2)
        ax.plot(cost_sin_marl, '--', label='Cost without control', linewidth=2)

        # Zoom the y-axis relative to the consensus curve.
        ax.set_ylim(bottom=np.max(cost_consenso) / 3)

        ax.set_title('Cost Comparison', fontsize=22)
        ax.set_ylabel('Cost ($/kWh)', fontsize=22)
        ax.set_xlabel('Epochs', fontsize=22)
        ax.set_xlim(0, len(cost_consenso) - 1)
        ax.legend(fontsize=22)
        ax.tick_params(axis='both', which='major', labelsize=20)
        ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        plt.savefig("./results/costComp.pdf", format="pdf", bbox_inches="tight")
        plt.show()

    def plot_agent_graph(self, connections):
        """Draw the agent communication graph.

        Node size is 1000 for agents whose building id is 1 and 3000
        otherwise. The figure is saved to
        ``./results/netConf_A{max index + 1}_E{#edges}.pdf`` and shown.
        """
        G = nx.Graph()
        # Add every agent in index order first. This draws isolated agents
        # and keeps ``node_sizes`` aligned with ``G.nodes()``.
        G.add_nodes_from(range(self.num_agents))
        G.add_edges_from(connections)
        node_sizes = [1000 if size == 1 else 3000 for size in self.buildings]

        plt.figure(figsize=(4.5, 4.5))
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=False, node_color='skyblue', node_size=node_sizes, edge_color='k', linewidths=1, font_size=5)

        labels = {i: f'Agent {i}' for i in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=6)

        plt.savefig(f"./results/netConf_A{np.max(connections)+1}_E{len(connections)}.pdf", format="pdf", bbox_inches="tight")
        plt.show()

    def plot_multipliers(self, window_size=10, plot_mus=True, plot_nus=True, cons=True, agents_to_plot=3):
        """Plot moving-averaged multipliers recorded by the last ``run_episode``.

        Row 0 always shows ``lambdas_list``. Up to two more rows follow, in
        this order: ``nus_list`` if ``plot_nus``, then ``mus_list`` if
        ``plot_mus``. There is one column per agent for the first
        ``agents_to_plot`` agents, capped at the number recorded.

        Args:
            window_size: moving-average window, clamped to ``[1, #samples]``.
            plot_mus: add the consensus-multiplier (mu) row.
            plot_nus: add the local-multiplier (nu) row.
            cons: selects the file name when only lambdas are plotted.
            agents_to_plot: number of agent columns.

        The figure is saved to one of:

        * ``multipliers.pdf`` (both rows);
        * ``lambda_mus.pdf`` (mu row only);
        * ``lambda_nus.pdf`` (nu row only);
        * ``./results/lambdas{cons|noCons}.pdf`` (lambdas only).

        It is then shown.
        """
        def moving_average(data, window):
            return np.convolve(data, np.ones(window) / window, mode='valid')

        samples, num_agents = self.lambdas_list.shape
        agents_to_plot = max(1, min(agents_to_plot, num_agents))
        window_size = max(1, min(window_size, samples))
        time = np.arange(samples)

        # Each flag controls its own series; rows are allocated in order so any
        # flag combination indexes existing rows.
        extra_rows = []
        if plot_nus:
            extra_rows.append((self.nus_list, 'Nu', 'green', 'local multiplier', None))
        if plot_mus:
            extra_rows.append((self.mus_list, 'Mu', 'red', 'consensus multiplier', 'Sample'))
        rows = 1 + len(extra_rows)

        # squeeze=False always yields a 2-D axes array, even for 1x1 grids.
        fig, axes = plt.subplots(rows, agents_to_plot, figsize=(5 * agents_to_plot, 4 * rows), sharex=True, squeeze=False)

        for i in range(agents_to_plot):
            smoothed_lambda = moving_average(self.lambdas_list[:, i], window_size)
            ax = axes[0, i]
            # 'valid' convolution shortens the series; trim time to match.
            ax.plot(time[:len(smoothed_lambda)], smoothed_lambda, label='Lambda', color='blue')
            ax.set_title(f'Agent {i} Global multiplier', fontsize=20)
            ax.set_xlabel('Timestep', fontsize=22)
            ax.grid(True, linestyle='--', alpha=0.6)
            ax.tick_params(axis='both', which='major', labelsize=20)
            ax.legend(fontsize=22)

        for row, (data, label, color, title, xlabel) in enumerate(extra_rows, start=1):
            for i in range(agents_to_plot):
                smoothed = moving_average(data[:, i], window_size)
                ax = axes[row, i]
                ax.plot(time[:len(smoothed)], smoothed, label=label, color=color)
                ax.set_title(f'Agent {i} {title}')
                if xlabel is not None:
                    ax.set_xlabel(xlabel)
                ax.legend()

        # Single shared y-label for the whole grid.
        fig.text(0.02, 0.5, 'Multiplier Value', va='center', rotation='vertical', fontsize=22)
        plt.tight_layout(rect=[0.04, 0, 1, 1])  # leave room for the shared y-label
        if plot_mus and plot_nus:
            plt.savefig("multipliers.pdf", format="pdf", bbox_inches="tight")
        elif plot_mus:
            plt.savefig("lambda_mus.pdf", format="pdf", bbox_inches="tight")
        elif plot_nus:
            plt.savefig("lambda_nus.pdf", format="pdf", bbox_inches="tight")
        else:
            plt.savefig(f"./results/lambdas{'cons' if cons else 'noCons'}.pdf", format="pdf", bbox_inches="tight")
        plt.show()

    def compare_algs2(self, epochs=100, consensus_steps=1, from_start=False, cum=True, num_runs=10):
        """Plot running-mean total grid consumption with and without consensus.

        Runs ``num_runs`` pairs of episodes (consensus / no consensus). For
        each episode, action component 0 is summed over agents. The plot
        shows the running mean of the across-run mean, with a band from the
        across-run min to max. It also shows the demand (obs column 1) of the
        last episode and the global constraint. The figure is saved to
        ``./results/gridConsComp.pdf`` and shown. ``cum`` is unused and kept
        for backward compatibility.
        """
        runs_consensus = []
        runs_no_consensus = []
        for _ in range(num_runs):
            self.run_episode(epochs=epochs, consensus=True, consensus_steps=consensus_steps, from_start=from_start)
            runs_consensus.append(self.actions_list[..., 0].sum(1))  # grid consumption summed over agents

            self.run_episode(epochs=epochs, consensus=False, from_start=from_start)
            runs_no_consensus.append(self.actions_list[..., 0].sum(1))

        tot_demand = self.obs_history[..., 1].sum(1)  # demand summed over agents (last episode)
        runs_consensus = np.array(runs_consensus)
        runs_no_consensus = np.array(runs_no_consensus)

        mean_consensus = runs_consensus.mean(axis=0)
        mean_no_consensus = runs_no_consensus.mean(axis=0)

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(self._running_mean(tot_demand), color='red', label='Demand', linewidth=2)

        ax.plot(self._running_mean(mean_consensus), color='blue', label='Grid consumption with consensus', linewidth=2)
        ax.fill_between(np.arange(len(mean_consensus)), self._running_mean(runs_consensus.min(axis=0)),
                        self._running_mean(runs_consensus.max(axis=0)), color='blue', alpha=0.2)

        ax.plot(self._running_mean(mean_no_consensus), color='orange', label='Grid consumption without consensus', linewidth=2)
        ax.fill_between(np.arange(len(mean_no_consensus)), self._running_mean(runs_no_consensus.min(axis=0)),
                        self._running_mean(runs_no_consensus.max(axis=0)), color='orange', alpha=0.2)

        ax.axhline(y=self._scalar(self.constraint), color='r', linestyle='-', linewidth=2)

        ax.set_title('Grid Consumption Comparison', fontsize=22)
        ax.set_ylabel('Grid consumption (kWh)', fontsize=22)
        ax.set_xlabel('Timestep', fontsize=22)
        ax.legend(fontsize=22)
        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.set_xlim(0, len(mean_no_consensus) - 1)
        ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        plt.savefig("./results/gridConsComp.pdf", format="pdf", bbox_inches="tight")
        plt.show()

    def run_experiment_no_consensus(self, epochs=100, num_runs=10):
        """Plot running-mean demand and grid consumption over no-consensus runs.

        Runs ``num_runs`` ``run_episode_no_consensus`` episodes. For each, it
        records demand (obs column 1) and action component 0, both summed over
        agents. Both are plotted as the running mean of the across-run mean,
        with min/max bands, plus the global constraint line. The figure is
        saved to ``gridConsNoConsensus.pdf`` and shown.
        """
        runs_actions = []
        runs_demand = []
        for _ in range(num_runs):
            self.run_episode_no_consensus(epochs=epochs)
            runs_actions.append(self.actions_list[..., 0].sum(1))
            runs_demand.append(self.obs_history[..., 1].sum(1))

        runs_actions = np.array(runs_actions)
        runs_demand = np.array(runs_demand)
        mean_actions = runs_actions.mean(axis=0)
        mean_obs = runs_demand.mean(axis=0)

        fig, ax = plt.subplots(figsize=(10, 10))

        ax.plot(self._running_mean(mean_obs), color='red', label='Mean Demand', linewidth=2)
        ax.fill_between(np.arange(len(mean_obs)), self._running_mean(runs_demand.min(axis=0)),
                        self._running_mean(runs_demand.max(axis=0)), color='red', alpha=0.2)

        ax.plot(self._running_mean(mean_actions), color='blue', label='Grid consumption without consensus', linewidth=2)
        ax.fill_between(np.arange(len(mean_actions)), self._running_mean(runs_actions.min(axis=0)),
                        self._running_mean(runs_actions.max(axis=0)), color='blue', alpha=0.2)

        ax.axhline(y=self._scalar(self.constraint), color='r', linestyle='-', linewidth=2)

        ax.set_title('Grid Consumption Without Consensus', fontsize=22)
        ax.set_xlabel('Timestep', fontsize=22)
        ax.set_ylabel('Grid consumption (kWh)', fontsize=22)
        ax.legend(fontsize=18)
        ax.tick_params(axis='both', which='major', labelsize=20)
        ax.set_xlim(0, len(mean_obs) - 1)
        ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        plt.savefig("gridConsNoConsensus.pdf", format="pdf", bbox_inches="tight")
        plt.show()

    def run_experiment_no_consensus_summary(
        self, epochs=100, num_runs=10, window_size=10
    ):
        """Summarise ``num_runs`` no-consensus episodes with scalar statistics.

        Returns:
            A 6-tuple:

            * mean, min and max across runs of the final running-mean grid
              consumption (action component 0 summed over agents);
            * the windowed slope of the run-averaged cumulative running mean
              of obs column 3 summed over agents (called "postponed" here).
              This is ``(mean of the last window_size epochs - mean of the
              window_size epochs before) / window_size``;
            * min and max across runs of that cumulative value at the last
              epoch.

        Raises:
            ValueError: if ``window_size < 1`` or ``epochs < 2 * window_size``.
                Both are checked before any episode is run.
        """
        # Validate up front so a bad window does not waste num_runs episodes.
        if window_size < 1:
            raise ValueError(f"window_size must be at least 1. Got {window_size}.")
        if epochs < 2 * window_size:
            raise ValueError(f"Need at least 2 * window_size epochs to do a windowed derivative. Got {epochs}.")

        all_tot_actions = []
        all_tot_postponed = []
        for _ in range(num_runs):
            self.run_episode_no_consensus(epochs=epochs)
            all_tot_actions.append(self._running_mean(self.actions_list[..., 0].sum(axis=1)))
            all_tot_postponed.append(np.cumsum(self._running_mean(self.obs_history[..., 3].sum(axis=1))))

        all_tot_actions = np.array(all_tot_actions)      # (num_runs, epochs)
        all_tot_postponed = np.array(all_tot_postponed)  # (num_runs, epochs)

        final_actions = all_tot_actions[:, -1]
        mean_postponed_by_epoch = all_tot_postponed.mean(axis=0)

        # Average the two most recent windows. Their centres are window_size
        # epochs apart, which gives the denominator.
        last_window_start = epochs - window_size
        prev_window_start = epochs - 2 * window_size
        avg_last_window = mean_postponed_by_epoch[last_window_start:epochs].mean()
        avg_prev_window = mean_postponed_by_epoch[prev_window_start:last_window_start].mean()
        mean_postponed_derivative_last = (avg_last_window - avg_prev_window) / float(window_size)

        return (
            np.mean(final_actions),
            np.min(final_actions),
            np.max(final_actions),
            mean_postponed_derivative_last,
            np.min(all_tot_postponed[:, -1]),
            np.max(all_tot_postponed[:, -1]),
        )

    def compare_algs3(self, epochs=100, consensus_steps=1, from_start=False, cum=True,
                      constraints=(0.2, 0.25, 0.28, 0.34, 0.4, 0.5), run=True, num_runs=10):
        """Plot unmet demand across constraint thresholds, with and without consensus.

        For every threshold in ``constraints``, runs ``num_runs`` pairs of
        episodes. Unmet demand is demand (obs column 1) minus every recorded
        action component, summed over agents. Each series is optionally
        running-averaged (``run``) and then optionally cumulated (``cum``).
        Each threshold is plotted as the across-run mean with a min/max band.

        The constraint configured before the call is restored afterwards. The
        two figures are saved to
        ``./results/unmet_demand_{with,without}_consensus.pdf`` and shown.
        """
        def transform(arr):
            smoothed = self._running_mean(arr) if run else arr
            return np.cumsum(smoothed) if cum else smoothed

        def unmet_demand():
            return transform(self.obs_history[:, :, 1].sum(1) - self.actions_list.sum((1, 2)))

        all_unmet_consensus = {constraint: [] for constraint in constraints}
        all_unmet_no_consensus = {constraint: [] for constraint in constraints}

        # The sweep reconfigures the constraint; restore it afterwards so later
        # runs do not silently use the last threshold.
        saved_constraint = self.constraint
        saved_agent_constraints = [agent.constraints for agent in self.agents]
        try:
            for constraint in constraints:
                self.set_costraint(thresh=constraint)
                for _ in range(num_runs):
                    self.run_episode(epochs=epochs, consensus=True, consensus_steps=consensus_steps, from_start=from_start)
                    all_unmet_consensus[constraint].append(unmet_demand())

                    self.run_episode(epochs=epochs, consensus=False, from_start=from_start)
                    all_unmet_no_consensus[constraint].append(unmet_demand())
        finally:
            self.constraint = saved_constraint
            for agent, saved in zip(self.agents, saved_agent_constraints):
                agent.constraints = saved

        fig_consensus, ax_consensus = plt.subplots(figsize=(12, 6))
        fig_no_consensus, ax_no_consensus = plt.subplots(figsize=(12, 6))

        for constraint in constraints:
            for ax, runs in ((ax_consensus, all_unmet_consensus[constraint]),
                             (ax_no_consensus, all_unmet_no_consensus[constraint])):
                runs = np.array(runs)
                mean = runs.mean(axis=0)
                ax.plot(mean, label=f'Constraint={constraint:.2f}', linewidth=2)
                ax.fill_between(np.arange(len(mean)), runs.min(axis=0), runs.max(axis=0), alpha=0.3)

        def configure(ax, title):
            ax.legend(fontsize=22)
            ax.axhline(y=0, color='r', linestyle='-', linewidth=2)
            ax.set_title(title, fontsize=22)
            ax.set_ylabel('Unmet Demand (kWh)', fontsize=22)
            ax.set_xlabel('Epochs', fontsize=22)
            ax.set_xlim(0, epochs - 1)
            ax.set_ylim(-800, 500)
            ax.tick_params(axis='both', which='major', labelsize=20)
            ax.grid(True, linestyle='--', alpha=0.6)

        configure(ax_consensus, 'Unmet Demand with Consensus')
        fig_consensus.tight_layout()
        fig_consensus.savefig('./results/unmet_demand_with_consensus.pdf', format="pdf", bbox_inches="tight")

        configure(ax_no_consensus, 'Unmet Demand without Consensus')
        fig_no_consensus.tight_layout()
        fig_no_consensus.savefig('./results/unmet_demand_without_consensus.pdf', format="pdf", bbox_inches="tight")

        plt.show()
    
