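"""
Federated IPW/AIPW simulation utilities.

Provides data-generation, propensity-score and membership-weighting helpers,
plus the ``Simulations_Fed`` class with its ``estimate_ipw_and_aipw`` and
``plot_aipw`` methods.
"""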

# Standard library imports
import copy
import random

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Scikit-learn imports
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, train_test_split, cross_val_predict
from sklearn.preprocessing import StandardScaler, LabelEncoder

# SciPy imports
from scipy.special import expit
from scipy.stats import multivariate_normal

# R integration (conditional)
try:
    import rpy2.robjects as ro
    from rpy2.robjects import pandas2ri
    from rpy2.robjects.packages import importr
    R_AVAILABLE = True
except ImportError:
    R_AVAILABLE = False


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def safe_divide(numerator, denominator, eps=1e-9, fill_value=np.nan):
    """Divide ``numerator`` by ``denominator`` without failing on zero denominators.

    Only denominators that are exactly zero are treated specially:

    * ``eps > 0``: zero denominators are replaced by ``eps``.
    * ``eps <= 0``: the result is ``fill_value`` wherever the denominator is zero.

    Non-zero denominators, including negative ones, are used unchanged.

    Args:
        numerator: Scalar or array-like dividend.
        denominator: Scalar or array-like divisor. Anything with ``__len__``
            (list, ndarray, Series) takes the vectorised path.
        eps: Replacement for zero denominators; ``<= 0`` selects ``fill_value``.
        fill_value: Result for zero denominators when ``eps <= 0``.

    Returns:
        An ndarray on the vectorised path, otherwise a scalar.
    """
    if hasattr(denominator, '__len__'):
        # Convert so that `== 0` is elementwise even for plain Python lists.
        den = np.asarray(denominator)
        if eps > 0:
            return numerator / np.where(den == 0, eps, den)
        num = np.asarray(numerator, dtype=float)
        # The output takes the broadcast shape, so a scalar numerator still works.
        out = np.full(np.broadcast(num, den).shape, fill_value, dtype=float)
        np.divide(num, den, out=out, where=den != 0)
        return out

    # Scalar path: only an exact zero is guarded (negative values are kept as-is).
    if denominator == 0:
        return numerator / eps if eps > 0 else fill_value
    return numerator / denominator


def expit_safe(x, clip_bounds=None):
    """Return the logistic sigmoid of ``x``, optionally clipped.

    Args:
        x: Scalar or array of logits.
        clip_bounds: Optional indexable ``(low, high)``; when given, the
            probabilities are clipped into ``[low, high]``.
    """
    sigmoid = expit(x)
    if clip_bounds is None:
        return sigmoid
    lower, upper = clip_bounds[0], clip_bounds[1]
    return np.clip(sigmoid, lower, upper)


def softmax(x):
    """Row-wise softmax of a 2-D array of scores.

    The row maximum is subtracted before exponentiating so that large scores
    do not overflow; this does not change the result.
    """
    row_max = np.max(x, axis=1, keepdims=True)
    unnormalised = np.exp(x - row_max)
    row_totals = np.sum(unnormalised, axis=1, keepdims=True)
    return unnormalised / row_totals


def normal_density(X, mu, Sigma):
    """Evaluate the (multivariate) normal density at the points in ``X``.

    When ``Sigma`` is a scalar, a 1-D ``X`` is reshaped into a column so that
    each element is evaluated as a separate one-dimensional point. A scalar
    ``mu`` is wrapped into a length-1 array in that case as well.
    """
    if not np.isscalar(Sigma):
        return multivariate_normal.pdf(X, mu, Sigma)

    points = X.reshape(-1, 1) if len(X.shape) == 1 else X
    mean = np.array([mu]) if np.isscalar(mu) else mu
    return multivariate_normal.pdf(points, mean, Sigma)


def multi_logistic(X, Theta, return_prob=False):
    """Multinomial logistic model: score = ``X @ Theta.T``.

    Args:
        X: ``(n, d)`` design matrix.
        Theta: ``(n_classes, d)`` coefficient matrix.
        return_prob: If True, return row-wise softmax probabilities;
            otherwise return the arg-max class index of each row.

    Raises:
        ValueError: If ``X`` and ``Theta`` disagree on the feature dimension.
    """
    n_x, n_theta = X.shape[1], Theta.shape[1]
    if n_x != n_theta:
        raise ValueError(f"Dimension mismatch: X has {n_x} features, Theta has {n_theta}")

    scores = X @ Theta.T
    return softmax(scores) if return_prob else np.argmax(scores, axis=1)


def logistic_function_vectorized(df, X_cols, gamma):
    """Return ``sigmoid(df[X_cols] @ gamma)`` for every row (no intercept term).

    Raises:
        ValueError: If ``len(gamma)`` differs from the number of ``X_cols``.
    """
    design = df[X_cols].values
    n_cols = design.shape[1]
    if len(gamma) != n_cols:
        raise ValueError(f"Gamma length {len(gamma)} doesn't match X columns {n_cols}")
    linear_predictor = design @ gamma
    return expit_safe(linear_predictor)


def membership_weighting_vectorized(membership_weights, hat_e_k, eps=1e-9, keepdims=False):
    """Mix per-client propensities with membership weights, row by row.

    Computes ``sum_k w_ik * e_ik / sum_k w_ik`` for each row ``i``. A row whose
    weights sum to zero has its total replaced by ``eps``. If ``eps`` itself
    is 0, the total is replaced by ``1e-9``, the default guard of
    ``safe_divide``, which the original delegated this step to.

    Args:
        membership_weights: ``(n, K)`` array-like of membership weights.
        hat_e_k: ``(n, K)`` array-like of per-client propensity scores.
        eps: Substitute for zero row totals.
        keepdims: If True, return shape ``(n, 1)``; otherwise ``(n,)``.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    weights = np.array(membership_weights)
    scores = np.array(hat_e_k)
    if weights.shape != scores.shape:
        raise ValueError(f"Shape mismatch: membership_weights {weights.shape}, hat_e_k {scores.shape}")

    row_mass = weights.sum(axis=1, keepdims=True)
    zero_guard = eps if eps != 0 else 1e-9
    row_mass = np.where(row_mass == 0, zero_guard, row_mass)

    mixed = (weights * scores).sum(axis=1, keepdims=keepdims)
    return mixed / (row_mass if keepdims else row_mass.ravel())


def compute_gamma(df, X_cols, regularization=1e-6, max_iter=1000):
    """Fit a no-intercept logistic regression of ``df["W"]`` on ``df[X_cols]``.

    Args:
        df: Data with the feature columns and a binary ``"W"`` column.
        X_cols: Feature column names.
        regularization: Inverse of sklearn's ``C`` (penalty strength).
        max_iter: Maximum lbfgs iterations.

    Returns:
        The coefficient vector (length ``len(X_cols)``). A zero vector is
        returned when ``W`` has a single class or when fitting raises (a
        warning is printed in that case).
    """
    features = df[X_cols].values
    labels = df["W"].values
    n_features = features.shape[1]

    # A single observed class cannot be separated; fall back to zero coefficients.
    if np.unique(labels).size < 2:
        return np.zeros(n_features)

    try:
        clf = LogisticRegression(
            fit_intercept=False,
            C=1/regularization,
            max_iter=max_iter,
            solver='lbfgs'
        )
        clf.fit(features, labels)
    except Exception as err:
        print(f"Warning: LogisticRegression failed: {err}. Using zeros.")
        return np.zeros(n_features)
    return clf.coef_[0]


def generalized_rf(df, feature_cols, target_col, use_grf=False, n_splits=5, rf_params=None):
    """Random-forest estimate of P(target == 1 | features) for every row of ``df``.

    Predictions are cross-fitted (out-of-fold) with ``cross_val_predict``. For
    a classifier, an integer ``cv`` means stratified folds, and stratified
    folds fail when the classes are too small. The number of folds is
    therefore capped at the size of the smallest class. A plain in-sample fit
    is used only if fewer than 2 folds remain.

    Args:
        df: Data frame holding the features and the binary target.
        feature_cols: Feature column names.
        target_col: Name of the binary target column.
        use_grf: Reserved for an R ``grf`` backend (not implemented; the
            sklearn forest is always used).
        n_splits: Maximum number of cross-fitting folds.
        rf_params: Keyword arguments for ``RandomForestClassifier``.

    Returns:
        Array of length ``len(df)`` with the predicted probability of the
        second class (class 1 for a 0/1 target). All-NaN if the target has a
        single class or the fit fails (a warning is printed on failure).
    """
    if rf_params is None:
        rf_params = {'n_estimators': 100, 'min_samples_leaf': 1, 'max_depth': None}

    X = df[feature_cols].values
    y = df[target_col].values

    _, class_counts = np.unique(y, return_counts=True)
    if class_counts.size < 2:
        return np.full(len(df), np.nan)

    if use_grf and R_AVAILABLE:
        # TODO: R GRF implementation would go here; the sklearn forest below is used instead.
        pass

    # Each stratified fold needs at least one member of every class.
    n_folds = min(n_splits, int(class_counts.min()))
    try:
        model = RandomForestClassifier(**rf_params)
        if n_folds < 2:
            model.fit(X, y)
            return model.predict_proba(X)[:, 1]
        return cross_val_predict(model, X, y, cv=n_folds, method='predict_proba')[:, 1]
    except Exception as e:
        print(f"Warning: RF failed: {e}")
        return np.full(len(df), np.nan)


def generate_W_sequential(df, treatment_cols, scenario, return_p=False):
    """Draw a binary treatment for each row of ``df``.

    Scenarios:
        * ``"random"``: P(W=1) ~ Uniform(0.3, 0.7), independently per row.
        * ``"linear"``: P(W=1) = sigmoid(X @ g), with ``g`` drawn from a
          standard normal and X = ``df[treatment_cols]``.
        * anything else: P(W=1) = 0.5.

    Args:
        return_p: If True, return the probabilities instead of the 0/1 draws.
    """
    n_rows = len(df)
    if scenario == "linear":
        covariates = df[treatment_cols].values
        coefficients = np.random.randn(covariates.shape[1])
        probs = expit_safe(covariates @ coefficients)
    elif scenario == "random":
        probs = np.random.uniform(0.3, 0.7, n_rows)
    else:
        probs = np.full(n_rows, 0.5)  # default scenario

    return probs if return_p else np.random.binomial(1, probs)


def MW_estimation(df, X_cols, T_rounds=10, E_local_steps=5, eta=0.01, B_batch_size=32, scale=False):
    """Placeholder membership-weight estimator.

    Returns one random Dirichlet(1, ..., 1) weight vector per row, with one
    column per distinct value of ``df["client"]``. All training arguments
    (``X_cols``, ``T_rounds``, ``E_local_steps``, ``eta``, ``B_batch_size``,
    ``scale``) are currently ignored.
    """
    n_clients = df['client'].unique().size
    concentration = np.ones(n_clients)
    return np.random.dirichlet(concentration, df.shape[0])


def MW_estimation_pooled(df, X_cols, scale=False, T_rounds=10, E_local_steps=5, eta=0.01, B_batch_size=32):
    """Pooled membership-weight estimation; forwards all arguments to ``MW_estimation``."""
    return MW_estimation(
        df,
        X_cols,
        T_rounds=T_rounds,
        E_local_steps=E_local_steps,
        eta=eta,
        B_batch_size=B_batch_size,
        scale=scale,
    )


# ============================================================================
# NEURAL NETWORK COMPONENTS
# ============================================================================

class MembershipNN(nn.Module):
    """Feed-forward classifier mapping covariates to client-membership logits.

    Architecture: Linear(input_dim, hidden_dim) -> ReLU -> Dropout(0.2) ->
    Linear(hidden_dim, hidden_dim // 2) -> ReLU -> Dropout(0.2) ->
    Linear(hidden_dim // 2, num_classes). ``forward`` returns raw logits (no
    softmax).
    """

    def __init__(self, input_dim, hidden_dim=64, num_classes=2):
        super().__init__()
        half_width = hidden_dim // 2
        layers = [
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, half_width), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(half_width, num_classes),
        ]
        # Kept under the attribute name `network` so state_dict keys are unchanged.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.network(x)
        return logits


# ============================================================================
# MAIN SIMULATION CLASS
# ============================================================================

class Simulations_Fed:
    """Simplified simulation class for federated IPW/AIPW estimation.

    ``client_params_dict`` maps client names to per-client data-generating
    parameters; every key starting with ``"client"`` is treated as a client.
    Per-client keys read by :meth:`combine_data`, with their defaults:

    * ``sample_size`` (100)
    * ``mean_covariates`` (zeros)
    * ``cov_covariates`` (identity)
    * ``gamma`` (zeros): treatment logistic coefficients, no intercept
    * ``beta_1`` (ones) and ``beta_0`` (zeros): outcome coefficients, intercept first
    * ``sigma2`` (1.0): outcome noise variance

    Several methods rebuild client names as ``f"client{k}"`` for ``k = 1..K``.
    Clients are therefore assumed to be named ``client1`` ... ``clientK``.
    """

    def __init__(self, client_params_dict, estimator="aipw", n_simulations=100, fixed_design=False):
        self.client_params_dict = client_params_dict
        self.estimator = estimator
        self.n_simulations = n_simulations
        self.fixed_design = fixed_design

        # Every key starting with "client" is one federated centre.
        self.clients_list = [k for k in client_params_dict.keys() if k.startswith('client')]

        # The covariate dimension is read from client1's mean vector (1 if absent).
        self.dim_x = len(client_params_dict.get('client1', {}).get('mean_covariates', [1]))
        self.treatment_cols = [f"X{i}" for i in range(1, self.dim_x + 1)]
        self.outcome_cols = ["Y"]
        self.sorting_columns = self.treatment_cols  # features used for membership weighting

        # The membership network and its label encoder are created lazily by train_NN.
        self.label_encoder = None
        self.membership_nn = None

        # When True, combine_data also emits the true propensity "e*" and "ITE*".
        self.show_hidden_variables = False

    # ------------------------------------------------------------------
    # Data generation
    # ------------------------------------------------------------------

    def combine_data(self):
        """Simulate one data set per client and stack them.

        For each client:

        * ``X ~ N(mean_covariates, cov_covariates)``
        * ``W ~ Bernoulli(sigmoid(X @ gamma))``
        * ``Y = [1, X] @ beta_W + N(0, sigma2)``

        Returns:
            A DataFrame with columns ``X1..Xd``, ``W``, ``Y`` and ``client``.
            When ``show_hidden_variables`` is True, it also has the true
            propensity ``e*`` and the individual effect ``ITE*``.
        """
        dfs = []
        for client in self.clients_list:
            params = self.client_params_dict[client]
            n = params.get('sample_size', 100)

            # Covariates
            mean_cov = params.get('mean_covariates', np.zeros(self.dim_x))
            cov_cov = params.get('cov_covariates', np.eye(self.dim_x))
            X = np.random.multivariate_normal(mean_cov, cov_cov, n)

            # Treatment from a logistic model without intercept
            gamma = params.get('gamma', np.zeros(self.dim_x))
            p_treat = expit_safe(X @ gamma)
            W = np.random.binomial(1, p_treat)

            # Outcome: arm-specific linear model (intercept first) plus Gaussian noise
            beta_1 = params.get('beta_1', np.ones(self.dim_x + 1))
            beta_0 = params.get('beta_0', np.zeros(self.dim_x + 1))
            sigma2 = params.get('sigma2', 1.0)

            X_with_intercept = np.column_stack([np.ones(n), X])
            Y = np.where(W == 1,
                         X_with_intercept @ beta_1,
                         X_with_intercept @ beta_0) + np.random.normal(0, np.sqrt(sigma2), n)

            df_client = pd.DataFrame(X, columns=self.treatment_cols)
            df_client['W'] = W
            df_client['Y'] = Y
            df_client['client'] = client

            if self.show_hidden_variables:
                df_client['e*'] = p_treat
                df_client['ITE*'] = X_with_intercept @ (beta_1 - beta_0)

            dfs.append(df_client)

        return pd.concat(dfs, ignore_index=True)

    def make_data_by_H_given_X(self):
        """Alternative data generation; currently identical to :meth:`combine_data`."""
        return self.combine_data()  # Simplified

    # ------------------------------------------------------------------
    # Membership network
    # ------------------------------------------------------------------

    def train_NN(self, df, rerun_NN=False):
        """Return client-membership probabilities for each row of ``df``.

        The existing network is reused unless ``rerun_NN`` is True or no
        network has been trained yet. In that case a new ``MembershipNN`` is
        trained for 50 full-batch Adam epochs on ``df["client"]``.

        Returns:
            A ``(len(df), n_classes)`` numpy array. Columns follow the order of
            ``self.label_encoder.classes_``, which sorts client names
            lexicographically (``client10`` comes before ``client2``).
        """
        if self.membership_nn is not None and not rerun_NN:
            X = df[self.sorting_columns].values
            X_tensor = torch.FloatTensor(X)
            self.membership_nn.eval()
            with torch.no_grad():
                outputs = self.membership_nn(X_tensor)
                probs = torch.softmax(outputs, dim=1)
            return probs.numpy()

        X = df[self.sorting_columns].values
        y = df['client'].values

        if self.label_encoder is None:
            self.label_encoder = LabelEncoder()
        y_encoded = self.label_encoder.fit_transform(y)

        input_dim = X.shape[1]
        num_classes = len(np.unique(y_encoded))
        self.membership_nn = MembershipNN(input_dim, num_classes=num_classes)

        X_tensor = torch.FloatTensor(X)
        y_tensor = torch.LongTensor(y_encoded)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.membership_nn.parameters())

        self.membership_nn.train()
        for _ in range(50):  # Quick full-batch training
            optimizer.zero_grad()
            outputs = self.membership_nn(X_tensor)
            loss = criterion(outputs, y_tensor)
            loss.backward()
            optimizer.step()

        self.membership_nn.eval()
        with torch.no_grad():
            outputs = self.membership_nn(X_tensor)
            probs = torch.softmax(outputs, dim=1)

        return probs.numpy()

    def _nn_membership_by_client(self, nn_probs):
        """Reorder NN probability columns into client1..clientK order.

        Clients that the label encoder never saw get a zero column.
        """
        classes = list(self.label_encoder.classes_)
        aligned = np.zeros((nn_probs.shape[0], len(self.clients_list)))
        for k in range(1, len(self.clients_list) + 1):
            name = f"client{k}"
            if name in classes:
                aligned[:, k - 1] = nn_probs[:, classes.index(name)]
        return aligned

    # ------------------------------------------------------------------
    # Propensity estimation helpers
    # ------------------------------------------------------------------

    def _rf_local_propensity(self, df, in_client, use_grf):
        """Random-forest propensity of one client's model, scored on every row.

        The client's own rows get cross-fitted predictions from
        ``generalized_rf``. Rows from other clients are scored by a forest
        fitted on all of that client's rows.
        """
        cols = self.treatment_cols
        df_k = df[in_client]
        preds = np.full(len(df), np.nan)
        preds[in_client] = generalized_rf(df_k, cols, "W", use_grf=use_grf)

        others = ~in_client
        if others.any():
            model = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, max_depth=None)
            model.fit(df_k[cols].values, df_k["W"].values)
            preds[others] = model.predict_proba(df.loc[others, cols].values)[:, 1]
        return preds

    def _estimate_propensities(self, df, propensity_estimation, propensity_federations,
                               use_grf, rerun_NN, scale):
        """Return ``{column name: per-row values}`` for every requested strategy."""
        new_cols = {}
        cols = self.treatment_cols
        n_clients = len(self.clients_list)
        row_client = df["client"].to_numpy()

        # Pooled: a single model on all rows.
        if "pool" in propensity_federations:
            if propensity_estimation == "logistic":
                hat_gamma_pool = compute_gamma(df, cols)
                new_cols["hat_e_pool"] = logistic_function_vectorized(df, cols, hat_gamma_pool)
            elif propensity_estimation == "random_forest":
                new_cols["hat_e_pool"] = generalized_rf(df, cols, "W", use_grf=use_grf)

        # Client-level models, each scored on EVERY row so they can be mixed by membership weights.
        e_local_values = None
        if any(p in propensity_federations for p in ("MW", "local", "MW_NN")):
            for k in range(1, n_clients + 1):
                in_client = row_client == f"client{k}"
                df_k = df[in_client]
                col = f"hat_e_local_{k}"
                if df_k.empty or df_k["W"].nunique() < 2:
                    # A propensity model needs both arms; fall back to 0.5.
                    new_cols[col] = np.full(len(df), 0.5)
                elif propensity_estimation == "logistic":
                    hat_gamma_k = compute_gamma(df_k, cols)
                    new_cols[col] = logistic_function_vectorized(df, cols, hat_gamma_k)
                elif propensity_estimation == "random_forest":
                    new_cols[col] = self._rf_local_propensity(df, in_client, use_grf)

            local_cols = [f"hat_e_local_{k}" for k in range(1, n_clients + 1)]
            if local_cols and all(c in new_cols for c in local_cols):
                e_local_values = np.column_stack([new_cols[c] for c in local_cols])

        # Oracle: the true logistic propensity of each client (independent of the estimation method).
        if "oracle" in propensity_federations:
            for k in range(1, n_clients + 1):
                gamma_true = self.client_params_dict[f"client{k}"].get("gamma", np.zeros(self.dim_x))
                new_cols[f"e_oracle_{k}"] = logistic_function_vectorized(df, cols, gamma_true)

        # Membership weighting (placeholder weights).
        if "MW" in propensity_federations:
            df_MW_values = MW_estimation_pooled(df, self.sorting_columns, scale=scale)
            for i in range(n_clients):
                new_cols[f"mw_{i + 1}"] = df_MW_values[:, i]
            if e_local_values is not None:
                new_cols["hat_e_MW"] = membership_weighting_vectorized(df_MW_values, e_local_values)

        # Neural-network membership weighting, aligned to client1..clientK columns.
        if "MW_NN" in propensity_federations:
            nn_probs = self._nn_membership_by_client(self.train_NN(df, rerun_NN=rerun_NN))
            for i in range(n_clients):
                new_cols[f"mw_nn_{i + 1}"] = nn_probs[:, i]
            if e_local_values is not None:
                new_cols["hat_e_MW_NN"] = membership_weighting_vectorized(nn_probs, e_local_values)

        return new_cols

    def _own_client_column(self, df, prefix):
        """For each row of client k, take column ``f"{prefix}{k}"`` (NaN if it is missing)."""
        values = np.full(len(df), np.nan)
        row_client = df["client"].to_numpy()
        for k in range(1, len(self.clients_list) + 1):
            col = f"{prefix}{k}"
            if col in df.columns:
                rows = row_client == f"client{k}"
                values[rows] = df[col].to_numpy(dtype=float)[rows]
        return values

    # ------------------------------------------------------------------
    # Estimators
    # ------------------------------------------------------------------

    def _fit_outcome_models(self, df):
        """Fit pooled OLS outcome models on the treated and control rows.

        Returns:
            ``(mu1, mu0)``: the two models' predictions for every row. An arm
            with no rows predicts 0.
        """
        design = np.column_stack([np.ones(len(df)), df[self.treatment_cols].to_numpy(dtype=float)])
        treated = df["W"].to_numpy() == 1
        Y = df["Y"].to_numpy(dtype=float)
        predictions = []
        for arm_rows in (treated, ~treated):
            if arm_rows.any():
                coef, *_ = np.linalg.lstsq(design[arm_rows], Y[arm_rows], rcond=None)
            else:
                coef = np.zeros(design.shape[1])
            predictions.append(design @ coef)
        return predictions[0], predictions[1]

    @staticmethod
    def _propensity_column(df, method):
        """Return the propensity column used by ``method`` as floats, or None if it is absent."""
        col = "e_oracle" if method == "oracle" else f"hat_e_{method}"
        return df[col].to_numpy(dtype=float) if col in df.columns else None

    @staticmethod
    def _ipw_estimate(W, Y, e, ht_or_hajek="HT"):
        """IPW estimate of the ATE. Propensities are clipped to [0.01, 0.99]; NaN rows are dropped.

        Uses the Horvitz-Thompson mean by default. ``"Hajek"``
        (case-insensitive) selects the self-normalised form.
        """
        keep = ~np.isnan(e)
        if not keep.any():
            return np.nan
        W, Y = W[keep], Y[keep]
        e = np.clip(e[keep], 0.01, 0.99)
        if str(ht_or_hajek).lower() == "hajek":
            w1 = W / e
            w0 = (1 - W) / (1 - e)
            s1, s0 = np.sum(w1), np.sum(w0)
            if s1 == 0 or s0 == 0:
                return np.nan  # one arm is empty
            return np.sum(w1 * Y) / s1 - np.sum(w0 * Y) / s0
        return np.mean(W * Y / e - (1 - W) * Y / (1 - e))

    @staticmethod
    def _aipw_estimate(W, Y, e, mu1, mu0):
        """Augmented IPW estimate of the ATE. Propensities are clipped to [0.01, 0.99]; NaN rows are dropped."""
        keep = ~np.isnan(e)
        if not keep.any():
            return np.nan
        W, Y, mu1, mu0 = W[keep], Y[keep], mu1[keep], mu0[keep]
        e = np.clip(e[keep], 0.01, 0.99)
        scores = (mu1 - mu0) + W * (Y - mu1) / e - (1 - W) * (Y - mu0) / (1 - e)
        return np.mean(scores)

    @staticmethod
    def _row_groups(row_client, store):
        """Yield ``("total_data", all rows)`` and then ``(client, that client's rows)`` for each client key with rows."""
        yield "total_data", np.ones(len(row_client), dtype=bool)
        for key in store:
            if key != "total_data":
                rows = row_client == key
                if rows.any():
                    yield key, rows

    def estimate_ipw_and_aipw(
        self,
        propensity_estimation: str,
        propensity_federations: list,
        ipw_final_estimators: list = None,
        aipw_final_estimators: list = None,
        outcome_estimators: list = None,
        generate_by_center: bool = True,
        use_grf: bool = False,
        params_for_federation_omega: dict = None,
        params_for_federation: dict = None,
        rerun_NN: bool = False,
        scale: bool = False,
        ht_or_hajek: str = "HT",
        inconsistent_e1: bool = False,
        run_ipw=True,
        run_aipw=True
    ):
        """Run ``n_simulations`` IPW/AIPW experiments and collect the estimates.

        Each simulation generates a data set and estimates propensities with
        ``propensity_estimation`` (``"logistic"`` or ``"random_forest"``), for
        every strategy listed in ``propensity_federations``:

        * ``"pool"``: ``hat_e_pool``, from one model on the pooled data.
        * ``"local"``: ``hat_e_local``, where each row is scored by its own client's model.
        * ``"MW"`` / ``"MW_NN"``: ``hat_e_MW`` / ``hat_e_MW_NN``, which mix the
          client models with membership weights (placeholder Dirichlet weights
          or the neural network).
        * ``"oracle"``: ``e_oracle``, the row's true logistic propensity under
          its client's ``gamma``.

        IPW uses the Horvitz-Thompson form, or the Hajek form when
        ``ht_or_hajek`` is ``"Hajek"``. AIPW augments Horvitz-Thompson with
        pooled linear (OLS) outcome models fitted separately on treated and
        control rows.

        ``ipw_final_estimators`` and ``aipw_final_estimators`` only create
        empty lists in the returned dict. ``outcome_estimators``,
        ``params_for_federation_omega``, ``params_for_federation`` and
        ``inconsistent_e1`` are accepted for compatibility but not used.

        Returns:
            A tuple ``(ipw_estimates, aipw_estimates, last_df, final_estimators)``.

            * Each estimates dict maps a method to
              ``{"client1": [...], ..., "total_data": [...]}``. Each simulation
              appends one all-rows estimate to ``total_data`` and one
              per-client estimate (on that client's rows) to each client list.
            * ``last_df`` is the last simulated data set with its added
              columns, or None if no simulation ran.
        """
        print(f"Running IPW: {run_ipw}, AIPW: {run_aipw}")
        if not run_ipw and not run_aipw:
            return {}, {}, None, {}

        if "oracle" in propensity_federations:
            self.show_hidden_variables = True

        dict_structure = {f"client{i}": [] for i in range(1, len(self.clients_list) + 1)}
        dict_structure["total_data"] = []

        local_ipw_estimates, estimators_ipw = {}, {}
        if run_ipw:
            for key in ["oracle", "pool", "local", "MW", "MW_NN"]:
                local_ipw_estimates[key] = copy.deepcopy(dict_structure)
            if ipw_final_estimators:
                estimators_ipw = {est: [] for est in ipw_final_estimators}

        local_aipw_estimates, estimators_aipw = {}, {}
        if run_aipw:
            for key in ["oracle", "pool", "MW", "MW_NN"]:
                local_aipw_estimates[key] = copy.deepcopy(dict_structure)
            if aipw_final_estimators:
                estimators_aipw = {est: [] for est in aipw_final_estimators}

        df = None  # returned unchanged when n_simulations == 0
        for _ in tqdm(range(self.n_simulations), desc="Simulations"):
            df = self.combine_data() if generate_by_center else self.make_data_by_H_given_X()
            if df.empty:
                continue

            new_cols_for_df = self._estimate_propensities(
                df, propensity_estimation, propensity_federations, use_grf, rerun_NN, scale
            )
            for col, values in new_cols_for_df.items():
                df[col] = values

            # Per-row propensities taken from the row's own client.
            if "local" in propensity_federations:
                df["hat_e_local"] = self._own_client_column(df, "hat_e_local_")
            if "oracle" in propensity_federations:
                df["e_oracle"] = self._own_client_column(df, "e_oracle_")

            W = df["W"].to_numpy(dtype=float)
            Y = df["Y"].to_numpy(dtype=float)
            row_client = df["client"].to_numpy()

            if run_ipw:
                for method, store in local_ipw_estimates.items():
                    e = self._propensity_column(df, method)
                    if e is None:
                        continue
                    for key, rows in self._row_groups(row_client, store):
                        store[key].append(self._ipw_estimate(W[rows], Y[rows], e[rows], ht_or_hajek))

            if run_aipw:
                mu1, mu0 = self._fit_outcome_models(df)
                for method, store in local_aipw_estimates.items():
                    e = self._propensity_column(df, method)
                    if e is None:
                        continue
                    for key, rows in self._row_groups(row_client, store):
                        store[key].append(
                            self._aipw_estimate(W[rows], Y[rows], e[rows], mu1[rows], mu0[rows])
                        )

        return local_ipw_estimates, local_aipw_estimates, df, {"ipw": estimators_ipw, "aipw": estimators_aipw}

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    def _expected_tau(self):
        """Sample-size-weighted average of each client's expected treatment effect.

        For client k the effect is ``(b1_0 - b0_0) + mean_k @ (b1[1:] - b0[1:])``,
        using the same parameter defaults as :meth:`combine_data`.
        """
        per_client, sizes = [], []
        for client in self.clients_list:
            params = self.client_params_dict[client]
            beta_1 = np.asarray(params.get("beta_1", np.ones(self.dim_x + 1)), dtype=float)
            beta_0 = np.asarray(params.get("beta_0", np.zeros(self.dim_x + 1)), dtype=float)
            mean_cov = np.asarray(params.get("mean_covariates", np.zeros(self.dim_x)), dtype=float)
            per_client.append(beta_1[0] - beta_0[0] + mean_cov @ (beta_1[1:] - beta_0[1:]))
            sizes.append(params.get("sample_size", 100))
        return float(np.average(per_client, weights=sizes))

    def plot_aipw(
        self,
        dict_results: dict,
        data_scaling=100,
        tau_star="expectancy_of_tau",
        y_lim: tuple = None,
        title: str = None,
        save_pdf: str = None,
        showmeans: bool = False,
        showfliers=False,
        figsize=None,
        rotation=20,
        labels_list=None,
        print_labels=True,
        generate_by_center=True,
        print_1s_ivw_in_blue=True,
        pattern_for_NN=True,
        figsize_cm=None,
        left=0.16, right=0.98, bottom=0.14, top=0.98,
        meta_hidden=True,
        plot_title_ate=False
    ):
        """Box-plot the ``"total_data"`` estimates of each entry in ``dict_results``.

        Args:
            dict_results: ``{estimator name: {"total_data": [...], ...}}``, e.g.
                one of the dicts returned by ``estimate_ipw_and_aipw``.
            tau_star: The true ATE drawn as a dashed line. The default string
                means "compute it from the client parameters"; None draws no line.
            labels_list: Optional tick labels. They are used only when there is
                exactly one label per plotted box.
            figsize_cm: Figure size in centimetres; it takes precedence over ``figsize``.
            left, right, bottom, top: Subplot margins, applied to the returned
                figure and to the saved PDF.

        ``data_scaling``, ``print_labels``, ``generate_by_center``,
        ``print_1s_ivw_in_blue``, ``pattern_for_NN``, ``meta_hidden`` and
        ``plot_title_ate`` are accepted for compatibility but not used.

        Returns:
            The matplotlib Figure.
        """
        def _cm_to_in(x):
            return x / 2.54

        plt.rcParams.update({
            "font.size": 14,
            "font.weight": "normal",
            "axes.titlesize": 14,
            "axes.labelsize": 18,
            "xtick.labelsize": 18,
            "ytick.labelsize": 14,
            "axes.linewidth": 2.5,
            "lines.linewidth": 2.5,
            "patch.linewidth": 8,
        })
        sns.set_style("whitegrid")
        sns.set_palette("Set2")
        set2_colors = sns.color_palette("Set2")

        if isinstance(tau_star, str) and tau_star == "expectancy_of_tau":
            tau_star = self._expected_tau()
        print(f"True ATE: {tau_star}")

        custom_palette = {}
        for estimator in dict_results:
            if ("pool" in estimator) or ("oracle" in estimator.lower()):
                custom_palette[estimator] = set2_colors[1]  # orange
            elif "MW" in estimator:
                custom_palette[estimator] = set2_colors[0]  # green
            elif "Meta" in estimator:
                custom_palette[estimator] = set2_colors[5]  # yellow
            else:
                custom_palette[estimator] = set2_colors[2]  # blue

        if figsize_cm is not None:
            figsize = (_cm_to_in(figsize_cm[0]), _cm_to_in(figsize_cm[1]))
        elif figsize is None:
            figsize = (12, 8)

        fig, ax = plt.subplots(figsize=figsize)

        data_to_plot, labels, colors = [], [], []
        for estimator, results in dict_results.items():
            if "total_data" in results and len(results["total_data"]) > 0:
                data_to_plot.append(results["total_data"])
                labels.append(estimator)
                colors.append(custom_palette.get(estimator, set2_colors[0]))

        if not data_to_plot:
            print("No data to plot")
            return fig

        if labels_list is not None and len(labels_list) == len(labels):
            labels = list(labels_list)

        bp = ax.boxplot(data_to_plot, patch_artist=True, showmeans=showmeans, showfliers=showfliers)
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        if tau_star is not None:
            ax.axhline(y=tau_star, color='red', linestyle='--', linewidth=2, label=f'True ATE: {tau_star:.3f}')
            ax.legend()  # the ATE line is the only labelled artist

        ax.set_xticklabels(labels, rotation=rotation)
        ax.set_ylabel('ATE Estimate')
        if title:
            ax.set_title(title)
        if y_lim:
            ax.set_ylim(y_lim)
        ax.grid(True, alpha=0.3)

        # Caller-supplied margins; no tight_layout afterwards, so the returned
        # figure matches the saved PDF.
        plt.subplots_adjust(left=left, right=right, bottom=bottom, top=top)

        if save_pdf:
            plt.savefig(save_pdf, format='pdf', dpi=300, bbox_inches='tight')

        return fig