import os
import sys
import pandas as pd
import torch 
import numpy as np 
import torch.nn.functional as F

from tqdm import tqdm
from src.utils.data_loader import DataLoader
from src.utils.scores_loader import ScoresLoader
from torch.utils.data import Dataset
from sklearn.preprocessing import PolynomialFeatures

from torch.nn.utils.rnn import pad_sequence

# Make the current working directory importable by appending it to sys.path
# when it is not already listed there.
# NOTE(review): this runs after the `src.utils...` imports at the top of the
# file, so it cannot help those imports resolve - confirm whether it should be
# moved above them.
path = os.getcwd()

if path not in sys.path:
    sys.path.append(path)
    

def featurize_csv(path_features:str, degree:int, bias:bool):
    """Create (or reuse) a CSV of polynomial / interaction features.

    The input CSV is expected to hold an identifier in its first column and
    numeric features in every remaining column. The expanded table is written
    next to the input as ``<stem>_deg_<degree>_bias_<0|1>.csv``; if that file
    already exists it is reused and nothing is recomputed.

    Args:
        path_features: Path of the CSV holding the initial features.
        degree: Degree passed to ``PolynomialFeatures``.
        bias: Whether ``PolynomialFeatures`` adds the constant (bias) column.

    Returns:
        The path of the CSV containing the polynomial features.

    Raises:
        ValueError: If the input CSV has no rows or no feature columns.

    Note:
        Output columns keep the original feature names for the first columns
        and are named ``extra_1``, ``extra_2``, ... for the additional ones.
        These names are positional only: they do not describe which
        polynomial term a column holds.
    """
    # Strip only a *trailing* ".csv"; splitting on the first ".csv" would
    # truncate the path when a directory name happens to contain ".csv".
    suffix = ".csv"
    if path_features.endswith(suffix):
        stem = path_features[:-len(suffix)]
    else:
        stem = path_features
    path_features_poly = f"{stem}_deg_{degree}_bias_{int(bias)}.csv"

    # Cache hit: the expanded file was already produced by an earlier call.
    if os.path.exists(path_features_poly):
        return path_features_poly

    print("Creating polynomial features because the file does not exist yet")
    df_feat = pd.read_csv(path_features)
    df_feat_name = df_feat.iloc[:, [0]]  # identifier column, kept untouched
    df_feat_val = df_feat.iloc[:, 1:]    # columns to expand

    n_rows, n_in = df_feat_val.shape
    if n_rows == 0 or n_in == 0:
        raise ValueError(
            f"{path_features} must contain at least one row and one feature column"
        )

    # The expansion does not depend on the data values, so transforming the
    # whole matrix at once gives the same result as a row-by-row loop.
    poly = PolynomialFeatures(degree=degree, include_bias=bias)
    poly_values = poly.fit_transform(df_feat_val.values)
    n_out = poly_values.shape[1]

    # Same naming scheme as before: original names first, then extra_<i>.
    columns = list(df_feat_val.columns[:n_out])
    columns += [f"extra_{i+1}" for i in range(n_out - n_in)]

    df_poly = pd.DataFrame(poly_values, columns=columns, index=df_feat.index)
    df_feat = pd.concat([df_feat_name, df_poly], axis=1)

    df_feat.to_csv(path_features_poly, index=False)

    print(df_feat.shape)
    return path_features_poly



class TimeseriesFeaturesDataset(Dataset):
    """Dataset yielding ``(features, scores, label)`` for each row of a features CSV.

    The first column of the CSV read from ``path_features`` is used as the
    time-series file name; every remaining column is treated as a numeric
    feature. Scores and labels for that file are loaded on demand through the
    project's ``ScoresLoader`` and ``DataLoader``.
    """

    def __init__(self, model:str, path_features:str, path_data:str, path_scores:str, device, testing:bool):
        """Read the features table and remember where data and scores live.

        Args:
            model: Not used by this class; kept for interface compatibility.
            path_features: CSV with the file name in column 0 and features after it.
            path_data: Path handed to ``DataLoader`` when an item is fetched.
            path_scores: Path handed to ``ScoresLoader`` when an item is fetched.
            device: Device every returned tensor is moved to.
            testing: When true, only the first two rows of the table are kept.
        """
        self.feat = pd.read_csv(path_features)
        self.path_data = path_data
        self.path_scores = path_scores
        self.device = device

        if testing:
            self.feat = self.feat.iloc[:2, :]

    def __len__(self):
        """Return the number of rows in the features table."""
        return len(self.feat)

    def __getitem__(self, idx):
        """Load features, scores and label for row ``idx`` as float32 tensors."""
        # Both loaders are called with a list of file names, hence the
        # one-element list and the ``[0]`` indexing of their results below.
        ts_file_name = [self.feat.iloc[idx, 0]]
        _, label, _ = DataLoader(self.path_data).load_timeseries(ts_file_name)
        # NOTE(review): the second value returned by ``load`` (named
        # ``idx_failed`` originally) is ignored here - confirm whether a
        # failed load should be reported instead.
        scores, _ = ScoresLoader(self.path_scores).load(ts_file_name)

        features_array = np.array(self.feat.iloc[idx, 1:].values, dtype=np.float32)

        features = torch.tensor(features_array, dtype=torch.float32, requires_grad=False).to(self.device)
        scores = torch.tensor(scores[0], dtype=torch.float32, requires_grad=False).to(self.device)
        label = torch.tensor(label[0], dtype=torch.float32, requires_grad=False).to(self.device)

        return features, scores, label

    @staticmethod
    def collate_fn(batch):
        """Zero-pad a list of ``(features, scores, label)`` items into batch tensors.

        Declared as a static method so that it works both as
        ``TimeseriesFeaturesDataset.collate_fn`` and as ``dataset.collate_fn``.

        Args:
            batch: Sequence of ``(features, scores, label)`` tensor triples.

        Returns:
            ``(feat_batch, scor_batch, lab_batch)``. Features are padded with
            ``pad_sequence``; scores and labels are padded with zeros along
            their first dimension to the longest length found among all
            scores and labels of the batch, then stacked.
        """
        feat_batch = [item[0] for item in batch]
        scor_batch = [item[1] for item in batch]
        lab_batch = [item[2] for item in batch]

        # Pad the features so every sequence in the batch has the same length.
        feat_batch = pad_sequence(feat_batch, batch_first=True, padding_value=0)

        # Scores and labels share one target length.
        max_len = max(x.size(0) for x in scor_batch + lab_batch)

        def pad_first_dim(tensor, target_len):
            # F.pad lists (left, right) pairs starting from the LAST dimension,
            # so the pair for dim 0 comes after one (0, 0) pair per other dim.
            pad_size = target_len - tensor.size(0)
            pad_spec = (0, 0) * (tensor.dim() - 1) + (0, pad_size)
            return F.pad(tensor, pad_spec, "constant", 0)

        scor_batch = torch.stack([pad_first_dim(x, max_len) for x in scor_batch])
        lab_batch = torch.stack([pad_first_dim(x, max_len) for x in lab_batch])

        return feat_batch, scor_batch, lab_batch
    
# Script entry point: currently a no-op when the module is run directly.
if __name__ == "__main__":
    pass