import numpy as np


class Batcher(object):
    """Serve fixed-size mini-batches from ``data``/``labels``, reshuffling per pass.

    Rows are taken sequentially in blocks of ``B``. When fewer than ``B`` rows
    remain, the batcher resets to the start of the (optionally reshuffled)
    data, so every returned batch has exactly ``B`` rows and leftover rows of
    a pass are skipped.

    Parameters
        B: batch size; must not exceed the number of rows ``data.shape[0]``.
        data: 2-D indexable array (numpy array; presumably also scipy.sparse
            matrices, given the ``todense`` option — confirm with callers).
        labels: 1-D array aligned with the rows of ``data``.
        todense: if True, call ``.todense()`` on each data batch before returning it.
        shuffle: if True, permute rows on every reset (including construction).
        rnd_seed: ``None`` (use the global ``np.random`` state), an integer seed,
            or any object with a ``permutation(n)`` method such as
            ``np.random.RandomState`` / ``np.random.Generator``.
    """

    def __init__(self, B, data, labels, todense = False, shuffle = True, rnd_seed = None):
        self.data = data
        self.labels = labels
        self.B = B
        self.N = data.shape[0]
        self.todense = todense
        # Permuted alongside data/labels so each current row can be traced
        # back to its row position in the arrays passed to the constructor.
        self.original_indices = np.arange(self.N)
        self.shuffle = shuffle

        # An int seed has no ``permutation`` method, so wrap it in a
        # RandomState; previously this crashed inside reset().
        if rnd_seed is None:
            rnd_seed = np.random
        elif isinstance(rnd_seed, (int, np.integer)):
            rnd_seed = np.random.RandomState(rnd_seed)
        self.rnd_seed = rnd_seed

        assert self.B <= self.N, f'batch size B={self.B} exceeds number of samples N={self.N}'

        self.reset()

    def reset(self):
        """Rewind to the first row, reshuffling data/labels/indices if enabled."""
        self.curr_idx = 0

        if self.shuffle:
            perm = self.rnd_seed.permutation(self.N)
            self.data = self.data[perm,:]
            self.labels = self.labels[perm]
            self.original_indices = self.original_indices[perm]

    def get_batch(self):
        """Return the next batch.

        Returns
            (xbatch, lxbatch, batch_end): ``B`` rows of data, their labels, and
            a flag that is True when the *next* call will wrap around (reset).

        Side effects: sets ``current_range`` and ``curr_orig_idxs`` (original
        row positions of this batch) and advances ``curr_idx``.
        """
        if self.curr_idx + self.B > self.N:
            self.reset()

        start, stop = self.curr_idx, self.curr_idx + self.B
        self.current_range = range(start, stop)

        xbatch = self.data[start:stop,:]
        lxbatch = self.labels[start:stop]

        # Simpler to save it here than to recompute it on demand.
        self.curr_orig_idxs = self.original_indices[start:stop]

        self.curr_idx = stop

        batch_end = self.curr_idx + self.B > self.N

        if self.todense:
            xbatch = xbatch.todense()

        return xbatch, lxbatch, batch_end

        



def eval_cost(M,batcher, n_batches):
    """Average the model's value and class-split similarity stats over batches.

    Each of the ``n_batches`` iterations draws two consecutive batches from
    ``batcher`` (the first is fed as x, the second as y) and fetches
    ``t_value``, ``t_value_orig``, ``t_label_eq_mask``, ``positive_vals`` and
    ``negative_vals`` through ``M.sess.run``.

    Per batch, ``positive_vals`` is summed per row and divided by the number of
    mask entries in that row; ``negative_vals`` is divided by the number of
    non-mask entries. ``max(1, count)`` guards against empty rows. Both are
    then averaged over rows. Callers label these "Same Class" and "Different
    Class" values.

    Parameters
        M: model object exposing ``sess`` and the placeholders/tensors above.
        batcher: object whose ``get_batch()`` returns ``(data, labels, flag)``.
        n_batches: number of (x, y) batch pairs to evaluate.

    Returns
        (mean value, mean value_orig, mean same-class value, mean different-class
        value), each a float64 mean over batches. NaN (with a numpy
        RuntimeWarning) if ``n_batches`` is 0.
    """
    # Plain lists avoid the per-iteration copy made by np.append.
    val_lst, val_orig_lst, pos_lst, neg_lst = [], [], [], []

    for _batch_i in range(n_batches):
        xbatch, xlabels, _ = batcher.get_batch()
        ybatch, ylabels, _ = batcher.get_batch()

        feed_dict = {M.t_x_ph: xbatch,
                     M.t_x_labels_ph: xlabels,
                     M.t_y_ph: ybatch,
                     M.t_y_labels_ph: ylabels,
                     }
        rval, rval_orig, rmask, rposv, rnegv = M.sess.run(
            (M.t_value, M.t_value_orig,
             M.t_label_eq_mask, M.positive_vals, M.negative_vals),
            feed_dict = feed_dict,
        )

        pos_mean = (rposv.sum(axis = 1) / np.maximum(1, rmask.sum(axis = 1))).mean()
        neg_mean = (rnegv.sum(axis = 1) / np.maximum(1, (1. - rmask).sum(axis = 1))).mean()

        val_lst.append(rval)
        val_orig_lst.append(rval_orig)
        pos_lst.append(pos_mean)
        neg_lst.append(neg_mean)

    # dtype=float matches the float64 accumulation of the former np.zeros(0)-based arrays.
    return tuple(np.asarray(lst, dtype = float).mean()
                 for lst in (val_lst, val_orig_lst, pos_lst, neg_lst))









class fitCB:
    """Training callback that records loss/value, prints diagnostics and runs evals.

    ``cb(M, curr_feed_dict)`` is expected to be invoked by the training loop
    (presumably once per batch; confirm against the caller). ``M`` must expose
    ``sess`` (with ``run(fetches, feed_dict=...)``), ``global_epoch``,
    ``batch_counter``, ``K`` and the tensors referenced in ``cb``.

    Gating:
      * if ``batch_cnt_to_cb`` is None, ``cb`` acts only on the first batch
        (``batch_counter == 0``) of epochs that are multiples of ``log_period``;
      * otherwise it acts whenever ``batch_counter`` is a multiple of
        ``batch_cnt_to_cb``.
    A period left as ``None`` disables the corresponding action rather than
    raising ``TypeError``. Note that every constructor argument defaults to
    ``None``.

    Evaluation is only reached after the logging gate has passed. With
    ``batch_cnt_to_cb=None`` it therefore runs on epochs that are multiples
    of both ``log_period`` and ``eval_period``.
    """

    def __init__(self,
                 verbose_period = None,
                 log_period = None,
                 eval_period = None,
                 batch_cnt_to_cb = None,
                 train_eval_batcher = None,
                 test_eval_batcher = None,
                 n_eval_batches = None,
                 do_knn_test = None
                ):
        # Stored but not read anywhere in this class; the kNN scores recorded
        # in the eval lists are always NaN.
        self.do_knn_test = do_knn_test

        # Epoch period for printing diagnostics; None disables printing.
        self.verbose_period = verbose_period

        # Logging / eval periods fall back to verbose_period when not given.
        self.log_period = verbose_period if log_period is None else log_period
        self.eval_period = verbose_period if eval_period is None else eval_period

        # When set, gate on batch_counter instead of on epochs.
        self.batch_cnt_to_cb = batch_cnt_to_cb

        self.train_eval_batcher = train_eval_batcher
        self.test_eval_batcher = test_eval_batcher
        self.n_eval_batches = n_eval_batches

        # Number of cb() calls that passed the gating and did work.
        self.proper_invocation_count = 0

        # Loss and value per logged call.
        self.loss_lst = np.zeros(0)
        self.val_lst = np.zeros(0)

        # Tuples (value, same-class mean, different-class mean, value_orig, knn_score).
        self.train_eval_lst = []
        self.test_eval_lst = []

    @staticmethod
    def _is_multiple(count, period):
        """Return True when ``period`` is set and ``count`` is a multiple of it."""
        return period is not None and count % period == 0

    def cb(self, M, curr_feed_dict):
        """Record and report training statistics for the current batch.

        Parameters
            M: model object (see class docstring).
            curr_feed_dict: feed dict of the current training batch. It is
                re-run through ``M.sess`` to fetch loss, value and similarity stats.

        Returns None. Appends to ``loss_lst``/``val_lst`` and, when evaluating,
        to ``train_eval_lst``/``test_eval_lst``.
        """
        if self.batch_cnt_to_cb is None:
            if M.batch_counter != 0 or not self._is_multiple(M.global_epoch, self.log_period):
                return
        elif M.batch_counter % self.batch_cnt_to_cb != 0:
            return

        rloss, rval, rmask, rposv, rnegv = M.sess.run(
            (M.t_loss, M.t_value,
             M.t_label_eq_mask, M.positive_vals, M.negative_vals),
            feed_dict = curr_feed_dict,
        )

        self.loss_lst = np.append(self.loss_lst, [rloss])
        self.val_lst = np.append(self.val_lst, [rval])

        if self._is_multiple(M.global_epoch, self.verbose_period):
            print(f'Epoch: {M.global_epoch}, batch:{M.batch_counter}, loss: {rloss}, value: {rval} \n ')

            l2inf_val, l2_val = M.sess.run((M.t_weight_l2_infty, M.t_weight_l2))
            l2_val *= M.K
            print(f'    l2inf val     : {l2inf_val}, l2_val*K:    {l2_val}')

            # Per-row mean of positive_vals over mask entries and of
            # negative_vals over non-mask entries, then averaged over rows.
            # max(1, .) avoids division by zero for rows with no such entries.
            pos_mean = (rposv.sum(axis = 1) / np.maximum(1, rmask.sum(axis = 1))).mean()
            neg_mean = (rnegv.sum(axis = 1) / np.maximum(1, (1. - rmask).sum(axis = 1))).mean()
            print(f'    pos mean   : {pos_mean}')
            print(f'    neg mean   : {neg_mean}')

        # Evaluation on the held-out batchers, only at the start of an epoch.
        if (self.train_eval_batcher is not None
                and M.batch_counter == 0
                and self._is_multiple(M.global_epoch, self.eval_period)):
            assert self.test_eval_batcher is not None, 'supply both'

            # kNN evaluation is not implemented; placeholders keep the tuple layout.
            knn_tr_score = knn_ts_score = np.nan

            rv, rvo, rp, rn = eval_cost(M, self.train_eval_batcher, self.n_eval_batches)
            self.train_eval_lst.append((rv, rp, rn, rvo, knn_tr_score))
            print('---------------------------')
            print(f'    Train Eval   : loss val: {rv}, Same Class val:{rp}, Different Class:{rn}')

            rv, rvo, rp, rn = eval_cost(M, self.test_eval_batcher, self.n_eval_batches)
            self.test_eval_lst.append((rv, rp, rn, rvo, knn_ts_score))
            print(f'    Test Eval   : loss val: {rv}, Same Class val:{rp}, Different Class:{rn}')

        self.proper_invocation_count += 1
        





def jitter(v, eps = 1e-2, *, rng = None):
    """Return ``v`` plus i.i.d. Gaussian noise with standard deviation ``eps``.

    Parameters
        v: array-like, list or scalar; the input is not modified.
        eps: noise scale (standard deviation).
        rng: optional source of randomness with a ``randn`` method
            (e.g. ``np.random.RandomState``); defaults to the global
            ``np.random`` state.

    Returns
        A new value of the same shape as ``v``.
    """
    source = np.random if rng is None else rng
    # np.shape works for scalars and lists too, where ``v.shape`` would fail.
    noise = source.randn(*np.shape(v))
    return v + eps * noise

