#%%
import os
import argparse
import traceback
import numpy as np
import pandas as pd
import pandapower as pp
import simbench as sb
import torch
from tqdm import tqdm
from utils.training_utils import get_dist_grid_codes
from torch_geometric.utils import is_undirected

#%%
def get_pp_sources(output_base_dir, grid_type):
    """Get the list of solved source networks for a specific grid type.

    Reads ``<output_base_dir>/<grid_type>/train/dataset_src.csv`` and re-roots the
    file name of every entry in its ``src`` column into that same ``train``
    directory, so the directories recorded in the CSV do not need to exist locally.

    Args:
        output_base_dir (str): Base directory where datasets are stored.
        grid_type (str): The type of sb grid (e.g., '1-LV-rural1--1-no_sw', '1-MV-urban--1-no_sw', etc.)

    Returns:
        list of str: One path per row of the CSV, in file order.

    Raises:
        FileNotFoundError: If ``dataset_src.csv`` does not exist.
    """
    train_dir = os.path.join(output_base_dir, grid_type, 'train')
    recorded_paths = pd.read_csv(os.path.join(train_dir, 'dataset_src.csv'), index_col=0)['src'].to_list()
    # Only the file name is kept. The recorded paths may have been written with
    # either '/' or '\\' as separator, so both are treated as directory separators.
    return [
        os.path.join(train_dir, src.replace('\\', '/').rsplit('/', 1)[-1])
        for src in recorded_paths
    ]

# Short labels for the grid codes: voltage level followed by a running number
# within that level (e.g. "1-LV-rural1--1-no_sw" -> "LV1", "1-MV-comm--1-no_sw" -> "MV4").
GRID_ACRONYMS = {
    f"1-{level}-{area}--1-no_sw": f"{level}{number}"
    for level, areas in (
        ("LV", ("rural1", "rural2", "rural3", "semiurb4", "semiurb5", "urban6")),
        ("MV", ("rural", "semiurb", "urban", "comm")),
    )
    for number, area in enumerate(areas, start=1)
}

def calculate_grid_statistics(data_dir, grid_type):
    """
    Summarise the training split of one grid type as a dictionary of statistics.

    Loads ``<data_dir>/<grid_type>/train/dataset.pt`` together with the pandapower
    network listed for each sample and aggregates per-sample figures over the split.

    Args:
        data_dir (str): Base directory where datasets are stored.
        grid_type (str): The type of sb grid (e.g., '1-LV-rural1--1-no_sw', '1-MV-urban--1-no_sw', etc.)

    Returns:
        dict: Statistics with the keys
            - ``name``: acronym looked up in ``GRID_ACRONYMS``.
            - ``type``: 'rural', 'semi-urban', 'urban', 'commercial' or 'unknown'.
            - ``num_buses`` / ``num_lines``: per-sample mean, truncated to ``int``.
            - ``avg_rx_ratio``: mean of the per-graph arithmetic-mean R/X ratio.
            - ``rated_voltage_kv``: mean of each network's most frequent bus ``vn_kv``.
            - ``min_line_loading_percent`` / ``max_line_loading_percent``: extremes over all samples.
            - ``avg_line_loading_percent``: mean of the per-sample mean line loading.

    Raises:
        KeyError: If ``grid_type`` has no entry in ``GRID_ACRONYMS``.
        ValueError: If the split contains no samples.
        SystemExit: If the list of pandapower sources cannot be found.
    """
    dataset_path = os.path.join(data_dir, grid_type, 'train', 'dataset.pt')
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load dataset files you trust.
    pyg_dataset = torch.load(dataset_path, weights_only=False)
    try:
        sources = get_pp_sources(data_dir, grid_type)
    except FileNotFoundError:
        traceback.print_exc()
        print("\nNeed to provide path to full ENGAGE dataset, with pandapower network sources. Use the '--data_dir' argument.\n")
        # Raise SystemExit directly: the `exit` builtin is added by the `site`
        # module for interactive use and is not guaranteed to exist.
        raise SystemExit(1)

    stats = {
        "name": GRID_ACRONYMS[grid_type],
    }

    if "rural" in grid_type:
        stats["type"] = "rural"
    elif "semiurb" in grid_type:
        stats["type"] = "semi-urban"
    elif "urban" in grid_type:
        stats["type"] = "urban"
    elif "comm" in grid_type:
        stats["type"] = "commercial"
    else:
        # Always provide the key so consumers of the dictionary do not fail on it.
        stats["type"] = "unknown"

    num_buses = []
    num_lines = []
    avg_rx_ratio = []
    rated_voltage_kv = []
    min_line_loading_percent = []
    max_line_loading_percent = []
    avg_line_loading_percent = []

    # zip() stops at the shorter input; samples and sources are assumed to be
    # stored in the same order -- TODO confirm against the dataset generator.
    for data, src in tqdm(zip(pyg_dataset, sources), total=len(pyg_dataset)):
        net = pp.from_json(src)

        # Buses: one row of node features per bus.
        num_buses.append(data.x.shape[0])

        # Lines: an undirected graph stores every edge twice, so halve the count
        # only in that case (same convention as compute_rx_for_graph).
        # NOTE(review): edges may include transformers as well as lines -- confirm.
        n_edges = data.edge_index.shape[1]
        num_lines.append(n_edges // 2 if is_undirected(data.edge_index) else n_edges)

        # R/X Ratio
        avg_rx_ratio.append(compute_rx_for_graph(data)['arithmetic_mean'])

        # Rated Voltage: most frequent nominal bus voltage of the network.
        rated_voltage_kv.append(net['bus']['vn_kv'].mode()[0])

        # Line loading statistics
        line_loading = net['res_line']['loading_percent']

        min_line_loading_percent.append(line_loading.min())
        max_line_loading_percent.append(line_loading.max())
        avg_line_loading_percent.append(line_loading.mean())

    if not num_buses:
        # Without this the aggregation below fails with an opaque NaN-to-int error.
        raise ValueError(f"No samples found for grid type '{grid_type}' in '{data_dir}'.")

    stats["num_buses"] = int(np.mean(num_buses))
    stats["num_lines"] = int(np.mean(num_lines))
    stats["avg_rx_ratio"] = np.mean(avg_rx_ratio)
    stats["rated_voltage_kv"] = np.mean(rated_voltage_kv)
    stats["min_line_loading_percent"] = np.min(min_line_loading_percent)
    stats["max_line_loading_percent"] = np.max(max_line_loading_percent)
    stats["avg_line_loading_percent"] = np.mean(avg_line_loading_percent)

    return stats

def compute_rx_for_graph(data):
    """
    Compute R/X ratio summaries for a single PyG Data object.

    Assumes edge_attr = [trafo?, r_pu, x_pu, sc_voltage].
    When the graph is undirected only every second row of ``edge_attr`` is used,
    which expects the two directions of an edge to be stored one after the other.
    Edges with zero reactance are excluded from every statistic.

    Args:
        data: Graph object exposing ``edge_index`` and ``edge_attr``.

    Returns:
        dict: With the keys
            - ``arithmetic_mean``: mean of the per-edge R/X ratios.
            - ``impedance_weighted_mean``: per-edge R/X weighted by sqrt(R^2 + X^2).
            - ``aggregate_rx``: sum(R) / sum(X).
            - ``n_edges_used``: number of edges that entered the statistics.
        A statistic whose denominator is zero (for example when no edge has a
        non-zero reactance) is reported as ``nan`` instead of raising.
    """
    edge_attr = data.edge_attr
    if is_undirected(data.edge_index):
        # Skips duplicate edges.
        edge_attr = edge_attr[::2]
    r = edge_attr[:, 1]  # r_pu
    x = edge_attr[:, 2]  # x_pu

    # Drop edges with zero reactance; their R/X ratio is undefined.
    nonzero_x = (x != 0)
    r = r[nonzero_x]
    x = x[nonzero_x]

    # Elementwise R/X
    rx = r / x
    n_used = len(rx)
    nan = float("nan")

    arithmetic_mean = rx.mean().item() if n_used else nan

    # Weighted by impedance magnitude
    z_mag = torch.sqrt(r**2 + x**2)
    z_total = z_mag.sum().item()
    imp_weighted_mean = (rx * z_mag).sum().item() / z_total if z_total != 0 else nan

    # Aggregate R/X = sum(R) / sum(X); sum(X) can be zero even when no single x is.
    x_total = x.sum().item()
    aggregate_rx = r.sum().item() / x_total if x_total != 0 else nan

    return {
        "arithmetic_mean": arithmetic_mean,
        "impedance_weighted_mean": imp_weighted_mean,
        "aggregate_rx": aggregate_rx,
        "n_edges_used": n_used
    }

def grid_stats_to_latex(grid_stats_list, caption="ENGAGE Dataset Statistics", label="tab:engage_stats"):
    """
    Convert a list of grid statistics dictionaries to a LaTeX table.

    The rows are sorted by grid name. The table uses ``\\toprule``, ``\\midrule``
    and ``\\bottomrule``, so the including document needs a package that defines
    them (e.g. ``booktabs``). ``caption`` and ``label`` are inserted verbatim.

    Parameters:
    -----------
    grid_stats_list : list of dict
        List of dictionaries containing grid statistics. Each one must provide
        'name', 'type', 'num_buses', 'num_lines', 'rated_voltage_kv',
        'avg_rx_ratio', 'min_line_loading_percent', 'max_line_loading_percent'
        and 'avg_line_loading_percent'.
    caption : str
        Table caption
    label : str
        LaTeX label for referencing the table

    Returns:
    --------
    str : LaTeX table code (no trailing newline)
    """
    lines = [
        "\\begin{table}[t]",
        "\\centering",
        f"\\caption{{{caption}}}",
        f"\\label{{{label}}}",
        # Nine columns: left-aligned name/type followed by seven centred numeric
        # columns, matching the number of cells in every row below.
        "\\begin{tabular}{llccccccc}",
        "\\toprule",
        # First header row with multicolumn for Line Loading
        "\\textbf{Grid} & \\textbf{Type} & \\textbf{Buses} & \\textbf{Lines} & "
        "\\textbf{Voltage} & \\textbf{R/X} & \\multicolumn{3}{c}{\\textbf{Line Loading}} \\\\",
        # Second header row for units and Min/Max/Avg
        " & & & & \\textbf{(kV)} & \\textbf{Ratio} & "
        "\\textbf{Min (\\%)} & \\textbf{Max (\\%)} & \\textbf{Avg (\\%)} \\\\",
        "\\midrule",
    ]

    # Data rows
    for stats in sorted(grid_stats_list, key=lambda d: d['name']):
        lines.append(
            f"\\textbf{{{stats['name']}}} & {stats['type'].capitalize()} & "
            f"{stats['num_buses']} & {stats['num_lines']} & "
            f"{stats['rated_voltage_kv']:.1f} & {stats['avg_rx_ratio']:.2f} & "
            f"{stats['min_line_loading_percent']:.2f} & "
            f"{stats['max_line_loading_percent']:.2f} & "
            f"{stats['avg_line_loading_percent']:.2f} \\\\"
        )

    # Close the table
    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    return "\n".join(lines)

def main(data_dir='data/ENGAGE_dataset/', save=False):
    """Compute the statistics of every training grid and print them as a LaTeX table.

    The grids are the codes returned by ``get_dist_grid_codes(scenario=1)``; each
    code is printed before its statistics are computed.

    Args:
        data_dir (str): Base directory where datasets are stored.
        save (bool): If True, also write the statistics to 'engage_statistics.csv'
            in the current working directory.
    """
    training_grids = get_dist_grid_codes(scenario=1)

    grid_stats_list = []
    for grid in training_grids:
        print(grid)
        grid_stats_list.append(calculate_grid_statistics(data_dir, grid))
    print(grid_stats_to_latex(grid_stats_list))

    if save:
        # Columns are named explicitly and the values looked up by key, so a
        # header can never end up above another statistic's values.
        columns = [
            'name',
            'type',
            'num_buses',
            'num_lines',
            'rated_voltage_kv',
            'avg_rx_ratio',
            'min_line_loading_percent',
            'max_line_loading_percent',
            'avg_line_loading_percent',
        ]
        rows = [{**stats, 'type': stats['type'].capitalize()} for stats in grid_stats_list]
        results_df = pd.DataFrame(rows, columns=columns)
        results_df.to_csv('engage_statistics.csv')

# %%
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Calculate ENGAGE dataset statistics"
    )
    parser.add_argument(
        "--data_dir",
        default="data/ENGAGE_dataset/",
        help="Path to the directory containing the ENGAGE grid datasets",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Also write the statistics to a CSV file",
    )
    args = parser.parse_args()
    # Raise SystemExit directly instead of calling `exit`: that builtin is added
    # by the `site` module for interactive sessions and may be missing (e.g.
    # under `python -S`). The exit status is unchanged.
    raise SystemExit(main(data_dir=args.data_dir, save=args.save))