import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re
from tqdm import tqdm
from src.marlAcrl import marlAcrl
from matplotlib.colors import TwoSlopeNorm
from matplotlib.ticker import FuncFormatter
import json
import argparse

def extract_params_from_filename(filename):
    """Extract the ``nu`` and ``lambda`` values encoded in a model filename.

    Values are written with ``_`` in place of the decimal point and an optional
    leading ``-``, e.g. ``"ippo_nu0_5_lambd10_b1.pt"`` -> ``(0.5, 10.0)``.

    Args:
        filename: Filename (or path) to parse.

    Returns:
        ``(nu, lambda)`` as floats, or ``(None, None)`` when no parseable
        ``nu<value>_lambd<value>`` segment is present.
    """
    # Each value: optional sign, integer digits, then at most ONE "_<digits>"
    # fractional part. A character class like [-\d_]+ would also swallow the
    # underscore of a following suffix such as "_b1", turning "lambd1_0_b1"
    # into "1_0_" -> float("1.0.") -> ValueError.
    match = re.search(r'nu(-?\d*(?:_\d+)?)_lambd(-?\d*(?:_\d+)?)', filename)
    if match is None:
        return None, None
    try:
        nu_val = float(match.group(1).replace('_', '.'))
        lambda_val = float(match.group(2).replace('_', '.'))
    except ValueError:
        # e.g. "nu_lambd1" (empty value): treat as unparseable, like no match.
        return None, None
    return nu_val, lambda_val

# Tick-label callback with the (value, position) signature expected by matplotlib's FuncFormatter.
def integer_formatter(x, pos):
    """Return ``x`` truncated toward zero and rendered as an integer string.

    ``pos`` (the tick position supplied by ``FuncFormatter``) is ignored.
    """
    truncated = int(x)
    return str(truncated)

def load_config(config_path):
    """Load the evaluation configuration from a JSON file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The decoded JSON document (``main`` indexes it as a dict of settings).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

# Constraint level used to centre the grid-consumption colour scale when results
# are loaded from a stored CSV instead of being recomputed (the value for the
# pretrained IPPO models). A "constraint_value" entry in the config overrides it.
DEFAULT_STORED_CONSTRAINT_VALUE = 47.14


def _run_experiments(models_folder, connections, buildings, epochs, num_runs):
    """Evaluate every matching b1/b5 model pair found in ``models_folder``.

    A file whose name contains ``b1`` is paired with the file whose name has
    every ``b1`` replaced by ``b5``. Pairs whose ``nu``/``lambda`` cannot be
    parsed from the b1 filename are skipped with a message.

    Returns:
        ``(results, constraint_value)``: ``results`` is a list of per-pair dicts
        (keys ``nu``, ``lambda``, ``mean_grid``, ``min_grid``, ``max_grid``,
        ``mean_pos_der``, ``min_pos``, ``max_pos``); ``constraint_value`` is
        ``manager.constraint[0]`` of the last manager built, or ``None`` if no
        pair was evaluated.
    """
    model_files = sorted(os.listdir(models_folder))
    available = set(model_files)  # O(1) membership tests while pairing
    valid_models = [
        (b1_model, b1_model.replace('b1', 'b5'))
        for b1_model in model_files
        if 'b1' in b1_model and b1_model.replace('b1', 'b5') in available
    ]

    results = []
    constraint_value = None
    with tqdm(total=len(valid_models), desc="Running experiments") as pbar:
        for b1_model, b5_model in valid_models:
            nu_val, lambda_val = extract_params_from_filename(b1_model)
            if nu_val is None or lambda_val is None:
                print(f"Skipping file {b1_model}: Could not extract parameters.")
                pbar.update(1)
                continue

            b1_path = os.path.join(models_folder, b1_model)
            b5_path = os.path.join(models_folder, b5_model)
            # b5 model at both ends, b1 for the len(buildings) - 2 entries in
            # between (presumably one entry per building -- confirm in marlAcrl).
            models = [b5_path] + [b1_path] * (len(buildings) - 2) + [b5_path]

            manager = marlAcrl(
                connections=connections,
                buildings=buildings,
                models=models,
                nu_val=nu_val,
                lambda_val=lambda_val,
                derenv=False,
            )
            constraint_value = manager.constraint[0]

            (mean_grid, min_grid, max_grid,
             mean_pos_der, min_pos, max_pos) = manager.run_experiment_no_consensus_summary(
                epochs=epochs, num_runs=num_runs
            )

            results.append({
                "nu": nu_val,
                "lambda": lambda_val,
                "mean_grid": mean_grid,
                "min_grid": min_grid,
                "max_grid": max_grid,
                "mean_pos_der": mean_pos_der,
                "min_pos": min_pos,
                "max_pos": max_pos,
            })
            pbar.update(1)

    return results, constraint_value


def _plot_heatmap(data, output_path, *, norm=None, tick_labelsize=14):
    """Draw ``data`` as an annotated nu-by-lambda heatmap and save it as a PDF.

    Args:
        data: Pivoted DataFrame (index = nu values, columns = lambda values).
        output_path: Destination PDF path.
        norm: Optional matplotlib normalisation for the colour scale.
        tick_labelsize: Font size of the tick labels.

    Returns:
        The created matplotlib Figure (still open, so the caller can show it).
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    # Only forward ``norm`` when one is given, so seaborn applies its default scaling otherwise.
    norm_kwargs = {} if norm is None else {"norm": norm}
    sns.heatmap(
        data,
        annot=True,
        fmt=".0f",
        cmap="coolwarm",
        ax=ax,
        # Label ticks with the actual nu/lambda values. A FuncFormatter on the
        # axes would instead receive seaborn's cell-centre positions
        # (0.5, 1.5, ...) and print row/column indices.
        xticklabels=[f"{value:g}" for value in data.columns],
        yticklabels=[f"{value:g}" for value in data.index],
        cbar_kws={"orientation": "horizontal", "pad": 0.1, "aspect": 50, "fraction": 0.03},
        annot_kws={"size": 14},
        **norm_kwargs,
    )
    ax.set_xlabel("Lambda", fontsize=18)
    ax.set_ylabel("Nu", fontsize=18)
    ax.tick_params(labelsize=tick_labelsize)

    fig.tight_layout()
    fig.savefig(output_path, format="pdf", bbox_inches="tight")
    return fig


def main(config_path, use_stored_results):
    """Run (or load) the nu/lambda sweep and plot its results as heatmaps.

    Args:
        config_path: Path to the JSON configuration file.
        use_stored_results: If True and ``config["output_csv"]`` exists, load the
            results from it instead of re-running the experiments.

    Raises:
        RuntimeError: If experiments are run but no model pair could be evaluated.
    """
    config = load_config(config_path)
    output_csv = config["output_csv"]
    output_heatmap_grid = config["output_heatmap_grid"]
    output_heatmap_pos = config["output_heatmap_pos"]

    if use_stored_results and os.path.exists(output_csv):
        print(f"Loading stored results from {output_csv}")
        df_results = pd.read_csv(output_csv)
        constraint_value = config.get("constraint_value", DEFAULT_STORED_CONSTRAINT_VALUE)
    else:
        models_folder = config["models_folder"]
        results, constraint_value = _run_experiments(
            models_folder,
            config["connections"],
            config["buildings"],
            config["epochs"],
            config["num_runs"],
        )
        if not results:
            raise RuntimeError(f"No evaluable b1/b5 model pairs found in {models_folder}")

        df_results = pd.DataFrame(results)
        df_results.to_csv(output_csv, index=False)
        print(f"Results saved to {output_csv}")

    # Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
    heatmap_data_grid = df_results.pivot(index="nu", columns="lambda", values="mean_grid")
    heatmap_data_pos_abs = df_results.pivot(
        index="nu", columns="lambda", values="mean_pos_der"
    ).abs()

    # Diverging colour scale centred on the constraint and symmetric around it.
    # DataFrame.min()/max() skip NaN cells (missing nu/lambda combinations).
    data_min = heatmap_data_grid.min().min()
    data_max = heatmap_data_grid.max().max()
    max_span = max(constraint_value - data_min, data_max - constraint_value)
    if not max_span > 0:
        # All values equal the constraint (or no data): TwoSlopeNorm needs vmin < vcenter < vmax.
        max_span = 1.0
    norm = TwoSlopeNorm(
        vmin=constraint_value - max_span,
        vcenter=constraint_value,
        vmax=constraint_value + max_span,
    )

    # First figure: mean grid consumption.
    _plot_heatmap(heatmap_data_grid, output_heatmap_grid, norm=norm, tick_labelsize=13)
    print(f"Grid Consumption Heatmap saved to {output_heatmap_grid}")
    plt.show()

    # Second figure: absolute mean postponed-demand derivative.
    _plot_heatmap(heatmap_data_pos_abs, output_heatmap_pos, tick_labelsize=14)
    print(f"Postponed Demand Heatmap saved to {output_heatmap_pos}")
    plt.show()

if __name__ == "__main__":
    # Command-line entry point: parse the flags and hand them to main().
    cli = argparse.ArgumentParser(description="Run model evaluation experiments.")
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the configuration file (e.g., evaluation_config.json).",
    )
    cli.add_argument(
        "--use_stored_results",
        action="store_true",
        help="Use stored CSV results if available instead of running experiments.",
    )
    parsed = cli.parse_args()
    main(config_path=parsed.config, use_stored_results=parsed.use_stored_results)
