import os
import time
import pprint
from TSception import TSception
from FeelNet5 import FeelNet
from FeelNet6 import FeelNet1

from baselines import deepConvNet, shallowConvNet, eegNet, ERTNet, EEGTCNet
from eeg_dataset import *
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score


def set_gpu(x):
    """Select the visible GPU(s) for this process and limit torch to one CPU thread.

    Parameters
    ----------
    x: GPU index or comma-separated list of indices, e.g. ``0``, ``'0'`` or ``'0,1'``.
       Ints are accepted and converted to str because ``os.environ`` only
       stores strings.

    Notes
    -----
    CUDA_VISIBLE_DEVICES is normally only honoured if it is set before CUDA is
    initialised in this process, so call this early.
    """
    torch.set_num_threads(1)
    # Number devices by PCI bus so the indices match nvidia-smi.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ rejects non-str values; str() lets callers pass an int id.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(x)
    print('using gpu:', x)


def seed_all(seed):
    """Seed every global RNG used during training and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch
    (``torch.manual_seed`` also seeds the CUDA generators).

    Parameters
    ----------
    seed: integer seed applied to all generators.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # cuDNN autotuning chooses kernels by benchmarking, which can differ between
    # runs; turn it off together with deterministic mode for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def ensure_path(path):
    """Create directory ``path`` (and any missing parents) if nothing exists there.

    If ``path`` already exists — as a directory or as any other file — this is
    a no-op.
    """
    if not os.path.exists(path):
        # exist_ok avoids a FileExistsError race when another process creates
        # the directory between the exists() check and makedirs().
        os.makedirs(path, exist_ok=True)


class Averager:
    """Incrementally maintained arithmetic mean of the values passed to ``add``."""

    def __init__(self):
        self.n = 0  # how many values have been folded in
        self.v = 0  # mean of those values (0 before the first add)

    def add(self, x):
        """Fold ``x`` into the running mean."""
        count = self.n
        self.v = (self.v * count + x) / (count + 1)
        self.n = count + 1

    def item(self):
        """Return the current mean."""
        return self.v


def count_acc(logits, label):
    """Return the fraction of rows whose argmax over dim 1 equals ``label``.

    Parameters
    ----------
    logits: tensor of per-class scores; the predicted class is argmax along dim 1.
    label: tensor of integer class ids, comparable element-wise with the predictions.

    Returns
    -------
    Accuracy as a Python float.
    """
    pred = torch.argmax(logits, dim=1)
    # .float() keeps the comparison result on its current device; the previous
    # torch.cuda.FloatTensor cast failed on CPU-only machines.
    return (pred == label).float().mean().item()


class Timer:
    """Wall-clock stopwatch that starts when it is constructed."""

    def __init__(self):
        self.o = time.time()  # start timestamp in seconds

    def measure(self, p=1):
        """Return elapsed time divided by ``p`` as a short human-readable string.

        The elapsed seconds are truncated to an int first. The result uses hours
        (one decimal) at >= 3600 s, whole minutes at >= 60 s, otherwise seconds.
        """
        seconds = int((time.time() - self.o) / p)
        if seconds >= 3600:
            return '{:.1f}h'.format(seconds / 3600)
        elif seconds >= 60:
            return '{}m'.format(round(seconds / 60))
        else:
            return '{}s'.format(seconds)

# Must be created before the def below, which rebinds the name `pprint`
# (shadowing the stdlib module) for the rest of this module.
_utils_pp = pprint.PrettyPrinter()


def pprint(x):
    """Pretty-print ``x`` via the module-level PrettyPrinter (stdout by default)."""
    printer = _utils_pp
    printer.pprint(x)


def get_model(args):
    """Instantiate the network named by ``args.model``.

    Parameters
    ----------
    args: namespace-like object. ``args.model`` selects the architecture; the
        other attributes read depend on the model: ``num_class``,
        ``input_shape`` (elements 1 and 2 are passed as the channel count and
        the sample count), ``dropout``, ``adaptive_filter`` (FeelNet variants),
        ``sampling_rate``, ``T`` and ``hidden`` (TSception).

    Returns
    -------
    The constructed model instance.

    Raises
    ------
    ValueError: if ``args.model`` is not a supported model name.
    """
    name = args.model
    if name == 'FeelNet':
        model = FeelNet(
            nb_classes=args.num_class, Chans=args.input_shape[1], dropout_rate=args.dropout,
            Samples=args.input_shape[2], Filter=args.adaptive_filter
        )
    elif name == 'FeelNet1':
        model = FeelNet1(
            nb_classes=args.num_class, Chans=args.input_shape[1], dropout_rate=args.dropout,
            Samples=args.input_shape[2], Filter=args.adaptive_filter
        )
    elif name == 'TSception':
        # NOTE(review): num_S reuses args.T — confirm a separate S argument isn't intended.
        model = TSception(
            num_classes=args.num_class, input_size=args.input_shape,
            sampling_rate=args.sampling_rate, num_T=args.T, num_S=args.T,
            hidden=args.hidden, dropout_rate=args.dropout)
    elif name == 'deepConvNet':
        model = deepConvNet(
            nChan=args.input_shape[1], nTime=args.input_shape[2], nClass=args.num_class, dropoutP=args.dropout
        )
    elif name == 'shallowConvNet':
        model = shallowConvNet(
            nChan=args.input_shape[1], nTime=args.input_shape[2], nClass=args.num_class, dropoutP=args.dropout
        )
    elif name == 'eegNet':
        model = eegNet(
            nChan=args.input_shape[1], nTime=args.input_shape[2], nClass=args.num_class, dropoutP=args.dropout
        )
    elif name == 'ERTNet':
        model = ERTNet(
            nb_classes=args.num_class, Chans=args.input_shape[1], dropout_rate=args.dropout, Samples=args.input_shape[2]
        )
    elif name == 'EEGTCNet':
        model = EEGTCNet(
            nChan=args.input_shape[1], nTime=args.input_shape[2], nClass=args.num_class, dropoutP=args.dropout
        )
    else:
        # Without this branch an unknown name surfaced as UnboundLocalError on `model`.
        raise ValueError(
            'Unknown model {!r}; expected one of: FeelNet, FeelNet1, TSception, '
            'deepConvNet, shallowConvNet, eegNet, ERTNet, EEGTCNet'.format(name)
        )
    return model


def get_dataloader(data, label, batch_size, shuffle=True):
    """Wrap ``data``/``label`` in an ``eegDataset`` and return a DataLoader over it.

    Batches are drawn with ``batch_size``, optionally shuffled, with
    ``pin_memory=True``.
    """
    return DataLoader(
        dataset=eegDataset(data, label),
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=True,
    )


def get_metrics(y_pred, y_true, classes=None, average='binary'):
    """Compute accuracy, F1 score and confusion matrix for a set of predictions.

    Note the argument order: predictions first, ground truth second.

    Parameters
    ----------
    y_pred: predicted class labels.
    y_true: ground-truth class labels.
    classes: optional label list fixing the rows/columns of the confusion
        matrix; None lets sklearn infer the labels from the data.
    average: forwarded to ``sklearn.metrics.f1_score``. The default 'binary'
        keeps the previous behaviour; pass 'macro', 'weighted' or 'micro' for
        multiclass targets, for which 'binary' raises ValueError.

    Returns
    -------
    (acc, f1, cm)
    """
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average=average)
    # labels=None is confusion_matrix's default, so one call covers both cases.
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    return acc, f1, cm


def get_trainable_parameter_num(model):
    """Return the total number of scalar elements in parameters with requires_grad=True."""
    count = 0
    for param in model.parameters():
        if param.requires_grad:
            count += param.numel()
    return count

def L1Loss(model, Lambda):
    """Return ``Lambda`` times the L1 norm of all parameters of ``model``.

    Every parameter is included, regardless of ``requires_grad``. The result
    is a differentiable scalar tensor (a zero tensor for a parameterless model).
    """
    params = list(model.parameters())
    if not params:
        # torch.cat([]) would raise; a model with no parameters has zero penalty.
        return Lambda * torch.zeros(())
    # Summing per-parameter norms avoids materialising one flattened copy of
    # every weight (as torch.cat did) on each call.
    l1 = sum(p.abs().sum() for p in params)
    return Lambda * l1


def generate_TS_channel_order(original_order: list):
    """
    Generate the channel order used by TSception.

    Channels whose names end in a number are paired with their odd/even
    partner (1<->2, 3<->4, ...; in 10-20 naming, the mirror electrode on the
    other hemisphere). Channels without a trailing number (e.g. 'Fz', 'Oz')
    are dropped. A channel that is already part of an earlier pair is skipped.

    Parameters
    ----------
    original_order: list of the channel names

    Returns
    -------
    list of channel names: the first member of every pair, in order of first
    appearance, followed by their partners in the same order. A partner is
    emitted even if it does not occur in ``original_order``.
    """
    digits = '0123456789'
    seen = set()
    firsts, partners = [], []
    for channel in original_order:
        prefix = channel.rstrip(digits)
        suffix = channel[len(prefix):]
        if not suffix:
            continue  # no trailing electrode number -> no partner
        number = int(suffix)
        partner = prefix + str(number - 1 if number % 2 == 0 else number + 1)
        if channel in seen:
            continue  # already emitted as a first member or as a partner
        seen.update((channel, partner))
        firsts.append(channel)
        partners.append(partner)
    return firsts + partners


if __name__ == "__main__":
    # Demo: print the TSception ordering for a 32-channel example montage.
    demo_channels = [
        'Fp1', 'AF3', 'F3', 'F7', 'FC5', 'FC1', 'C3', 'T7', 'CP5', 'CP1', 'P3', 'P7', 'PO3',
        'O1', 'Oz', 'Pz', 'Fp2', 'AF4', 'Fz', 'F4', 'F8', 'FC6', 'FC2', 'Cz', 'C4', 'T8', 'CP6',
        'CP2', 'P4', 'P8', 'PO4', 'O2',
    ]
    print(generate_TS_channel_order(demo_channels))
    print('done')







