import os

import numpy as np
import torch
from torch.utils.data import Dataset

from .augmentations import DataTransform

import os
import numpy as np
import pandas as pd
from tslearn.metrics import dtw 
import tqdm

from sklearn.preprocessing import MinMaxScaler

def get_DTW(UTS_tr):
    """Compute the pairwise DTW distance matrix of univariate time series.

    Args:
        UTS_tr: Indexable collection of series supporting ``len``; each
            element is reshaped to a single-feature column with
            ``reshape(-1, 1)`` before being passed to ``dtw``.

    Returns:
        A symmetric ``(N, N)`` NumPy array of DTW distances whose diagonal
        is zero, where ``N = len(UTS_tr)``.
    """
    n_series = len(UTS_tr)
    distances = np.zeros((n_series, n_series))
    for row in tqdm.tqdm(range(n_series)):
        # Only the strict lower triangle is computed; every value is mirrored
        # into the upper triangle and the diagonal keeps its initial zero.
        for col in range(row):
            pair_distance = dtw(UTS_tr[row].reshape(-1, 1),
                                UTS_tr[col].reshape(-1, 1))
            distances[row, col] = pair_distance
            distances[col, row] = pair_distance
    return distances

def get_MDTW(MTS_tr):
    """Compute the pairwise DTW distance matrix of multivariate time series.

    Args:
        MTS_tr: Array whose first axis indexes the series; each ``MTS_tr[i]``
            is handed to ``dtw`` unchanged.

    Returns:
        A symmetric ``(N, N)`` NumPy array of DTW distances whose diagonal
        is zero, where ``N = MTS_tr.shape[0]``.
    """
    n_series = MTS_tr.shape[0]
    distances = np.zeros((n_series, n_series))
    for row in tqdm.tqdm(range(n_series)):
        # Only the strict lower triangle is computed; every value is mirrored
        # into the upper triangle and the diagonal keeps its initial zero.
        for col in range(row):
            pair_distance = dtw(MTS_tr[row], MTS_tr[col])
            distances[row, col] = pair_distance
            distances[col, row] = pair_distance
    return distances

def save_dtw_similarity(X_tr, min_=0, max_=1, multivariate=False):
    """Turn pairwise DTW distances into a similarity matrix.

    Despite its name, this function writes nothing to disk; it only returns
    the matrix.

    Args:
        X_tr: Time series forwarded to ``get_MDTW`` (when ``multivariate`` is
            true) or ``get_DTW`` (otherwise).
        min_: Lower bound of the ``MinMaxScaler`` feature range.
        max_: Upper bound of the ``MinMaxScaler`` feature range.
        multivariate: Selects the multivariate distance routine.

    Returns:
        ``1 - scaled_distances`` as an ``(N, N)`` NumPy array.
    """
    pairwise = get_MDTW(X_tr) if multivariate else get_DTW(X_tr)

    # Boolean mask selecting every entry except the diagonal.
    n_rows, n_cols = pairwise.shape
    off_diagonal = np.ones((n_rows, n_cols), dtype=bool)
    np.fill_diagonal(off_diagonal, False)

    # Replace the zero self-distances with the smallest off-diagonal distance
    # so the diagonal does not define the lower end of the scaling range.
    smallest = pairwise[off_diagonal].reshape(n_rows, n_cols - 1).min()
    np.fill_diagonal(pairwise, smallest)

    # MinMaxScaler rescales each column (feature) independently.
    rescaled = MinMaxScaler(feature_range=(min_, max_)).fit_transform(pairwise)
    return 1 - rescaled

class Load_Dataset(Dataset):
    """Map-style dataset built from a ``{"samples", "labels"}`` mapping.

    Samples are stored as a 3-D tensor with the channel axis in position 1.
    In the ``"self_supervised"`` and ``"SupCon"`` modes two augmented views of
    the whole sample tensor are precomputed with ``DataTransform``.

    Args:
        dataset: Mapping with keys ``"samples"`` and ``"labels"``; the values
            may be torch tensors or NumPy arrays.
        config: Configuration object forwarded to ``DataTransform``.
        training_mode: Name of the training mode; decides whether augmented
            views are created and returned.
    """

    def __init__(self, dataset, config, training_mode):
        super().__init__()
        self.training_mode = training_mode

        X_train = dataset["samples"]
        y_train = dataset["labels"]

        # Convert NumPy inputs first: the layout fixes below use the
        # tensor-only methods ``unsqueeze`` and ``permute`` and would raise
        # AttributeError on an ndarray.
        if isinstance(X_train, np.ndarray):
            X_train = torch.from_numpy(X_train)
            y_train = torch.from_numpy(y_train).long()

        if len(X_train.shape) < 3:
            X_train = X_train.unsqueeze(2)

        # Make sure the channels are in the second dim. The smallest
        # dimension is taken to be the channel axis.
        # NOTE(review): this heuristic also swaps axes 1 and 2 when the sample
        # count is the smallest dimension - confirm callers never hit that.
        if X_train.shape.index(min(X_train.shape)) != 1:
            X_train = X_train.permute(0, 2, 1)

        self.x_data = X_train
        self.y_data = y_train

        self.len = X_train.shape[0]
        # No need to apply augmentations in other modes.
        if training_mode in ("self_supervised", "SupCon"):
            self.aug1, self.aug2 = DataTransform(self.x_data, config)

    def __getitem__(self, index):
        """Return ``(index, sample, label, view_1, view_2)``.

        In the contrastive modes the two views are the precomputed
        augmentations; otherwise the untouched sample fills both slots so the
        tuple layout is the same in every mode.
        """
        if self.training_mode in ("self_supervised", "SupCon"):
            return index, self.x_data[index], self.y_data[index], self.aug1[index], self.aug2[index]
        return index, self.x_data[index], self.y_data[index], self.x_data[index], self.x_data[index]

    def __len__(self):
        """Return the number of samples."""
        return self.len


def data_generator(data_path, configs, training_mode, pc, batch_size):
    """Build the train, validation and test ``DataLoader`` objects.

    Args:
        data_path: Directory holding the ``*.pt`` split files.
        configs: Configuration object; ``batch_size`` and ``drop_last`` are
            read here and the object is forwarded to ``Load_Dataset``.
        training_mode: Training mode name. ``"SupCon"`` loads the
            pseudo-labelled training file; a mode containing ``"ft"`` loads a
            reduced training subset when ``pc`` is one of 1, 5, 10, 50, 75.
        pc: Percentage selecting the training subset file.
        batch_size: Batch size for all loaders. The value ``999`` is a
            sentinel meaning "choose automatically" from the training-set
            size and ``configs.batch_size``.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    def _load(file_name):
        return torch.load(os.path.join(data_path, file_name))

    if training_mode == "SupCon":
        train_split = _load(f"pseudo_train_data_{str(pc)}perc.pt")
    else:
        # Reduced subsets are only used by fine-tuning ("ft") modes.
        subset = None
        if 'ft' in training_mode:
            subset = next((p for p in (1, 5, 10, 50, 75) if pc == p), None)
        if subset is None:
            train_split = _load("train.pt")
        else:
            print(f'{subset}%')
            train_split = _load(f"train_{subset}perc.pt")

    valid_split = _load("val.pt")
    test_split = _load("test.pt")
    splits = (train_split, valid_split, test_split)
    for split in splits:
        print(split['samples'].shape)
    train_set, valid_set, test_set = [
        Load_Dataset(split, configs, training_mode) for split in splits
    ]

    if batch_size == 999:
        n_train = len(train_set)
        if n_train >= batch_size:
            batch_size = configs.batch_size
        elif n_train > 16:
            batch_size = 16
        else:
            batch_size = 4

    print('batch_size', batch_size)
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(dataset=train_set, batch_size=batch_size,
                               shuffle=True, drop_last=configs.drop_last, num_workers=0)
    valid_loader = make_loader(dataset=valid_set, batch_size=batch_size,
                               shuffle=False, drop_last=configs.drop_last, num_workers=0)
    # The test loader never drops the final partial batch.
    test_loader = make_loader(dataset=test_set, batch_size=batch_size,
                              shuffle=False, drop_last=False, num_workers=0)

    return train_loader, valid_loader, test_loader


def data_generator_wo_val(data_path, configs, training_mode, pc, batch_size):
    """Build the train and test ``DataLoader`` objects (no validation split).

    Args:
        data_path: Directory holding the ``*.pt`` split files.
        configs: Configuration object; ``batch_size`` and ``drop_last`` are
            read here and the object is forwarded to ``Load_Dataset``.
        training_mode: Training mode name. ``"SupCon"`` loads the
            pseudo-labelled training file; a mode containing ``"ft"`` loads a
            reduced training subset when ``pc`` is one of 1, 5, 10, 50, 75.
        pc: Percentage selecting the training subset file.
        batch_size: Accepted for signature parity but never read: the batch
            size is always derived from the training-set size and
            ``configs.batch_size``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    def _load(file_name):
        return torch.load(os.path.join(data_path, file_name))

    if training_mode == "SupCon":
        train_split = _load(f"pseudo_train_data_{str(pc)}perc.pt")
    else:
        # Reduced subsets are only used by fine-tuning ("ft") modes.
        subset = None
        if 'ft' in training_mode:
            subset = next((p for p in (1, 5, 10, 50, 75) if pc == p), None)
        if subset is None:
            train_split = _load("train.pt")
        else:
            print(f'{subset}%')
            train_split = _load(f"train_{subset}perc.pt")

    test_split = _load("test.pt")
    print(train_split['samples'].shape)
    print(test_split['samples'].shape)

    train_set = Load_Dataset(train_split, configs, training_mode)
    test_set = Load_Dataset(test_split, configs, training_mode)

    # Fall back to a small batch when the training set cannot fill one
    # configured batch.
    n_train = len(train_set)
    if n_train >= configs.batch_size:
        batch_size = configs.batch_size
    elif n_train > 16:
        batch_size = 16
    else:
        batch_size = 4

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(dataset=train_set, batch_size=batch_size,
                               shuffle=True, drop_last=configs.drop_last, num_workers=0)
    # The test loader never drops the final partial batch.
    test_loader = make_loader(dataset=test_set, batch_size=batch_size,
                              shuffle=False, drop_last=False, num_workers=0)

    return train_loader, test_loader

def set_nan_to_zero(a):
    """Replace every NaN in ``a`` with zero, in place, and return ``a``.

    Args:
        a: Array to clean; it is modified directly, not copied.

    Returns:
        The same object that was passed in.
    """
    a[np.isnan(a)] = 0
    return a


def preprocess_TS(TS):
    """Zero out NaNs and min-max scale each series along its last axis.

    NaNs are replaced in place in the input array. 2-D input is scaled per
    row and 3-D input per ``(sample, axis-1)`` slice; input of any other rank
    is returned after NaN replacement only.

    Args:
        TS: NumPy array of time series.

    Returns:
        The scaled array (a new array for 2-D and 3-D input).
    """
    TS = set_nan_to_zero(TS)
    if TS.ndim == 2:
        print('Preprocessing Univariate Time Series ...')
    elif TS.ndim == 3:
        print('Preprocessing Multivariate Time Series ...')
    else:
        return TS

    # The small offset on the maximum keeps the denominator non-zero for
    # constant series.
    upper = TS.max(axis=-1, keepdims=True) + (1e-6)
    lower = TS.min(axis=-1, keepdims=True)
    return (TS - lower) / (upper - lower)


def gen_semiCLS_DTW(dataset):
    """Return the DTW-based similarity matrix of a dataset's training split.

    The matrix is cached in ``data/<dataset>/DTW.npy``. When that file exists
    it is loaded and returned as-is; otherwise the matrix is computed from
    ``data/<dataset>/train.pt`` and written to the cache.

    Args:
        dataset: Name of the dataset directory under ``data/``.

    Returns:
        The similarity matrix produced by ``save_dtw_similarity``, or the
        cached copy of it.

    NOTE: a cache file written before multivariate data was routed to
    ``get_MDTW`` holds values from the univariate routine; delete it to
    recompute.
    """
    data_dir = f'data/{dataset}'
    DTW_file = os.path.join(data_dir, 'DTW.npy')

    # Check the cache first so the training tensor is not loaded needlessly.
    if os.path.exists(DTW_file):
        print("DTW already exists")
        return np.load(DTW_file)

    tr = torch.load(f'data/{dataset}/train.pt')
    train = tr['samples'].detach().cpu().numpy().astype(np.float64)

    if train.ndim == 3:
        if train.shape[1] == 1:
            # Single channel: drop the channel axis and treat as univariate.
            train = train.squeeze(1)
        elif train.shape[1] > 1:
            # Swap axes 1 and 2 so each sample is laid out as tslearn's
            # ``dtw`` expects a multivariate series.
            train = train.transpose(0, 2, 1)

    # Data that is still 3-D must go through the multivariate DTW routine;
    # the univariate one would flatten every sample into a single series.
    multivariate = train.ndim == 3

    print("Saving DTW ...")
    # NOTE(review): preprocess_TS scales along the last axis, which after the
    # transpose above is the channel axis, not time - confirm this is
    # intended.
    DTW = save_dtw_similarity(preprocess_TS(train), min_=0, max_=1,
                              multivariate=multivariate)
    os.makedirs(data_dir, exist_ok=True)
    np.save(DTW_file, DTW)
    return DTW