from Bao.RatBrainBayesian import RatBayesian
from Simulation import Simulation
from Simulation import CanonicalRats
import numpy as np
from encodeData import Encode

from scipy.interpolate import interp1d

class BayesianEnvironment:
    """Environment pairing a rat ``Simulation`` with a ``RatBayesian`` estimator.

    Each ``step`` queries the simulation with an action, feeds the result to
    the estimator, and reports the estimator's reward/error terms measured
    against a fixed grid of test inputs produced by ``genTestingData``.
    """

    def __init__(self, ratNum=0, optimizeModel=True):
        """Build the environment.

        Args:
            ratNum: index of the canonical rat to simulate.
            optimizeModel: if true, the estimator is updated with the raw
                simulation output and scored with ``getRewardModel``;
                otherwise it is updated with ``action * output`` and scored
                with ``getRewardParam``.
        """
        self.optimizeModel = optimizeModel
        self.ratNum = ratNum
        cRats = CanonicalRats()
        self.allRats = cRats.getCanonicalRat()
        # Honour the requested rat index (this was hard-coded to 0 before,
        # silently ignoring the ``ratNum`` argument).
        self.sim = Simulation(self.allRats, ratNum=ratNum)

        self.ratEst = RatBayesian(optimizeModel=self.optimizeModel)
        self.testS, self.testA, self.testDS, self.testBestDA = self.genTestingData()

        self.encoder = Encode()

    def step(self, action, modelReward=True):
        """Apply ``action`` to the simulated rat and update the estimator.

        Args:
            action: input forwarded to ``self.sim.simulateRat``.
            modelReward: unused; kept for interface compatibility.

        Returns:
            ``(r1, error, bestParamError, DAError)``. In model mode
            ``DAError`` is ``None`` because ``getRewardModel`` does not
            report it; in parameter mode ``bestParamError`` equals ``error``.
        """
        out = self.sim.simulateRat(action)

        if self.optimizeModel:
            self.ratEst.updateModel(action, out)
            r1, error, bestParamError = self.ratEst.getRewardModel(
                self.testS, self.testA, self.sim.bestP)
            # getRewardModel has no DA term; without this assignment the
            # return statement below raised UnboundLocalError.
            DAError = None
        else:
            # Parameter mode feeds the estimator action * output.
            self.ratEst.updateModel(action, np.multiply(action, out))
            r1, error, DAError = self.ratEst.getRewardParam(
                self.testS, self.testA, self.sim.bestP, self.testBestDA)
            bestParamError = error

        return r1, error, bestParamError, DAError

    def reset(self, ratNum=None, seed=None, sim=None):
        """Start a new episode on a (possibly random) rat and run one burn-in step.

        Args:
            ratNum: rat index to simulate; if None, drawn with
                ``np.random.randint(0, 5)`` (i.e. 0..4 inclusive).
            seed: passed to ``np.random.seed`` and to the new Simulation.
            sim: optional pre-built simulation used instead of creating one.

        Returns:
            ``(r1, error, DAError)`` from the burn-in update; ``DAError`` is
            ``None`` in model mode.
        """
        # TODO: will need to generate several starting points.
        self.ratEst.reset()
        np.random.seed(seed)
        i = np.random.randint(0, 5) if ratNum is None else ratNum
        self.newSimulation(ratNum=i, seed=seed, sim=sim)
        self.ratEst.reset(self.sim)

        # Fixed burn-in action; a uniform draw from [0, 4) is a possible
        # alternative if randomised burn-in is wanted.
        burn_in = np.array(2.5).reshape(1, 1)
        burn_out = np.array(self.sim.simulateRat(burn_in))

        if self.optimizeModel:
            self.ratEst.updateModel(burn_in, burn_out)
            r1, error, _bestParamError = self.ratEst.getRewardModel(
                self.testS, self.testA, self.sim.bestP)
            # getRewardModel has no DA term; previously this path raised
            # UnboundLocalError at the return statement.
            DAError = None
        else:
            self.ratEst.updateModel(burn_in, np.multiply(burn_in, burn_out))
            r1, error, DAError = self.ratEst.getRewardParam(
                self.testS, self.testA, self.sim.bestP, self.testBestDA)

        return r1, error, DAError

    def genTestingData(self, numTest=500, low=0.0, high=4.0):
        """Evaluate the simulation's predictions on an evenly spaced grid.

        Args:
            numTest: number of grid points.
            low: first grid value (inclusive).
            high: last grid value (inclusive).

        Returns:
            ``(testParams, testOutputs, bestParams, bestDA)``, where
            ``testParams`` is a ``(numTest, 1)`` column of inputs,
            ``testOutputs`` is ``self.sim.predictRat(testParams)``,
            ``bestParams`` is the grid row maximising ``param * output``, and
            ``bestDA`` is that maximum product.
        """
        testParams = np.atleast_2d(np.linspace(low, high, numTest)).T
        testOutputs = self.sim.predictRat(testParams)

        # Score every grid point once and reuse it for both argmax and max.
        scores = np.multiply(testOutputs.reshape(numTest),
                             testParams.reshape(numTest))
        i = np.argmax(scores)
        bestParams = testParams[i]
        bestDA = scores[i]

        return testParams, testOutputs, bestParams, bestDA

    def newSimulation(self, ratNum=1, seed=None, sim=None):
        """Switch to a new simulation and regenerate the test data for it.

        Args:
            ratNum: rat index recorded on ``self`` and used when building a
                new Simulation.
            seed: seed forwarded to the new Simulation.
            sim: if given, used as-is instead of building a new Simulation.
        """
        self.ratNum = ratNum
        # Identity check avoids invoking any custom ``__eq__`` on ``sim``.
        if sim is None:
            self.sim = Simulation(self.allRats, ratNum=ratNum, seed=seed)
        else:
            self.sim = sim
        self.testS, self.testA, self.testDS, self.testBestDA = self.genTestingData()