import numpy as np
import matplotlib.pyplot as plt

class PDE2d:
    '''Finite-volume solver for a 2-D density rho on a uniform cell-centred grid.

    Arrays on the grid are stored with axis 0 along x and axis 1 along y.
    Each call to ``update_RK`` advances rho with a three-stage Runge-Kutta
    update whose right-hand side (``_F``) is built from:
      * a "potential" xi = V(x, z) + sum_z' W(|z - z'|) rho(z'),
      * minmod-limited piecewise-linear face reconstructions of rho,
      * upwinded interface fluxes including a term from ``H_prime_rho``,
      * zero flux through the outer boundary of the grid.
    After each sub-step rho is renormalised to unit (trapezoidal) mass.
    '''

    def __init__(self, dx, dy, N, nT, H_prime_rho=None, H_prime=None, V=None, W=None,
                 save_data=True, theta=2.):
        '''Build the grid and the storage for a simulation.

        dx, dy:      grid spacing along x and y
        N:           number of cell centres below (and above) zero on each axis;
                     each axis gets 2*N+2 centres from -(N+0.5)*d to (N+0.5)*d
        nT:          number of time steps for the simulation
        H_prime_rho: callable (face_value, rho_neighbour, rho_tilde) -> array,
                     described as "diffusion function times rho"; it is
                     differenced across every interface in the flux
        H_prime:     stored only; not used by the methods of this class
        V:           potential function, called as V(x, z_i) with z_i of
                     shape (nx, ny, 2)
        W:           kernel function, called as W(|x_j - z_x|, |y_k - z_y|, dx, dy)
        save_data:   if True store every time step in rho_history, otherwise
                     only two snapshots (at steps floor(nT/5) and floor(nT/3))
        theta:       minmod limiter parameter (defaults to the previous
                     hard-coded value 2)
        '''
        n_points = 2 * N + 2
        # Build the centres from an exact count.  The former
        # np.arange(start, stop + 0.0001, d) used an absolute tolerance and
        # produced extra points whenever the spacing was below 1e-4.
        rx = -(N + 0.5) * dx + dx * np.arange(n_points)
        ry = -(N + 0.5) * dy + dy * np.arange(n_points)
        self.rx = rx
        self.ry = ry
        nRho_x = len(rx)
        nRho_y = len(ry)
        # 'ij' indexing: axis 0 is x, axis 1 is y.  This is the layout the
        # rest of the class relies on (rho_history shape, dx-differences along
        # axis 0, integration axes in update_RK); the default 'xy' indexing
        # transposed the grid relative to all of those.
        z_x, z_y = np.meshgrid(rx, ry, indexing='ij')

        self.theta = theta
        self.dx = dx
        self.dy = dy
        self.z_i = np.stack([z_x, z_y], axis=-1)  # shape (nx, ny, 2)
        self.z_x = z_x
        self.z_y = z_y
        self.nT = nT
        self.H_prime_rho = H_prime_rho  # diffusion function times rho
        self.H_prime = H_prime
        self.V = V  # potential function
        self.W = W  # kernel function
        self.save_data = save_data
        self.loss = np.zeros(nT - 1)  # loss[t-1] is written by compute_loss
        if save_data:
            self.rho_history = np.zeros((nRho_x, nRho_y, nT))
        else:
            # Only two snapshots are kept; saveat holds their time-step indices.
            self.rho_history = np.zeros((nRho_x, nRho_y, 2))
            self.saveat = [np.floor(nT / 5.), np.floor(nT / 3.)]

    def set_initial_distribution(self, f, g=None):
        '''Set the initial density.

        f: callable f(z_x, z_y) giving the initial density of the negative labels
        g: optional callable g(z_x, z_y) giving the initial density of the
           positive labels; stored as g0 when provided
        '''
        # Cast to float so the in-place Runge-Kutta update never fails on
        # integer-valued input.
        self.rho0 = np.asarray(f(self.z_x, self.z_y), dtype=float)
        # rho is the distribution being evolved; rho0 is kept as the
        # reference density (rho_tilde) used in the fluxes.
        self.rho = self.rho0.copy()
        if callable(g):
            self.g0 = g(self.z_x, self.z_y)

    def _integrate_grid(self, field):
        '''Trapezoidal integral of a grid field over the whole (x, y) domain.'''
        # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        return trapezoid(trapezoid(field, self.ry, axis=1), self.rx, axis=0)

    def update_RK(self, x, t, dt, update_type=None):
        '''Advance rho by one time step and record loss/history for step t.

        x:           external influence, passed through to V
        t:           time-step index; used for loss[t-1] and rho_history
        dt:          time-step size
        update_type: "best response" runs 50 Runge-Kutta sub-steps instead of 1

        Returns the new rho.  It is a fresh array, so arrays returned by
        earlier calls are not modified by later ones.
        '''
        n_substeps = 50 if update_type == "best response" else 1

        # Work on a copy: updating self.rho in place would silently change
        # arrays handed back to the caller by previous calls.
        rho_bar = np.array(self.rho, dtype=float)
        for _ in range(n_substeps):
            k1 = self._F(rho_bar, x)                            # +0
            k2 = self._F(rho_bar + dt / 2 * k1, x)              # +h/2
            k3 = self._F(rho_bar - dt * k1 + 2 * dt * k2, x)    # +h
            rho_bar += dt / 6. * (k1 + 4 * k2 + k3)
            # Renormalise to unit mass on the grid.
            rho_bar /= self._integrate_grid(rho_bar)

        self.rho = rho_bar
        self.compute_loss(x, t)
        if self.save_data:
            self.rho_history[:, :, t] = rho_bar
        else:
            # Independent checks so both snapshots are kept even when the
            # two save indices coincide (small nT).
            if t == self.saveat[0]:
                self.rho_history[:, :, 0] = np.copy(rho_bar)
            if t == self.saveat[1]:
                self.rho_history[:, :, 1] = np.copy(rho_bar)
        return rho_bar

    def compute_loss(self, x, t):
        '''Store in loss[t-1] the mean of V(x, .) under rho normalised to sum 1.

        NOTE(review): the original author flagged this cost definition as
        possibly incorrect; confirm the intended loss.
        '''
        normalized_rho = self.rho / np.sum(self.rho)
        # tensordot with the default axes=2 sums over both grid axes.
        self.loss[t - 1] = np.tensordot(self.V(x, self.z_i), normalized_rho)

    def _get_drho_dt(self, flux_list_x, flux_list_y):
        '''Return d(rho)/dt = -(flux divergence) with zero flux at the boundary.

        flux_list_x: fluxes between rows i and i+1, shape (nx-1, ny)
        flux_list_y: fluxes between columns k and k+1, shape (nx, ny-1)
        '''
        # Padding a zero on both ends gives [0, F_1/2, ..., F_{n-3/2}, 0], so
        # np.diff yields F_{i+1/2} - F_{i-1/2} for every cell.  (np.insert at
        # index -1 inserted *before* the last flux and misplaced it.)
        padded_x = np.pad(flux_list_x, ((1, 1), (0, 0)))
        padded_y = np.pad(flux_list_y, ((0, 0), (1, 1)))
        return -np.diff(padded_x, axis=0) / self.dx - np.diff(padded_y, axis=1) / self.dy

    def _minmod(self, dz, rho_bar_, rho_bar_plus_1_, rho_bar_minus_1_, axis=0):
        '''Return the minmod-limited slope of rho along ``axis``.

        dz is the spacing along that axis (dx or dy).  The plus/minus arrays
        hold each cell's neighbour at index +1 / -1 along ``axis``.  At the
        first and last cell along ``axis`` the one-sided differences are used.
        '''
        theta_ = self.theta
        q1 = theta_ * (rho_bar_plus_1_ - rho_bar_) / dz
        q2 = (rho_bar_plus_1_ - rho_bar_minus_1_) / (2 * dz)
        q3 = theta_ * (rho_bar_ - rho_bar_minus_1_) / dz

        all_pos = (q1 > 0) & (q2 > 0) & (q3 > 0)
        all_neg = (q1 < 0) & (q2 < 0) & (q3 < 0)
        minmod_array = np.where(all_pos, np.minimum(np.minimum(q1, q2), q3),
                                np.where(all_neg, np.maximum(np.maximum(q1, q2), q3), 0.0))

        # Boundary cells along the differentiation axis (previously always
        # axis 0, which overwrote rows of the y-slope instead of columns).
        out = np.moveaxis(minmod_array, axis, 0)  # view; writes go through
        out[0] = np.moveaxis(q1, axis, 0)[0]      # forward difference
        out[-1] = np.moveaxis(q3, axis, 0)[-1]    # backward difference
        return minmod_array

    def _get_minmod(self, rho_bar_):
        '''Return the limited slopes (rho_x, rho_y) along axis 0 and axis 1.'''
        # Neighbour arrays padded with zeros beyond the grid edge.
        plus_x = np.pad(rho_bar_[1:, :], ((0, 1), (0, 0)))
        minus_x = np.pad(rho_bar_[:-1, :], ((1, 0), (0, 0)))
        rho_x = self._minmod(self.dx, rho_bar_, plus_x, minus_x, axis=0)

        plus_y = np.pad(rho_bar_[:, 1:], ((0, 0), (0, 1)))
        minus_y = np.pad(rho_bar_[:, :-1], ((0, 0), (1, 0)))
        rho_y = self._minmod(self.dy, rho_bar_, plus_y, minus_y, axis=1)
        return rho_x, rho_y

    def _upwind_flux(self, vel, face_from_lo, face_from_hi, rho_lo, rho_hi,
                     tilde_lo, tilde_hi, h):
        '''Flux through interfaces separating a lower-index and a higher-index cell.

        vel:          velocity at each interface
        face_from_lo: lower cell's reconstructed value at the interface
        face_from_hi: higher cell's reconstructed value at the interface
        rho_lo/hi, tilde_lo/hi: rho and rho0 in the lower/higher cell
        h:            grid spacing normal to the interface
        The positive part comes from the lower cell's candidate flux and the
        negative part from the higher cell's candidate flux.
        '''
        H = self.H_prime_rho
        cand_lo = vel * face_from_lo - (H(face_from_lo, rho_hi, tilde_hi) / h
                                        - H(face_from_lo, rho_lo, tilde_lo) / h)
        cand_hi = vel * face_from_hi - (H(face_from_hi, rho_hi, tilde_hi) / h
                                        - H(face_from_hi, rho_lo, tilde_lo) / h)
        return np.maximum(cand_lo, 0) + np.minimum(cand_hi, 0)

    def _compute_F_and_u(self, xi_list_, rho_bar, rho_x_, rho_y_):
        '''Return the interface fluxes (flux_x, flux_y).

        flux_x has shape (nx-1, ny) and flux_y has shape (nx, ny-1).
        '''
        dx = self.dx
        dy = self.dy
        rho_tilde = self.rho0

        # Piecewise-linear reconstruction of every cell at its two faces.
        face_x_lo = rho_bar - dx * rho_x_ / 2  # face i-1/2
        face_x_hi = rho_bar + dx * rho_x_ / 2  # face i+1/2
        face_y_lo = rho_bar - dy * rho_y_ / 2  # face k-1/2
        face_y_hi = rho_bar + dy * rho_y_ / 2  # face k+1/2

        # Velocities from differences of xi across each interface.
        u = -(xi_list_[1:, :] - xi_list_[:-1, :]) / dx
        v = -(xi_list_[:, 1:] - xi_list_[:, :-1]) / dy

        flux_x = self._upwind_flux(
            u, face_x_hi[:-1, :], face_x_lo[1:, :],
            rho_bar[:-1, :], rho_bar[1:, :], rho_tilde[:-1, :], rho_tilde[1:, :], dx)
        # At interface k+1/2 use the faces that touch it (upper face of cell
        # k, lower face of cell k+1), exactly as in the x-direction.
        flux_y = self._upwind_flux(
            v, face_y_hi[:, :-1], face_y_lo[:, 1:],
            rho_bar[:, :-1], rho_bar[:, 1:], rho_tilde[:, :-1], rho_tilde[:, 1:], dy)
        return flux_x, flux_y

    def _evaluate_W(self, rho_bar):
        '''Return xi[j, k] = sum over the grid of W(|x_j - z_x|, |y_k - z_y|, dx, dy) * rho.

        Assumes W depends only on the absolute coordinate differences.
        '''
        z_x = self.z_x
        z_y = self.z_y
        n_x, n_y = np.shape(rho_bar)
        xi_list = np.zeros((n_x, n_y))
        for jj in range(n_x):
            dist_x = np.abs(self.rx[jj] - z_x)  # invariant over kk
            for kk in range(n_y):
                dist_y = np.abs(self.ry[kk] - z_y)
                xi_list[jj, kk] = np.sum(self.W(dist_x, dist_y, self.dx, self.dy) * rho_bar)
        return xi_list

    def _F(self, rho_bar, x):
        '''Return d(rho)/dt for the Runge-Kutta stages; x is the external influence.'''
        xi = self.V(x, self.z_i) + self._evaluate_W(rho_bar)
        rho_x_, rho_y_ = self._get_minmod(rho_bar)
        flux_x, flux_y = self._compute_F_and_u(xi, rho_bar, rho_x_, rho_y_)
        return self._get_drho_dt(flux_x, flux_y)