import json
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import time
import pickle
from tqdm import tqdm
from src.marlAcrl import marlAcrl


def load_config(config_path):
    """Load the experiment configuration from a JSON file.

    Parameters
    ----------
    config_path : str or path-like
        Location of the JSON configuration file.

    Returns
    -------
    object
        The decoded JSON document (``main`` indexes it like a dict).

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file content is not valid JSON.
    """
    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def run_experiment(num_nodes_list, p, repeats, max_trials, constraint_thresh, epochs, consensus_steps, lambda_step):
    """
    Run scaling experiments for the MARL algorithm on Erdos-Rényi graphs.

    For every graph size, ``repeats`` random connected graphs are generated,
    a ``marlAcrl`` manager is built on each and the wall time of
    ``run_episode`` is measured. The running average of the lambdas is then
    plotted for each graph size.

    Side effect: a new matplotlib figure holding the lambda plot is created
    and left as the current figure; it is neither saved nor shown here.

    Parameters
    ----------
    num_nodes_list : list of int
        List of graph sizes (number of nodes).
    p : float
        Fixed connection probability (Erdos-Rényi).
    repeats : int
        Number of repetitions for each graph size. Must be at least 1.
    max_trials : int
        Maximum attempts to generate a connected graph.
    constraint_thresh : float
        Constraint threshold for the manager.
    epochs : int
        Number of epochs for the main algorithm.
    consensus_steps : int
        Number of consensus steps.
    lambda_step : float
        Lambda step size.

    Returns
    -------
    results : list of (n, mean_time, std_time)
        Results containing (n, mean execution time, std deviation), one
        entry per graph size, in input order. Empty if ``num_nodes_list``
        is empty.

    Raises
    ------
    ValueError
        If ``repeats`` is smaller than 1, or if no connected graph could be
        generated within ``max_trials`` attempts.

    Notes
    -----
    The lambda curves are taken from the manager of the *last* repetition
    of each graph size only; earlier repetitions contribute to the timing
    statistics but not to the plot.
    """
    if repeats < 1:
        # Without at least one run there is neither a timing sample nor a
        # manager whose lambdas could be plotted.
        raise ValueError(f"repeats must be at least 1, got {repeats}.")

    model_b1 = './models/ppo_cmarl_b1_T1000000_dd'
    model_b5 = './models/ppo_cmarl_b5_T1000000_dd'

    results = []
    max_samples = 0  # longest lambda series plotted so far (for the x-limit)
    plt.figure(figsize=(10, 6))

    for n in tqdm(num_nodes_list, desc="Scaling Experiment"):
        times = []

        for _repeat in range(repeats):
            # Generate a connected Erdos-Rényi graph
            for _trial in range(max_trials):
                G = nx.erdos_renyi_graph(n, p, seed=np.random.randint(10**9))
                if nx.is_connected(G):
                    break
            else:
                raise ValueError(f"Could not generate a connected graph for n={n} with p={p}.")

            connections = list(G.edges)
            # Each node gets building type 1 when the draw exceeds 0.2, else type 5.
            buildings = [1 if np.random.random() > 0.2 else 5 for _ in range(n)]

            # Initialize MARL manager
            manager = marlAcrl(
                connections=connections,
                buildings=buildings,
                models=[model_b1 if b == 1 else model_b5 for b in buildings],
                nu_val=1.5,
                lambda_val=1.5,
                derenv=True
            )

            manager.set_costraint(thresh=constraint_thresh)

            # Run algorithm and measure execution time. perf_counter is a
            # monotonic high-resolution clock, so the interval is not
            # distorted by system clock adjustments.
            start_time = time.perf_counter()
            manager.run_episode(
                epochs=epochs,
                consensus=True,
                lamdaStep=lambda_step,
                consensus_steps=consensus_steps
            )
            times.append(time.perf_counter() - start_time)

        mean_time = np.mean(times)
        std_time = np.std(times)
        results.append((n, mean_time, std_time))

        # Compute running averages for lambdas (last repetition only).
        # The two-value unpacking requires lambdas_list to be 2-D;
        # assumes rows are samples and columns are agents - TODO confirm.
        lambdas = manager.lambdas_list
        samples, _num_agents = lambdas.shape
        t = np.arange(samples)
        counts = np.arange(1, samples + 1)[:, None]
        running_lambdas = np.cumsum(lambdas, axis=0) / counts

        mean_lambda = np.mean(running_lambdas, axis=1)
        min_lambda = np.min(running_lambdas, axis=1)
        max_lambda = np.max(running_lambdas, axis=1)

        plt.plot(t, mean_lambda, label=f'Mean Lambda - {n} agents')
        plt.fill_between(t, min_lambda, max_lambda, alpha=0.2)
        max_samples = max(max_samples, samples)

        print(f"n = {n}, p = {p:.4f}, mean time = {mean_time:.4f}s ± {std_time:.4f}s")

    plt.title('Mean Lambda Across Agents (Running Average)', fontsize=22)
    plt.xlabel('Time', fontsize=22)
    plt.ylabel('Lambda Value', fontsize=22)
    if results:
        # A legend is only meaningful when at least one curve was drawn.
        plt.legend(fontsize=22)
    plt.grid(True)
    if max_samples > 0:
        # Span the longest series so no curve is cut off.
        plt.xlim(0, max_samples)
    plt.tick_params(labelsize=20)
    plt.tight_layout()
    return results


def main(config_path):
    """Run the scaling experiment described by a config file and plot it.

    Two figures are written: the lambda running-average plot to
    ``config["output_graph"]`` and the execution-time plot to
    ``config["output_times"]``. The figures are then shown.

    Parameters
    ----------
    config_path : str
        Path to the JSON configuration file. The keys read here are
        ``num_nodes_list``, ``p``, ``repeats``, ``max_trials``,
        ``constraint_thresh``, ``epochs``, ``consensus_steps``,
        ``lambda_step``, ``output_graph`` and ``output_times``.

    Raises
    ------
    KeyError
        If a required key is missing from the configuration.
    ValueError
        If the experiment produced no results (e.g. an empty
        ``num_nodes_list``).
    """
    # Load configuration
    config = load_config(config_path)

    # Run experiment
    results = run_experiment(
        num_nodes_list=config["num_nodes_list"],
        p=config["p"],
        repeats=config["repeats"],
        max_trials=config["max_trials"],
        constraint_thresh=config["constraint_thresh"],
        epochs=config["epochs"],
        consensus_steps=config["consensus_steps"],
        lambda_step=config["lambda_step"]
    )

    if not results:
        # zip(*[]) yields nothing, which would otherwise surface as an
        # opaque tuple-unpacking error.
        raise ValueError("No results to plot: 'num_nodes_list' in the configuration is empty.")

    ns, mean_times, std_times = zip(*results)

    # Save the lambda plot: run_experiment leaves it as the current figure,
    # so this must happen before a new figure is opened below.
    plt.savefig(config["output_graph"])
    print(f"Graph saved to {config['output_graph']}")

    # Plot execution times
    plt.figure(figsize=(8, 5))
    plt.errorbar(ns, mean_times, yerr=std_times, fmt='o-', capsize=5, label="Execution Time")
    plt.xlabel("Number of Agents (n)", fontsize=22)
    plt.ylabel("Execution Time (seconds)", fontsize=22)
    plt.title("Algorithm Scalability", fontsize=22)
    plt.grid(True)
    plt.legend(fontsize=22)
    plt.tick_params(labelsize=20)
    # Same as for the lambda figure: keep the large labels inside the canvas
    # of the saved image.
    plt.tight_layout()
    plt.savefig(config["output_times"])
    print(f"Graph saved to {config['output_times']}")
    plt.show()


if __name__ == "__main__":
    # Command-line entry point: the only option is the mandatory path to the
    # JSON configuration, which is handed straight to main().
    import argparse

    cli = argparse.ArgumentParser(
        description="Run scalability experiments for MARL algorithms."
    )
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to the configuration file (e.g., scaling_config.json).",
    )

    main(cli.parse_args().config)
