from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from itertools import combinations_with_replacement, compress
from lassonet import LassoNetRegressor, LassoNetRegressorCV
from sklearn.preprocessing import PolynomialFeatures
import torch 
import warnings
import argparse
import ast    
import numpy as np

# Silence every warning category process-wide (sklearn / torch / lassonet are noisy).
warnings.simplefilter("ignore")

def _list2set(arr):
    """Turn a list of index groups into a set of hashable group keys.

    A one-element group ``[i]`` becomes the int ``i``; any longer group
    becomes a tuple, e.g. ``[[0], [1, 2]] -> {0, (1, 2)}``.
    """
    return {item[0] if len(item) == 1 else tuple(item) for item in arr}


def _detection(tm, ti, arr, d):
    """Return ``(TPR, FPR)`` of the selected groups against the true groups.

    Args:
        tm: true main-effect groups, e.g. ``[[0], [1]]``.
        ti: true interaction groups, e.g. ``[[1, 2]]``.
        arr: selected groups, in the same list-of-lists form.
        d: number of raw input features.

    The negative class is every main effect and every distinct pair
    ``(i, j)`` with ``i < j`` that is not a true effect, i.e.
    ``d*(d-1)/2 + d - |true|`` candidates.  Squared terms ``[i, i]`` are not
    part of that space, so they are dropped from ``arr`` before scoring;
    otherwise a selected square would be counted as a false positive in the
    numerator with no matching slot in the denominator.
    """
    trueset = _list2set(tm + ti)
    # Keep only groups whose indices are all distinct (drops [i, i]).
    predset = _list2set([g for g in arr if len(set(g)) == len(g)])

    tpr = len(trueset & predset) / len(trueset)
    fpr = len(predset - trueset) / ((d * (d - 1) / 2) + d - len(trueset))
    return tpr, fpr


def _main_and_interactions(indices):
    """Return the candidate groups for a degree-2 polynomial basis.

    Main effects ``[i]`` come first, then every pair ``[i, j]`` with
    ``i <= j`` (squares included) in ``combinations_with_replacement``
    order -- intended to line up one-to-one with the output columns of
    ``PolynomialFeatures(degree=2, include_bias=False)``.
    """
    indices = list(indices)  # iterated twice below; materialise one-shot iterables
    groups = [[i] for i in indices]
    groups.extend(list(pair) for pair in combinations_with_replacement(indices, 2))
    return groups


def main():
    """Run LassoNet selection on one simulated dataset and print TPR/FPR.

    ``--data_name`` picks one of six datasets loaded from
    ``../data/<name>.pt``.  For each of the first ``r`` replicates the inputs
    are standardised, expanded into a degree-2 polynomial basis and fitted
    with ``LassoNetRegressorCV``; the groups selected at the lowest
    validation loss along the regularisation path are scored against the
    known true main effects / interactions of that dataset.
    """
    import sys  # only needed for the unknown-dataset notice below

    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument('--data_name', default="", type=str, help='Type of dataset')
    args = parser.parse_args()

    name = ['only_main300_data', 'weak_main300_data', 'inter_no_overlap300_data', 'inter_mild_overlap300_data', 'inter_strong_overlap300_data', 'only_inter300_data']
    # Alternative file names without the "300" suffix:
    #name = ['only_main_data', 'weak_main_data', 'inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']

    # Ground-truth main effects (tml) and interactions (til) for each entry of `name`.
    tml = [[[0], [1], [2], [3]], [[0], [1], [2], [3]], [[0], [1], [2]], [[0], [1], [2]],
           [[0], [1], [2]], []]
    til = [[], [], [[3, 4]], [[2, 3]], [[1, 2]], [[1, 2], [3, 4]]]

    dataset_index = {
        'only_main_data': 0,
        'weak_main_data': 1,
        'inter_no_overlap_data': 2,
        'inter_mild_overlap_data': 3,
        'inter_strong_overlap_data': 4,
        'only_inter_data': 5,
    }
    # Anything unrecognised (including the default "") still falls back to the
    # only-interaction dataset, but a non-empty typo is now reported.
    if args.data_name and args.data_name not in dataset_index:
        print(f"Unknown --data_name {args.data_name!r}; falling back to {name[5]!r}",
              file=sys.stderr)
    j = dataset_index.get(args.data_name, 5)

    _dict = torch.load('../data/' + name[j] + '.pt', weights_only=True)
    X_train = np.array(_dict['X_train'])
    y_train = np.array(_dict['y_train'])

    poly = PolynomialFeatures(degree=2, include_bias=False)  # refit per replicate
    TPR = []
    FPR = []

    r = 100  # number of replicates evaluated
    for i in range(r):
        # Standardise replicate i with its *own* mean and std (previously the
        # std came from all replicates stacked together). X_train is not mutated.
        X_i = (X_train[i] - X_train[i].mean()) / X_train[i].std()
        X_basis = poly.fit_transform(X_i)

        model = LassoNetRegressorCV(
            hidden_dims=(10, 10),   # neural net architecture
            M=1,                    # hierarchy parameter (linear vs nonlinear strength)
            path_multiplier=2,      # geometric progression along regularization path
            verbose=0,              # verbosity
            patience=10,            # early stopping for training
            batch_size=128,
            torch_seed=42,
        )
        path = model.path(X_basis, y_train[i])

        pdim = X_i.shape[1]
        group_ID = _main_and_interactions(range(pdim))

        # Pick the path point with the lowest validation loss.
        val_losses = np.array([float(item.val_loss) for item in path])
        best = int(np.argmin(val_losses))
        pred_set = list(compress(group_ID, path[best].selected))

        tvalue, fvalue = _detection(tml[j], til[j], pred_set, pdim)
        TPR.append(tvalue)
        FPR.append(fvalue)

    print(f"TPR Mean (LassoNet): {np.mean(TPR):.4f}, Std: {np.std(TPR):.4f}")
    print(f"FPR Mean (LassoNet): {np.mean(FPR):.4f}, Std: {np.std(FPR):.4f}")

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()