import torch 
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import SplineTransformer, StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def customized_knots(X, boundary_knots, n_knots, degree, unif = False):
    """Build a column of spline knot positions for the values in ``X``.

    With ``unif=True`` the knots are ``n_knots`` evenly spaced points from
    ``X.min()`` to ``X.max()``. Otherwise the knots are the sorted, de-duplicated
    union of three evenly spaced runs, with ``pad = 0.1 * X.std()``:

    * ``boundary_knots`` points on ``[min - pad, min + pad]``;
    * ``n_knots - boundary_knots`` points on ``[min + pad, max - pad]``;
    * ``boundary_knots`` points on ``[max - pad, max + pad]``.

    The shared end points of neighbouring runs collapse under ``np.unique``, so the
    total number of knots can differ from ``n_knots``.

    Args:
        X: array-like supporting ``.min()``, ``.max()`` and ``.std()``.
        boundary_knots: number of points in each boundary run.
        n_knots: knot count for the uniform case; the interior run gets
            ``n_knots - boundary_knots`` points.
        degree: unused; accepted for signature compatibility.
        unif: if True, return plain uniform knots.

    Returns:
        Knot positions reshaped to a single column, shape ``(k, 1)``.
    """
    lo, hi = X.min(), X.max()
    pad = 0.1 * X.std()

    if not unif:
        # Boundary runs straddle each extreme so the basis extends slightly past the data.
        segments = (
            np.linspace(lo - pad, lo + pad, boundary_knots),
            np.linspace(lo + pad, hi - pad, n_knots - boundary_knots),
            np.linspace(hi - pad, hi + pad, boundary_knots),
        )
        return np.unique(np.concatenate(segments)).reshape(-1, 1)

    return np.linspace(lo, hi, n_knots).reshape(-1, 1)


def spline_transform(X, n_knots = None, boundary_knots = None, degree = 3, custom = False, extrapolation = 'constant'):
    """Expand a single feature into a B-spline basis using scikit-learn's SplineTransformer.

    Args:
        X: 1-D array-like of feature values; reshaped to one column before fitting.
        n_knots: number of knots. With ``custom=False`` and ``n_knots=None``,
            SplineTransformer's own default is used (previously ``None`` was
            forwarded and rejected).
        boundary_knots: points per boundary run for ``customized_knots``
            (only used when ``custom=True``).
        degree: polynomial degree of the spline pieces.
        custom: if True, use the knot positions from ``customized_knots``;
            otherwise SplineTransformer places the knots itself.
        extrapolation: SplineTransformer extrapolation mode. It is now applied in
            both branches; before, the non-custom branch silently ignored it.

    Returns:
        The basis matrix from ``SplineTransformer.fit_transform``, one row per value of ``X``.
    """
    if custom:
        knots = customized_knots(X, boundary_knots, n_knots, degree)
        spline = SplineTransformer(knots=knots, degree=degree, extrapolation=extrapolation)
    else:
        options = {'degree': degree, 'extrapolation': extrapolation}
        # Only forward n_knots when given, so the default call works.
        if n_knots is not None:
            options['n_knots'] = n_knots
        spline = SplineTransformer(**options)

    return spline.fit_transform(X.reshape(-1, 1))

def diff_penalty(num_coefs, order=2):
    """Return the ``order``-th finite-difference matrix for ``num_coefs`` coefficients.

    The result ``D`` has shape ``(num_coefs - order, num_coefs)`` and satisfies
    ``D @ c == np.diff(c, n=order)``; ``D.T @ D`` is the P-spline roughness penalty.
    """
    identity = np.identity(num_coefs)
    # Differencing the rows of the identity gives the difference operator itself.
    return np.diff(identity, n=order, axis=0)


def PS_smoothing_matrix(x, lam, n_knots, degree, boundary_knots, custom, penalty_order=2):
    """Return the P-spline smoother ("hat") matrix for one feature.

    ``S = B (B^T B + lam * D^T D)^{-1} B^T``, where ``B`` is the B-spline basis of
    ``x`` and ``D`` is the ``penalty_order``-th difference matrix on the basis
    coefficients. ``S @ r`` is the penalized least-squares spline fit of ``r`` on ``x``.

    Args:
        x: 1-D feature values.
        lam: roughness-penalty weight.
        n_knots, boundary_knots, custom: knot settings forwarded to ``spline_transform``.
        degree: spline degree. It is now forwarded; before, it was ignored and
            the basis always used degree 3.
        penalty_order: order of the difference penalty (default 2, as before).

    Returns:
        Square ``(len(x), len(x))`` numpy array.
    """
    B = spline_transform(x, n_knots=n_knots, boundary_knots=boundary_knots, degree=degree, custom=custom)
    D = diff_penalty(B.shape[1], order=penalty_order)

    # Solve the linear system instead of forming an explicit inverse: same result, better conditioned.
    return B @ np.linalg.solve(B.T @ B + lam * D.T @ D, B.T)

def SAM(X, y, lam = 0, alpha=0.25, max_iter=10, tol=1e-6, ftol = 1e-3, n_knots=10, boundary_knots = 3, degree=3, custom = False):
    """Fit a sparse additive model by P-spline backfitting with soft-thresholding.

    Each sweep visits every feature still in the model:

    1. Form the partial residual ``y - sum_{k != j} f_k``.
    2. Smooth it with the P-spline smoother matrix of column ``j`` of ``X``.
    3. Shrink the result by ``(1 - alpha / s_j)``, where ``s_j`` is its
       root-mean-square.

    A feature with ``s_j <= alpha`` is set to zero and excluded from all later
    sweeps. Iteration stops when the squared change of ``f`` over a sweep falls
    below ``tol``, or after ``max_iter`` sweeps. A status line is printed either way.

    Args:
        X: 2-D tensor of shape (n_samples, n_features).
        y: response tensor (cloned); presumably 1-D of length n_samples — TODO confirm.
        lam: roughness-penalty weight of the smoother.
        alpha: soft-threshold / sparsity level.
        max_iter: maximum number of backfitting sweeps.
        tol: convergence threshold on ``sum((f - f_old) ** 2)``.
        ftol: currently unused; kept for interface compatibility.
        n_knots, boundary_knots, degree, custom: smoother settings forwarded to
            ``PS_smoothing_matrix``.

    Returns:
        Tensor of shape (n_samples, n_features) holding the fitted components;
        columns of eliminated features are zero.
    """
    n_samples, n_features = X.size()
    f = torch.zeros((n_samples, n_features))
    R = torch.clone(y)
    # Column indices of X whose components are still in the model.
    feature_space = list(range(n_features))

    for _ in range(max_iter):
        f_old = torch.clone(f)
        survivors = []

        for feat in feature_space:
            # Partial residual: the response minus every other component.
            others = [k for k in range(n_features) if k != feat]
            Res = torch.FloatTensor(R - f[:, others].sum(axis=1))

            # Smooth against this feature's own column. The old code indexed X by the
            # position inside feature_space, which selects the wrong column as soon
            # as any earlier feature has been dropped.
            PS_matrix = torch.FloatTensor(PS_smoothing_matrix(X[:, feat], lam=lam, n_knots=n_knots, degree=degree, boundary_knots=boundary_knots, custom=custom))
            P_j = PS_matrix @ Res
            del PS_matrix  # release the n x n smoother before building the next one

            s_j = torch.sqrt(torch.mean(P_j ** 2))
            if s_j > alpha:
                f[:, feat] = (1 - alpha / s_j) * P_j
                survivors.append(feat)
            else:
                f[:, feat] = 0

        feature_space = survivors

        if torch.sum(torch.square(f - f_old)) < tol:
            print(f"Alpha: {alpha:.2f} | Convergence.")
            return f

    print(f"Alpha: {alpha:.2f} | Not Convergence yet.")
    return f


def GCV_loss(alpha, comp, y, df):
    """Compute the generalized cross-validation score for each fit on an alpha path.

    For path index ``i``: ``GCV_i = MSE_i / (1 - df[i] / n)^2``, where ``MSE_i`` is
    the mean squared error of the additive prediction ``comp[i].sum(dim=1)``.
    When ``df[i] > n`` the score is copied from ``GCV[-1]`` instead.

    Indices are processed from last to first, so ``GCV[-1]`` is already filled
    for every ``i`` except the last one.

    NOTE(review): for the last index itself, ``GCV[-1]`` is still its zero
    initial value, so ``df[-1] > n`` yields a GCV of 0 — confirm intended.
    """
    n_obs = y.size()[0]
    scores = torch.zeros_like(alpha)

    for idx in reversed(range(len(alpha))):
        residual = y - comp[idx].sum(dim=1)
        mse = torch.sum(torch.square(residual)) / n_obs

        if df[idx] > n_obs:
            scores[idx] = scores[-1]
        else:
            scores[idx] = mse / ((1 - df[idx] / n_obs) ** 2)

    return scores

def estimate_gcv(alpha, comp, X, y, lam=0.1, n_knots=10, degree=3, boundary_knots=3, custom=True):
    """Score every fit on an alpha path and select the lowest-scoring one.

    The score for path index ``i`` is the MSE of the additive prediction
    ``comp[i].sum(dim=1)`` plus, for each column with a non-zero component, the
    trace of that column's P-spline smoother matrix divided by ``n_samples``.

    Args:
        alpha: 1-D tensor of sparsity levels; only its length and dtype are used.
        comp: mapping from path index ``i`` to the (n_samples, n_features) component
            matrix fitted at ``alpha[i]``.
        X: design matrix; column ``idx`` is passed to ``PS_smoothing_matrix``.
        y: response tensor.
        lam, n_knots, degree, boundary_knots, custom: smoother settings for the
            trace term. The defaults reproduce the previously hard-coded values;
            pass the settings used to fit ``comp`` to keep the two consistent.

    Returns:
        ``(GCV, optimalloc_, optimalset_)``: the score per path index; the index of
        the lowest score (the first one on ties); and the column indices with
        non-zero components at that index.
    """
    GCV = torch.zeros_like(alpha)
    n_samples = y.size()[0]
    criterion = nn.MSELoss()
    # The trace term depends only on the column and the smoother settings, so
    # compute it once per column instead of once per (alpha, column) pair.
    trace_cache = {}

    for i in range(len(alpha)):
        pred_y = comp[i].sum(dim=1)
        MSE = criterion(y, pred_y)

        active_set = torch.where(torch.norm(comp[i], p=1, dim=0) != 0)[0].tolist()
        effective_df = 0
        for idx in active_set:
            if idx not in trace_cache:
                PS_matrix = PS_smoothing_matrix(X[:, idx], lam=lam, n_knots=n_knots, degree=degree, boundary_knots=boundary_knots, custom=custom)
                trace_cache[idx] = np.trace(PS_matrix) / n_samples
                del PS_matrix
            effective_df += trace_cache[idx]

        GCV[i] = MSE + effective_df

    # argmin returns the first minimum. `torch.where(GCV == min).item()` raised
    # whenever several path points tied, e.g. all-zero fits at large alpha.
    optimalloc_ = torch.argmin(GCV).item()
    optimalset_ = torch.where(torch.norm(comp[optimalloc_], p=1, dim=0) != 0)[0].tolist()

    return GCV, optimalloc_, optimalset_

def plot_comp_norm(GCV, alpha, comp, maxl, loc):
    """Draw the component-norm solution path next to the GCV curve and save the figure.

    The horizontal axis is the summed L2 norm of all components at each path
    index, divided by ``maxl``.

    * Left panel: for every feature, the L1 norm of its component divided by
      ``n_samples``.
    * Right panel: ``GCV`` over the same axis.

    A dashed red vertical line marks path index ``loc`` in both panels. The figure
    is written to ``./img/EX1-norm.png``.
    """
    n_path = len(alpha)
    x_axis = [torch.sum(torch.norm(comp[k], dim=0)) / maxl for k in range(n_path)]

    plt.figure(figsize=(12, 4))
    n_samples, n_features = comp[0].size()

    annotations = []
    for feat in range(n_features):
        path_norms = torch.zeros_like(alpha)
        for k in range(n_path):
            path_norms[k] = torch.norm(comp[k][:, feat], p=1) / n_samples
        plt.subplot(121)
        plt.plot(x_axis, path_norms, linestyle='--', label=f"x {feat+1}")
        annotations.append((x_axis, path_norms, f"x{feat+1}"))

    # Each label goes at x = x_axis[-1] + 1.01, y = the curve's first value.
    # NOTE(review): the x offset is additive, not a scale factor — confirm intended.
    for xs, ys, text in annotations:
        plt.text(xs[-1] + 1.01, ys[0], text, va='center')

    marker = x_axis[loc]
    plt.vlines(marker, ymin=-0.1, ymax=2.5, colors='red', linestyles='dashed', label='Vertical Lines')
    plt.ylabel('Component Norms', fontweight='bold')

    plt.subplot(122)
    plt.plot(x_axis, GCV, color='b')
    plt.vlines(marker, ymin=torch.min(GCV) - 1, ymax=torch.max(GCV) + 0.3, colors='red', linestyles='dashed', label='Vertical Lines')
    plt.ylabel('GCV')
    plt.savefig('./img/EX1-norm.png')

def plot_component(X, true, opt_comp, title, save_path='./img/EX1.png'):
    """Plot true vs. estimated component functions, one panel per feature, and save.

    Two all-zero reference curves are appended after ``true``. Panel ``i`` plots
    ``true[i]`` (or zero) and column ``i`` of ``opt_comp`` against column ``i``
    of ``X``, sorted in ascending order.

    Args:
        X: 2-D array of feature values; column ``i`` is the x-axis of panel ``i``.
        true: sequence of 1-D arrays with the true function values per feature.
        opt_comp: 2-D array of estimated component values, indexed like ``X``.
        title: figure-level title.
        save_path: output file; defaults to the previously hard-coded path.
    """
    panels = list(true) + [np.zeros_like(true[0]), np.zeros_like(true[0])]
    n_cols = 3
    # Keep the original 2x3 grid and add rows only when there are more than six
    # panels; `plt.subplot(231 + i)` raised for i >= 6.
    n_rows = max(2, (len(panels) + n_cols - 1) // n_cols)

    plt.figure(figsize=(16, 8))
    # suptitle rather than plt.title: calling plt.title before any subplot exists
    # attaches the title to an extra full-figure Axes instead of the figure.
    plt.suptitle(title)
    for i, func in enumerate(panels):
        plt.subplot(n_rows, n_cols, i + 1)
        order = np.argsort(X[:, i])
        x_sorted = X[order, i]

        plt.plot(x_sorted, func[order], color='black', linestyle='-', label='True function')
        plt.plot(x_sorted, opt_comp[order, i], c='red', linestyle='--', label='Estimated function')
        plt.xlabel(f'x{i+1}')
        plt.ylabel(f'f{i+1}')
        plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)

def feature_selection(X, active_idx):
    """Restrict train/valid/test design matrices to selected columns and standardize them.

    The scaler is fitted on the training split only and then applied to all three
    splits, so they share the training mean and variance.

    Args:
        X: sequence ``[X_train, X_valid, X_test]`` of 2-D arrays.
        active_idx: column indices to keep; an empty (or ``None``) selection keeps
            every column.

    Returns:
        ``[X_train, X_valid, X_test]`` as float32 tensors.
    """
    X_train, X_valid, X_test = X[0], X[1], X[2]

    # Keep only the selected columns; `len` also works for array-like indices,
    # where `!= []` would compare element-wise.
    if active_idx is not None and len(active_idx) > 0:
        X_train, X_valid, X_test = X_train[:, active_idx], X_valid[:, active_idx], X_test[:, active_idx]

    # Fit on training data only. Re-fitting on the validation/test splits (as
    # before) leaked their statistics and put the splits on different scales.
    scaler = StandardScaler().fit(X_train)

    # Convert to PyTorch tensors
    return [torch.tensor(scaler.transform(split), dtype=torch.float32) for split in (X_train, X_valid, X_test)]


def _clustering(optimal_f, active_index, threshold, type):
    """
    Returns:
        dict with either 'main_idx' or 'inter_idx'

    """

    active_f = optimal_f[:, active_index]
    active_norm = torch.norm(active_f, dim = 0)
    feature_stack = torch.column_stack([
        active_norm,
        torch.mean(active_f, dim = 0),
        torch.std(active_f, dim = 0),
        torch.min(active_f, dim = 0).values,
        torch.max(active_f, dim = 0).values
    ])
    
    if type == 'kmean':
        kmeans = KMeans(n_clusters = 2)
        labels = kmeans.fit_predict(feature_stack)
        scores = silhouette_score(feature_stack, labels)
    else:
        pass

    ## Score
    if scores < threshold:
        labels = [0 for i in range(len(active_index))]

    if all(label == 0 for label in labels):
        return {'main_idx': active_index, 'inter_idx': None}
    
    label0_norms = [norm for norm, label in zip(active_norm, labels) if label == 0]
    label1_norms = [norm for norm, label in zip(active_norm, labels) if label == 1]

    sum_label0_norms = sum(label0_norms) if label0_norms else float('-inf')
    sum_label1_norms = sum(label1_norms) if label1_norms else float('-inf')

    main_label = 1 if sum_label1_norms > sum_label0_norms else 0
        
    main_idx = [i for i, (label, norm) in enumerate(zip(labels, active_norm)) if label == main_label]
    inter_idx = [i for i in active_index if i not in main_idx]

    return {'main_idx': main_idx, "inter_idx": [inter_idx]}


def train_SAM(X, y, alpha_list, max_iter, nk, nb, custom):
    """Fit ``SAM`` for every value in ``alpha_list`` and pick one with ``estimate_gcv``.

    Every fit uses ``lam=0.1``, ``tol=1e-6``, ``ftol=1e-3`` and ``degree=3``, with
    ``nk``/``nb``/``custom`` as the knot settings.

    Returns:
        dict with keys:

        * ``'component'``: position -> fitted component matrix.
        * ``'GCV'``, ``'opt_loc'``, ``'opt_var'``: the three outputs of ``estimate_gcv``.
        * ``'Max_L1'``: the largest summed column L2 norm over all fits.
        * ``'alpha'``: the input ``alpha_list``.
    """
    fits = {}
    l1_totals = torch.zeros((len(alpha_list)))

    for pos, alpha in enumerate(alpha_list):
        fitted = SAM(X, y, lam=0.1, alpha=alpha, max_iter=max_iter, tol=1e-6, ftol=1e-3, n_knots=nk, boundary_knots=nb, degree=3, custom=custom)
        # Columns whose squared sum is exactly zero are explicitly re-zeroed.
        zero_cols = torch.where(torch.sum(torch.square(fitted), dim=0) == 0)[0]
        l1_totals[pos] = torch.sum(torch.norm(fitted, dim=0))
        fitted[:, zero_cols] = 0
        fits[pos] = fitted

    GCV_list, loc, active_dict = estimate_gcv(alpha_list, fits, X, y)

    return {
        'component': fits,
        'GCV': GCV_list,
        'opt_loc': loc,
        'opt_var': active_dict,
        'Max_L1': torch.max(l1_totals),
        'alpha': alpha_list,
    }

def model_summary(cluster_result):
    """Combine first- and second-stage clustering results into main and group effects.

    Args:
        cluster_result: sequence whose first two items are ``_clustering``-style
            dicts for the first stage (FS) and second stage (SS). Each has
            ``'main_idx'`` and ``'inter_idx'``; the latter is either ``None`` or a
            one-element list wrapping a list of indices.

    Returns:
        ``{'MAIN': [...], 'GROUP': [confirmed, potential]}``:

        * ``MAIN``: FS main indices plus FS interaction indices not in SS's main indices.
        * ``confirmed``: FS interaction indices that are in SS's main indices.
        * ``potential``: SS interaction indices.

        All lists are sorted.
    """
    FS, SS = cluster_result[0], cluster_result[1]

    def _inter(result):
        # `_clustering` reports "no interaction group" as None rather than [[]];
        # indexing it directly used to raise TypeError.
        inter = result['inter_idx']
        return set(inter[0]) if inter else set()

    fs_inter = _inter(FS)
    ss_main = set(SS['main_idx'])

    MAIN_EFFECTS = sorted(set(FS['main_idx']) | (fs_inter - ss_main))
    GROUP_EFFECTS = [sorted(fs_inter & ss_main), sorted(_inter(SS))]

    return {'MAIN': MAIN_EFFECTS, 'GROUP': GROUP_EFFECTS}