import os
from tqdm import tqdm

import pandas as pd
import numpy as np
import math
import torch.nn.functional as F

from collections import Counter
from sklearn.utils.class_weight import compute_class_weight
from torch.nn.utils.rnn import pad_sequence

from src.utils.data_loader import DataLoader
from src.utils.scores_loader import ScoresLoader
from src.create_split import SplitTrainTest

import torch
from torch.utils.data import Dataset

# Integer constant per benchmark dataset, keyed by dataset name.
# NOTE(review): the name suggests a per-dataset series length; confirm against
# the code that consumes DICT_LENGTH. "Occupancy" is deliberately not listed.
DICT_LENGTH = dict(
    OPPORTUNITY=22229,
    IOPS=7577,
    SVDB=230399,
    Daphnet=9599,
    MGAB=99999,
    MITDB=649999,
    ECG=230399,
    GHL=200000,
    SensorScope=20733,
    SMD=23688,
    KDD21=6683,
    NAB=1880,
    Genesis=16219,
    YAHOO=740,
)


class TimeseriesDataset(Dataset):
    """Shuffled time-series windows paired with per-series scores and labels.

    Every CSV in ``fnames_ts`` is read from ``path_raw.format(window_size)``. In
    each row, column 0 is the window name (``<ts_name>.<window_number>``),
    column 1 is not used here, and columns 2+ are the window values. All
    windows are concatenated and shuffled after seeding NumPy's global RNG
    with ``seed``.

    Scores (via ``ScoresLoader``) and labels (via ``DataLoader``) are loaded
    once per time series. ``__getitem__`` looks them up by the series name of
    the requested window.

    Args:
        fnames_ts: relative file names of the form ``"<dataset>/<ts_name>.csv"``.
        path_data: root passed to ``DataLoader`` to load labels.
        path_raw: directory of window CSVs; it is formatted with ``window_size``.
        path_scores: root passed to ``ScoresLoader`` to load scores.
        window_size: window length (also used to format ``path_raw``).
        device: torch device that returned tensors are moved to.
        seed: seed for NumPy's global RNG, used to shuffle the windows.
        verbose: when False, progress bars and diagnostic prints are suppressed.
    """

    def __init__(self, fnames_ts:list, path_data:str, path_raw:str, path_scores:str, window_size:int, device, seed, verbose=True):
        self.fnames_ts = fnames_ts
        self.path_data = path_data
        self.path_raw = path_raw
        self.path_scores = path_scores
        self.device = device
        self.window_size = window_size

        self.labels = []
        self.samples = []
        self.window_ts_name = []
        # Initialised up front so that an empty dataset still exposes every attribute.
        self.load_with_ts_names = []
        self.ts_names = []
        self.dataset_ts_name_list = []
        self.dic_ts_to_scores = {}
        self.dic_ts_to_labels = {}
        self.dic_ts_to_dataset = {}

        if len(self.fnames_ts) == 0:
            return

        # Diagnostics honour the verbose flag, like the progress bars do.
        log = print if verbose else (lambda *args, **kwargs: None)
        log(len(self.fnames_ts))

        for fname in tqdm(self.fnames_ts, disable=not verbose, desc="Loading all windows"):
            data = pd.read_csv(os.path.join(self.path_raw.format(self.window_size), fname), index_col=False)
            # Strip the trailing window number: 'MBA_ECG803_data.out.22' -> 'MBA_ECG803_data.out'
            names = data.iloc[:, 0].apply(lambda name: '.'.join(name.split('.')[:-1]))
            self.window_ts_name.append(names.to_numpy())
            self.samples.append(data.iloc[:, 2:].to_numpy())

        # Concatenate the windows of all files.
        self.samples = np.concatenate(self.samples, axis=0)
        self.window_ts_name = np.concatenate(self.window_ts_name, axis=0)

        # Shuffle windows and their names with one shared permutation.
        # NOTE: this reseeds NumPy's *global* RNG (unchanged behaviour).
        np.random.seed(seed)
        log(f"windows shape : {self.window_ts_name.shape}")
        log(f"samples shape : {self.samples.shape}")
        indexs = np.random.permutation(len(self.samples))
        self.samples = self.samples[indexs]
        self.window_ts_name = self.window_ts_name[indexs]
        log(f"shape window_ts_name : {self.window_ts_name.shape}")
        log(f"shape samples : {self.samples.shape}")

        # Add a channel dimension: samples becomes (n_windows, 1, n_values).
        log("Adding channel dimension to samples")
        self.samples = self.samples[:, np.newaxis, :]
        log(f"self.samples shape : {self.samples.shape}")

        log("loading scores and labels")
        for fname in tqdm(self.fnames_ts, disable=not verbose, desc="Formatting names for loading scores and labels"):
            load_name = fname[:-4]  # drop ".csv": "dataset/ts_name.csv" -> "dataset/ts_name"
            self.load_with_ts_names.append(load_name)
            self.ts_names.append(load_name.split("/")[1])  # "dataset/ts_name" -> "ts_name"

        log(f"ts_names[:3] : {self.ts_names[:3]}")

        log("loading scores from files")
        scores_list = ScoresLoader(self.path_scores).load(self.load_with_ts_names)[0]
        log("loading labels from files")
        labels_list = DataLoader(self.path_data).load_timeseries(self.load_with_ts_names)[1]

        log("Creating dictionaries for scores and labels")
        self.dic_ts_to_scores = dict(zip(self.ts_names, scores_list))
        self.dic_ts_to_labels = dict(zip(self.ts_names, labels_list))
        # ts_name -> dataset, built with the same zip so a repeated ts_name resolves
        # consistently with the score/label dicts (last one wins), and lookup is O(1).
        datasets = [name.split("/")[0] for name in self.load_with_ts_names]
        self.dic_ts_to_dataset = dict(zip(self.ts_names, datasets))

    def __len__(self):
        """Number of windows (0 for an empty dataset, where samples is still a list)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(ts_window, scores, label, dataset)`` for window ``idx``.

        ``ts_window`` has shape (1, 1, n_values): the channel axis added in
        ``__init__`` plus one leading axis added here. ``scores`` and ``label``
        are the full per-series arrays (float32) of the window's time series.
        ``dataset`` is the ``<dataset>`` prefix of that series' file name.
        """
        ts_name = self.window_ts_name[idx]
        dataset = self.dic_ts_to_dataset[ts_name]

        ts_window = torch.tensor(self.samples[idx], dtype=torch.float32).to(self.device)
        scores = torch.tensor(self.dic_ts_to_scores[ts_name], dtype=torch.float32).to(self.device)
        label = torch.tensor(self.dic_ts_to_labels[ts_name], dtype=torch.float32).to(self.device)

        # Extra leading axis: in the RL pipeline no torch DataLoader adds the batch dimension.
        ts_window = ts_window[None, :, :]
        return ts_window, scores, label, dataset

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``__getitem__`` outputs into padded batch tensors.

        Scores are zero-padded along dim 0 (assumed (length, n_detectors), TODO
        confirm). Labels are zero-padded along their only dim, to the longest
        series in the batch. The dataset name (item[3]) is not returned.

        Returns:
            (ts_batch, scor_batch, lab_batch)

        Raises:
            ValueError: if the longest score series and the longest label
                series in the batch have different lengths.
        """
        ts_batch = [item[0] for item in batch]
        scor_batch = [item[1] for item in batch]
        lab_batch = [item[2] for item in batch]
        max_len_scor = max(x.size(0) for x in scor_batch)
        max_len_lab = max(x.size(0) for x in lab_batch)

        if max_len_scor != max_len_lab:
            raise ValueError(f"max_len_scor and max_len_lab are not equal ({max_len_scor} != {max_len_lab})")
        max_len = max_len_scor

        def pad_1d_tensor(tensor, max_len):
            # Zero-pad the end of a 1-D tensor up to max_len.
            return F.pad(tensor, (0, max_len - tensor.shape[0]), "constant", 0)

        def pad_tensor(tensor, max_len):
            # Zero-pad dim 0 (rows) of a 2-D tensor up to max_len.
            return F.pad(tensor, (0, 0, 0, max_len - tensor.shape[0]), "constant", 0)

        scor_batch = torch.stack([pad_tensor(x, max_len) for x in scor_batch])
        lab_batch = torch.stack([pad_1d_tensor(x, max_len) for x in lab_batch])
        ts_batch = torch.stack(ts_batch)
        return ts_batch, scor_batch, lab_batch


def df_shift_to_df_slide(df_shift, window_size):
    '''
    Convert a dataframe of shifting windows into a dataframe of sliding windows.

    The values in columns 2 onwards of every row of ``df_shift`` are concatenated
    row by row into one sequence. Every contiguous sub-sequence of length
    ``window_size`` (stride 1) becomes one row of the result.

    Note: the first two columns (the ts name and, presumably, the label) are
    dropped; they are not needed in the sliding windows.

    Args:
        df_shift: DataFrame whose columns 2+ hold the window values.
        window_size: length of each sliding window.

    Returns:
        DataFrame of shape (n_values - window_size + 1, window_size). It has no
        rows when window_size exceeds the total number of values (previously
        this case raised inside pd.concat).
    '''
    # A row-major ravel concatenates the rows in order, same as extending row by row.
    arr = df_shift.iloc[:, 2:].to_numpy().ravel()
    print(f"df_ts shape : {(1, arr.shape[0])}")

    n_windows = arr.shape[0] - window_size + 1
    if n_windows <= 0:
        windows = np.empty((0, window_size), dtype=arr.dtype)
    else:
        # Strided read-only view; copy() gives the DataFrame its own writable buffer.
        windows = np.lib.stride_tricks.sliding_window_view(arr, window_size).copy()

    df_slide = pd.DataFrame(windows)
    print(f"df_slide shape : {df_slide.shape}")
    return df_slide

def make_df_slide(ts, window_size):
    """
    -------------
    Idea : Build the df_slide whose rows are all the sliding windows (stride 1)
    of length window_size of ts
    -------------
    Note : no label or ts name is carried over; only the values are windowed
    -------------
    Args:
        - ts : tensor (any device, may require grad) or array-like; it is flattened
        - window_size : int
    Returns:
        - DataFrame of shape (len(ts) - window_size + 1, window_size), with no
          rows when window_size exceeds len(ts)
    ------------
    """
    if isinstance(ts, torch.Tensor):
        # detach() is required for tensors that require grad; cpu() for GPU tensors.
        ts = ts.detach().cpu().numpy()
    arr = np.asarray(ts).flatten()  # shape: (total_length,)

    n_windows = len(arr) - window_size + 1
    if n_windows <= 0:
        windows = np.empty((0, window_size), dtype=arr.dtype)
    else:
        # Strided read-only view; copy() gives the DataFrame its own writable buffer.
        windows = np.lib.stride_tricks.sliding_window_view(arr, window_size).copy()

    df_slide = pd.DataFrame(windows)
    print(f"df_slide shape : {df_slide.shape}")
    return df_slide


if __name__ == "__main__":
    # No command-line entry point: this module only provides the dataset class
    # and window helpers, so running it directly does nothing.
    ...
    