import torch
import torch.optim as optim
from sklearn.linear_model import Ridge
from sklearn.ensemble import AdaBoostRegressor, GradientBoostingRegressor
from sklearn.model_selection import KFold
import pandas as pd
import numpy as np

class ThetaAlphaEstimator:
    """Estimate θ_0 (and optionally α_0) from cross-fitted moment functions.

    The nuisance regressions μ_Y, μ_W and μ_T are fitted with K-fold
    cross-fitting using scikit-learn regressors.  θ_0 and α_0 are scalar
    torch tensors that ``optimize`` updates with Adam so that the squared
    sample mean of a chosen moment function (``R_b_star``, ``R_a_star`` or
    ``R_0_star``) is minimised.

    Workflow: construct the estimator, call ``cross_fit_models()``, then call
    ``optimize(estimator.R_x_star)`` and read ``theta_0`` / ``alpha_0``.
    """

    def __init__(self, data, T_col, S_col, V_col, model_type='ridge', n_splits=5, predefined_alpha=None):
        """Store the data, the nuisance-model settings and the initial parameters.

        Args:
            data: pandas DataFrame holding the treatment, study-indicator and
                outcome columns; every remaining column is used as a covariate.
            T_col: Name of the treatment column.
            S_col: Name of the study-indicator column (compared against 0 and 1).
            V_col: Name of the observed-outcome column.
            model_type: 'ridge', 'adaboost' or 'gradientboosting'
                (case-insensitive); validated when a model is first created.
            n_splits: Number of folds used for cross-fitting.
            predefined_alpha: If given, α_0 is fixed to this value and is not
                learned; if None, α_0 starts at 0.1 and is learnable.
        """
        # Store column names and data
        self.T_col = T_col
        self.S_col = S_col
        self.V_col = V_col
        reserved = (T_col, S_col, V_col)
        self.X_cols = [col for col in data.columns if col not in reserved]

        # Extract columns from DataFrame
        self.data = data
        self.X = data[self.X_cols].values  # Covariates (numpy array for compatibility with sklearn)
        self.T = data[T_col].values        # Treatment
        self.V = data[V_col].values        # Observed outcome
        self.S = data[S_col].values        # Study indicator

        # Store model type and number of splits for cross-fitting
        self.model_type = model_type.lower()
        self.n_splits = n_splits

        # Sample variances of the outcome within each study group.
        self.sigma_delta2 = data[V_col].loc[data[S_col] == 0].var()  # Variance term σ^2_δ
        self.sigma_epsilon2 = data[V_col].loc[data[S_col] == 1].var()  # Variance term σ^2_ε

        # θ_0 is always learnable; α_0 is learnable only when no value was supplied.
        self.theta_0 = torch.tensor(0.1, dtype=torch.float32, requires_grad=True)
        self.has_predefined_alpha = predefined_alpha is not None
        if self.has_predefined_alpha:
            self.alpha_0 = torch.tensor(predefined_alpha, dtype=torch.float32, requires_grad=False)
        else:
            self.alpha_0 = torch.tensor(0.1, dtype=torch.float32, requires_grad=True)

    def _select_model(self):
        """Return a new, unfitted regressor matching ``self.model_type``.

        Raises:
            ValueError: If ``self.model_type`` is not one of the supported names.
        """
        if self.model_type == 'ridge':
            return Ridge(alpha=1.0)
        if self.model_type == 'adaboost':
            return AdaBoostRegressor(n_estimators=50)
        if self.model_type == 'gradientboosting':
            return GradientBoostingRegressor(n_estimators=100)
        raise ValueError("Unsupported model type. Choose 'ridge', 'adaboost', or 'gradientboosting'.")

    def cross_fit_models(self):
        """Fit the nuisance models with K-fold cross-fitting and store predictions.

        For every fold, models are trained on the training part only and used
        to predict the held-out part:

        * μ_Y: outcome regressed on covariates using rows with S == 0;
        * μ_W: outcome regressed on covariates using rows with S == 1;
        * μ_T: treatment regressed on covariates separately per study group,
          with each held-out row taking the prediction of its own group.

        The results are stored as float32 tensors in ``mu_Y_pred``,
        ``mu_W_pred`` and ``mu_T_pred``.

        Raises:
            ValueError: If a training fold lacks rows with S == 0 or S == 1,
                or if the model type is unsupported.
        """
        kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=42)

        n_rows = len(self.data)
        mu_Y_pred = np.zeros(n_rows)
        mu_W_pred = np.zeros(n_rows)
        mu_T_pred = np.zeros(n_rows)

        for train_index, test_index in kf.split(self.X):
            X_train, X_test = self.X[train_index], self.X[test_index]
            T_train = self.T[train_index]
            V_train = self.V[train_index]
            S_train, S_test = self.S[train_index], self.S[test_index]

            in_s0 = S_train == 0
            in_s1 = S_train == 1
            if not in_s0.any() or not in_s1.any():
                # Fail with an actionable message instead of a generic
                # "0 samples" error from the regressor.
                raise ValueError(
                    "Each training fold must contain rows with S == 0 and rows with S == 1; "
                    "reduce n_splits or check the study indicator column."
                )

            # Fresh models per fold so nothing leaks between folds.
            model_mu_Y = self._select_model()
            model_mu_W = self._select_model()
            model_mu_T_0 = self._select_model()
            model_mu_T_1 = self._select_model()

            # Fit μ_Y(X,0) on data with S=0
            model_mu_Y.fit(X_train[in_s0], V_train[in_s0])
            mu_Y_pred[test_index] = model_mu_Y.predict(X_test)

            # Fit μ_W(X,1) on data with S=1
            model_mu_W.fit(X_train[in_s1], V_train[in_s1])
            mu_W_pred[test_index] = model_mu_W.predict(X_test)

            # Fit μ_T(X,0) and μ_T(X,1); pick the prediction matching each row's S.
            model_mu_T_0.fit(X_train[in_s0], T_train[in_s0])
            model_mu_T_1.fit(X_train[in_s1], T_train[in_s1])
            mu_T_pred[test_index] = np.where(
                S_test == 0, model_mu_T_0.predict(X_test), model_mu_T_1.predict(X_test)
            )

        # Convert predictions to PyTorch tensors
        self.mu_Y_pred = torch.tensor(mu_Y_pred, dtype=torch.float32)
        self.mu_W_pred = torch.tensor(mu_W_pred, dtype=torch.float32)
        self.mu_T_pred = torch.tensor(mu_T_pred, dtype=torch.float32)

    def _prediction(self, name):
        """Return the stored prediction tensor ``name`` or explain how to create it."""
        try:
            return getattr(self, name)
        except AttributeError:
            raise AttributeError(
                f"'{name}' is not available; call cross_fit_models() before "
                "evaluating a moment function."
            ) from None

    @staticmethod
    def _as_float(values):
        """Convert array-like ``values`` to a float32 tensor."""
        return torch.as_tensor(values, dtype=torch.float32)

    def _moment_s0(self, res_T, S):
        """Return (T-μ_T) * ((V-μ_Y) - θ_0 (T-μ_T)) * (1-S)."""
        res_Y = self._as_float(self.V) - self._prediction('mu_Y_pred')
        return res_T * (res_Y - self.theta_0 * res_T) * (1 - S)

    def _moment_s1(self, res_T, res_W, S):
        """Return (T-μ_T) * (α_0 (V-μ_W) - θ_0 (T-μ_T)) * S."""
        return res_T * (self.alpha_0 * res_W - self.theta_0 * res_T) * S

    def R_b_star(self):
        """Evaluate the two-component moment function R_b^*.

        Column 0 is the sum of the S == 1 and S == 0 moments (the same
        quantity ``R_a_star`` returns); column 1 is
        ``S * (T-μ_T) * (θ_0/α_0) * (θ_0 (T-μ_T) - α_0 (V-μ_W))``.

        The current ``self.alpha_0`` is used as is: it is learned by
        ``optimize`` only when it was created as a learnable tensor (no
        ``predefined_alpha``); otherwise it acts as a fixed constant.
        α_0 must be non-zero because of the division by α_0.

        Returns:
            Tensor of shape (n_rows, 2).
        """
        # Build each tensor once per call instead of once per use.
        res_T = self._as_float(self.T) - self._prediction('mu_T_pred')
        res_W = self._as_float(self.V) - self._prediction('mu_W_pred')
        S = self._as_float(self.S)

        R_0 = self._moment_s1(res_T, res_W, S) + self._moment_s0(res_T, S)
        R_1 = S * res_T * (self.theta_0 / self.alpha_0) * \
            (self.theta_0 * res_T - self.alpha_0 * res_W)

        # Stack R_0 and R_1 to form the 2D R_b^* matrix
        return torch.stack((R_0, R_1), dim=1)

    def R_a_star(self):
        """Evaluate the moment function R_a^*, which needs a predefined α_0.

        Returns:
            1-D tensor with one value per row: the S == 1 moment plus the
            S == 0 moment.

        Raises:
            ValueError: If the estimator was built without ``predefined_alpha``.
        """
        if not self.has_predefined_alpha:
            raise ValueError("α_0 must be predefined for R_a_star. Please provide a predefined α_0 when initializing the class.")

        res_T = self._as_float(self.T) - self._prediction('mu_T_pred')
        res_W = self._as_float(self.V) - self._prediction('mu_W_pred')
        S = self._as_float(self.S)
        return self._moment_s1(res_T, res_W, S) + self._moment_s0(res_T, S)

    def R_0_star(self):
        """Evaluate the moment function R_0^*, which depends only on θ_0.

        Returns:
            1-D tensor with one value per row; rows with S == 1 contribute 0.
        """
        res_T = self._as_float(self.T) - self._prediction('mu_T_pred')
        return self._moment_s0(res_T, self._as_float(self.S))

    def objective(self, R_star_fn):
        """Return the sum of squared sample means of ``R_star_fn()``.

        The mean is taken over rows (dim 0), so a 1-D moment yields one
        squared mean and the 2-D ``R_b_star`` yields the sum of two.
        """
        R = R_star_fn()
        expectation = torch.mean(R, dim=0)  # Approximate expectation
        return torch.sum(expectation**2)  # Minimize this to get as close to zero as possible

    def optimize(self, R_star_fn, lr=0.01, num_epochs=1000, *, verbose=False):
        """Minimise ``objective(R_star_fn)`` with Adam, updating parameters in place.

        Optimisation starts from the current values of ``theta_0`` and
        ``alpha_0``, so successive calls continue from the previous result.

        Args:
            R_star_fn: Zero-argument callable returning the moment tensor,
                e.g. ``self.R_b_star``.
            lr: Adam learning rate.
            num_epochs: Number of gradient steps.
            verbose: If True, print progress every 100 epochs and the final
                estimates.
        """
        # Every tensor has a `requires_grad` attribute, so test its value:
        # a predefined (frozen) α_0 must not be handed to the optimizer.
        params = [p for p in (self.theta_0, self.alpha_0) if p.requires_grad]
        alpha_is_learned = self.alpha_0.requires_grad
        optimizer = optim.Adam(params, lr=lr)

        for epoch in range(num_epochs):
            optimizer.zero_grad()  # Zero out the gradients
            loss = self.objective(R_star_fn)  # Calculate the loss
            loss.backward()  # Backpropagate to compute gradients
            optimizer.step()  # Update θ_0 and/or α_0

            if verbose and epoch % 100 == 0:
                print(f'Epoch {epoch}, Loss: {loss.item()}, θ_0: {self.theta_0.item()}')
                if alpha_is_learned:
                    print(f'α_0: {self.alpha_0.item()}')

        if verbose:
            print("Optimization complete.")
            print("Optimal θ_0:", self.theta_0.item())
            if alpha_is_learned:
                print("Optimal α_0:", self.alpha_0.item())

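# Example usage:
# data = pd.DataFrame(...)  # Load your data here: treatment, study-indicator and outcome columns; all other columns are used as covariates
# estimator = ThetaAlphaEstimator(data, T_col='T', S_col='S', V_col='V', model_type='ridge', n_splits=5, predefined_alpha=...)
# estimator.cross_fit_models()  # Must be called first: performs cross-fitting to get the conditional expectations
# estimator.optimize(estimator.R_b_star)  # Optimize for R_b^*
# estimator.optimize(estimator.R_a_star)  # Optimize for R_a^* (raises ValueError unless predefined_alpha was given)
# estimator.optimize(estimator.R_0_star)  # Optimize for R_0^*
# Note: each optimize() call starts from the current theta_0 / alpha_0, so successive calls continue from the previous result.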