import random

import numpy as np
import torch

import conf
from utils.custom_exceptions import *
from scipy.cluster.vq import vq
from sklearn.cluster import KMeans
import math

device = torch.device("cuda:{:d}".format(conf.args.gpu_idx) if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    # Make the selected GPU the current CUDA device; this prevents unnecessary gpu memory
    # allocation to cuda:0 when using estimator. Guarded so that importing this module on a
    # CPU-only machine does not fail (device above already falls back to "cpu" there).
    torch.cuda.set_device(conf.args.gpu_idx)

class ActivePriorityFIFO:
    """Bounded sample memory split into correct / wrong / unlabeled pools.

    Each pool ("mem") is a list of four parallel lists holding, per stored
    sample, ``[feat, cls, domain, entropy]``; position ``i`` across the four
    lists describes one sample. Each pool holds at most ``capacity`` samples:
    adding to a full pool first evicts one sample (see ``remove_instance``).
    """

    # Number of parallel field lists per pool: feat, cls, domain, entropy.
    _NUM_FIELDS = 4
    # Position of the entropy field, used as the eviction priority.
    _PRIORITY_IDX = 3

    def __init__(self, capacity, pop="min"):
        """Create empty pools.

        Args:
            capacity: maximum number of samples kept in each pool.
            pop: eviction policy for the correct/wrong pools. ``"min"`` evicts
                the sample with the smallest entropy, ``"max"`` the largest,
                and any other value evicts the oldest sample.
        """
        # feat, cls, domain, entropy
        self.correct_mem = self._empty_mem()  # for correct samples
        self.wrong_mem = self._empty_mem()  # for wrong samples
        self.u_mem = self._empty_mem()  # for unlabeled samples : wait to be labeled
        self.capacity = capacity
        self.pop = pop

    @classmethod
    def _empty_mem(cls):
        """Return a fresh empty pool: one independent list per field."""
        return [[] for _ in range(cls._NUM_FIELDS)]

    def set_memory(self, state_dict):  # for tta_attack
        """Restore the pools (and capacity, if present) from ``state_dict``.

        Each field list is shallow-copied, so later changes to the lists in
        ``state_dict`` do not affect this memory; the stored elements
        themselves are shared, not copied.
        """
        self.correct_mem = [ls[:] for ls in state_dict['correct_mem']]
        self.wrong_mem = [ls[:] for ls in state_dict['wrong_mem']]
        self.u_mem = [ls[:] for ls in state_dict['u_mem']]

        if 'capacity' in state_dict:
            self.capacity = state_dict['capacity']

    def save_state_dict(self):
        """Return a snapshot of the pools and capacity.

        The field lists are shallow copies, so later add/remove calls on this
        object do not alter the snapshot. Stored elements are shared.
        """
        return {
            'correct_mem': [ls[:] for ls in self.correct_mem],
            'wrong_mem': [ls[:] for ls in self.wrong_mem],
            'u_mem': [ls[:] for ls in self.u_mem],
            'capacity': self.capacity,
        }

    def get_memory(self):
        """Return the pools and capacity.

        Unlike ``save_state_dict``, the returned pools are the live lists, not
        copies, so mutating them mutates this memory.
        """
        return {'correct_mem': self.correct_mem,
                'wrong_mem': self.wrong_mem,
                'u_mem': self.u_mem,
                'capacity': self.capacity}

    def get_correct_memory(self):
        """Return the live correct-sample pool."""
        return self.correct_mem

    def get_wrong_memory(self):
        """Return the live wrong-sample pool."""
        return self.wrong_mem

    def get_u_memory(self):
        """Return the live unlabeled-sample pool."""
        return self.u_mem

    def get_occupancy(self, mem):
        """Return the number of samples stored in ``mem``.

        All field lists are appended/popped together, so the length of the
        first one is the sample count.
        """
        return len(mem[0])

    def add_instance(self, instance):
        """Not supported: use one of the pool-specific ``add_*_instance`` methods."""
        raise NotImplementedError

    def _add(self, mem, instance, pop):
        """Append ``instance`` to ``mem``, first evicting one sample (by ``pop``) if full.

        Args:
            mem: the pool to add to.
            instance: a 4-element ``(feat, cls, domain, entropy)`` sequence.
            pop: eviction policy passed to ``remove_instance``.
        """
        assert (len(instance) == self._NUM_FIELDS)

        if self.get_occupancy(mem) >= self.capacity:
            self.remove_instance(mem, pop=pop)

        for i, field in enumerate(mem):
            field.append(instance[i])

    def add_correct_instance(self, instance):
        """Add a ``(feat, cls, domain, entropy)`` sample to the correct pool (evicts by ``self.pop``)."""
        self._add(self.correct_mem, instance, self.pop)

    def add_wrong_instance(self, instance):
        """Add a ``(feat, cls, domain, entropy)`` sample to the wrong pool (evicts by ``self.pop``)."""
        self._add(self.wrong_mem, instance, self.pop)

    def add_u_instance(self, instance):
        """Add a ``(feat, cls, domain, entropy)`` sample to the unlabeled pool.

        This pool always evicts its oldest sample when full (pure FIFO),
        regardless of ``self.pop``.
        """
        self._add(self.u_mem, instance, None)

    def remove_instance(self, mem, pop=None):
        """Evict one sample from ``mem``.

        ``pop="min"`` / ``"max"`` evict the sample with the smallest / largest
        entropy (``np.argmin`` / ``np.argmax`` pick the first index on ties).
        Any other value evicts the oldest sample (index 0). Raises if ``mem``
        is empty.
        """
        if pop == "min":
            target_idx = np.argmin(mem[self._PRIORITY_IDX])
        elif pop == "max":
            target_idx = np.argmax(mem[self._PRIORITY_IDX])
        else:
            target_idx = 0
        self.remove_instance_by_index(mem, target_idx)

    def remove_instance_by_index(self, mem, index):
        """Remove the sample at ``index`` from every field list of ``mem``."""
        for field in mem:
            field.pop(index)

    def remove_u_instance_by_index(self, index):
        """Remove the sample at ``index`` from the unlabeled pool."""
        self.remove_instance_by_index(self.u_mem, index)

    def reset(self):
        """Empty all three pools (capacity and pop policy are kept)."""
        # Every pool must have exactly _NUM_FIELDS lists, or add_* would index past the instance.
        self.u_mem = self._empty_mem()
        self.correct_mem = self._empty_mem()
        self.wrong_mem = self._empty_mem()
        