from RatBrainModel import RatEst
from Simulation import Simulation
from Simulation import CanonicalRats
import numpy as np
from encodeData import Encode
import torch

from scipy.interpolate import interp1d

class Environment:
    """Environment wrapping a simulated rat and a set of ``RatEst`` estimators.

    ``reset`` selects a canonical rat to simulate and applies a fixed burn-in
    action; ``step`` applies an action to the simulated rat, feeds the result
    to the estimators and scores ``ratEst1`` against a dense grid of test
    parameters produced by ``genTestingData``.
    """

    def __init__(self, ratNum=0, optimizeModel=False, bootNum=2):
        """Build the simulator, the estimators and the initial test grid.

        Args:
            ratNum: index of the canonical rat simulated initially.
            optimizeModel: if True, estimators are updated with the raw
                simulator output and scored via ``getRewardModel``; otherwise
                every estimator in ``allRatsEst`` is updated with
                ``action * out`` and ``ratEst1`` is scored via
                ``getRewardParam``.
            bootNum: number of additional estimators appended to
                ``allRatsEst`` after ``ratEst1``.
        """
        self.ratNum = ratNum
        cRats = CanonicalRats()
        self.allRats = cRats.getCanonicalRat()
        # Use the requested rat (this was previously hard-coded to 0, which
        # silently ignored the ``ratNum`` argument).
        self.sim = Simulation(self.allRats, ratNum=self.ratNum)

        self.optimizeModel = optimizeModel
        self.ratEst1 = RatEst(optimizeModel=self.optimizeModel)
        self.ratEst2 = RatEst(optimizeModel=self.optimizeModel)
        self.ratEst3 = RatEst(optimizeModel=self.optimizeModel)

        # ratEst1 is the estimator that is scored; the remaining ``bootNum``
        # estimators are updated alongside it in the non-optimizeModel path.
        self.allRatsEst = [self.ratEst1]
        self.allRatsEst.extend(
            RatEst(optimizeModel=self.optimizeModel) for _ in range(bootNum)
        )

        self._setTestingData()

        self.encoder = Encode()

    def _setTestingData(self):
        """(Re)generate the test grid for the current ``self.sim``.

        ``testS``/``testA`` are kept as aliases of ``testParams``/``testOut``
        so older callers reading them see data for the current rat.
        """
        (self.testParams, self.testOut,
         self.testbestP, self.testBestDA) = self.genTestingData()
        self.testS, self.testA = self.testParams, self.testOut

    def step(self, action):
        """Apply ``action`` to the simulated rat and update the estimators.

        Args:
            action: action passed to ``self.sim.simulateRat``.

        Returns:
            ``(obs, r1, error, bestParamError, DAError)`` where ``obs`` is
            ``(Z, prevS_A, bestGuess, ratNum)``: the encoder output for
            ``ratEst1``'s actions/outputs, ``ratEst1``'s stacked history as a
            tensor, ``ratEst1.BestGuess`` as a tensor and ``self.sim.ratNum``
            as a tensor. In the non-optimizeModel path ``bestParamError`` equals
            ``error``; in the optimizeModel path ``DAError`` is ``None``.
        """
        out = self.sim.simulateRat(action)

        if self.optimizeModel:
            self.ratEst1.updateModel(action, out)
            r1, error, bestParamError = self.ratEst1.getRewardModel(
                self.testParams, self.testOut, self.sim.bestP)
            prevS_A = np.hstack((self.ratEst1.actions, self.ratEst1.out))
            self.ratEst2.updateModel(action, out)
            # getRewardModel does not report a DA error; without this the
            # return statement below raised NameError in this branch.
            DAError = None
        else:
            for est in self.allRatsEst:
                # A fresh product per estimator, as before, so estimators
                # never share one array object.
                est.updateModel(action, np.multiply(action, out))
            # Score against the test grid of the *current* rat (previously the
            # grid generated for the initial rat was used after every reset).
            r1, error, DAError = self.ratEst1.getRewardParam(
                self.testParams, self.testOut, self.sim.bestP, self.testBestDA)
            bestParamError = error
            prevS_A = np.hstack((self.ratEst1.actions, self.ratEst1.out,
                                 self.ratEst1.prevBestParams))

        Z = self.encoder.encode(self.ratEst1.actions, self.ratEst1.out)
        obs = (Z,
               torch.from_numpy(prevS_A),
               torch.from_numpy(np.array(self.ratEst1.BestGuess)),
               torch.from_numpy(np.array(self.sim.ratNum)))
        return obs, r1, error, bestParamError, DAError

    def reset(self, ratNum=None, seed=None):
        """Start a new episode on a (possibly random) rat with a burn-in step.

        Args:
            ratNum: rat index to simulate; if None, one is drawn with
                ``np.random.randint(0, 4)`` after seeding.
            seed: passed to ``np.random.seed`` (reseeds NumPy's global RNG)
                and to the new ``Simulation``.

        Returns:
            ``(obs, r1, error, DAError)`` with ``obs`` shaped as in ``step``
            except the last element is the selected rat index. ``DAError`` is
            ``None`` in the optimizeModel path.
        """
        np.random.seed(seed)
        rNum = np.random.randint(0, 4) if ratNum is None else ratNum
        self.newSimulation(ratNum=rNum, seed=seed)
        for est in self.allRatsEst:
            est.reset(self.sim)

        # Fixed burn-in action applied before the agent's first step.
        burn_in = np.array([4.0]).reshape(-1, 1)
        burn_out = np.array(self.sim.simulateRat(burn_in))

        if self.optimizeModel:
            # NOTE(review): ratEst2 is not in allRatsEst, so it is never reset
            # above — confirm whether that is intended.
            self.ratEst1.updateModel(burn_in, burn_out)
            self.ratEst2.updateModel(burn_in, burn_out)
            r1, error, bestParamError = self.ratEst1.getRewardModel(
                self.testParams, self.testOut, self.sim.bestP)
            prevS_A = np.hstack((self.ratEst1.actions, self.ratEst1.out))
            # getRewardModel does not report a DA error; without this the
            # return statement below raised NameError in this branch.
            DAError = None
        else:
            for est in self.allRatsEst:
                est.updateModel(burn_in, np.multiply(burn_in, burn_out))
            r1, error, DAError = self.ratEst1.getRewardParam(
                self.testParams, self.testOut, self.sim.bestP, self.testBestDA)
            prevS_A = np.hstack((self.ratEst1.actions, self.ratEst1.out,
                                 self.ratEst1.prevBestParams))

        Z = self.encoder.encode(self.ratEst1.actions, self.ratEst1.out)
        obs = (Z,
               torch.from_numpy(prevS_A),
               torch.from_numpy(np.array(self.ratEst1.BestGuess)),
               torch.from_numpy(np.array(rNum)))
        return obs, r1, error, DAError

    def genTestingData(self, numTest=500):
        """Evaluate the current simulator on an evenly spaced parameter grid.

        Args:
            numTest: number of grid points in ``[0, 4]`` (default 500).

        Returns:
            ``(testParams, testOutputs, bestParams, bestDA)``: the grid as a
            ``(numTest, 1)`` column, ``self.sim.predictRat`` of that grid, the
            grid row maximizing ``param * output`` (first one on ties) and that
            maximum value.
        """
        testParams = np.atleast_2d(np.linspace(0, 4, numTest)).T
        testOutputs = self.sim.predictRat(testParams)

        # Score each grid point once: parameter times predicted output.
        score = np.multiply(np.asarray(testOutputs).reshape(numTest),
                            testParams.reshape(numTest))
        i = np.argmax(score)
        bestParams = testParams[i]
        bestDA = score[i]

        return testParams, testOutputs, bestParams, bestDA

    def newSimulation(self, ratNum=1, seed=None):
        """Replace the simulator with rat ``ratNum`` and refresh the test grid."""
        self.ratNum = ratNum
        self.sim = Simulation(self.allRats, ratNum=ratNum, seed=seed)
        self._setTestingData()