import numpy as np
import torch
from copy import deepcopy
import pandas as pd

from env import Env
from benchmark import Benchmark

class PBR:
    env:Env

    def __init__(self, 
                 vec: np.array, # dataset
                 eps: float = 0.1, # error bias
                 delta: float = 0.05, # confidence parameter
                 lmbda: float = 0.2, # regularization term
                 k: int = 10, # subset size
                 update_inf:bool = True,
                 adaptive_delta:bool = True,
                 flex_rw_change:bool = False,
                 random_build = False,
                 corr_noise=0.,
                 sharpness = None
                 ):
        if torch.cuda.is_available():
            self.DEVICE = 'cuda'
        else:
            self.DEVICE = 'cpu'
        self.vecs = self.preproc(torch.from_numpy(vec).to(self.DEVICE)) # n x d
        self.n, self.d = self.vecs.shape
        self.k = k
        self.delta = self.delta_eff = delta
        self.eps = eps
        self.update_inf = update_inf
        self.lmbda = lmbda
        
        
        self.corr_noise = corr_noise
        if corr_noise:
            noise_term = torch.randn_like(self.vecs).to(self.DEVICE)
            noise_term = noise_term/torch.norm(noise_term, p=2, dim=-1, keepdim=True)
            print(noise_term.norm(-1))
            tmp_vecs = self.vecs + self.corr_noise * noise_term
            tmp_vecs = tmp_vecs/torch.norm(tmp_vecs, p=2, dim=-1, keepdim=True)
            self.corr = torch.matmul(tmp_vecs, tmp_vecs.T)
        else:
            self.corr = torch.matmul(self.vecs, self.vecs.T) # n x n
        #self.corr = self.corr.clip(min=0)
        
        if not random_build:
            self.info_mat = self.corr
            self.info_mat = self.info_mat - torch.eye(self.n).to(self.DEVICE)
        else:
            self.info_mat = torch.rand(size=[self.n,self.n]).to(self.DEVICE)
        
        self.rw = None
        self.played_item_mask = torch.zeros(self.n).to(self.DEVICE)
        self.unplayed_item_mask = torch.ones(self.n).to(self.DEVICE)

        self.remaining_items = set([i for i in range(self.n)])

        self.W = torch.zeros(self.n,self.n).float().to(self.DEVICE)
        self.W_emp = torch.zeros(self.n,self.n).float().to(self.DEVICE)
        self.W_inf_count = torch.zeros(self.n,self.n).to(self.DEVICE)
        self.num_items_pruned_from_inf = 0

        self.env = None

        self.adaptive_delta = adaptive_delta
        self.flex_rw_change = flex_rw_change

        self.benchmark = Benchmark(self)

        self.sharpness = sharpness
    
    def get_result_dict(self):
        d = dict()
        for key in ['n', 'd', 'eps', 'delta', 'k', 'num_steps', 'pr_diff', 'num_items_pruned_from_inf', 'update_inf', 'corr_noise']:
            d[key] = self.__dict__.get(key, None)
        d['num_valid_items'] = self.env.num_valid_items
        d['max_corr'] = self.env.max_corr
        d['sharpness'] = self.env.sharpness
        return d

    
    def reset_delta_eps(self, delta_=None):
        if delta_ is None:
            delta_ = (np.ceil(self.n/self.k)/self.delta)**-1
        self.delta_eff = delta_
        self.m = 2 * np.log(1/delta_)/(self.eps**2)
        self.cb_const = np.log(1/delta_)

    
    def run_sim(self, q_vec, log_freq=None):
        q_vec = torch.from_numpy(q_vec).to(self.DEVICE)
        q_vec = q_vec/torch.norm(q_vec)

        if log_freq is not None:
            log_df = pd.DataFrame()

        self.num_steps = 0
        self.pr_diff_list = []
        if self.flex_rw_change:
            self.adaptive_delta = True
            self.cum_delta = 0
            self.reset_delta_eps((np.ceil(self.n)/self.delta)**-1)
        elif self.adaptive_delta:
            self.cum_delta = 0
            self.reset_delta_eps()
        else:
            self.reset_delta_eps()

        if self.sharpness is not None:
            self.env = Env(self.vecs, q_vec, eps=self.eps, device = self.DEVICE, sharpness=self.sharpness)
        else:
            self.env = Env(self.vecs, q_vec, eps=self.eps, device = self.DEVICE)
        G = self.build_initial_set()

        while len(self.remaining_items) > 1:
            win = self.env.play_set(G)
            self.pr_diff = self.env.get_pr_diff(G)
            self.pr_diff_list.append(self.pr_diff)
            if self.num_steps%3000 == 0:
                print(f'Iteration: {self.num_steps}; pr_diff: {self.pr_diff:.4f}, RW: {self.rw}, num_remaining_items: {len(self.remaining_items)}, m: {self.m:.4f}, cum_delta: {self.cum_delta:.4f}, eff_delta: {self.delta_eff:.4f}, items_pruned: {self.num_items_pruned_from_inf}')
                if self.num_steps == 0:
                    print(self.get_result_dict())
            if log_freq is not None:
                if self.num_steps%log_freq==0:
                    log_df = self.append_dict_to_df(df=log_df, dict_row={'Iteration': self.num_steps,
                                                                'pr_diff': self.pr_diff,
                                                                'num_remaining_items': len(self.remaining_items),
                                                                'items_pruned': self.num_items_pruned_from_inf})

            self.num_steps += 1
            self.update_w(win, G)
            G = self.update_set(G, win)
            self.num_steps +=1
        
        print(f'Iteration: {self.num_steps}; pr_diff: {self.pr_diff}, RW: {self.rw}, num_remaining_items: {len(self.remaining_items)}, played_set: {G}, items_pruned: {self.num_items_pruned_from_inf}')
        return log_df

    

    def prune_all_items(self, U:torch.Tensor, P:torch.Tensor, N:torch.Tensor):
        elim_mask = U.min(-1)[0].lt(0.5)
        elim_mask_2 = N[self.rw,:].ge(self.m) * P[self.rw,:].ge(0.5 - self.eps/2)
        elim_mask[self.rw] = 0
        elim_mask_2[self.rw] = 0 # self.rw is epsilon-optimal w.r.t. to it self
        items_to_elim = self.mask_to_set(torch.logical_or(elim_mask, elim_mask_2))
        if len(set.intersection(items_to_elim, self.remaining_items)):
            pass
        self.num_items_pruned_from_inf += len(set.intersection(items_to_elim, self.remaining_items))
        self.eliminate(items_to_elim)
    
    def build_initial_set(self):
        G = set()        
        # running winner = maximally correlated item
        self.rw = self.info_mat.sum(-1).argmax().item()
        self.update_masks(self.rw)
        G.add(self.rw)

        for _ in range(self.k-1):
            if self.unplayed_item_mask.sum() != 0:
                itm = self.select_item_to_add(self.played_item_mask, self.unplayed_item_mask)
                self.update_masks(itm)
                G.add(itm)
        return G
    
    def select_item_to_add(self, played_item_mask, unplayed_item_mask, by_n = False):
        '''
        item least correlated with played items
        '''
        if not by_n:
            return ((self.info_mat * (played_item_mask.view(1,-1))).mean(1) + played_item_mask*1e6).argmin().item()
        else:
            N = (self.W + self.W.T).sum(-1)
            return ((self.info_mat * (played_item_mask.view(1,-1))).mean(1) + N + played_item_mask*1e9).argmin().item()
    
    def update_masks(self, itm):
        self.played_item_mask[itm] = 1
        self.unplayed_item_mask[itm] = 0
    
    def update_w(self, win, itms):
        itms = deepcopy(itms)
        itms.discard(win)
        for itm in itms:
            if win == itm:
                raise ValueError()
            self.W[win, itm] += 1
            self.W_emp[win, itm] += 1
        if self.update_inf:
            itms = torch.Tensor(list(itms)).long()
            p_cond_win = self.p_cond_fn(win, itms) # n x itms
            r_cond_win = self.info_fn(p_cond_win)
            self.W[:,itms] += r_cond_win * p_cond_win
            self.W[itms,:] += (r_cond_win * (1 - p_cond_win)).T
            p_cond_lose = self.p_cond_fn(win, itms, mode = 'lose') # n x itms
            r_cond_lose = self.info_fn(p_cond_lose)
            self.W[win,:] += ((r_cond_lose * p_cond_lose)).sum(-1)
            self.W[:,win] += (r_cond_lose * (1 - p_cond_lose)).sum(-1)
            
    
    def update_set(self, G, win):
        '''
        Build the next set to play from the set G that was just played.

        G   : set of item indices played in this round.
        win : winner of the round (not read in this method).

        Returns the next set H. Along the way items are eliminated, freed
        slots are refilled with unplayed items, prune_all_items is applied and
        the running winner self.rw may be replaced. Raises ValueError if the
        running winner is not in G at the end, or if H and G differ in size
        while at least k items remain.
        '''

        # snapshot of the running winner before any change (not read here)
        self.old_rw = self.rw *1

        H = set()
        
        # N: total pairwise weight, P: estimated win probability (row beats
        # column), U: P plus a confidence bonus (infinite where N is 0)
        N = self.W + self.W.T
        P = (self.W/(N + 1e-9))
        U = P + torch.sqrt(self.cb_const/(2*N))

        G_mask = self.set_to_mask(G)
        # played set with the running winner removed
        G_mask_wo_win = deepcopy(G_mask); G_mask_wo_win[self.rw] = 0
        candidate = None
        # played items with at least m weight against the running winner whose
        # estimate says the running winner wins less than 0.5 - eps/2 of the time
        W_mask = (N[:, self.rw].ge(self.m) * G_mask_wo_win * (P[self.rw,:].lt(0.5 - self.eps/2))).bool()
        if W_mask.sum().item()==0 and self.flex_rw_change:
            # relaxed test: any item (not restricted to G) against which the
            # running winner's upper bound is below 0.5
            W_mask = U[self.rw,:].lt(0.5)
            if W_mask.sum().item()>0:
                print('flex elim')
        if W_mask.sum().item()>0:
            # new running winner: the masked item with the best estimate against
            # the current one; all masked items are carried into the next set
            candidate = (P[:, self.rw]*W_mask).argmax().item()
            H.update(self.mask_to_set(W_mask))
        else:
            H.add(self.rw)
                
        # played items that are neither the running winner nor in W_mask
        G_mask_not_W =  deepcopy(G_mask_wo_win) * (~W_mask)
        G_mask_not_W_ge_m = G_mask_not_W * N[:, self.rw].ge(self.m)
        
        # still a potential winner: upper bound against the running winner minus
        # eps is at least 0.5, fewer than m weight so far, and no opponent
        # against which its upper bound is below 0.5
        pot_win_mask = (U[:,self.rw] - self.eps).ge(0.5) * (N[:, self.rw].lt(self.m))
        pot_win_mask = torch.logical_and(pot_win_mask, U.min(-1)[0].ge(0.5))
        G_mask_not_W_pass = G_mask_not_W * pot_win_mask
        H.update(self.mask_to_set(G_mask_not_W_pass))
        # everything else in that group is eliminated
        G_mask_not_W_elim = torch.logical_or(G_mask_not_W * (~pot_win_mask), G_mask_not_W_ge_m)
        if candidate is not None:
            # the old running winner is eliminated; marking it here also makes
            # the refill loop below add one item for its slot
            self.eliminate(set([self.rw]))
            G_mask_not_W_elim[self.rw] = 1
        self.eliminate(self.mask_to_set(G_mask_not_W_elim))
        # refill one slot per eliminated item while unplayed items remain
        for _ in range(int(G_mask_not_W_elim.sum().item())):
            if self.unplayed_item_mask.sum() != 0:
                itm = self.select_item_to_add(self.played_item_mask, self.unplayed_item_mask, by_n = False)
                self.update_masks(itm)
                H.add(itm)
        
        # self.rw is still the old running winner at this point
        self.prune_all_items(U, P, N)

        # every item in W_mask takes over the old running winner's statistics
        for itm in self.mask_to_set(W_mask):
            self.rw_inherit(new_rw=itm, old_rw=self.rw)

        if candidate is not None:
            self.change_rw(candidate)

        # sanity check: the (possibly new) running winner was in the played set
        if self.rw not in G:
            raise ValueError()

        # sanity check: the set size is preserved while enough items remain
        if len(G) != len(H) and len(self.remaining_items) >= self.k:
            raise ValueError()
        
        return H

    def change_rw(self, new_rw):
        self.rw = new_rw
        if self.adaptive_delta:
            self.cum_delta += self.delta_eff
        
    def eliminate(self, itms:set):
        if self.env.optimal_item in itms:
            pass
        self.remaining_items.difference_update(itms)
        if self.flex_rw_change:
            self.reset_delta_eps(delta_=(self.delta - self.cum_delta)/(len(self.remaining_items)))
        elif self.adaptive_delta:
            self.reset_delta_eps(delta_=(self.delta - self.cum_delta)/(len(self.remaining_items)/self.k))

    
    def rw_inherit(self, new_rw, old_rw):
        self.W[new_rw,:] = self.W[new_rw,:] + self.W[old_rw,:]
        self.W[:,new_rw] = self.W[:,new_rw] + self.W[:,old_rw]

    def p_cond_fn(self, itm_i, itm_k, mode = 'win', threshold = 0.85):
        if type(itm_i) == torch.Tensor:
            corr_ik = self.corr[itm_k,itm_i]
        else:
            corr_ik = self.corr[itm_i,itm_k]
        if mode == 'win':
            out = 1 - torch.arccos((self.corr[:,itm_i].view(self.n,-1) - corr_ik - self.corr[:,itm_k].view(self.n,-1) + 1)/\
                                   (2*torch.sqrt((1-corr_ik) * (1-self.corr[:,itm_k].view(self.n,-1))) + self.lmbda))/np.pi
            out[out<=threshold] = 0.5
            if type(itm_i) == torch.Tensor or type(itm_k) == torch.Tensor:
                out[itm_i,:] = 0.5; out[itm_k,:] = 0.5
            else:
                out[itm_i] = 0.5; out[itm_k] = 0.5
            #if (out != out).any():
            #    raise ValueError()
            return out
        elif mode == 'lose':
            return self.p_cond_fn(itm_k, itm_i, mode = 'win')

    @staticmethod
    def info_fn(arr: torch.FloatTensor):
        out = (1 - (-arr*torch.log2(arr) -(1-arr)*torch.log2(1-arr)))
        #if (out != out).any():
        #    raise ValueError()
        return out
    
    def set_to_mask(self, ls:set):
        out = torch.zeros(self.n).to(self.DEVICE)
        idxs = torch.Tensor(list(ls)).long().to(self.DEVICE)
        out[idxs] = 1
        return out

    def mask_to_set(self, mask:torch.Tensor):
        return set(torch.nonzero(mask.flatten(), as_tuple=True)[0].tolist())
    
    @staticmethod
    def preproc(vecs):
        return vecs / torch.norm(vecs, dim=-1, keepdim=True)
    
    @staticmethod
    def append_dict_to_df(df, dict_row):
        # Convert the dict to DataFrame
        dict_df = pd.DataFrame([dict_row])
        # Append the dict DataFrame to the original DataFrame
        df = pd.concat([df, dict_df], ignore_index=True)
        return df

if __name__ == "__main__":
    # Smoke run: random Gaussian item vectors, with one of the items itself
    # (row 5) used as the query vector.
    num_items, dim = 5000, 16
    item_vecs = np.random.randn(num_items, dim)
    simulator = PBR(item_vecs, update_inf=True)
    simulator.run_sim(item_vecs[5, :])

#TODO: figure out threshold, check info weighting on inferred updates - make it more harsh?
#TODO: why does wrongful elimination only happen when it is the running winner?