import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize

class ODE2d:
    """Decision-variable dynamics driven by two population-weighted payoffs.

    The variable ``x`` either follows the gradient flow

        x_dot = x_speed * ( <f1_grad(x, z), rho> + <f2_grad(x, z), g0>
                            - beta * (x - x0) )

    integrated with a third-order Runge-Kutta step, or (when ``x_speed`` is
    the string ``"best response"``) is reset at every call to a minimiser of
    the associated loss

        loss(x) = -<f1(x, z), rho> - <f2(x, z), g0> + beta/2 * ||x - x0||^2.

    ``rho`` and ``g0`` are normalised to sum to one before being used as
    weights.
    """

    def __init__(self, z_i, dt, nT, g0, x0, f1, f2, f1_grad, f2_grad,
                 save_data=True, x_speed=1.e0, subsample=1, beta=0.05):
        """Store the model and allocate the history / loss buffers.

        Args:
            z_i: grid passed unchanged as second argument to ``f1``, ``f2``
                and their gradients.
            dt: time step of the integrator.
            nT: number of time steps; sizes the history and loss buffers.
            g0: weights of the fixed population (normalised here).
            x0: initial value of ``x``; also the anchor of the ``beta``
                penalty.
            f1, f2: callables ``f(x, z_i)`` returning the payoffs.
            f1_grad, f2_grad: callables ``f_grad(x, z_i)`` returning the
                gradients of ``f1`` / ``f2`` with respect to ``x``.
            save_data: if true, keep ``x`` every ``subsample`` steps;
                otherwise keep it only at the two steps in ``self.saveat``.
            x_speed: scalar multiplier of the flow, or the string
                ``"best response"``.
            subsample: spacing (in steps) between saved states.
            beta: weight of the quadratic penalty pulling ``x`` towards
                ``x0``.
        """
        self.z_i = z_i
        self.dt = dt
        self.g0 = g0 / np.sum(g0)  # weights of the fixed population, sum to 1
        # Private float copy: the state must not alias the caller's array,
        # otherwise stepping the ODE would silently overwrite it.
        self.x = np.array(x0, dtype=float)
        self.x0 = np.copy(x0)
        self.f1 = f1
        self.f2 = f2
        self.f1_grad = f1_grad
        self.f2_grad = f2_grad
        if save_data:
            # one row per saved step (every `subsample` steps)
            self.x_history = np.zeros((int(np.ceil(nT / subsample)), 2))
        else:
            # only two snapshots are kept, at the steps listed in `saveat`
            self.x_history = np.zeros((2, 2))
            self.saveat = [np.floor(nT / 5.), np.floor(nT / 3.)]
        self.save_data = save_data
        self.loss = np.zeros(nT - 1)  # loss of the x player, indexed by t-1
        self.x_speed = x_speed
        self.subsample = subsample  # save every `subsample` time steps
        self.beta = beta

    def update_x(self, rho, t):
        """Advance ``x`` by one step and record loss / history for step ``t``.

        Args:
            rho: unnormalised weights of the evolving population; normalised
                to sum to one before use.
            t: index of the current step; the loss is written to
                ``self.loss[t-1]``.

        Returns:
            The updated ``self.x``.
        """
        dt = self.dt
        x = self.x
        z_i = self.z_i
        weights = rho / np.sum(rho)
        weights_bar = self.g0 / np.sum(self.g0)

        def loss_(point):
            # Quadratic penalty beta/2 * ||x - x0||^2; its gradient
            # beta * (x - x0) is the term used in `jac` and `_get_grad`.
            penalty = 0.5 * self.beta * np.sum((point - self.x0) ** 2)
            return (-np.tensordot(self.f1(point, z_i), weights, axes=2)
                    - np.tensordot(self.f2(point, z_i), weights_bar, axes=2)
                    + penalty)

        # isinstance guard: comparing an array-valued speed with a string
        # would be elementwise / ambiguous.
        if isinstance(self.x_speed, str) and self.x_speed == "best response":
            # NOTE(review): this branch starts SLSQP from the scalar 1. and
            # keeps only opt.x[0], i.e. it treats x as one-dimensional, and
            # the np.dot contraction in `jac` differs from the tensordot in
            # `_get_grad` -- confirm the intended shapes before relying on it.
            def jac(point):
                return (-np.dot(self.f1_grad(point, z_i), weights)
                        - np.dot(self.f2_grad(point, z_i), weights_bar)
                        + self.beta * (point - self.x0))

            opt = minimize(loss_, 1., jac=jac, method="SLSQP")

            if opt.success:
                x = opt.x[0]
                self.x = x
                # save population loss for later
                self.loss[t - 1] = loss_(x)
            else:
                # keep the previous x; the loss entry for this step is left
                # untouched
                print("x not optimized")
                print(opt.message)
        else:
            speed = self.x_speed

            def flow(point):
                # right-hand side of x_dot = x_speed * total_grad(x)
                return speed * self._get_grad(point, rho)

            # Kutta's third-order scheme. The stage points must move along
            # the same direction as the final update.
            k1 = flow(x)                               # t
            k2 = flow(x + dt / 2 * k1)                 # t + dt/2
            k3 = flow(x - dt * k1 + 2 * dt * k2)       # t + dt
            # Rebind instead of updating in place so previously returned
            # states are not modified behind the caller's back.
            x = x + dt / 6. * (k1 + 4 * k2 + k3)
            self.x = x
            self.loss[t - 1] = loss_(x)

        self.prob = self.f2(x, z_i)

        if self.save_data:
            if t % self.subsample == 0:
                self.x_history[int(t // self.subsample), :] = self.x
        elif t == self.saveat[0]:
            self.x_history[0, :] = self.x
        elif t == self.saveat[1]:
            self.x_history[1, :] = self.x
        return self.x

    def _get_grad(self, current_x, rho):
        """Return the flow direction at ``current_x`` (``x`` is not modified).

        The value is
        ``<f1_grad, rho> + <f2_grad, g0> - beta * (current_x - x0)``,
        i.e. minus the gradient of the loss minimised by the x player, with
        ``rho`` and ``g0`` normalised to sum to one.

        Assumes the first two axes of the arrays returned by ``f1_grad`` /
        ``f2_grad`` match the two axes of ``rho`` / ``g0`` -- they are the
        axes contracted below.
        """
        grad_f1 = self.f1_grad(current_x, self.z_i)
        grad_f2 = self.f2_grad(current_x, self.z_i)
        weights = rho / np.sum(rho)
        weights_bar = self.g0 / np.sum(self.g0)
        contract = [[0, 1], [0, 1]]
        return (np.tensordot(grad_f1, weights, axes=contract)
                + np.tensordot(grad_f2, weights_bar, axes=contract)
                - self.beta * (current_x - self.x0))