import os
import numpy as np
import h5py
import cv2
import matplotlib.pyplot as plt
import numpy as np 
# Print NumPy arrays using the NumPy 1.13 legacy repr format (a global, process-wide setting).
np.set_printoptions(legacy='1.13')
from sklearn.preprocessing import SplineTransformer
from torch import nn
from numpy.polynomial.legendre import legvander
import torch
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from itertools import combinations
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd 
import time
import copy
import argparse
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import warnings
import numpy as np
import pandas as pd
from sklearn.preprocessing import SplineTransformer, StandardScaler, PolynomialFeatures
from sklearn.model_selection import KFold
from sklearn.linear_model import LassoCV, Lasso
import matplotlib.pyplot as plt
import warnings
import torch
from itertools import combinations
# Silence every Python warning process-wide (this also hides library deprecation warnings).
warnings.filterwarnings("ignore")
from group_lasso import GroupLasso
import xgboost as xgb


# NOTE(review): repeats the filterwarnings("ignore") call above; redundant but harmless.
warnings.filterwarnings("ignore")

def main():

    # Command-line interface.
    # NOTE(review): parser.parse_args() is never called in main(), so
    # --data_name is never read and currently has no effect.
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument(
        "--data_name",
        type=str,
        default="",
        help="Type of dataset",
    )

    # Use the GPU when torch can see one, otherwise the CPU.
    # NOTE(review): `device` is not referenced anywhere else in main().
    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")
    
    
    def rmse(y_true, y_pred):
        '''
        Root mean square *percentage* error, reduced over dim 0.

        Despite the name, each residual is divided by ``y_true`` before it is
        squared, so zero entries in ``y_true`` yield inf/nan.
        '''
        relative_err = (y_true - y_pred) / y_true
        mean_sq = torch.mean(torch.square(relative_err), axis=0)
        return torch.sqrt(mean_sq)
    
    def customized_knots(X, boundary_knots, n_knots, degree, unif = False):
        """Build a column vector of spline knot locations for the 1-D sample X.

        With ``unif=True`` the result is ``n_knots`` evenly spaced points from
        min(X) to max(X).  Otherwise three linearly spaced segments are merged:
        ``boundary_knots`` points spanning min(X) +/- 0.1*std(X),
        ``n_knots - boundary_knots`` interior points, and ``boundary_knots``
        points spanning max(X) +/- 0.1*std(X).  ``np.unique`` sorts the merged
        points and drops the shared segment endpoints.  ``degree`` is accepted
        but not used.

        Returns an array of shape (k, 1).
        """
        x_lo, x_hi, spread = X.min(), X.max(), X.std()

        if unif:
            return np.linspace(x_lo, x_hi, n_knots).reshape(-1, 1)

        pad = 0.1 * spread
        left = np.linspace(x_lo - pad, x_lo + pad, boundary_knots)
        middle = np.linspace(x_lo + pad, x_hi - pad, n_knots - boundary_knots)
        right = np.linspace(x_hi - pad, x_hi + pad, boundary_knots)
        return np.unique(np.concatenate([left, middle, right])).reshape(-1, 1)
    
    
    def spline_transform(X, n_knots = None, boundary_knots = None, degree = 3, custom = False, extrapolation = 'constant'):
        """Expand a 1-D sample into a B-spline design matrix.

        Parameters
        ----------
        X : array-like with a ``reshape`` method
            Values to expand; reshaped to a single column before fitting.
        n_knots : int
            Number of knots.  Must be an int when ``custom`` is False.
        boundary_knots : int
            Knots placed around each end of the range; only used when ``custom``.
        degree : int
            Polynomial degree of the splines.
        custom : bool
            If True, place knots with ``customized_knots`` instead of uniformly.
        extrapolation : str
            Forwarded to ``SplineTransformer`` in both branches.

        Returns
        -------
        ndarray of shape (n_samples, n_basis)
        """
        if custom:
            knots = customized_knots(X, boundary_knots, n_knots, degree)
            spline = SplineTransformer(knots=knots, degree=degree, extrapolation=extrapolation)
        else:
            # Forward `extrapolation` here too; it used to be dropped in this branch.
            spline = SplineTransformer(n_knots=n_knots, degree=degree, extrapolation=extrapolation)

        return spline.fit_transform(X.reshape(-1, 1))
    
    def diff_penalty(num_coefs, order=2):
        """Return the ``order``-th finite-difference operator on ``num_coefs`` coefficients.

        Each row differences neighbouring coefficients; ``D.T @ D`` is the
        roughness penalty used in ``PS_smoothing_matrix``.
        """
        identity = np.eye(num_coefs)
        return np.diff(identity, n=order, axis=0)
    
    
    def PS_smoothing_matrix(x, lam, n_knots, degree, boundary_knots, custom):
        """Return the P-spline smoother (hat) matrix for a single covariate.

        S = B (B^T B + lam * D^T D)^{-1} B^T, where B is the spline basis of ``x``
        and D is the second-order difference operator.  Callers smooth a
        residual r as ``S @ r`` and use ``trace(S)`` as its degrees of freedom.

        Returns an ndarray of shape (len(x), len(x)).

        Raises
        ------
        numpy.linalg.LinAlgError
            If the penalised Gram matrix is singular (e.g. ``lam == 0`` with a
            rank-deficient basis).
        """
        # Forward `degree`; previously it was accepted but never passed on, so
        # spline_transform's default of 3 was always used.
        B = spline_transform(x, n_knots=n_knots, boundary_knots=boundary_knots,
                             degree=degree, custom=custom)
        D = diff_penalty(B.shape[1], order=2)

        gram = B.T @ B + lam * (D.T @ D)
        # solve(A, B^T) equals inv(A) @ B^T without forming the explicit inverse,
        # which is cheaper and numerically more stable.
        return B @ np.linalg.solve(gram, B.T)
    
    def SAM(X, y, lam = 0, alpha=0.25, max_iter=10, tol=1e-6, ftol = 1e-3, n_knots=10, boundary_knots = 3, degree=3, custom = False):
        """Fit a sparse additive model by penalised backfitting.

        In every sweep each still-active feature is updated: its partial residual
        (target minus all other components) is smoothed with the P-spline smoother
        of that feature's column.  The result is then shrunk by the factor
        (1 - alpha / s_j), where s_j is its root-mean-square.  A feature with
        s_j <= alpha is set to zero and dropped from later sweeps.

        Parameters
        ----------
        X : torch.Tensor of shape (n_samples, n_features)
        y : torch.Tensor with n_samples elements (flattened internally)
        lam : float
            Roughness penalty passed to ``PS_smoothing_matrix``.
        alpha : float
            Soft-threshold (sparsity) level.
        max_iter : int
            Maximum number of backfitting sweeps.
        tol : float
            Stop once the summed squared change of all components is below this.
        ftol : float
            Unused; kept for interface compatibility.
        n_knots, boundary_knots, degree, custom :
            Spline settings forwarded to ``PS_smoothing_matrix``.

        Returns
        -------
        torch.Tensor of shape (n_samples, n_features)
            Fitted component values at the samples; dropped features are zero.
        """
        n_samples, n_features = X.size()
        all_features = set(range(n_features))
        active = list(range(n_features))
        f = torch.zeros((n_samples, n_features))
        # Flatten so an (n, 1) target does not broadcast against the (n,) sums.
        R = torch.clone(y).reshape(-1)

        for _ in range(max_iter):
            f_old = torch.clone(f)
            survivors = []

            for feat in active:
                others = list(all_features - {feat})
                partial_res = (R - f[:, others].sum(axis=1)).float()

                # Smooth on the column of the feature being updated.  The old code
                # used X[:, j] with j = position in the shrinking active list, which
                # picked the wrong column as soon as any feature had been dropped.
                S = PS_smoothing_matrix(X[:, feat], lam=lam, n_knots=n_knots, degree=degree,
                                        boundary_knots=boundary_knots, custom=custom)
                S = torch.as_tensor(S, dtype=torch.float32)

                P_j = S @ partial_res
                s_j = torch.sqrt(torch.mean(P_j ** 2))
                if s_j > alpha:
                    f[:, feat] = (1 - alpha / s_j) * P_j
                    survivors.append(feat)
                else:
                    f[:, feat] = 0

                # Free the (n, n) smoother before building the next one.
                del S

            active = survivors

            if torch.sum(torch.square(f - f_old)) < tol:
                print(f"Alpha: {alpha:.2f} | Convergence.")
                return f

        print(f"Alpha: {alpha:.2f} | Ain't Convergence.")
        return f
    
    
    def standardize(y):
        """Center ``y`` by subtracting its mean.

        Despite the name, no division by the standard deviation is applied.
        """
        centered = y - y.mean()
        return centered
    
    def f1(x):
        """Centered simulation signal -2 * sin(2x)."""
        return standardize(-2 * torch.sin(2 * x))
    
    def f2(x):
        """Centered quadratic x^2 / 2 + 1 (the constant vanishes after centering)."""
        quad = x ** 2 / 2 + 1
        return standardize(quad)
        
    def f3(x):
        """Centered linear signal x - 1/2."""
        shifted = x - (1 / 2)
        return standardize(shifted)
    
    def f4(x1):
        """Centered exponential decay exp(-x1) + exp(-1) - 1."""
        e_inv = torch.exp(torch.Tensor([-1]))
        decay = torch.exp(-x1) + e_inv - 1
        return standardize(decay)
    
    def f5(x1, x2):
        """Centered two-variable interaction exp(sin(x1) + cos(x2) - 1)."""
        exponent = torch.sin(x1) + torch.cos(x2) - 1
        return standardize(torch.exp(exponent))
    
    #def f5(x1, x2):
        #y = torch.sin(x1) * (x2 >= 0) + 0.5 * (x1 * x2)
        #return standardize(y)
    
    def generate_dataset(nsample, nfeature, UB, LB, type):
        """Simulate train/validation/test splits for the additive-model experiments.

        Parameters
        ----------
        nsample : sequence of int
            Split sizes in the order [train, valid, test].
        nfeature : int
            Number of covariates, each drawn i.i.d. Uniform(LB, UB).  Must cover
            the columns the chosen setting reads (up to column 4 for
            'inter_no_overlap').
        UB, LB : float
            Upper and lower bounds of the covariate distribution.
        type : str
            Signal structure: 'only_main', 'weak_main', 'inter_no_overlap',
            'inter_mild_overlap', 'inter_strong_overlap' or 'only_inter'.

        Returns
        -------
        dict
            Maps 'Train'/'Valid'/'Test' to dicts with keys 'data' (X), 'target'
            (y) and 'true_func' (list of the four true components).

        Raises
        ------
        ValueError
            If ``type`` is not a supported setting.  Previously this surfaced
            later as an UnboundLocalError on ``func4``.
        """
        valid_types = ('only_main', 'weak_main', 'inter_no_overlap',
                       'inter_mild_overlap', 'inter_strong_overlap', 'only_inter')
        if type not in valid_types:
            raise ValueError(f"Unknown dataset type {type!r}; expected one of {valid_types}")

        dataset = {}
        name = ['Train', 'Valid', 'Test']
        for i, n in enumerate(nsample):
            data = {}
            X = torch.FloatTensor(n, nfeature).uniform_(LB, UB)

            func1 = f1(X[:, 0])
            func2 = f2(X[:, 1])
            func3 = f3(X[:, 2])

            if type == 'only_main':
                func4 = f4(X[:, 3])
            elif type == 'weak_main':
                func4 = 0.01 * f4(X[:, 3])
            elif type == 'inter_no_overlap':
                func4 = f5(X[:, 3], X[:, 4])
            elif type == 'inter_mild_overlap':
                func4 = f5(X[:, 2], X[:, 3])
            elif type == 'inter_strong_overlap':
                func4 = f5(X[:, 1], X[:, 2])
            else:  # 'only_inter'
                func1 = f5(X[:, 0], X[:, 1])
                func2 = f5(X[:, 2], X[:, 3])
                func3 = torch.zeros_like(func1)
                func4 = torch.zeros_like(func1)

            # NOTE(review): torch.rand is Uniform[0, 1) noise with mean 0.5, not
            # zero-mean Gaussian noise -- confirm this is intended.
            y = func1 + func2 + func3 + func4 + torch.rand(n)
            data['data'] = X
            data['target'] = y
            data['true_func'] = [func1, func2, func3, func4]
            dataset[name[i]] = data

        return dataset
    
    
    def estimate_gcv(alpha, comp, X, y):
        """Pick a path index by an MSE + effective-df criterion (first version).

        NOTE(review): ``main`` defines ``estimate_gcv`` again further down in the
        same scope.  That later definition rebinds the name before anything can
        call it (``train_SAM`` looks the name up at call time), so this version
        is dead code.

        For each path index i the criterion is MSE(y, row-sum of comp[i]) plus
        sum over active features of trace(S_j) / n_samples.  S_j is a P-spline
        smoother with fixed settings (lam=0.1, 10 knots, 3 boundary knots).
        Despite the name, this is not the (RSS/n) / (1 - df/n)^2 GCV formula.

        Returns (criterion per index, first index of the minimum, indices of
        the active features at that index).
        """
    
        GCV = torch.zeros_like(alpha)
        n_samples = y.size()[0]
        criterion = nn.MSELoss()
        
        # Walk the path from the last index to the first.
        for i in range(len(alpha)-1, -1, -1):
            pred_y = comp[i].sum(axis = 1)
            # NOTE(review): assumes y has shape (n_samples,) like pred_y; an (n, 1)
            # y would broadcast inside MSELoss -- confirm.
            MSE = criterion(y, pred_y)
            
            # Active features: columns of comp[i] with non-zero L1 norm.
            active_set = torch.where(torch.norm(comp[i], p = 1, dim=0) != 0)[0].tolist()
            effective_df = 0
            for idx in active_set:
                PS_matrix = PS_smoothing_matrix(X[:, idx], lam = 0.1, n_knots = 10, degree = 3, boundary_knots = 3, custom = False)    
                effective_df += (np.trace(PS_matrix)/n_samples)
        
                del PS_matrix
    
            # Criterion = MSE + (sum of smoother traces) / n.
            GCV[i] = MSE + effective_df
        
        # First index attaining the minimum criterion.
        minGCV = torch.where(GCV == torch.min(GCV))[0]
        optimalloc_ = minGCV.item() if len(minGCV) == 1 else minGCV[0].item()
        optimalset_ = torch.where(torch.norm(comp[optimalloc_], p = 1, dim=0) != 0)[0].tolist()
        
        return GCV, optimalloc_, optimalset_
        
    def plot_comp_norm(GCV, alpha, comp, maxl, loc):
        """Draw the component-norm path (left) and the selection criterion (right).

        Parameters
        ----------
        GCV : torch.Tensor
            Criterion value per path index (right panel).
        alpha : torch.Tensor
            The penalty path; only its length and ``zeros_like`` are used.
        comp : mapping from path index to a (n_samples, n_features) tensor
        maxl : scalar
            Normaliser for the x-axis (sum of column L2 norms divided by maxl).
        loc : int
            Path index marked with a red dashed line in both panels.
        """
        n_path = len(alpha)
        x_axis = [torch.sum(torch.norm(comp[i], dim = 0)) / maxl for i in range(n_path)]

        plt.figure(figsize = (12,4))
        n_samples, n_features = comp[0].size()
        marker_style = dict(colors='red', linestyles='dashed', label='Vertical Lines')

        # Left panel: mean absolute value of each feature's component along the path.
        plt.subplot(121)
        for j in range(n_features):
            y_axis_list = torch.zeros_like(alpha)
            for i in range(n_path):
                y_axis_list[i] = (torch.norm(comp[i][:, j], p = 1)/n_samples)
            plt.plot(x_axis, y_axis_list, linestyle='--', label=f"x {j+1}")
            # NOTE(review): the label sits at (last x + 1.01, first y).  If maxl is
            # the path maximum, x lies in [0, 1] and +1.01 pushes the label far to
            # the right of the data -- confirm the intended offset.
            plt.text(x_axis[-1] + 1.01, y_axis_list[0], f"x{j+1}", va='center')

        plt.vlines(x_axis[loc], ymin = -0.1, ymax = 2.5, **marker_style)
        plt.ylabel('Component Norms',fontweight='bold')

        # Right panel: criterion along the same x-axis.
        plt.subplot(122)
        plt.plot(x_axis, GCV, color = 'b')
        plt.vlines(x_axis[loc], ymin = torch.min(GCV) - 1, ymax = torch.max(GCV) + 0.3, **marker_style)
        plt.ylabel('GCV')
    
    def plot_component(X, y, true, opt_comp, title):
        """Plot each true component function against its estimate on a 2x3 grid.

        ``true`` is padded with two all-zero arrays and one panel is drawn per
        entry.  So ``true`` should hold at most 4 arrays, and ``X`` / ``opt_comp``
        need at least ``len(true) + 2`` columns.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Covariates; panel i is sorted by column i.
        y : unused
        true : list of per-sample arrays
            True component values, one array per panel.
        opt_comp : array-like of shape (n_samples, n_features)
            Estimated components; column i is drawn in panel i.
        title : str
            Title for the whole figure.
        """
        true = true + [np.zeros_like(true[0]), np.zeros_like(true[0])]
        plt.figure(figsize=(16,8))
        # Title the figure itself.  plt.title() here would target an implicit
        # full-figure Axes that the subplots below overlap (or remove, in older
        # matplotlib), so the title would be hidden or lost.
        plt.suptitle(title)
        for i, true_i in enumerate(true):
            plt.subplot(231+i)
            sorted_indices = np.argsort(X[:, i])
            X_sort = X[sorted_indices, i]
            func_sort = true_i[sorted_indices]
            comp_sort = opt_comp[sorted_indices, i]

            plt.plot(X_sort, func_sort, color = 'black', linestyle='-', label = 'True function')
            plt.plot(X_sort, comp_sort, c='red', linestyle='--', label = 'Estimated function')
            plt.xlabel(f'x{i+1}')
            plt.ylabel(f'f{i+1}')
            plt.legend()
        plt.tight_layout()
        #plt.savefig('./img/EX1.png')
        plt.show()
    
    def train_SAM(X, y, alpha_list, max_iter, nk, nb, custom):
        """Fit SAM at every sparsity level in ``alpha_list`` and select one.

        Returns a dict with keys:
        'component' (path index -> fitted component matrix), 'GCV' (selection
        criterion per index), 'opt_loc' (selected index), 'opt_var' (active
        feature indices there), 'Max_L1' (largest summed component L2 norm over
        the path) and 'alpha' (the input path).
        """
        norms = torch.zeros((len(alpha_list)))
        component_list = {}

        for i, a in enumerate(alpha_list):
            comp = SAM(X, y, lam = 0.1, alpha = a, max_iter = max_iter, tol=1e-6, ftol = 1e-3, n_knots=nk, boundary_knots = nb, degree=3, custom = custom)
            # Columns whose sum of squares is exactly zero are inactive.
            inactive = torch.where(torch.sum(torch.square(comp), axis = 0) == 0)[0]
            norms[i] = torch.sum(torch.norm(comp, dim = 0))
            comp[:, inactive] = 0
            component_list[i] = comp

        GCV_list, loc, active_dict = estimate_gcv(alpha_list, component_list, X, y)
        return {
            'component': component_list,
            'GCV': GCV_list,
            'opt_loc': loc,
            'opt_var': active_dict,
            'Max_L1': torch.max(norms),
            'alpha': alpha_list,
        }
    
    def estimate_gcv(alpha, comp, X, y):
        """Score each point of a SAM path with Mallows' Cp and pick the best.

        Despite the name, the criterion is Cp = RSS/n + 2 * (sigma2_hat / n) * df.
        Here df is the sum of trace(S_j) over the active features (S_j a P-spline
        smoother with fixed settings lam=0.1, 10 knots, 3 boundary knots) and
        sigma2_hat = RSS / (n - df).  When df exceeds n the fit is treated as
        over-fitted and sigma2_hat falls back to RSS.  Every index below the
        largest over-fitted index is then overwritten with the maximum of the
        remaining Cp values, so that region is not preferred.

        Parameters
        ----------
        alpha : torch.Tensor of shape (n_path,)
        comp : mapping from path index to a (n_samples, n_features) tensor
        X : torch.Tensor of shape (n_samples, n_features)
        y : torch.Tensor of shape (n_samples,)
            NOTE(review): an (n, 1) y would broadcast against the (n,)
            predictions -- callers should pass a 1-D target.

        Returns
        -------
        (CP tensor, index of the first minimum, active feature indices there)
        """

        def estimate_sigma2(y, y_pred, edf_total):
            n = len(y)
            rss = torch.sum(torch.square(y - y_pred))
            if edf_total > n:
                return rss
            else:
                return rss / (n - edf_total)

        def compute_cp(y, y_pred, edf_total, sigma2_hat):
            n = len(y)
            rss = torch.sum(torch.square(y - y_pred))
            cp = (rss / n) + 2 * (sigma2_hat / n) * edf_total
            return cp

        CP = torch.zeros_like(alpha)
        n_samples = y.size()[0]
        Critical_point = None

        # Walk the path from the last index to the first; Critical_point records
        # the first (i.e. largest) index whose effective df exceeds n.
        for i in range(len(alpha)-1, -1, -1):
            pred_y = comp[i].sum(axis = 1)

            active_set = torch.where(torch.norm(comp[i], p = 1, dim=0) != 0)[0].tolist()
            effective_df = 0
            for feat in active_set:
                PS_matrix = PS_smoothing_matrix(X[:, feat], lam = 0.1, n_knots = 10, degree = 3, boundary_knots = 3, custom = False)
                effective_df += np.trace(PS_matrix)
                del PS_matrix

            if effective_df > n_samples:
                # Over-fitted: RSS / (n - df) is undefined, fall back to RSS.
                sigma2_hat = torch.sum(torch.square(y - pred_y))
                if Critical_point is None:
                    Critical_point = i
            else:
                sigma2_hat = estimate_sigma2(y, pred_y, effective_df)

            CP[i] = compute_cp(y, pred_y, effective_df, sigma2_hat)

            print(alpha[i], effective_df)

        # Simplify the plot by masking the over-fitted region.  Only do this when
        # an over-fitted index exists: with Critical_point None, CP[:None] and
        # CP[None:] both span the whole tensor, which used to set every entry to
        # max(CP) and made the selection always return index 0.
        if Critical_point is not None:
            CP[:Critical_point] = torch.max(CP[Critical_point:])
        minCP = torch.where(CP == torch.min(CP))[0]
        optimalloc_ = minCP[0].item()
        optimalset_ = torch.where(torch.norm(comp[optimalloc_], p = 1, dim=0) != 0)[0].tolist()

        return CP, optimalloc_, optimalset_
    
    class AdditiveInteractionSelector:
        """
        Fit additive models with candidate main effects and interactions
        using spline basis expansions + group sparsity (group lasso).

        Each main effect contributes one group of B-spline columns built with
        ``n_splines`` knots.  Each interaction (a, b) contributes one group
        holding the row-wise tensor product of two bases built with
        ``interaction_splines`` knots.  After ``fit``, ``group_norms_`` holds the
        L2 norm of every group's coefficients, which the reporting helpers rank
        and threshold.
        """

        def __init__(self, n_splines=10, spline_degree=3, include_intercept=False,
                     interaction_splines=10, random_state=0):
            self.n_splines = n_splines  # knots for each main-effect basis
            self.spline_degree = spline_degree
            self.include_intercept = include_intercept  # forwarded as include_bias
            self.interaction_splines = interaction_splines  # knots for each marginal of an interaction basis
            self.random_state = random_state  # seeds the KFold shuffle

            # Internal storage
            self.groups = []  # design-matrix column indices of each group
            self.group_names = []  # tuple of DataFrame column names per group
            self.scaler = None
            self.model = None
            self.group_norms_ = None
            self.design_matrix_ = None

        # -----------------------------
        # Basis construction utilities
        # -----------------------------
        def _build_univariate_basis(self, x, n_knots=None):
            """Build a spline basis for one variable.

            ``n_knots`` defaults to ``self.n_splines``.
            """
            x = np.asarray(x).reshape(-1, 1)
            sp = SplineTransformer(
                degree=self.spline_degree,
                n_knots=self.n_splines if n_knots is None else n_knots,
                include_bias=self.include_intercept
            )
            return sp.fit_transform(x)

        def _build_bivariate_basis(self, x1, x2):
            """Build the tensor-product spline basis for an interaction.

            Both marginals use ``self.interaction_splines`` knots.  Previously this
            setting was stored but ignored in favour of ``n_splines``.
            """
            B1 = self._build_univariate_basis(x1, n_knots=self.interaction_splines)
            B2 = self._build_univariate_basis(x2, n_knots=self.interaction_splines)
            # Row-wise outer product: column (j, k) holds B1[:, j] * B2[:, k].
            return np.einsum("ij,ik->ijk", B1, B2).reshape(len(x1), -1)

        def _build_design(self, X_df, interactions=None):
            """Construct the design matrix and record its column groups.

            Main effects come first (one group per DataFrame column, in column
            order), followed by one group per ``(a, b)`` pair in ``interactions``.
            Resets and fills ``self.groups`` / ``self.group_names``, and stores the
            result in ``self.design_matrix_``.
            """
            blocks, self.groups, self.group_names = [], [], []
            col_idx = 0

            # Main effects
            for col in X_df.columns:
                B = self._build_univariate_basis(X_df[col].values)
                blocks.append(B)
                self.groups.append(list(range(col_idx, col_idx + B.shape[1])))
                self.group_names.append((col,))
                col_idx += B.shape[1]

            # Interactions
            if interactions:
                for a, b in interactions:
                    Bt = self._build_bivariate_basis(X_df[a].values, X_df[b].values)
                    blocks.append(Bt)
                    self.groups.append(list(range(col_idx, col_idx + Bt.shape[1])))
                    self.group_names.append((a, b))
                    col_idx += Bt.shape[1]

            self.design_matrix_ = np.hstack(blocks)
            return self.design_matrix_

        # -----------------------------
        # Fitting
        # -----------------------------
        def _make_group_lasso(self, col_to_group, lam):
            """Return an unfitted GroupLasso with this selector's fixed settings."""
            return GroupLasso(
                groups=col_to_group,
                group_reg=lam, l1_reg=0.0,
                scale_reg="group_size",
                supress_warning=True,
                n_iter=2000, tol=1e-3
            )

        def fit(self, X_df, y, interactions=None, cv=5, HAS_GROUP_LASSO=True, lambdas=None):
            """
            Fit model with group lasso (preferred), or sklearn LassoCV as fallback.

            Parameters
            ----------
            X_df : pandas.DataFrame
                One column per candidate main effect.
            y : array-like indexable by integer arrays, length n_samples
            interactions : iterable of (col_a, col_b) pairs, optional
            cv : int
                Number of cross-validation folds.
            HAS_GROUP_LASSO : bool
                If False, fit a plain LassoCV instead of the group lasso.
            lambdas : array-like, optional
                Candidate group penalties; defaults to ``np.logspace(-3, 1, 10)``.

            Returns
            -------
            self

            Raises
            ------
            RuntimeError
                If no candidate penalty produced a usable mean CV score.
            """
            X = self._build_design(X_df, interactions)
            self.scaler = StandardScaler()
            Xs = self.scaler.fit_transform(X)

            if HAS_GROUP_LASSO:
                # Map every design column to the index of its group.
                col_to_group = np.zeros(X.shape[1], dtype=int)
                for gid, idxs in enumerate(self.groups):
                    col_to_group[idxs] = gid

                # Cross-validate the group penalty by the mean of gl.score over folds.
                if lambdas is None:
                    lambdas = np.logspace(-3, 1, 10)
                best_score, best_model = -np.inf, None
                kf = KFold(n_splits=cv, shuffle=True, random_state=self.random_state)

                for lam in lambdas:
                    scores = []
                    for tr, va in kf.split(Xs):
                        gl = self._make_group_lasso(col_to_group, lam)
                        gl.fit(Xs[tr], y[tr])
                        scores.append(gl.score(Xs[va], y[va]))
                    mean_score = np.mean(scores)
                    if mean_score > best_score:
                        best_score = mean_score
                        # Refit the new best penalty on all data.
                        best_model = self._make_group_lasso(col_to_group, lam)
                        best_model.fit(Xs, y)

                # A NaN mean never compares greater than -inf, so best_model can
                # stay None; fail clearly instead of on `None.coef_`.
                if best_model is None:
                    raise RuntimeError(
                        "Group lasso CV: no penalty in `lambdas` produced a usable mean score"
                    )
                self.model = best_model
                coefs = self.model.coef_.ravel()

            else:
                # Fallback to plain Lasso
                lasso = LassoCV(cv=cv).fit(Xs, y)
                self.model = lasso
                coefs = lasso.coef_

            # L2 norm of each group's coefficient block.
            self.group_norms_ = [
                np.linalg.norm(coefs[idxs], ord=2) for idxs in self.groups
            ]
            return self

        # -----------------------------
        # Reporting
        # -----------------------------
        def get_group_importance(self):
            """Return a DataFrame of group names and norms, largest norm first."""
            return pd.DataFrame({
                "group": self.group_names,
                "norm": self.group_norms_
            }).sort_values("norm", ascending=False).reset_index(drop=True)

        def get_important_groups(self, threshold=0.1):
            """Return groups whose norm exceeds ``threshold`` as 0-based feature indices.

            Group names are expected to look like "x<k>" with k 1-based (the
            naming used by ``extract_active_features``), e.g. ("x2", "x5") -> [1, 4].
            """
            selected = []
            for tup, val in zip(self.group_names, self.group_norms_):
                if val > threshold:
                    indices = [int(s[1:]) - 1 for s in tup]  # convert "x1" -> 0
                    selected.append(indices)
            return selected

        def summary(self):
            """Print ranked group importance."""
            df = self.get_group_importance()
            print("Group importance (higher = more important):")
            print(df)
    
    
    def extract_active_features(X: torch.Tensor, active_idx: list[int]) -> pd.DataFrame:
        """
        Extract active features from X based on active indices.

        Parameters
        ----------
        X : torch.Tensor or np.ndarray
            Data matrix of shape (n_samples, n_features).  A torch tensor must be
            convertible with ``np.asarray`` (on the CPU and not requiring grad).
        active_idx : iterable of int
            Indices of active features.  Integer tensors such as a 1-D LongTensor
            are accepted too.

        Returns
        -------
        pd.DataFrame
            One column per active feature, named x{idx+1} (1-based), holding
            that column of X as a NumPy array.
        """
        data = {}
        for idx in active_idx:
            col = int(idx)  # normalise 0-d tensors / NumPy ints to a Python int
            data[f"x{col + 1}"] = np.asarray(X[:, col])
        return pd.DataFrame(data)

    def normalize(x):
        """Min-max scale every column of a 2-D tensor to [0, 1].

        Constant columns (max == min) are mapped to 0 instead of producing
        0 / 0 = NaN.

        Parameters
        ----------
        x : torch.Tensor of shape (n_samples, n_features)

        Returns
        -------
        torch.Tensor of the same shape.
        """
        x_min = x.min(dim=0, keepdim=True).values  # per-column minimum, dims kept for broadcasting
        x_max = x.max(dim=0, keepdim=True).values  # per-column maximum, dims kept for broadcasting

        span = x_max - x_min
        # Divide constant columns by 1 so (x - x_min) == 0 stays 0 rather than NaN.
        span = torch.where(span == 0, torch.ones_like(span), span)

        return (x - x_min) / span
    

    
    with h5py.File('../data/EstimatedResponses.mat', 'r') as f:
        # List the top-level variables stored in the file (.mat v7.3 files are HDF5).
        print("Keys in the .mat file:", list(f.keys()))

        # `[:]` reads each HDF5 dataset fully into memory as a NumPy array.
        YTrainS1 = f['dataTrnS1'][:]
        YTrainS2 = f['dataTrnS2'][:]
        YValS1 = f['dataValS1'][:]
        YValS2 = f['dataValS2'][:]
        voxS1 = f['voxIdxS1'][:]
        voxS2 = f['voxIdxS2'][:]
        roiS1 = f['roiS1'][:]
        roiS2 = f['roiS2'][:]

    # Voxel columns whose ROI label equals 1.  roiS1 is indexed as 2-D here, so
    # the column positions come from the second output of np.where.
    roi1_cols = np.where(roiS1 == 1)[1]
    y_trn = torch.tensor(YTrainS1[:, roi1_cols])
    # Bug fix: validation responses come from YValS1 (this previously re-used
    # YTrainS1).  Assumes YValS1 shares YTrainS1's voxel layout on axis 1 -- confirm.
    y_val = torch.tensor(YValS1[:, roi1_cols])

    X_trn = np.load('../data//complex_cell_train.npy')
    X_trn = torch.tensor(X_trn).to(torch.float32)
    X_val = np.load('../data//complex_cell_valid.npy')
    X_val = torch.tensor(X_val).to(torch.float32)

    # Work on a subsample: the first 800 rows, the first 500 features, one voxel column.
    sub_sample = 800
    v1_idx = 810
    sub_feature = 500
    X_sub = X_trn[:sub_sample, :sub_feature]
    y_sub = y_trn[:sub_sample, v1_idx].to(torch.float32).view(-1, 1)

    # Hard-coded feature indices into the first `sub_feature` columns.
    # NOTE(review): presumably the output of an earlier screening step -- confirm provenance.
    top_n_original_indices = torch.tensor(np.array([  1,   2,   3,   4,  90,  91,  92,  96,  97, 101, 102, 110, 150, 155,
        164, 170, 172, 177, 187, 189, 190, 191, 199, 201, 206, 214, 217, 228,
        237, 268, 284, 314, 341, 447, 456, 457, 490]))

    X_sub = normalize(X_sub)
    opt_df = extract_active_features(X_sub, top_n_original_indices)
    # Every pair of selected features is a candidate interaction.
    interactions = list(combinations(list(opt_df.keys()), 2))
    selector = AdditiveInteractionSelector(n_splines = 5, interaction_splines = 5)
    selector.fit(opt_df, y_sub, interactions=interactions)
    print(selector.get_group_importance())
    SDAM_config = selector.get_important_groups(threshold=1e-3)
    print(SDAM_config)
    
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
    