import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset
from datetime import datetime
import logging

from Dataset import data_loader

# Configure the root logger at import time: "<timestamp> | <LEVEL> : <message>", INFO and above.
logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s', level=logging.INFO)
# Module-level logger used by the helper functions in this file.
logger = logging.getLogger(__name__)


def Setup(args):
    """
        Build the run configuration, create the output directory tree and
        store the configuration as a json file.

        Input:
            args: arguments object from argparse. Its ``__dict__`` is used
                directly (not copied), so every key written here is also
                visible as an attribute of ``args``. The keys 'output_dir',
                'Training_mode', 'data_dir' and 'seed' are read.
        Returns:
            config: configuration dictionary. 'output_dir' is rewritten to
                <output_dir>/<Training_mode>/<problem>/<seed>, and 'problem',
                'save_dir', 'pred_dir' and 'tensorboard_dir' are added.
    """
    config = args.__dict__  # configuration dictionary (same object as the namespace)

    # Create output directory
    mode_dir = os.path.join(config['output_dir'], config['Training_mode'])
    # exist_ok avoids failing when a concurrent run creates the directory first.
    os.makedirs(mode_dir, exist_ok=True)

    # normpath drops a trailing separator, so 'path/to/TUEV/' still gives
    # 'TUEV' instead of an empty problem name.
    config['problem'] = os.path.basename(os.path.normpath(config['data_dir']))
    output_dir = os.path.join(mode_dir, config['problem'], str(config['seed']))

    config['output_dir'] = output_dir
    config['save_dir'] = os.path.join(output_dir, 'checkpoints')
    config['pred_dir'] = os.path.join(output_dir, 'predictions')
    config['tensorboard_dir'] = os.path.join(output_dir, 'tb_summaries')
    create_dirs([config['save_dir'], config['pred_dir'], config['tensorboard_dir']])

    # Save configuration as a (pretty) json file
    with open(os.path.join(output_dir, 'configuration.json'), 'w') as fp:
        json.dump(config, fp, indent=4, sort_keys=True)

    logger.info("Stored configuration file in '%s'", output_dir)

    return config

def z_score(x):
    """
    Standardise a tensor along dim 1.

    Input:
        x: torch tensor with at least two dimensions
    Returns:
        tensor of the same shape as x: (x - mean) / (std + 1e-16), where mean
        and std are taken over dim 1 with the dimension kept for broadcasting
    """
    # Tiny offset so a zero standard deviation does not divide by zero.
    eps = 1e-16
    centered = x - x.mean(dim=1, keepdim=True)
    spread = x.std(dim=1, keepdim=True)
    return centered / (spread + eps)

def save_data(data, name='data'):
    """
    Write ``data`` to '<name>.npy' with numpy.save (pickling allowed).

    Input:
        data: object to store
        name: output path without the '.npy' extension
    Returns:
        None
    """
    target = name + '.npy'
    with open(target, 'wb') as handle:
        np.save(handle, data, allow_pickle=True)
    return None

def create_dirs(dirs):
    """
    Input:
        dirs: a list of directories to create, in case these directories are not found
    Returns:
        exit_code: 0 if success. On an OS-level failure the function does not
            return: it prints the error and terminates the process with exit
            status -1 (raises SystemExit).
    """
    # Local import: sys.exit is always available, unlike the `exit` helper
    # that is only injected by the site module.
    import sys

    try:
        for dir_ in dirs:
            if not os.path.exists(dir_):
                # exist_ok covers the case where the directory appears between
                # the existence check and the creation call.
                os.makedirs(dir_, exist_ok=True)
        return 0
    except OSError as err:
        print("Creating directories error: {0}".format(err))
        sys.exit(-1)


def Initialization(config):
    """
    Seed PyTorch and select the compute device.

    Input:
        config: configuration dictionary. 'seed' is passed to
            torch.manual_seed unless it is None. 'gpu' is only read when CUDA
            is available; a value of '-1' (string or integer) forces the CPU.
    Returns:
        device: torch.device('cuda') or torch.device('cpu')
    """
    if config['seed'] is not None:
        torch.manual_seed(config['seed'])

    # str() makes an integer -1 disable the GPU the same way the string '-1' does.
    use_cuda = torch.cuda.is_available() and str(config['gpu']) != '-1'
    device = torch.device('cuda' if use_cuda else 'cpu')
    logger.info("Using device: {}".format(device))

    # A torch.device never compares equal to a plain string, so the device
    # kind has to be checked through its `.type` attribute.
    if device.type == 'cuda':
        logger.info("Device index: {}".format(torch.cuda.current_device()))
    return device


def Data_Loader(config):
    """
    Pick the loader from ``data_loader`` that matches the configuration and
    return whatever it produces.

    Input:
        config: configuration dictionary; reads 'Pre_Training' and, unless
            that equals 'Cross-domain', 'problem'
    Returns:
        Data: the value returned by the selected loader
    """
    # Cross-domain pre-training takes precedence over dataset-specific loaders.
    if config['Pre_Training'] == 'Cross-domain':
        return data_loader.cross_domain_loader(config)

    problem = config['problem']
    if problem == 'TUEV':
        return data_loader.tuev_loader(config)
    if problem == 'TUAB':
        return data_loader.tuab_loader(config)
    if problem == 'CHB-MIT':
        return data_loader.chbmit_loader(config)

    # Any other problem name goes through the generic loader.
    return data_loader.load(config)


class dataset_class(Dataset):
    """
    Torch dataset over in-memory features and labels.

    Each item is a tuple (data, label, coherence_label): the feature entry as
    a float32 tensor, the label as an int32 tensor, and either the matching
    entry of ``coherence_label`` or, when none was given, the item index.
    """

    def __init__(self, data, label, patch_size, coherence_label=None):
        """
        Input:
            data: indexable collection of numpy arrays (one per sample)
            label: numpy array of labels; stored cast to int32
            patch_size: only used by __padding__
            coherence_label: optional indexable with one entry per sample
        """
        super().__init__()

        self.patch_size = patch_size
        self.coherence_label = coherence_label
        self.labels = label.astype(np.int32)
        self.feature = data
        # NOTE: padding is not applied automatically; call __padding__() explicitly if needed.

    def __padding__(self):
        """Zero-pad the last axis of ``feature`` up to a multiple of ``patch_size``."""
        first = self.feature[0]
        remainder = first.shape[1] % self.patch_size
        if not remainder:
            # Length already divisible by the patch size: nothing to do.
            return
        missing = self.patch_size - remainder
        zeros = np.zeros((len(self.feature), first.shape[0], missing), dtype=np.float32)
        self.feature = np.concatenate([self.feature, zeros], axis=-1)

    def __getitem__(self, ind):
        sample = self.feature[ind].astype(np.float32)
        target = self.labels[ind]  # (num_labels,) array
        # Fall back to the index itself when no coherence labels were supplied.
        if self.coherence_label is None:
            coherence = ind
        else:
            coherence = self.coherence_label[ind]

        return torch.tensor(sample), torch.tensor(target), coherence

    def __len__(self):
        return len(self.labels)


def print_title(text):
    """
    Print ``text`` padded with spaces, framed above and below by a row of
    asterisks of the same width.
    """
    padded = f"           {text}          "
    rule = '*' * len(padded)
    print("\n".join((rule, padded, rule)))


def convert_frequency(config, Data):
    """
    Replace every data split in ``Data`` with its get_fft() features, set
    'max_len' to 10 and save the dictionary next to the dataset.

    Input:
        config: configuration dictionary; reads 'data_dir'
        Data: dictionary with 'All_train_data', 'train_data', 'val_data' and
            'test_data'; modified in place
    Returns:
        Data: the same dictionary after conversion
    """
    data_dir = config['data_dir']
    # Last '/'-separated component of the data directory names the output file.
    dataset_name = data_dir.split('/')[-1]

    for split in ('All_train_data', 'train_data', 'val_data', 'test_data'):
        Data[split] = get_fft(Data[split])
    Data['max_len'] = 10

    # numpy.save appends '.npy'; the dictionary is stored as a pickled object.
    np.save(data_dir + "/" + dataset_name + '_f', Data, allow_pickle=True)
    return Data


def get_fft(train_data, fs=128):
    """
    Summarise every channel by the minimum and maximum FFT magnitude inside
    five EEG frequency bands.

    Input:
        train_data: numpy array expected to be 3-D, indexed as
            (sample, channel, time); the FFT is taken along the last axis
        fs: sampling rate in Hz used to map FFT bins to frequencies
            (default 128)
    Returns:
        F_train: float array of shape (n_samples, n_channels, 10). For band
            number k (Delta, Theta, Alpha, Beta, Gamma in that order),
            F_train[..., 2k] is the minimum and F_train[..., 2k + 1] the
            maximum magnitude over the band's bins.
    Raises:
        ValueError: raised by NumPy when a band contains no FFT bin (signal
            too short for the chosen sampling rate).
    """
    # Define EEG bands (Hz). Both edges are inclusive, so a bin that lies
    # exactly on a shared edge contributes to both neighbouring bands.
    eeg_bands = {'Delta': (0, 4),
                 'Theta': (4, 8),
                 'Alpha': (8, 12),
                 'Beta': (12, 30),
                 'Gamma': (30, 45)}

    n_samples, n_channels = train_data.shape[0], train_data.shape[1]
    F_train = np.zeros((n_samples, n_channels, 2 * len(eeg_bands)))
    if n_samples == 0 or n_channels == 0:
        return F_train

    # Get frequencies for amplitudes in Hz. They depend only on the signal
    # length, so the per-band bin indices are computed once.
    fft_freq = np.fft.rfftfreq(train_data.shape[-1], 1.0 / fs)
    band_bins = [np.where((fft_freq >= low) & (fft_freq <= high))[0]
                 for low, high in eeg_bands.values()]

    for i in range(n_samples):
        # Get real amplitudes of FFT (only in positive frequencies) for all
        # channels of this sample at once.
        fft_vals = np.absolute(np.fft.rfft(train_data[i], axis=-1))
        for k, bins in enumerate(band_bins):
            band_vals = fft_vals[:, bins]
            # Two slots per band, so one band never overwrites another's values.
            F_train[i, :, 2 * k] = band_vals.min(axis=1)
            F_train[i, :, 2 * k + 1] = band_vals.max(axis=1)

    return F_train

class EarlyStopping:
    """
    Track a validation metric and raise the ``early_stop`` flag once it has
    failed to improve for ``patience`` consecutive calls.

    Input (constructor):
        patience: number of non-improving calls tolerated; 0 disables the check
        verbose: stored but not used by this class
        delta: margin added to the best score in the improvement comparison
        if_max: True when a larger metric is better; otherwise the metric is
            negated so that a smaller value is better
        device: calls are ignored when this is neither None nor equal to 0
            (presumably a process rank; confirm against the caller)
    """

    def __init__(self, patience=3, verbose=False, delta=0, if_max=False, device=None):
        # Settings
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.if_max = if_max
        self.device = device
        # Running state
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss):
        """Update the tracker with a new metric value; always returns None."""
        ignored = self.device is not None and self.device != 0
        if ignored or self.patience == 0:
            return

        score = val_loss if self.if_max else -val_loss

        if self.best_score is None:
            # First observation becomes the baseline.
            self.best_score = score
            return

        worse = score < self.best_score + self.delta
        if not worse:
            self.best_score = score
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
            # NOTE(review): the message contains a typo ("Strop"); kept verbatim to preserve output.
            print('Early Strop Reached.')
