import numpy as np
import torch
from copy import deepcopy
from operator import itemgetter
import random

from env import Env
from benchmark import Benchmark

class DKWT:
    """Subset-choice tournament for finding the best item of a vector dataset.

    The dataset rows are L2-normalized and indexed ``0 .. n-1``. ``run_sim``
    repeatedly plays subsets of (about) ``k`` items against an ``Env`` built
    from the dataset and a query vector, carrying a running winner from one
    round into the next, until every item has taken part in some round.
    """

    # String annotation: keeps the class body from evaluating ``Env`` at
    # definition time. The attribute itself is assigned in ``run_sim``.
    env: "Env"

    def __init__(self,
                 vec: np.ndarray,  # dataset, one item per row
                 eps: float = 0.1,  # error bias
                 delta: float = 0.05,  # confidence parameter
                 lmbda: float = 0.2,  # regularization term
                 k: int = 10,  # subset size
                 sharpness=None,
                 lenient=False,
                 ):
        """Store the normalized dataset and the algorithm parameters.

        Args:
            vec: 2-D numpy array of item vectors (``n x d``); every row is
                scaled to unit L2 norm.
            eps: error bias; forwarded to ``Env`` and used to derive
                ``eff_eps``.
            delta: overall confidence parameter, split across rounds in
                ``run_sim``.
            lmbda: regularization term (stored as ``self.lmbda``; not read
                elsewhere in this class).
            k: number of items played per round.
            sharpness: forwarded to ``Env`` when not ``None``.
            lenient: if true, use ``eps`` itself as the tolerance passed to
                ``alg_2`` instead of the tightened ``eff_eps`` formula.
        """
        self.vecs = self.preproc(torch.from_numpy(vec))  # n x d
        self.n, self.d = self.vecs.shape
        self.k = k
        self.delta = self.delta_eff = delta
        self.eps = eps
        if not lenient:
            self.eff_eps = 4*eps/(k + 2*eps*k - 4*eps)
        else:
            self.eff_eps = self.eps
        print(f'Effective epsilon: {self.eff_eps}')
        self.lmbda = lmbda

        self.env = None

        self.benchmark = Benchmark(self)

        self.all_items = set(np.arange(self.n))
        self.remaining_items = deepcopy(self.all_items)
        self.rw = None  # running winner (an item index), None before round 1

        self.sharpness = sharpness

    def get_result_dict(self):
        """Return a summary of the parameters and of the last run.

        Must be called after ``run_sim``: it reads ``num_steps``, ``pr_diff``
        and attributes of ``self.env``.
        """
        keys = ['n', 'd', 'k', 'num_steps', 'eps', 'delta', 'eff_eps', 'pr_diff']
        d = {key: self.__dict__[key] for key in keys}
        d['num_valid_items'] = self.env.num_valid_items
        d['max_corr'] = self.env.max_corr
        d['sharpness'] = self.env.sharpness
        return d

    @staticmethod
    def preproc(vecs):
        """Scale each vector (along the last dimension) to unit L2 norm."""
        return vecs / torch.norm(vecs, dim=-1, keepdim=True)

    def run_sim(self, q_vec):
        """Run one full tournament against query vector ``q_vec``.

        Per-run state (step counter, remaining items, running winner) is
        reset first, so the method can be called repeatedly.

        Args:
            q_vec: 1-D numpy array; normalized to unit L2 norm before use.

        Returns:
            The final running winner (also stored in ``self.rw``).
        """
        q_vec = torch.from_numpy(q_vec)
        q_vec = q_vec/torch.norm(q_vec)

        self.num_steps = 0
        self.remaining_items = set(self.all_items)
        self.rw = None
        if self.sharpness is not None:
            self.env = Env(self.vecs, q_vec, eps=self.eps, device = 'cpu', sharpness=self.sharpness)
        else:
            self.env = Env(self.vecs, q_vec, eps=self.eps, device = 'cpu')

        n_rounds = int(np.ceil(self.n/(self.k-1)))
        # Confidence budget per round; the clamp keeps n <= k-1 (a single
        # round) from dividing by zero.
        round_delta = self.delta/max(n_rounds - 1, 1)

        s = 0
        # Continue while any item is unprocessed: stopping at one leftover
        # (the old `> 1` test) meant that item was never compared.
        while s < n_rounds and self.remaining_items:
            # After round 1 one slot is reserved for the running winner.
            tmp_k = self.k-1 if self.rw is not None else self.k
            G = self._draw_subset(tmp_k)
            if self.rw is not None:
                G.add(self.rw)
            self.rw = self.alg_2(G, round_delta, self.eff_eps)
            self.remaining_items.difference_update(G)
            s += 1
        return self.rw

    def _draw_subset(self, size):
        """Draw the next round's subset of (at most) ``size`` items.

        Unprocessed items are preferred. If fewer than ``size`` remain, the
        subset is padded with random already-processed items, excluding the
        running winner (the caller adds it separately). Padding is limited by
        how many such items exist, so small datasets do not raise.
        """
        # random.sample needs a sequence; sets raise TypeError on 3.11+.
        remaining = list(self.remaining_items)
        if len(remaining) > size:
            return set(random.sample(remaining, k=size))
        subset = set(remaining)
        excluded = self.remaining_items | {self.rw}
        pool = [i for i in self.all_items if i not in excluded]
        n_pad = min(size - len(subset), len(pool))
        subset.update(random.sample(pool, k=n_pad))
        return subset

    def alg_1(self, G, y, h):
        """Play subset ``G`` a fixed number of times and try to name a winner.

        ``T = 8 * ln(4 / y) / h**2`` and the subset is played ``ceil(T)``
        times. The most frequently chosen item is declared the winner only if
        its count exceeds the runner-up's by more than ``h * T``.

        Args:
            G: set of item indices to play.
            y: confidence parameter for this call.
            h: gap parameter for this call.

        Returns:
            ``(winner, mode)``: ``mode`` is the most frequently chosen item;
            ``winner`` is ``mode`` if the margin test passed, else ``None``.
        """
        if len(G) == 1:
            # A lone item wins trivially, and there is no runner-up to index.
            (only,) = G
            return only, only
        T = 8 * np.log(4/y)/(h**2)
        cnt = {item: 0 for item in G}
        x_batch = self.env.play_set_m_times(G, m=int(np.ceil(T)))
        self.pr_diff = self.env.get_pr_diff(G)
        for x in x_batch:
            if self.num_steps % 200000 == 0:
                print(f'Iteration: {self.num_steps}, remaining items: {len(self.remaining_items)}, running winner: {self.rw}, pr_diff = {self.pr_diff}')
            self.num_steps += 1
            cnt[x] += 1
        ranked = sorted(cnt.items(), key=itemgetter(1), reverse=True)
        (mode, top), (_, second) = ranked[0], ranked[1]
        if top > second + h*T:
            return mode, mode
        return None, mode

    def alg_2(self, G, y, h_min):
        """Call ``alg_1`` with halving gaps until a winner or the budget cap.

        Iteration ``s`` (from 1) uses confidence ``6*y / (pi**2 * s**2)`` and
        gap ``2**(-s-1)``. The loop stops once ``alg_1`` declares a winner,
        or once the iteration's sample budget reaches
        ``2 * ln(2/y) / h_min**2``.

        Returns:
            The most frequently chosen item of the last ``alg_1`` call.
        """
        s = 1
        y_fn = lambda s: 6 *y / ((np.pi**2) * s**2)
        h_fn = lambda s: 2**(-s-1)
        win = None
        while win is None:
            win, mode = self.alg_1(G, y_fn(s), h_fn(s))
            # NOTE(review): this budget check uses the outer ``y`` rather than
            # ``y_fn(s)``; confirm it matches the reference algorithm.
            if (8 * np.log(4/y)/(h_fn(s)**2)) >= (2 * np.log(2/y)/(h_min**2)):
                break
            s = s + 1
        return mode
    
if __name__ == "__main__":
    # Demo run: 1000 standard-normal 8-d items, queried with dataset row 5.
    data = np.random.randn(1000, 8)
    model = DKWT(data)
    model.run_sim(data[5, :])
