from collections import deque
import numpy as np
import torch
import ruamel.yaml as yaml
import time
import heapq
from queue import PriorityQueue
from sortedcollections import ValueSortedDict


def is_same_subset(sub1, sub2):
    """Return True when both subsets hold the same elements, ignoring order.

    Each argument is copied into a list and sorted; the two sorted lists are
    then compared for equality, so repeated elements are significant.
    """
    first = list(sub1)
    first.sort()
    second = list(sub2)
    second.sort()
    return first == second


def count_subsets(subsets):
    """Count the distinct rows contained in ``subsets``.

    A torch tensor is detached, moved to the CPU and converted to nested
    lists first. Every row is then turned into a tuple with its element
    order kept, so two rows holding the same elements in a different order
    are counted as different.
    """
    rows = subsets
    if isinstance(rows, torch.Tensor):
        rows = rows.detach().cpu().numpy().tolist()
    return len({tuple(row) for row in rows})


def count_subsets2(subsets):
    """Count the distinct values that appear anywhere in ``subsets``.

    Args:
        subsets: A torch tensor, a numpy array, or a (nested) list/tuple of
            values. Tensors are detached and moved to the CPU before use.

    Returns:
        The number of unique values after flattening ``subsets``.
    """
    if isinstance(subsets, torch.Tensor):
        subsets = subsets.detach().cpu().numpy()
    # np.asarray leaves existing arrays untouched and lets plain lists and
    # tuples through, which previously failed on the missing ``reshape``.
    flat_values = np.asarray(subsets).reshape(-1).tolist()
    return len(set(flat_values))


class BestSubsets(object):
    """Bounded collection of the best-scoring subsets seen so far.

    Subsets are stored as sorted tuples mapped to their score inside a
    ``ValueSortedDict``. Once more than ``max_len`` entries are held, the
    entry at index 0 of the value order (the lowest score) is dropped.
    """

    def __init__(self, max_len=100, save_path='./best_subsets.yaml'):
        """Create an empty collection.

        Args:
            max_len: Maximum number of subsets kept at any time.
            save_path: Default file used by ``save`` and ``load``.
        """
        # Maps normalized subset tuple -> score, ordered by score.
        self.data = ValueSortedDict()
        self.max_len = max_len
        # Highest score ever passed to ``append``, even if later evicted.
        self.best_score = -np.inf
        self.save_path = save_path

    @staticmethod
    def _normalize_subset(subset):
        """Return ``subset`` as a sorted tuple usable as a dictionary key."""
        if isinstance(subset, torch.Tensor):
            subset = subset.detach().cpu().numpy().tolist()
        elif isinstance(subset, np.ndarray):
            # Plain Python values keep the keys free of numpy scalar types.
            subset = subset.tolist()
        return tuple(sorted(subset))

    def append(self, subset, score):
        """Record ``subset`` with ``score``, keeping only the top entries.

        An already-known subset is only overwritten when the new score is
        not lower than the stored one. ``best_score`` is updated regardless
        of whether the subset ends up being stored.
        """
        subset = self._normalize_subset(subset)

        if self.best_score < score:
            self.best_score = score

        if subset in self.data and self.data[subset] > score:
            return

        self.data[subset] = score

        if len(self.data) > self.max_len:
            # Index 0 is the smallest value in the value-sorted mapping.
            self.data.popitem(0)

    def _change_data(self, data):
        """Recursively convert ``data`` into plain Python containers.

        Dicts are rebuilt with converted keys and values, lists and tuples
        become tuples, and numpy arrays/scalars are converted via
        ``tolist()``. Anything else is returned unchanged.
        """
        if isinstance(data, dict):
            return {self._change_data(k): self._change_data(v) for k, v in data.items()}
        if isinstance(data, (list, tuple)):
            return tuple([self._change_data(item) for item in data])
        if isinstance(data, np.ndarray) or isinstance(data, np.number):
            return data.tolist()
        return data

    def _serializable_state(self):
        """Build the plain-Python payload written by ``save``.

        Entries are stored as ``(subset, score)`` pairs rather than a
        ``{score: subset}`` mapping, because keying by score silently drops
        every subset that shares its score with another one.
        """
        return self._change_data({
            'data': list(self.data.items()),
            'max_len': self.max_len,
            'best_score': self.best_score,
        })

    @staticmethod
    def _entries_from_saved(saved):
        """Turn the saved ``data`` section into a ``{subset: score}`` dict.

        Accepts both the current layout (a sequence of ``(subset, score)``
        pairs) and the legacy layout (a ``{score: subset}`` mapping).
        """
        if not saved:
            return {}
        if isinstance(saved, dict):
            # Legacy files: scores were used as keys.
            return {tuple(subset): score for score, subset in saved.items()}
        return {tuple(subset): score for subset, score in saved}

    def save(self):
        """Write the current entries, ``max_len`` and ``best_score`` to ``save_path``."""
        with open(self.save_path, 'w') as f:
            yaml.dump(self._serializable_state(), f)

    def load(self, path=None):
        """Replace the current state with the one stored in a YAML file.

        Args:
            path: File to read; defaults to ``save_path``.
        """
        if path is None:
            path = self.save_path

        with open(path, 'r') as f:
            # NOTE(review): yaml.load is called without an explicit loader;
            # depending on the installed ruamel.yaml version this may build
            # arbitrary Python objects, so only load trusted files.
            res = yaml.load(f)
        self.data = ValueSortedDict(self._entries_from_saved(res['data']))
        self.max_len = res['max_len']
        self.best_score = res['best_score']

    def is_content(self, subset):
        """Return True when ``subset`` (in any element order) is stored."""
        subset = self._normalize_subset(subset)
        return subset in self.data


# a = BestSubsets(2)
# a.append([1, 2, 3, 0], 0)
# a.append([1, 2, 3, 0], 1)
# a.append([1, 2, 3, 0], 2)
# a.append([1, 2, 3, 0], 1)
# a.append([1, 2, 3, 2], 3)
# a.append([1, 2, 3, 4], 5)
# a.append([1, 2, 3, 1], 3)

#
# class BestSubset(object):
#     def __init__(self):
#         self.best_score = -np.inf
#         self.best_subset = None
#         self.non_appear_times = 0
#         self.true_data = {}
#         self.best_time = None
#         self.start_time = time.time()
#         self.historys = []
#
#         self.best_sorted = []
#         self.best_sorted_maxlen = 50
#
#         self._path = None
#
#     def _check_sorted_history(self, subset, score):
#         if len(self.best_sorted) > self.best_sorted_maxlen:
#             heapq.heapreplace(self.best_sorted, (score, subset))
#         else:
#             heapq.heappush(self.best_sorted, (score, subset))
#
#     def append(self, subset, score):
#         if isinstance(subset, torch.Tensor):
#             subset = subset.detach().cpu().numpy().tolist()
#             subset = tuple(sorted(subset))
#         res = self._update_best_subset(subset, score)
#         self._check_sorted_history(subset, score)
#         return res
#
#     def _update_best_subset(self, subset, scores):
#         self.non_appear_times += 1
#         if scores >= self.best_score:
#             self.historys.append((self.best_time, self.best_score, self.best_subset))
#             self.best_time = time.time() - self.start_time
#             self.non_appear_times = 0
#             self.best_subset = subset
#             self.best_score = scores
#             return self.best_score
#
#     def add_true_result(self, subset, result):
#         subset = tuple(sorted(list(subset)))
#         if subset not in self.true_data:
#             self.true_data[subset] = []
#         self.true_data[subset].append(result)
#
#     def true_subset_nums(self, subset):
#         subset = tuple(sorted(list(subset)))
#         if subset in self.true_data:
#             return len(self.true_data[subset])
#         return 0
#
#     def _change_data(self, data):
#         if isinstance(data, dict):
#             return {k: self._change_data(v) for k, v in data.items()}
#         if isinstance(data, (list, tuple)):
#             return [self._change_data(item) for item in data]
#         if isinstance(data, np.ndarray) or isinstance(data, np.number):
#             return data.tolist()
#         return data
#
#     def save(self, path):
#         self._path = path
#         with open(path, 'w') as f:
#             yaml.dump(self._change_data({
#                 'best_appear_time': self.best_time,
#                 'best_score': self.best_score,
#                 'best_subset': self.best_subset,
#                 'appear_times': self.non_appear_times,
#                 'history': self.historys,
#                 'best_sorted_history': self.best_sorted,
#             }), f)



