from typing import Literal
import pandas as pd
import os


def log_metrics(
    epoch_metrics,
    client_id,
    epoch,
    output_dir,
    evaluation_set: Literal["train", "test"],
    exp: int,
    write_from_beginning: bool = False,
):
    """
    Record one epoch of metrics for a client in that client's CSV log.

    The log is stored at
    ``<output_dir>/clients_logs/experiment_<exp>/<evaluation_set>/client_<client_id>_metrics.csv``.
    Each call contributes one row labelled with ``epoch``; floating-point
    values are written with three decimal places.

    Parameters:
        epoch_metrics (dict): metric values for this epoch; becomes one row
        client_id (str): identifier embedded in the CSV file name
        epoch (int): index label of the new row
        output_dir (str): root directory under which the log tree is created
        evaluation_set (str): "train" or "test"; selects the sub-directory
        exp (int): experiment number
        write_from_beginning (bool): when True, any existing log is replaced
            by a file containing only this epoch's row
    """

    # Build (and create if needed) the directory holding this client's logs
    log_dir = os.path.join(
        output_dir, "clients_logs", f"experiment_{exp}", evaluation_set
    )
    os.makedirs(log_dir, exist_ok=True)
    csv_file = os.path.join(log_dir, f"client_{client_id}_metrics.csv")

    # One-row frame for the current epoch, indexed by the epoch label
    new_row = pd.DataFrame([epoch_metrics], index=[epoch])

    start_fresh = write_from_beginning or not os.path.exists(csv_file)
    if start_fresh:
        # No usable history: the log consists of this row only
        table = new_row
    else:
        # Reload the rows written so far and put the new one after them
        history = pd.read_csv(csv_file, index_col=0)
        table = pd.concat([history, new_row])

    # The whole table is rewritten on every call
    table.to_csv(csv_file, float_format="%.3f")


def log_aggregation_losses(
    epoch_loss,
    epoch,
    output_dir,
    exp: int,
    write_from_beginning: bool = False,
):
    """
    Record one epoch of server-side aggregation losses in a CSV log.

    The log is stored at
    ``<output_dir>/aggregation_logs/experiment_<exp>/train/loss.csv``.
    Each call contributes one row labelled with ``epoch``; floating-point
    values are written with three decimal places.

    Parameters:
        epoch_loss (dict): loss values for this epoch; becomes one row
        epoch (int): index label of the new row
        output_dir (str): root directory under which the log tree is created
        exp (int): experiment number
        write_from_beginning (bool): when True, any existing log is replaced
            by a file containing only this epoch's row
    """

    # Build (and create if needed) the directory holding the aggregation logs
    target_dir = os.path.join(
        output_dir, "aggregation_logs", f"experiment_{exp}", "train"
    )
    os.makedirs(target_dir, exist_ok=True)

    # Path to the aggregation loss CSV file
    loss_file = os.path.join(target_dir, "loss.csv")

    # One-row frame for the current epoch, indexed by the epoch label
    current = pd.DataFrame([epoch_loss], index=[epoch])

    if write_from_beginning or not os.path.exists(loss_file):
        # Nothing to extend: write a log made of this row only
        current.to_csv(loss_file, float_format="%.3f")
        return

    # Extend the rows written so far and rewrite the whole file
    previous = pd.read_csv(loss_file, index_col=0)
    pd.concat([previous, current]).to_csv(loss_file, float_format="%.3f")
