import numpy as np
import pandas as pd
import scipy as sp
import cvxpy as cp
from itertools import islice
import multiprocessing
from multiprocessing import Pool



import torch
import copy
import time
import torch.nn as nn
from tqdm import tqdm
from optimalfair.algorithm.classifierbase import basicprocess
from optimalfair.utils.models import *
from optimalfair.utils.model_utils import *

class classifier(basicprocess):
    def __init__(self, dataset, options, name=''):
        '''
        Initialise the base process and cache this algorithm's hyper-parameters.

        All values are looked up in ``self.options`` after the base-class
        initialiser has run (presumably that is where ``self.options`` is
        populated -- confirm against ``basicprocess``). A missing key raises
        ``KeyError``.
        '''
        super().__init__(dataset, options, name)

        opts = self.options

        # dual-variable settings
        self.dual_params_lr = opts['dual_params_lr']
        self.dual_params_bound = opts['dual_params_bound']
        self.tau = opts['tau']

        # post-processing schedule
        self.post_num_round = opts['post_num_round']
        self.post_batch_size = opts['post_batch_size']

        # projection settings (number of ADMM iterations and divergence name)
        self.num_iter = opts['num_iter']
        self.div = opts['div']
        # self.inner_round = self.options['inner_round']

        self.Projected = False

    def train(self):
        '''
        Fit the base models, solve the projection on the training split and
        evaluate the projected classifier on the test split.

        The test accuracy and the fairness disparity are printed and written to
        the step logger as a single 'final' record. Nothing is returned; the
        learned dual parameter is stored in ``self.l``.
        '''
        # Base models (named after the conditionals they estimate). The model
        # of A given (X, Y) is deliberately left unset here.
        self.model_Y_give_X = self.fit_Y_give_X()
        self.model_A_give_X = self.fit_A_give_X()
        # self.model_A_give_XY = self.fit_A_give_XY()
        self.model_A_give_XY = None

        # Logger writing into a fresh run directory
        run_dir = make_run_dir(self.options)
        logger_config = {"lr": self.lr, "bs": self.batch_size}
        logger = JSONLStepLogger(run_dir, config=logger_config, flush_every=1)

        # Empirical joint distribution of (Y, A) on the training split
        train_split = self.train_data
        joint_counts = pd.crosstab(
            train_split.Y.ravel(),
            train_split.A.ravel(),
            rownames='Y',
            colnames='A',
        )
        self.Pys = joint_counts / len(train_split.Y)

        # Constraint matrix for the single configured fairness constraint
        self.constraints = [(self.fair_metric, self.fair_bound)]
        G = self.buildG(train_split.X, self.constraints)

        # Smoothed class probabilities: shift by a small constant, renormalize rows
        fudge = 1e-4
        raw_scores = self.model_Y_give_X.predict_proba(train_split.X)
        shifted = raw_scores + fudge
        Py_x = shifted / shifted.sum(axis=1, keepdims=True)

        # Solve for the dual parameter lambda
        self.l = self.admm(
            G,
            np.expand_dims(Py_x, axis=2),
            rho=2,
            max_iter=self.num_iter,
            report=True,
            div=self.div,
        )

        # Projected probabilities on the test split
        test_split = self.test_data
        y_prob = np.squeeze(self.predict_proba(X=test_split.X, s=None), axis=2)
        print(f'predict_y shape:{y_prob.shape}')

        # Labels are *sampled* from the projected distribution, so the reported
        # numbers vary between runs unless the torch RNG is seeded.
        predicted = torch.multinomial(torch.tensor(y_prob), 1).view(-1)
        total_samples = len(test_split.Y)
        true_labels = torch.tensor(test_split.Y).view(-1).long()
        n_correct = predicted.eq(true_labels).sum().item()
        test_acc = n_correct / total_samples

        test_pred_class = predicted.numpy()

        # The second return value of fair_evaluate is not used here
        test_diff, _ = self.fair_evaluate(
            Y=test_split.Y.ravel(),
            pred_Y=test_pred_class.ravel(),
            A=test_split.A.ravel(),
        )
        print(f"[Eval] Task: fairprojection, test accuracy = {test_acc:.4f}, disparity = {test_diff:.4f}")
        final_metrics = {"acc": float(test_acc), "fairness_level": float(test_diff)}
        logger.log_step(round='final', metrics=final_metrics)



    # core ADMM optimization in numpy
    def admm(self, G,y,rho=2,div = 'kl', tol = 1e-6, max_iter = 1000,report=False):
        '''
        Core Model Projection algorithm. Here:
        n - number of data points
        k - number of constraints
        c - number of classes
        
        
        Arguments:
            - G (n x c x k np.array): constraint matrices for each n points
            - y (n x c x 1 np.array): original classifier outputs
            - rho: ADMM parameter
            - div: f-divergence, can be 'kl' or 'cross-entropy'
            - tol: primal-dual gap
            - report: print out primal-dual and feasibility gap
            
        Returns:
            - l (k x 1 np.array): optimal dual parameter lambda
        '''
        
        n,c,k = G.shape
        obj_v = []
        
        # Initialize variables and constants
        
        logy = np.log(y) # used in Step 1
        x = np.zeros((n,c,1)) # initialize x, used in Step 1 
        
        l = np.ones((k,1)) # initialize lambda
        
        v = np.ones((n,c,1)) # initialize v
        
        mu = np.ones((n,c,1)) # initialize mu

        G_t = np.transpose(G, axes=[0, 2, 1])
        # sum of G_i^T \times G_i across batch
        Q = np.sum(G_t @ G, axis = 0)/n 
        Q += 1e-6 * np.eye(Q.shape[0])

        # sym_err = np.max(np.abs(Q - Q.T))
        # eig_min = np.min(np.linalg.eigvalsh(0.5*(Q+Q.T)))
        # print("Q sym err:", sym_err, "min eig:", eig_min)
        
        rho2 = rho/2
        
        # optimization variables for step 2
        l_cp = cp.Variable(shape=(k,1),nonneg=True) # lambda for cvx
        d_cp = cp.Parameter(shape=(k,1))
        
        cost = rho2*cp.quad_form(l_cp,Q)+ d_cp.T @ l_cp  # cost function for step 2
        objective = cp.Minimize(cost)
        prob = cp.Problem(objective) # quadratic optimization
        
        
        for ix in range(max_iter):
            
            cv = mu + rho*(G @ l) 
        
            ### Step 1: v update ##
            inner_tol = 1e-13
            
            # kl-divergence
            if div == 'kl':
                a = cv - rho*logy
                x = np.zeros((n,c,1))

                for jx in range(50): # TODO: check this limit
                    xold = x
                    x= -(sp.special.softmax(x,axis=1)+a)/rho 
                    
                    # check if update is small
                    if np.abs(x-xold).max()<inner_tol:
                        break

                v = (x-logy) # update v

            
            # cross-entropy
            elif div == 'cross-entropy':
                # initialize z
                z = np.zeros((n,1,1))
                a = 4*rho*y
                
                for jx in range(50): # TODO: check this limit
                    cpz = cv + z
                    b = np.sqrt(a + cpz*cpz)
                    num = (-cpz + b)
                    gz = (-1 + num.sum(axis=1)*0.5).reshape(n,1,1)
                    gprime = -0.5*(num/b).sum(axis=1).reshape(n,1,1)
                    z = z - gz/gprime
                    
                    if np.abs(gz/gprime).max() < inner_tol: #break if update is small
                        break

                cpz = cv + z
                x = .5*(-cpz + np.sqrt(a+cpz*cpz))
                
                v = -(x+cv)/rho # update v
        
                
            # TODO: raise error if divergence not listed
                
            ### Step 2: lambda update ###
            d_cp.value = np.sum(G_t @ (mu + rho * v), axis=0) / n  # linear part
            prob.solve(warm_start=True)
            
            l_old = l
                
            l = l_cp.value # assign value to tensorflow variable

            ### Step 3: mu update ###
            mu +=  rho*(v + (G @ l) )
            
        # report gaps
        
        if report:

            # compute primal classifier
            h = self.predict(l,G,y,div = div)
            infeas = ((np.transpose(G,axes=[0,2,1])@np.array(h)).sum(axis=0)/n).max()
            # print('Max infeasibility: '+str(infeas))
            
            if div == 'kl':
                obj = -np.sum(sp.special.logsumexp(x,1))/n
                error = 100*np.abs((sp.special.kl_div(h,y).sum()/n-obj)/obj)
                
            elif div == 'cross-entropy':
                _, obj = self.predict_cross(-v,y,return_obj=True)
                obj = obj.sum()/n
                error = 100*np.abs((sp.special.kl_div(y,h).sum()/n-obj)/obj)

            # print('Error (percentage of dual): ' + str(error))

        return l


    def buildG(self, X,constraints_tol,y=None,s=None):
        '''
        Build constraint matrix. We will need to perturb Py_x so it is in the middle of the simplex.
        '''
        fudge = 1e-4

        
        if y is None:
            # if  y is not given, use trained models
            Py_x = self.model_Y_give_X.predict_proba(X)
            self.Py_x = Py_x

            
        else:
            #ys = np.array(['-'.join([str(yx),str(sx)]) for (yx,sx) in zip(y,s) ])
            #self.Pys_x = self.enc_ys.fit_transform(ys.reshape(-1,1))
            
            # use one-hot encoding of y to build matrix
            Py_x = self.enc_y.fit_transform(y.reshape(-1,1))
            
        
        # reshape Pys_x
        #self.Pys_x = self.Pys_x.reshape(len(self.Pys_x),self.Pys.shape[0],self.Pys.shape[1])
            
        # add fudge
        Py_x = (Py_x+fudge)/((Py_x+fudge).sum(axis=1,keepdims=True)) ## constant fudge
        ################################################
        # Py_x = Py_x + 1e-9 * np.random.uniform(low=1e-1, high=1.0, size=Py_x.shape) ##
        # Py_x = Py_x / Py_x.sum(axis=1, keepdims=True) ##
        ################################################

        # compute marginals
        Py = self.Pys.sum(axis=1).to_numpy()
  
                
        # normalize marginal
        #normPy_x = Py_x/Py.reshape(1,len(Py))
                
        # useful constants
        y_len = self.n_class
        s_len = self.n_group
        n_samples = X.shape[0]
        self.y_categories_ = np.array(list(self.Pys.index))
        self.s_categories_ = np.array(list(self.Pys))


        
        Glist = [] # list for storing constraints
        
        for (constraint,alpha) in constraints_tol:
            if constraint == 'eo':

                
                # if s not given, use trained model
                if s is None:
                    assert not (self.model_A_give_XY is None), "Fit classifier for predicting S from X and Y first!"
                

                for (yv,y_ix) in zip(self.y_categories_,range(y_len)):
                    
                    # initialize constraint matrix
                    Gp = np.zeros((n_samples, y_len,s_len))
                    Gm = np.zeros((n_samples, y_len,s_len))
                    
                    # prepare proabilities of group membership
                    if s is None:
                        y = np.array([yv for i in range(n_samples)])
                        Xy = np.concatenate((X,(torch.zeros(n_samples,y_len)) + torch.eye(y_len)[y]), axis=1)
                        ## X: (n, m), y: (n, c), Xy: (n, m+c) Ps_xy: (n, d)
                        
                        Ps_xy = self.model_A_give_XY.predict_proba(Xy)

                    else:
                        Ps_xy = self.enc_s.fit_transform(s.reshape(-1,1))

                    ################################################
                    # Ps_xy = Ps_xy + 1e-9 * np.random.uniform(low=1e-1, high=1.0, size=Ps_xy.shape)  ## fudge
                    # Ps_xy = Ps_xy / Ps_xy.sum(axis=1, keepdims=True)  ## fudge
                    ################################################

                    for (sv,s_ix) in zip(self.s_categories_,range(s_len)):
                        
                        # upper constraint
                        Gp[:,y_ix,s_ix] = Py_x[:,y_ix]*( (Ps_xy[:,s_ix]/self.Pys.loc[yv,sv])-((1+alpha)/Py[y_ix]  ) )
                        
                        # lower constraint
                        Gm[:,y_ix,s_ix] = Py_x[:,y_ix]*( -(Ps_xy[:,s_ix]/self.Pys.loc[yv,sv])+((1-alpha)/Py[y_ix]  ) )
                        
                    Glist.append(Gp)
                    Glist.append(Gm)
                    
            if constraint == 'dp':

                # if s not given, use trained model
                if s is None:
                    assert not (self.model_A_give_X is None), "Fit classifier for predicting S from X and Y first!"
                

                for (yv,y_ix) in zip(self.y_categories_,range(y_len)):
                    
                    # initialize constraint matrix
                    Gp = np.zeros((n_samples, y_len,s_len))
                    Gm = np.zeros((n_samples, y_len,s_len))
                    
                    # prepare proabilities of group membership
                    if s is None:
                        
                        Ps_x = self.model_A_give_X.predict_proba(X)
                        
                    else:
                        Ps_x = self.enc_s.fit_transform(s.reshape(-1,1))

                    ################################################
                    # Ps_x = Ps_x + 1e-9 * np.random.uniform(low=1e-1, high=1.0, size=Ps_x.shape) ## fudge
                    # Ps_x = Ps_x / Ps_x.sum(axis=1, keepdims=True)  ## fudge
                    ################################################

                    for (sv,s_ix) in zip(self.s_categories_,range(s_len)):
                        
                        # compute marginal
                        Ps = sum(self.Pys.loc[:,sv])
                        
                        # upper constraint
                        Gp[:,y_ix,s_ix] = ( (Ps_x[:,s_ix]/Ps) - (1+alpha) )
                        
                        # lower constraint
                        Gm[:,y_ix,s_ix] = ( -(Ps_x[:,s_ix]/Ps) + (1-alpha) )
                        
                    Glist.append(Gp)
                    Glist.append(Gm)

        G_temp = np.concatenate(Glist,axis=2)
        ################################################
        G_temp = G_temp + np.random.normal(loc=0.0, scale=1e-5, size=G_temp.shape)
        ################################################
        self.G = G_temp
        return G_temp

    
    def predict(self, l,G,y,div='kl'):
        '''
        Compute the corrected classifier output.
        
        
        Arguments:
            - l (k x 1 np.array): dual parameter lambda
            - G (n x c x k np.array): constraint matrix for the given data point 
            - y (n x c x 1 np.array): original classifier output
            - div: f-divergence, can be 'kl' or 'cross-entropy'
            
        Retruns:
            - h (n x c x 1 np.array): corrected prediction
        '''
        
        # create optimization variable
        n,c,k = G.shape
            
        # compute v
        v = G @ l
        
        
        if div == 'kl':        
            # kl cost
            h = sp.special.softmax(-v + np.log(y),axis=1)
        
        elif div == 'cross-entropy':
            # cross-entropy cost
            
            if n< 5000:
                h = self.predict_cross(v,y)         
            # if batch size is large, use multiprocess
            else:
                cores = 4
                # create batches of size 100
                n_list = range(n)
                it = iter(n_list)
                size = 100
                ln = list(iter(lambda: tuple(islice(it, size)), ())) # list of indices
                
                vals = [(v[ix,:,:],y[ix,:,:]) for ix in ln] # split of v and y values
                
                #compute in parallel
                with Pool(cores) as p:
                    hvals = (p.starmap(self.predict_cross, [(v[ix,:,:],y[ix,:,:]) for ix in ln]))
                    
                h = np.concatenate(hvals,axis=0)
        
        return h
    
    def predict_cross(self, v,y,tol=1e-10,alpha=.3,beta=.5,max_iter = 100,return_obj = False):
        '''
        Interior-point method for computing corrected classifier with cross-entropy objective.
        This is essentially algorithm 10.1 in Boyd and Vandenberghe.
        As usual, n is the batch size.
        
        Arguments:
        - v (n x c x 1 np.array): linear term in the conjugate
        - y (n x c x 1 np.array): original classifier output
        - tol: worst-case batch relative error between objective and optimal
        - alpha, beta: line-search parameters (see CVX book, Algorithm 9.2)
        - max_iter: maximum number of iterations
        - return_obj: if objective should be returned as well as a second argument
        
        Returns:
        - h (n x c x 1 np.array): corrected predictions
        
        '''
        
        yinv = 1/(y+tol) # we will frequently use y inverse, so we pre-compute to avoid division by 0
        n,c,_ = y.shape
        
        ############### auxiliary functions ##############
        def newtonStep(h):
            a = h*yinv
            b = h*a
            fp = v - (1/a)
            w = -np.sum(fp*b,axis=1)/np.sum(b,axis=1)
            w = w.reshape(n,1,1)
            step = -(fp+w)*b
            return step

        # objective
        def f(h):
            cr = np.sum(v*h,axis=1) -np.sum(y*np.log(h),axis=1)
            return cr

        # grad
        def fp(h):
            return v - y/h

        # compute newton decrement (10.12 in Boyd's book)
        def newton_decrement(h,step):
            lx = np.sqrt(np.sum(step*step*y/(h*h),axis=1))
            return lx.max()

        # vectorized line search
        def line_search(h,step,alpha=.25,beta=.5):
            t = np.ones((n,1,1)) # initialize t

            # make sure no entry becomes negative
            while True:
                hnew = h+step*t
                ix = (hnew.min(axis=1)<0)
                if sum(ix) == 0:
                    break
                else:
                    t[ix] = t[ix]*beta

            # now search until break condition is met
            delta = (fp(h)*step).sum(axis=1).reshape(n,1,1)
            obj = f(h).reshape(n,1,1)


            while True:
                hnew = h + step*t
                obj_inc = f(hnew).reshape(n,1,1)
                ix = (obj_inc > obj + alpha*t*delta)
                if sum(ix) == 0:
                    break
                else:
                    t[ix] = t[ix]*beta          

            return hnew
        ############### end of auxiliary functions ##############
        
        # main procedure
        # initialize at y
        h = y
        obj = f(h)
        
        # Newton's method
        for j in range(max_iter):
            step = newtonStep(h)
            min_dec = newton_decrement(h,step)

            if min_dec**2/2<tol:
                break
            else:
                h = line_search(h,step,alpha=alpha,beta=beta)

                obj = f(h) + np.sum(y*np.log(y),axis=1)
            
        if return_obj:
            return h, obj
        else:
            return h
        
    def predict_proba(self,X,y=None,s=None):
        '''
        Predict with projected model.
        '''
        
        fudge = 1e-4
        
        # print('...Building constraint matrix...')
        
        G = self.buildG(X,self.constraints,y=y,s=s)
        Py_x = self.model_Y_give_X.predict_proba(X)
        
        Py_x = (Py_x+fudge)/((Py_x+fudge).sum(axis=1,keepdims=True))
        
        
        # print('...Predicting...')
        
        return self.predict(self.l, G, np.expand_dims(Py_x,axis=2),div=self.div)
    

    # @torch.no_grad()
    # def model_eval(self, data):
    #     assert self.lamb is not None

    #     dataLoader = DataLoader(data, batch_size = self.post_batch_size, shuffle = False)
    #     test_correct = test_num = 0.0
    #     preds = []

    #     M_a_lamb = self.get_cal_matrix(self.lamb) # (a,m,m)
    #     for (x, y, a) in dataLoader:
    #         if self.gpu:
    #             x, y, a = x.to(self.device), y.to(self.device), a.to(self.device)
    #         batch_size = x.shape[0]
    #         p_Y_given_XA = torch.stack([self.model_Y_give_XA.predict_proba(torch.cat([x,torch.ones((batch_size,1)).to(self.device) * a], dim=1)
    #                                                             ) for a in range(self.n_group)]).to(self.device) # (a,N,m)
    #         p_A_given_X = self.model_A_give_X.predict_proba(x).to(self.device) # (N,a)
    #         score = torch.bmm(p_Y_given_XA, M_a_lamb).permute(1,0,2) # (N, a, m)
    #         beta = torch.bmm(p_A_given_X.unsqueeze(1), score).squeeze(1) # (N,m)
    #         exp_beta = torch.exp(beta / self.tau)  
    #         softmax = exp_beta / torch.sum(exp_beta, dim=1, keepdim=True)  
    #         predicted = torch.multinomial(softmax, 1).view(-1)

    #         preds.append(predicted.detach().cpu())
    #         correct = predicted.eq(y.view(-1).long()).sum().item()
    #         batch_size = y.size(0)

    #         test_correct += correct # total correct, not average
    #         test_num += batch_size 

    #     test_acc = test_correct / test_num
    #     pred_class = torch.cat(preds, dim=0).numpy()

    #     group_confusion_metrix = confusion_matrix(data.Y.ravel(), pred_class.ravel(), data.A.ravel(), n_classes=self.n_class, n_groups=self.n_group, normalize='all')
    #     if self.fair_metric == 'dp':
    #         DP_test = delta_sp(pred_class.ravel(), self.test_data.A.ravel(), n_classes=self.n_class, n_groups=self.n_group)
            
    #         # group-wise constraint
    #         matrix = np.zeros((self.n_group, self.n_class))
    #         class_counts = np.bincount(self.test_data.A.ravel().astype(np.int64))  # count occurrences of each class
    #         total_samples = len(data)
    #         sen_priors = class_counts / total_samples  # class prior probabilities
    #         for i in range (self.n_class):
    #             for j in range (self.n_group):
    #                 matrix[i,j] = np.sum( [ ( sen_priors[a] - (a == j) )*group_confusion_metrix[a,:,i] for a in range (self.n_class) ])
    #         diff = np.max(matrix)
    #         DP_test_cal = diff
    #         print(f'[Eval] DP cal:{DP_test}, DP dir:{DP_test_cal}.')
        
    #     return test_acc, diff, matrix