import matplotlib.pyplot as plt
import numpy as np
import itertools
import os
import json
from sklearn.preprocessing import LabelEncoder

DATASET_PATH = "dataset/"

# Dataset name -> (CSV file name under DATASET_PATH, target column or None).
_DATASET_TABLES = {
    "diabetes": ("diabetes.csv", "Outcome"),  # 8 features
    "adult": ("adult.csv", "income"),  # 14 features
    "german_credit": ("german_credit.csv", "credit_risk"),  # 20 features
    "bank": ("bank.csv", "y"),  # 20 features
    "nhanes": ("nhanes.csv", None),  # 79
    "brca": ("brca_small.csv", "BRCA_Subtype_PAM50"),  # 100 features
    "flights": ("flights.csv", None),  # 639 features
    "independentlinear60": ("independentlinear60.csv", "y"),  # 59 features
    "communitiesandcrime": ("communitiesandcrime.csv", "y"),  # 101 features
    "tuandromd": ("TUANDROMD.csv", "Label"),  # 241 features
}


def get_tables_path(dataset):
    """Return the CSV path and target column name for a known dataset.

    :param dataset: str
        Dataset identifier, e.g. ``"diabetes"`` or ``"adult"``.
    :return: tuple
        ``(tables_path, target)`` where ``tables_path`` is ``DATASET_PATH``
        followed by the dataset's file name and ``target`` is the configured
        target column name, or ``None`` when no target is configured.
    :raises ValueError:
        If ``dataset`` is not one of the known identifiers.
    """
    try:
        file_name, target = _DATASET_TABLES[dataset]
    except (KeyError, TypeError):
        # Fail loudly with the valid options instead of crashing later on an
        # unassigned local variable.
        known = ", ".join(sorted(_DATASET_TABLES))
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of: {known}"
        ) from None

    # Build the path at call time so the current DATASET_PATH value is used.
    return DATASET_PATH + file_name, target


def get_data_and_explicand(data_size, base_data, random_state=42):
    """Draw ``data_size + 1`` rows and split them into background data and one explicand.

    ``base_data`` must provide ``.sample`` and ``.iloc`` (presumably a pandas
    DataFrame -- confirm against callers).

    :param data_size: number of rows to return as background data.
    :param base_data: table to sample from.
    :param random_state: seed forwarded to ``base_data.sample``.
    :return: ``(data, explicand)``; ``explicand`` is the first sampled row kept
        as a one-row table, ``data`` holds the remaining ``data_size`` rows.
    """
    sampled = base_data.sample(n=data_size + 1, random_state=random_state)
    # The double brackets keep the explicand two-dimensional (a one-row table).
    return sampled.iloc[1:], sampled.iloc[[0]]


def powerset(n):
    """Return every length-``n`` tuple of 0s and 1s as a list.

    Each tuple is an indicator vector of one subset of an ``n``-element set;
    the list has ``2 ** n`` entries in ``itertools.product`` order.
    """
    binary_choices = (0, 1)
    return [*itertools.product(binary_choices, repeat=n)]


def update_json_file(path, new_data, key):
    """Update or create a JSON file with new data under a specific key.

    If ``path`` exists, its top-level JSON object is loaded (content that does
    not parse as JSON is treated as an empty object), ``key`` is set to
    ``new_data`` and the file is rewritten. Otherwise a new file containing
    ``{key: new_data}`` is created.

    The document is serialized to a string *before* the file is touched, so a
    payload that cannot be encoded as JSON raises without truncating or
    partially overwriting the existing file.

    :param path: location of the JSON file.
    :param new_data: JSON-serializable value to store.
    :param key: top-level key under which ``new_data`` is stored.
    :raises TypeError: if the resulting document is not JSON-serializable.
    """
    if os.path.exists(path):
        try:
            with open(path, "r+") as file:
                try:
                    data = json.load(file)
                except json.JSONDecodeError:
                    data = {}

                data[key] = new_data
                # Encode first: if this raises, the file is still intact.
                payload = json.dumps(data, indent=4)
                file.seek(0)
                file.truncate()
                file.write(payload)
        except IOError as e:
            # Best-effort by design: report I/O problems without raising.
            print(f"An error occurred while reading or writing to the file: {e}")
    else:
        # Encode before opening so a failure does not leave a partial file.
        payload = json.dumps({key: new_data}, indent=4)
        with open(path, "w") as file:
            file.write(payload)


def crossentropyloss(pred, target):
    """Per-sample cross entropy loss (no averaging across samples).

    A one-dimensional ``pred`` is expanded into two columns ``(1 - pred, pred)``.
    When ``pred`` and ``target`` then have the same shape, ``target`` is used as
    per-class weights (soft labels) and ``pred`` is clipped to
    ``[1e-12, 1 - 1e-12]`` before the log. Otherwise ``target`` is used to index
    one column of ``pred`` per row (hard labels) and no clipping is applied.
    """
    if pred.ndim == 1:
        # Two-column form: column 0 is 1 - pred, column 1 is pred.
        pred = np.stack((1 - pred, pred), axis=1)

    if pred.shape != target.shape:
        # Standard cross entropy loss: pick the entry selected by each label.
        row_index = np.arange(len(pred))
        return -np.log(pred[row_index, target])

    # Soft cross entropy loss.
    clipped = np.clip(pred, 1e-12, 1 - 1e-12)
    weighted_log = np.log(clipped) * target
    return -np.sum(weighted_log, axis=1)


def mseloss(pred, target):
    """Per-sample squared error, summed over output columns (no averaging).

    One-dimensional ``pred`` or ``target`` is first turned into a single
    column so both are two-dimensional before the difference is taken.
    """
    if pred.ndim == 1:
        pred = np.expand_dims(pred, axis=1)
    if target.ndim == 1:
        target = np.expand_dims(target, axis=1)
    residual = pred - target
    return np.sum(residual ** 2, axis=1)


def calculate_l2_norm(exact, estimated):
    """Mean L2 distance between exact and estimated per-feature values.

    Both arguments are mappings keyed by feature; ``estimated`` must contain
    every key of ``exact``. The per-feature differences are stacked as rows,
    the L2 norm is taken down each column, and the column norms are averaged.
    """
    per_feature_diffs = [
        np.array(exact[name]) - np.array(estimated[name]) for name in exact
    ]
    column_norms = np.linalg.norm(np.vstack(per_feature_diffs), axis=0)
    return column_norms.mean()


def calculate_normalized_l2_norm(exact, estimated):
    """Mean L2 error of ``estimated`` relative to the L2 norm of ``exact``.

    Both arguments are mappings keyed by feature; ``estimated`` must contain
    every key of ``exact``. Per-feature values are stacked as rows; for each
    column the norm of the difference is divided by the norm of the exact
    values, and the resulting ratios are averaged.
    """
    exact_rows = [np.array(exact[name]) for name in exact]
    diff_rows = [
        np.array(exact[name]) - np.array(estimated[name]) for name in exact
    ]

    error_norms = np.linalg.norm(np.vstack(diff_rows), axis=0)
    reference_norms = np.linalg.norm(np.vstack(exact_rows), axis=0)
    return (error_norms / reference_norms).mean()


def calculate_objective(exact, S, b, estimated):
    """
    Calculate the objective gap between estimated and exact values.

    :param exact: dict
        Exact value(s) keyed by feature; its key order defines the row order.
    :param S: np.ndarray
        (2^n, n) powerset matrix.
    :param b: np.ndarray
        (2^n,) target vector.
    :param estimated: dict
        Estimated value(s), indexed with the keys of ``exact``.
    :return: float
        Mean of ||A * estimated_x - b||^2 - ||A * exact_x - b||^2 where A = S - 1/2.
    """
    design = S - 1 / 2
    feature_order = list(exact)

    exact_vec = np.array([exact[name] for name in feature_order])
    estimated_vec = np.array([estimated[name] for name in feature_order])

    def squared_residual(x):
        # Squared L2 norm of the residual, taken along axis 0.
        return np.linalg.norm(design @ x - b, axis=0) ** 2

    exact_sq = squared_residual(exact_vec)
    estimated_sq = squared_residual(estimated_vec)
    return np.mean(estimated_sq - exact_sq)


def calculate_gamma(exact, S, b):
    """
    Calculate gamma.

    :param exact: dict
        Exact value(s) keyed by feature; its key order defines the row order.
    :param S: np.ndarray
        (2^n, n) powerset matrix.
    :param b: np.ndarray
        (2^n,) target vector.
    :return: float
        ||b||_2^2 / ||A * exact_x ||_2^2 where A = S - 1/2.
    """
    design = S - 1 / 2
    exact_vec = np.array([exact[name] for name in exact])

    target_sq_norm = np.linalg.norm(b) ** 2
    fitted_sq_norm = np.linalg.norm(design @ exact_vec) ** 2
    return target_sq_norm / fitted_sq_norm


def convert_boolean_and_encode(data):
    """
    Run ``convert_boolean_and_encode_col`` on every column of ``data``.

    Each column is overwritten in place with its converted version, and the
    same (mutated) ``data`` object is returned.
    """
    for column_name in data.columns:
        converted = convert_boolean_and_encode_col(data[column_name])
        data[column_name] = converted
    return data


def convert_boolean_and_encode_col(col):
    """Convert a boolean-like column to booleans or label-encode a categorical one.

    Rules, applied in order:

    * Numeric and boolean dtypes are returned unchanged. In particular a 0/1
      column is *not* treated as boolean: ``0 == False`` and ``1 == True`` in
      Python, so a plain set-membership test would misclassify it.
    * A column whose non-missing values are all real booleans or the strings
      ``"True"`` / ``"False"`` has those strings mapped to booleans; existing
      booleans and missing values are kept as they are.
    * Any other ``category`` or ``object`` column is label-encoded on its
      string representation with ``LabelEncoder``.
    * Everything else is returned unchanged.

    :param col: column to convert (a pandas Series, as produced by ``data[col]``).
    :return: the converted column; label-encoded columns come back as the
        array returned by ``LabelEncoder.fit_transform``.
    """
    # Already model-ready: bool, signed/unsigned int, float, complex.
    if col.dtype.kind in "biufc":
        return col

    boolean_strings = {"True": True, "False": False}

    def is_boolean_like(value):
        # Type-strict check so that the integers 0/1 never count as booleans.
        if isinstance(value, str):
            return value in boolean_strings
        return isinstance(value, (bool, np.bool_))

    def to_boolean(value):
        # Only the two string spellings are translated; anything else
        # (real booleans, missing values) passes through untouched.
        if isinstance(value, str):
            return boolean_strings.get(value, value)
        return value

    unique_values = col.dropna().unique()

    # Handling boolean data
    if all(is_boolean_like(value) for value in unique_values):
        return col.map(to_boolean).infer_objects()

    if col.dtype.name == "category" or col.dtype == "object":
        # Encode remaining categorical / object columns as integer labels.
        le = LabelEncoder()
        return le.fit_transform(col.astype(str))

    return col


def compute_condition_number(matrix):
    """
    Computes the condition number of a given matrix using its singular values.

    Only the singular values are requested from the SVD; the singular vectors
    are never needed here, and for a tall matrix the full left factor would be
    a square matrix of size (rows x rows).

    Parameters:
    matrix (np.ndarray): The matrix for which to compute the condition number.

    Returns:
    float: Largest singular value divided by the smallest singular value that
    exceeds ``100 * machine epsilon``; ``np.inf`` if no singular value exceeds
    that threshold.
    """
    singular_values = np.linalg.svd(matrix, compute_uv=False)

    # Drop (near-)zero singular values to avoid dividing by zero.
    threshold = np.finfo(float).eps * 100
    significant = singular_values[singular_values > threshold]

    if significant.size == 0:
        # All singular values are zero or nearly zero.
        return np.inf

    return significant.max() / significant.min()
