from RatBrainModel import RatEst
from Simulation import Simulation
from Simulation import CanonicalRats
import numpy as np
from encodeData import Encode

from scipy.interpolate import interp1d

class Environment:
    """Environment pairing a simulated rat with an estimator of that rat.

    A ``Simulation`` of one canonical rat produces outputs for the actions it
    is given. Each (action, output) pair is fed to a ``RatEst`` model, and the
    reward comes from ``RatEst.getReward`` evaluated on a fixed grid of test
    parameters generated by ``genTestingData``.
    """

    # Action range shared by the random burn-in action and the test grid.
    _ACTION_LOW = 0.0
    _ACTION_HIGH = 4.0
    # Number of evenly spaced test parameters used for reward evaluation.
    _N_TEST = 50

    def __init__(self, ratNum=0):
        """Build the environment for canonical rat ``ratNum``.

        Args:
            ratNum: index of the canonical rat to simulate.
        """
        self.ratNum = ratNum
        cRats = CanonicalRats()
        self.allRats = cRats.getCanonicalRat()
        # Simulate the requested rat (previously this was always rat 0,
        # regardless of the ``ratNum`` argument).
        self.sim = Simulation(self.allRats, ratNum=self.ratNum)

        self.ratEst = RatEst()
        self.testS, self.testA, self.testDS = self.genTestingData()

        self.encoder = Encode()

    def _observe(self, step):
        """Return the ``(Z, prevS_A, step)`` observation for the estimator's history.

        ``prevS_A`` is ``np.hstack`` of the estimator's recorded actions and
        outputs; ``Z`` is the encoder's output for that same history.
        """
        prevS_A = np.hstack((self.ratEst.actions, self.ratEst.out))
        Z = self.encoder.encode(self.ratEst.actions, self.ratEst.out)
        return (Z, prevS_A, step)

    def step(self, action, step):
        """Apply ``action`` to the simulated rat and update the estimator.

        Args:
            action: parameters passed to ``Simulation.simulateRat`` and then
                recorded in the estimator via ``RatEst.updateModel``.
            step: step counter, echoed back inside the observation tuple.

        Returns:
            ``((Z, prevS_A, step), reward, error)``. ``reward`` and ``error``
            are the first two values returned by ``RatEst.getReward``.
        """
        print(action)
        out = self.sim.simulateRat(action)
        self.ratEst.updateModel(action, out)

        reward, error, _ = self.ratEst.getReward(
            self.testS, self.testA, self.testDS
        )
        print("reward,error", reward, error)
        return self._observe(step), reward, error

    def reset(self, ratNum=None, seed=None):
        """Start a new episode, optionally on a specific rat and seed.

        Resets the estimator, switches to a new simulation, then applies one
        random "burn-in" action so the estimator starts with one data point.

        Args:
            ratNum: rat index to simulate. If ``None``, one is drawn with
                ``np.random.randint(0, 5)``.
            seed: passed to ``np.random.seed`` (numpy's global RNG) and to
                the new ``Simulation``.

        Returns:
            ``((Z, prevS_A, 0), reward, error)``, as in ``step``.
        """
        print('reset')
        self.ratEst.reset()
        # Reseeds numpy's *global* RNG; with seed=None numpy reseeds from
        # fresh OS entropy.
        np.random.seed(seed)
        if ratNum is None:
            # NOTE(review): upper bound 5 is hard-coded; presumably the number
            # of canonical rats -- confirm against CanonicalRats.
            ratNum = np.random.randint(0, 5)
        self.newSimulation(ratNum=ratNum, seed=seed)

        # A single burn-in action drawn from [low, high), shape (1, 1).
        burn_in = np.random.uniform(
            low=self._ACTION_LOW, high=self._ACTION_HIGH, size=(1, 1)
        )
        print('burn_in', burn_in)

        burn_out = np.array(self.sim.simulateRat(burn_in))
        self.ratEst.updateModel(burn_in, burn_out)

        reward, error, _ = self.ratEst.getReward(
            self.testS, self.testA, self.testDS
        )
        return self._observe(0), reward, error

    def genTestingData(self):
        """Simulate the current rat on an evenly spaced grid of parameters.

        Returns:
            ``(testParams, testOutputs, bestParams)``:

            - ``testParams`` is a column array of shape ``(_N_TEST, 1)``
              spanning ``_ACTION_LOW`` to ``_ACTION_HIGH`` inclusive.
            - ``testOutputs`` is ``self.sim.simulateRat(testParams)``.
            - ``bestParams`` is the row of ``testParams`` at ``np.argmax``
              of the outputs.
        """
        testParams = np.atleast_2d(
            np.linspace(self._ACTION_LOW, self._ACTION_HIGH, self._N_TEST)
        ).T
        testOutputs = self.sim.simulateRat(testParams)

        # np.argmax flattens its input, so this assumes one output per test
        # parameter (e.g. shape (N,) or (N, 1)) -- TODO confirm.
        best_idx = np.argmax(testOutputs)
        bestParams = testParams[best_idx]

        return testParams, testOutputs, bestParams

    def newSimulation(self, ratNum=1, seed=None):
        """Switch to canonical rat ``ratNum`` and regenerate the test data.

        Args:
            ratNum: index of the canonical rat to simulate.
            seed: passed through to ``Simulation``.
        """
        self.ratNum = ratNum
        self.sim = Simulation(self.allRats, ratNum=ratNum, seed=seed)
        self.testS, self.testA, self.testDS = self.genTestingData()