import os
from torch import Tensor
import networkx as nx
import pandas as pd
import numpy as np
import pickle

def load_causal_graph(filename: str):
    r"""
    Read a pickled causal graph from disk.

    A missing '.pkl' suffix is appended to ``filename`` before the
    lookup. The unpickled object must be exactly a
    ``networkx.DiGraph``: the check compares types by identity, so
    subclasses of ``DiGraph`` are rejected as well.

    Warning:
        ``pickle.load`` can execute arbitrary code while
        deserializing; only load files from trusted sources.

    Args:
        filename (str): Location of the pickle file, with or without
            the '.pkl' extension.

    Returns:
        networkx.DiGraph: The graph stored in the file.

    Raises:
        FileNotFoundError: If no file exists at the (suffixed) path.
        TypeError: If the unpickled object's type is not exactly
            networkx.DiGraph.
    """
    path = filename if filename.endswith(".pkl") else filename + ".pkl"
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Causal graph '{path}' does not exist.")

    with open(path, "rb") as handle:
        graph = pickle.load(handle)

    graph_type = type(graph)
    if graph_type is nx.DiGraph:
        return graph
    raise TypeError(f"Expected a networkx.DiGraph, but got {graph_type}.")

def convert_df_to_npz(
        df: pd.DataFrame,
        save: bool = False,
        filename: str = "data.npz"
    ) -> dict:
    r"""
    Convert a full pandas DataFrame into a savable structure,
    optionally saving it as a .npz file. Works for mixed-type data:
    strings, numbers, timestamps, etc.

    ``DataFrame.to_numpy`` flattens mixed-dtype frames into a single
    ``object`` array. Such arrays are stored pickled inside the .npz,
    so the file must be read back with ``allow_pickle=True``.
    Per-column dtypes are not recorded.

    Args:
        df (pd.DataFrame): Input DataFrame.
        save (bool): Whether to save the .npz file.
        filename (str): Filename to save. Can be with or without
            `.npz`.

    Returns:
        dict: Dictionary with 'data', 'columns', and 'index'. 'index'
            is ``None`` when ``df`` has the default index
            ``RangeIndex(0, n, 1)``; in that case no 'index' array is
            written to the file. Any other index, including a
            non-default ``RangeIndex``, is returned and saved as an
            array.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        ValueError: If DataFrame has no columns or non-string column
            names.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame.")
    columns = df.columns.to_numpy()
    if len(columns) == 0 or not all(isinstance(col, str) for col in columns):
        raise ValueError("DataFrame must have at least one column, "
                            "and all column names must be strings.")
    data = df.to_numpy()
    # Only the default RangeIndex(0, n, 1) can be omitted: it is exactly
    # what pandas rebuilds when no index is supplied. Any other
    # RangeIndex (e.g. after ``df.iloc[k:]``) must be stored explicitly,
    # otherwise the row labels are silently lost on round-trip.
    idx = df.index
    is_default_index = (
        isinstance(idx, pd.RangeIndex) and idx.start == 0 and idx.step == 1
    )
    index = None if is_default_index else idx.to_numpy()
    if save:
        if not filename.endswith(".npz"):
            filename += ".npz"
        arrays = {"data": data, "columns": columns}
        if index is not None:
            arrays["index"] = index
        np.savez(filename, **arrays)
    return {"data": data, "columns": columns, "index": index}


def load_npz_as_df(filename: str) -> pd.DataFrame:
    r"""
    Load a .npz file saved from a DataFrame and reconstruct it fully.

    The archive must contain a ``data`` array and a ``columns`` array,
    and may contain an ``index`` array. When ``index`` is absent, the
    DataFrame gets pandas' default ``RangeIndex``. Column dtypes are
    whatever the stored ``data`` array carries (e.g. ``object`` for
    mixed-type data); no per-column dtype inference is performed.

    Warning:
        The file is read with ``allow_pickle=True`` because object
        arrays (mixed-type data, string column names) are stored
        pickled. Loading a file from an untrusted source can execute
        arbitrary code.

    Args:
        filename (str): Path to the .npz file (without or with
            `.npz` extension).

    Returns:
        pd.DataFrame: Reconstructed DataFrame with original data,
            columns, and index.

    Raises:
        FileNotFoundError: If the file does not exist (propagated
            from ``np.load``).
        ValueError: If file is not a valid .npz file or does not
            contain required structure.
    """
    if not filename.endswith(".npz"):
        filename += ".npz"
    invalid_msg = f"'{filename}' is not a valid saved DataFrame structure."
    try:
        loaded = np.load(filename, allow_pickle=True)
    except (EOFError, pickle.UnpicklingError) as err:
        # np.load falls back to unpickling anything that is neither a zip
        # archive nor an .npy file (and raises EOFError on empty files);
        # report those as invalid input, as documented.
        raise ValueError(invalid_msg) from err
    # A plain .npy file yields an ndarray and a pickle yields an arbitrary
    # object; only an NpzFile can hold the named arrays needed here.
    if not isinstance(loaded, np.lib.npyio.NpzFile):
        raise ValueError(invalid_msg)
    # Closing the archive releases its file handle, including on the
    # error paths below. Arrays read inside the block stay valid after
    # the archive is closed.
    with loaded as npz:
        if "data" not in npz or "columns" not in npz:
            raise ValueError(invalid_msg)
        columns = npz["columns"]
        if len(columns) == 0 or not all(isinstance(col, str) for col in columns):
            raise ValueError("The .npz file must contain at least one column, and "
                                "all column names must be strings.")
        data = npz["data"]
        index = npz["index"] if "index" in npz else None
    df = pd.DataFrame(data=data, columns=columns)
    if index is not None:
        df.index = index
    return df

def make_dataframe(
    design_variables: list[str],
    outcome_variables: list[str],
    train_x: Tensor,
    train_y: Tensor
) -> pd.DataFrame:
    r"""
    Convert input and objective tensors to a pandas DataFrame.

    Both tensors are detached, moved to CPU and converted to NumPy.
    They are then stacked side by side with ``np.hstack``, so the
    resulting columns share one common (promoted) dtype.

    Args:
        design_variables (list[str]): List of strings representing
            the names of the input variables. Any sequence of names
            (e.g. a tuple) is accepted.
        outcome_variables (list[str]): List of strings representing
            the names of the objective variables. Any sequence of
            names is accepted.
        train_x (Tensor): Tensor of shape (n_samples, input_dim),
            containing the input (design) variables.
        train_y (Tensor): Tensor of shape (n_samples, num_objectives),
            containing the corresponding objective values.

    Returns:
        pd.DataFrame: DataFrame with design and outcome variables as
            columns (design columns first).

    Raises:
        ValueError: If design_variables or outcome_variables count
            does not match tensor dimensions, or if train_x and
            train_y have a different number of samples.
    """
    # Convert tensors to numpy
    X_np = train_x.detach().cpu().numpy()
    Y_np = train_y.detach().cpu().numpy()

    # Number of inputs and objectives
    input_dim = X_np.shape[-1]
    output_dim = Y_np.shape[-1]

    # Validate design and objective variable names
    if len(design_variables) != input_dim:
        raise ValueError(
            f"`design_variables` count {len(design_variables)} must match input dimension {input_dim}."
        )
    if len(outcome_variables) != output_dim:
        raise ValueError(
            f"`outcome_variables` count {len(outcome_variables)} must match output dimension {output_dim}."
        )
    # Check row counts up front so a mismatch gets a clear message
    # instead of np.hstack's generic concatenation error.
    if X_np.shape[0] != Y_np.shape[0]:
        raise ValueError(
            f"`train_x` has {X_np.shape[0]} samples but `train_y` has {Y_np.shape[0]}."
        )

    # list(...) accepts any sequence of names and never aliases the
    # caller's list objects.
    columns = list(design_variables) + list(outcome_variables)
    return pd.DataFrame(data=np.hstack([X_np, Y_np]), columns=columns)