import os
import copy
import numpy as np
import ruamel.yaml as yaml


class PrettySafeLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` subclass with a constructor for ``python/tuple`` tags.

    Nodes tagged ``tag:yaml.org,2002:python/tuple`` are loaded as Python
    tuples (see the registration below).
    """

    def construct_python_tuple(self, node):
        """Build a tuple from the sequence held by ``node``."""
        elements = self.construct_sequence(node)
        return tuple(elements)


# Map the python/tuple tag to the tuple constructor defined above.
PrettySafeLoader.add_constructor(
    'tag:yaml.org,2002:python/tuple',
    PrettySafeLoader.construct_python_tuple,
)


def load_results(path):
    """Parse the YAML file at ``path`` with ``PrettySafeLoader`` and return it.

    ``python/tuple``-tagged nodes come back as tuples.
    """
    # NOTE(review): yaml.load(stream, Loader=...) is the legacy ruamel.yaml
    # function API and may be unavailable in newer releases — confirm the
    # pinned ruamel.yaml version.
    with open(path, 'r') as stream:
        results = yaml.load(stream, Loader=PrettySafeLoader)
    return results

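# Example queries for picking the best entry from a list of result records
# (presumably the list returned by load_results(); assumes each entry is a
# dict with these keys):
# max(data, key=lambda x: x['metaFE_score'])
# max(data, key=lambda x: x['true_results']['f1'])
# max(data, key=lambda x: x['true_results']['val_f1'])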


def append_to_txt(path, string, prefix='', suffix=''):
    """Append the line ``prefix + str(string) + suffix`` to the file at ``path``.

    Missing parent directories are created, and a trailing newline is added.

    Args:
        path: File to append to; it is created if it does not exist.
        string: Any object; it is converted with ``str()``.
        prefix: Text written before ``string``.
        suffix: Text written after ``string`` (before the newline).
    """
    path_dir = os.path.dirname(path)
    # A bare filename has an empty directory part; os.makedirs('') would raise.
    if path_dir:
        # exist_ok avoids a race between "does it exist?" and "create it".
        os.makedirs(path_dir, exist_ok=True)

    with open(path, 'a') as f:
        f.write(prefix + str(string) + suffix + '\n')


class DataLoaderSampler(object):
    """Endless sampler over a re-iterable data loader.

    ``get()`` (and ``next(sampler)``) returns the next item from
    ``data_loader``. When the current pass is exhausted, a fresh iterator is
    created, so sampling continues indefinitely. If ``device`` is given, each
    element of an item is moved with ``.to(device)`` and the item is returned
    as a tuple.

    Raises:
        ValueError: from ``get()`` / ``__next__`` if a fresh pass over
            ``data_loader`` yields no items (empty or one-shot loader).
    """

    def __init__(self, data_loader, device=None):
        self.data_loader = data_loader
        self._iter = iter(self.data_loader)
        self.device = device

    @property
    def loader_length(self):
        """Length of ``data_loader`` as reported by ``len()``."""
        return len(self.data_loader)

    def _next_raw(self):
        """Return the next raw item, restarting the loader at most once."""
        try:
            return next(self._iter)
        except StopIteration:
            self._iter = iter(self.data_loader)
        try:
            return next(self._iter)
        except StopIteration:
            # A fresh pass produced nothing: the loader is empty or cannot be
            # re-iterated. Restarting again would never terminate.
            raise ValueError('data_loader yielded no items') from None

    def get(self):
        """Return the next item, moved to ``self.device`` if one is set."""
        item = self._next_raw()
        if self.device is not None:
            # Transfer exactly once per item; assumes elements expose
            # .to(device) (e.g. torch tensors) — TODO confirm for all loaders.
            item = tuple(i.to(self.device) for i in item)
        return item

    def __iter__(self):
        return self

    def __next__(self):
        return self.get()


class Filter(object):
    """Exponential moving average (EMA) smoother for a stream of values.

    Each ``update(value)`` stores and returns
    ``alpha * previous + (1 - alpha) * value``. The first value is stored
    unchanged. ``alpha=0`` (the default) disables smoothing entirely.
    """

    def __init__(self, alpha=0):
        # Weight given to the previous smoothed value; 0 means pass-through.
        self.alpha = alpha
        # Last smoothed value, or None before the first update.
        self.memory = None

    def update(self, value):
        """Fold ``value`` into the running average and return the new average."""
        if self.memory is None or self.alpha == 0:
            # Pass-through. Skipping the blend for alpha == 0 avoids
            # 0 * inf == nan turning every later output into NaN after a
            # single non-finite value.
            self.memory = value
        else:
            self.memory = self.alpha * self.memory + (1 - self.alpha) * value
        return self.memory


class EarlyStopping(object):
    """Patience-based early stopping that keeps a copy of the best weights.

    Every ``update(loss)`` first smooths ``loss`` through a ``Filter`` built
    with ``filter`` as its alpha. A smoothed loss that is <= the best seen so
    far becomes the new best: the counter is cleared and
    ``net.state_dict()`` is deep-copied. Otherwise the counter grows. Stopping
    is signalled once the counter reaches ``patience``; ``patience <= 0``
    disables stopping.
    """

    def __init__(self, net, patience=10, filter=0):
        self.net = net
        self.patience = patience
        self.filter = Filter(filter)
        self.min_loss = None
        self.weight = None
        self._i = 0

    def reset(self):
        """Clear the patience counter; the best loss and weights are kept."""
        self._i = 0

    def update(self, loss):
        """Record a new loss and return whether training should stop."""
        smoothed = self.filter.update(loss)
        improved = self.min_loss is None or smoothed <= self.min_loss
        if improved:
            self.min_loss = smoothed
            self._i = 0
            # Deep copy so later training steps do not mutate the snapshot.
            self.weight = copy.deepcopy(self.net.state_dict())
        else:
            self._i += 1
        return self.is_stop

    @property
    def is_stop(self):
        """True once ``patience`` (> 0) updates passed without improvement."""
        return self.patience > 0 and self._i >= self.patience

    @property
    def best_weight(self):
        """Deep-copied ``state_dict`` from the best update, or None."""
        return self.weight


def change_data(data):
    """Recursively convert numpy values inside ``data`` into plain Python objects.

    - dict: keys and values are converted recursively. A key that would turn
      into a list (e.g. a tuple) is rebuilt as a tuple so it stays hashable.
    - list / tuple: returned as a list of converted items.
    - ``np.ndarray``, numpy numeric scalars and ``np.bool_``: converted with
      ``.tolist()``.
    - Anything else is returned unchanged.
    """
    if isinstance(data, dict):
        return {_change_key(k): change_data(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return [change_data(item) for item in data]
    # np.bool_ is not a subclass of np.number, so it is listed explicitly.
    if isinstance(data, (np.ndarray, np.number, np.bool_)):
        return data.tolist()
    return data


def _change_key(key):
    """Convert a dict key like change_data, but keep sequences hashable."""
    return _lists_to_tuples(change_data(key))


def _lists_to_tuples(obj):
    """Recursively turn lists into tuples; other objects pass through."""
    if isinstance(obj, list):
        return tuple(_lists_to_tuples(x) for x in obj)
    return obj
