import os
import numpy as np
import pandas as pd
import pandapower as pp
import simbench as sb
import torch

def get_pp_sources(output_base_dir, grid_type):
    """Get the list of solved source networks for a specific grid type.

    Reads ``<output_base_dir>/<grid_type>/train/dataset_src.csv`` (first column
    used as the index) and takes its ``src`` column. Only the file name of each
    recorded path is kept; it is re-rooted under the local ``train`` directory,
    so the dataset can be used from a different location than the one it was
    generated in.

    Args:
        output_base_dir (str): Base directory where datasets are stored.
        grid_type (str): The type of sb grid (e.g., '1-LV-rural1--0-no_sw', '1-MV-urban--1-no_sw', etc.)

    Returns:
        list of str: One path per row of the CSV, in file order, each pointing
        into ``<output_base_dir>/<grid_type>/train``.
    """
    train_dir = os.path.join(output_base_dir, grid_type, 'train')
    dataset_source_path = os.path.join(train_dir, 'dataset_src.csv')
    recorded_paths = pd.read_csv(dataset_source_path, index_col=0)['src'].to_list()
    # os.path.basename instead of a hard-coded '/' split, so native path
    # separators are honoured on every platform.
    return [os.path.join(train_dir, os.path.basename(src)) for src in recorded_paths]

def get_pyg_graphs(data_dir, grid_type, include_sources=True, max_sources=10):
    """
    Load PyTorch Geometric graphs from the specified directory and grid type.

    When ``include_sources`` is true, graphs are paired in order with the
    source network files returned by ``get_pp_sources``. For each pair the
    network is loaded from JSON, a power flow is run on it, and the resulting
    ``net["_ppc"]["internal"]`` is attached to the graph as ``data.ppci``.

    Args:
        data_dir (str): Base directory where datasets are stored.
        grid_type (str): The type of sb grid (e.g., '1-LV-rural1--0-no_sw', '1-MV-urban--1-no_sw', etc.)
        include_sources (bool): Whether to include source network info in each graph.
        max_sources (int | None): Upper bound on how many leading graphs get a
            ``ppci`` attached. Defaults to 10, the limit that used to be
            hard-coded here; pass ``None`` to process every graph.

    Returns:
        list of torch_geometric.data.Data: List of PyTorch Geometric Data objects
        (whatever ``dataset.pt`` contains), with ``ppci`` set on the processed ones.
    """
    dataset_path = os.path.join(data_dir, grid_type, 'train', 'dataset.pt')
    # NOTE(review): weights_only=False lets torch unpickle arbitrary objects;
    # only load dataset files from a trusted origin.
    pyg_dataset = torch.load(dataset_path, weights_only=False)
    if not include_sources:
        return pyg_dataset

    sources = get_pp_sources(data_dir, grid_type)
    if max_sources is not None:
        # Clamp at zero so a negative value cannot turn into a "drop from the
        # end" slice.
        sources = sources[:max(0, max_sources)]

    # zip stops at the shorter sequence, so graphs beyond the available (or
    # allowed) sources are returned without a ppci attribute.
    for data, src in zip(pyg_dataset, sources):
        net = pp.from_json(src)
        # The internal ppc is only populated once a power flow has been run.
        pp.runpp(net, verbose=False)
        data.ppci = net["_ppc"]["internal"]
    return pyg_dataset

def get_dist_grid_codes(scenario=0):
    """Return the sorted Simbench codes of the distribution-grid cases.

    All codes reported by Simbench for ``scenario`` are collected, and a code
    is kept when it contains "no_sw" and also contains "-MV-" or "-LV-".

    Args:
        scenario (int): Scenario passed through to
            ``sb.collect_all_simbench_codes``.

    Returns:
        list of str: The matching codes in ascending order.
    """
    # NOTE(review): the intent was described as "LV and MV and any combination
    # of the two"; the substring test below only matches codes with a literal
    # "-MV-" or "-LV-" segment -- confirm combined-level codes are meant to be
    # excluded.
    all_codes = sb.collect_all_simbench_codes(scenario=scenario)
    selected = [
        code
        for code in all_codes
        if "no_sw" in code and ("-MV-" in code or "-LV-" in code)
    ]
    selected.sort()
    return selected

# Root folder of the dataset; each grid code has its own sub-folder below it.
data_dir = 'data/ENGAGE_dataset/'
# Grid codes to process, taken from Simbench scenario 1.
training_grids = get_dist_grid_codes(scenario=1)

# For every grid: load its graphs with ppci attached and store the result next
# to the original dataset file under a separate name.
for grid in training_grids:
    filename = os.path.join(data_dir, grid, 'train', 'dataset_with_ppci.pt')
    dataset = get_pyg_graphs(data_dir, grid, include_sources=True)
    torch.save(dataset, filename)
