import pandas as pd
import numpy as np
from scipy.stats import multivariate_normal, binom
import random
        
def columns_list(letter, n):
    """
    Build the column labels for a covariate block of dimension n.

    Parameters:
    letter (str): Prefix shared by every label.
    n (int): Number of labels to produce.

    Returns:
    list: [letter] when n equals 1, otherwise [letter1, ..., lettern]
          (an empty list when n is smaller than 1).
    """
    # A one-dimensional block keeps the bare prefix, without a numeric suffix.
    if n == 1:
        return [letter]
    return [f"{letter}{index}" for index in range(1, n + 1)]

def generate_binomial(
    covariates_name: str, sample_size, p=0.3, dimension=1,
) -> pd.core.frame.DataFrame:
    """
    Draw independent Bernoulli(p) covariates.

    Parameters:
    covariates_name (str): Prefix passed to columns_list to name the columns.
    sample_size (int): Number of rows to generate.
    p (float): Success probability of each draw.
    dimension (int): Number of columns to generate.

    Returns:
    pandas.DataFrame: A (sample_size, dimension) frame of 0.0 / 1.0 values.
    """
    # scipy's signature is binom(n, p): the number of trials comes first, and
    # `size` belongs to rvs(), not to the distribution parameters.
    draws = binom.rvs(n=1, p=p, size=(sample_size, dimension))
    return pd.DataFrame(
        draws,
        dtype=float,
        columns=columns_list(covariates_name, dimension),
    )

def multi_logistic(X, Theta, return_prob=False):
    """
    Send the individuals to the centers using a multinomial logistic function.

    Parameters:
    X (numpy.ndarray): Input feature matrix of shape (m, n) where m is the number of samples and n is the number of features.
    Theta (numpy.ndarray): Weight matrix of shape (K, n + 1) where K is the number of classes; the first column
        multiplies the intercept that this function prepends to X.
    return_prob (bool): If True, return the class probabilities instead of sampled class labels.

    Returns:
    pandas.Series: Sampled class labels "1", ..., "K" (one per row of X), in a series whose name is the
        empty string. When return_prob is True, a numpy.ndarray of shape (m, K) holding the class
        probabilities is returned instead and no random draw is made.
    """
    # Add intercept column
    m = X.shape[0]
    X_b = np.c_[np.ones((m, 1)), X]

    scores = X_b @ Theta.T
    # Shift each row by its maximum before exponentiating: the softmax is
    # unchanged, but np.exp can no longer overflow to inf and produce NaN.
    shifted_scores = scores - np.max(scores, axis=1, keepdims=True)
    exp_scores = np.exp(shifted_scores)
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

    if return_prob:
        return probs

    K_classes = Theta.shape[0]
    list_client_name = [str(i) for i in range(1, K_classes + 1)]
    # Sample the realized class based on the probabilities
    realized_classes = np.array(
        [np.random.choice(list_client_name, p=row_probs) for row_probs in probs]
    )

    return pd.Series(realized_classes, name="")


def generate_multivariate_normal(
    covariates_name: str or list,
    sample_size,
    mean_vector: int or np.array = 0,
    cov_matrix=1,
) -> pd.core.frame.DataFrame:
    """
    Draw Gaussian covariates with the given mean and covariance.

    Parameters:
    covariates_name (str or list): Column label(s) handed to the DataFrame constructor.
    sample_size (int): Number of draws requested from the distribution.
    mean_vector (int or numpy.ndarray): Mean of the distribution.
    cov_matrix (int or numpy.ndarray): Covariance of the distribution.

    Returns:
    pandas.DataFrame: The draws, labelled with covariates_name.
    """
    distribution = multivariate_normal(mean_vector, cov_matrix)
    draws = distribution.rvs(sample_size)
    return pd.DataFrame(draws, columns=covariates_name)

def generate_exponential(
    covariates_name: str or list,
    sample_size,
    scale=1,
    dimension=1,
) -> pd.core.frame.DataFrame:
    """
    Draw independent exponential covariates.

    Parameters:
    covariates_name (str or list): Column label(s) handed to the DataFrame constructor.
    sample_size (int): Number of rows to generate.
    scale (float): Scale parameter forwarded to np.random.exponential.
    dimension (int): Number of columns to generate.

    Returns:
    pandas.DataFrame: A (sample_size, dimension) frame of exponential draws.
    """
    shape = (sample_size, dimension)
    draws = np.random.exponential(scale, shape)
    return pd.DataFrame(draws, columns=covariates_name)

def generate_multivariate_mixtures_of_normals(
    covariates_name: str or list,
    n: int,
    dim,
    MeanVectors: list,
    CovarianceMatrices: list,
    mixture_weights: list = [
        0.25,
        0.25,
        0.25,
        0.125,
        0.125,
    ],  # summing to 1, default is 5 components; never mutated here
) -> pd.core.frame.DataFrame:
    """
    Draw n samples from a finite mixture of multivariate normals.

    Parameters:
    covariates_name (str or list): Column label(s) handed to the DataFrame constructor.
    n (int): Number of samples to draw.
    dim (int): Dimension of each sample.
    MeanVectors (list): Mean vector of each mixture component, indexed by component.
    CovarianceMatrices (list): Covariance matrix of each mixture component, indexed by component.
    mixture_weights (list): Relative weight of each component; its length sets the number of components.

    Returns:
    pandas.DataFrame: An (n, dim) frame holding one draw per row.
    """
    # Every row is overwritten in the loop below, so no NaN pre-fill is needed
    # (the np.NaN alias previously used for it was removed in NumPy 2.0).
    samples = np.empty((n, dim))
    component_indices = range(len(mixture_weights))
    # Generate samples
    for row in range(n):
        # The component is picked with the stdlib `random` module while the
        # draw itself uses numpy's global generator: two separate RNG states.
        component = random.choices(component_indices, weights=mixture_weights, k=1)[0]
        # Draw sample from selected mixture component
        samples[row, :] = np.random.multivariate_normal(
            MeanVectors[component], CovarianceMatrices[component]
        )
    return pd.DataFrame(samples, columns=covariates_name)


def propensity_function(
    data: np.array, treatment_cols: list, gamma: np.array, clip=False
):
    """
    Logistic propensity score computed from selected columns of data.

    Parameters:
    data: Table indexed by column labels and exposing .values (a pandas DataFrame works).
    treatment_cols (list): Columns of data entering the linear index.
    gamma (numpy.ndarray): Coefficients applied to those columns.
    clip (bool): If True, bound the scores to [0.0001, 0.9999].

    Returns:
    numpy.ndarray: The sigmoid of the linear index, one value per row.
    """
    design = data[treatment_cols].values
    linear_index = np.dot(design, gamma)
    scores = 1 / (1 + np.exp(-linear_index))

    if not clip:
        return scores
    return np.clip(scores, 0.0001, 0.9999)


def mu(X_outcome: np.array, beta: np.array):
    """
    Linear response: the dot product of beta with the transpose of X_outcome.

    Parameters:
    X_outcome (numpy.ndarray): Covariate array; its transpose is what gets multiplied.
    beta (numpy.ndarray): Coefficient vector.

    Returns:
    The result of np.dot(beta, X_outcome.T).
    """
    transposed = X_outcome.T
    return np.dot(beta, transposed)


# Trim or clip the propensity scores
def trim_propensity_scores(
    df,
    gamma,
    treatment_cols,
    propensity_function=propensity_function,
    eta=0.01,
    clip=True,
):
    """
    Compute propensity scores and bound them to [eta, 1 - eta].

    The scores are written to a "propensity" column of df; df is modified in
    place and also returned.

    Parameters:
    df (pandas.DataFrame): Data passed to propensity_function; receives the "propensity" column.
    gamma: Coefficients forwarded to propensity_function.
    treatment_cols (list): Columns forwarded to propensity_function.
    propensity_function (callable): Called as propensity_function(df, treatment_cols, gamma).
    eta (float): Distance of the bounds from 0 and 1.
    clip (bool): Selects between Series.clip and an element-wise rule (see the note below).

    Returns:
    pandas.DataFrame: The same df object, with the "propensity" column set.
    """
    df["propensity"] = propensity_function(df, treatment_cols, gamma)
    if clip:
        df["propensity"] = df["propensity"].clip(eta, 1 - eta)
    else:
        # NOTE(review): this branch also bounds the scores rather than dropping
        # rows, so both settings of `clip` give the same result for finite
        # scores - confirm whether row trimming was intended.
        # The bounds are inclusive so a score exactly equal to eta is kept
        # instead of being sent to 1 - eta. A NaN score fails every comparison
        # and therefore ends up as 1 - eta.
        df["propensity"] = df["propensity"].apply(
            lambda x: x if eta <= x <= 1 - eta else eta if x < eta else 1 - eta
        )
    return df


### Data generation
def generate_treatment(
    data: pd.core.frame.DataFrame,
    treatment_cols: list,
    gamma: np.array,
    treatment_name: str,
    propensity_function: callable,
):
    """
    Draw a binary treatment indicator and store it in data.

    Parameters:
    data (pandas.DataFrame): Table that receives the new column (modified in place).
    treatment_cols (list): Columns forwarded to propensity_function.
    gamma (numpy.ndarray): Coefficients forwarded to propensity_function.
    treatment_name (str): Name of the column that stores the draws.
    propensity_function (callable): Called as propensity_function(data, treatment_cols, gamma)
        to obtain the success probabilities.

    Returns:
    pandas.DataFrame: The same data object, with the treatment column set.
    """
    success_probabilities = propensity_function(data, treatment_cols, gamma)
    # One Bernoulli draw per probability returned by propensity_function.
    assignments = np.random.binomial(1, success_probabilities)
    data[treatment_name] = assignments
    return data


def model_noise(data: pd.core.frame.DataFrame, sigma2: float):
    """
    Draw one centred Gaussian noise value per row of data.

    Parameters:
    data (pandas.DataFrame): Only its number of rows is used.
    sigma2 (float): Passed to np.random.normal as the scale argument.

    Returns:
    numpy.ndarray: Array of length data.shape[0].
    """
    # NOTE(review): the name suggests a variance, but np.random.normal treats
    # its second argument as a standard deviation - confirm which is intended.
    n_rows = data.shape[0]
    return np.random.normal(loc=0, scale=sigma2, size=n_rows)


def _polynomial_terms(covariates, coefficients, n_covariates):
    """Sum coefficient-weighted covariates, squaring the columns in the first half."""
    return np.sum(
        [
            coefficients[i]
            * (covariates[:, i] ** 2 if i < n_covariates / 2 else covariates[:, i])
            for i in range(n_covariates)
        ],
        axis=0,
    )


def _complex_nonlinear_response(covariates, beta):
    """Noise-free 'complex_nonlinear' response of every row for the 10 coefficients in beta."""
    x1, x2, x3, x4, x5, x6, x7 = (covariates[:, j] for j in range(7))
    return (
        x1 * beta[0]  # Linear term for X_1
        + x2 * beta[1]  # Linear term for X_2
        + x3 * beta[2]  # Linear term for X_3
        + x4 * beta[3]  # Linear term for X_4
        + x5 * beta[4]  # Linear term for X_5
        + x1 * x2 * beta[5]  # Interaction X_1 * X_2
        + x3 ** 2 * beta[6]  # Quadratic term for X_3
        + np.sin(x4) * beta[7]  # Sin term for X_4
        + np.log(np.abs(x5) + 1) * beta[8]  # Log term for X_5
        + x6 * x7 * beta[9]  # Interaction X_6 * X_7
    )


def generate_Y(
    data: pd.core.frame.DataFrame,
    beta_1: np.array,
    beta_0: np.array,
    covariates_name: list,
    treatment_name: str,
    h_k: float = 0,
    noise: str = "model_noise",
    Y_model='linear', # 'linear', 'polynomial', 'sinus', 'cosinus', 'logistic', 'complex_nonlinear'
):
    """
    Generate the outcome from the covariates, the treatment and a noise column.

    Rows whose treatment equals 1 use the "treated" response, all other rows
    the "control" response. Except for the 'logistic' model, the noise column
    is added exactly once to the response.

    Parameters:
    data (pandas.DataFrame): Holds the covariate, treatment and noise columns.
    beta_1 (numpy.ndarray): Coefficients of the treated response. The 'polynomial' model only
        uses its length; 'sinus' ignores it; 'complex_nonlinear' reads entries 0 to 9.
    beta_0 (numpy.ndarray): Coefficients of the control response (same remarks as beta_1).
    covariates_name (list): Covariate columns, in the order matching the coefficients.
        'polynomial' needs at least 2 of them, 'sinus' at least 3, 'cosinus' at least 2 and
        'complex_nonlinear' at least 7 (only the first 7 are used).
    treatment_name (str): Name of the treatment column.
    h_k (float): Constant shift added to the response (inside the cosine / sigmoid for
        'cosinus' and 'logistic'; not used by 'complex_nonlinear').
    noise (str): Name of the noise column of data; not read for the 'logistic' model.
    Y_model (str): One of 'linear', 'polynomial', 'sinus', 'cosinus', 'logistic',
        'complex_nonlinear'.

    Returns:
    numpy.ndarray: One outcome per row. For 'logistic' these are Bernoulli draws from the
        sigmoid response. For 'complex_nonlinear' a pandas.Series indexed like data is returned.

    Raises:
    ValueError: If Y_model is not one of the supported names.
    """
    covariates = data[covariates_name].values
    treated = data[treatment_name].values == 1

    if Y_model == 'linear':
        Y = np.where(treated, covariates @ beta_1 + h_k, covariates @ beta_0 + h_k)

    elif Y_model == "polynomial":
        n_covariates = len(covariates_name)
        # The coefficients are fixed by position; the betas only set their number.
        polyn_coeff_treated = np.array([j / 10 for j in range(len(beta_1))])
        polyn_coeff_control = np.array([j / 10 - 1 / 3 for j in range(len(beta_0))])

        Y = np.where(
            treated,
            _polynomial_terms(covariates, polyn_coeff_treated, n_covariates)
            + h_k
            + covariates[:, -1] * covariates[:, -2],
            _polynomial_terms(covariates, polyn_coeff_control, n_covariates)
            + h_k
            + covariates[:, 0] * covariates[:, -1],
        )

    elif Y_model == "sinus":
        Y = np.where(
            treated,
            np.sin(covariates[:, 0])
            + np.cos(covariates[:, 1]) / (1 + covariates[:, 2])
            + h_k,
            np.cos(np.sum(covariates, axis=1))
            + np.sin(covariates[:, 0]) / (1 + covariates[:, 1])
            + h_k,
        )

    elif Y_model == "cosinus":
        Y = np.where(
            treated,
            np.cos(covariates @ beta_1 + h_k)
            + 2 * np.cos(covariates[:, 0] + covariates[:, 1]),
            np.cos(covariates @ beta_0 + h_k),
        )

    elif Y_model == "logistic":
        # No intercept besides h_k. The outcome is a Bernoulli draw, so the
        # noise column is neither read nor added.
        success_probabilities = np.where(
            treated,
            1 / (1 + np.exp(-covariates @ beta_1 - h_k)),
            1 / (1 + np.exp(-covariates @ beta_0 - h_k)),
        )
        return np.random.binomial(1, success_probabilities)

    elif Y_model == "complex_nonlinear":
        Y = np.where(
            treated,
            _complex_nonlinear_response(covariates, beta_1),
            _complex_nonlinear_response(covariates, beta_0),
        )
        # Noise is added once here (it used to be counted twice); a Series is
        # returned to keep this branch's historical return type.
        return pd.Series(Y + data[noise].values, index=data.index)

    else:
        raise ValueError(
            "Invalid value for Y_model. Use 'linear', 'polynomial', 'sinus', "
            "'cosinus', 'logistic' or 'complex_nonlinear'."
        )

    # Out-of-place addition so an integer-valued response does not fail when
    # float noise is added.
    return Y + data[noise].values