import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data as utils
from torch.autograd import Variable
from MILPs.MILP_RatBrain import solve_br
from matplotlib import pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import random
from RatBrainModel import RatEst
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel
from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic,
                                              ExpSineSquared, DotProduct,
                                              ConstantKernel)
from sklearn.gaussian_process.kernels import ConstantKernel, Matern



class RatBayesian:
    """Gaussian-process surrogate model of a scalar rat-brain simulator response.

    Samples are (action, output) pairs with ``action_size == output_size == 1``.
    Every call to ``updateModel`` refits a GP from scratch on all accumulated
    samples, records the model's current best parameter, and re-scores all
    recorded best parameters on the simulator to choose ``BestGuess``.

    A simulator exposing ``simulateRat`` and ``predictRat`` must be attached via
    ``reset(sim=...)`` before ``updateModel``/``get_bestParam`` are called.
    """

    def __init__(self, states=None, actions=None, out=None, optimizeModel=True):
        """Create an empty (or pre-seeded) model.

        Args:
            states: stored as-is; not read by any method of this class.
            actions: previously observed inputs (default: empty array).
            out: outputs matching ``actions`` (default: empty array).
            optimizeModel: if true, the best parameter maximises
                ``param * prediction``; otherwise it maximises the prediction.
        """
        self.state_size = 1
        self.action_size = 1
        self.output_size = 1

        # None defaults avoid sharing one default array object between instances.
        self.states = np.array([]) if states is None else states
        self.actions = np.array([]) if actions is None else actions
        self.out = np.array([]) if out is None else out
        self.maxOut = 0

        # Reward bookkeeping: previous model error and previous best-param error.
        self.r_prev = 1
        self.prev_error = 20
        self.prev_reward = None

        self.rat = []    # fitted GaussianProcessRegressor once updateModel has run
        self.model = []

        self.optimizeModel = optimizeModel
        self.sim = []    # simulator; attach via reset(sim=...)
        self.prevBestParams = np.array([])
        self.BestGuess = None

    def updateModel(self, newinputsActions=None, newoutputs=None):
        """Add new samples, refit the GP surrogate and update ``BestGuess``.

        Args:
            newinputsActions: new inputs, reshapeable to ``(-1, action_size)``;
                None/empty means "refit on the stored data only".
            newoutputs: outputs matching ``newinputsActions``.

        Raises:
            ValueError: if there is neither new nor previously stored data.
        """
        # Global RNGs are reseeded on every call (the GP uses its own
        # random_state); this also affects any other code drawing from them.
        torch.manual_seed(7)
        np.random.seed(7)

        newinputsActions = np.asarray([] if newinputsActions is None else newinputsActions)
        newoutputs = np.asarray([] if newoutputs is None else newoutputs)
        haveNew = newinputsActions.shape[0] > 0
        haveStored = self.actions.shape[0] > 0
        if not haveNew and not haveStored:
            raise ValueError('No data')

        # Stack stored data (if any) followed by new data (if any), always as
        # 2-D column arrays so 1-D stored data also works with vstack / fit.
        inputParts, outputParts = [], []
        if haveStored:
            inputParts.append(self.actions.reshape(-1, self.action_size))
            outputParts.append(self.out.reshape(-1, self.output_size))
        if haveNew:
            inputParts.append(newinputsActions.reshape(-1, self.action_size))
            outputParts.append(newoutputs.reshape(-1, self.output_size))
        inputsActions = np.vstack(inputParts)
        outputs = np.vstack(outputParts)

        # Refit from scratch; random_state makes the optimizer restarts reproducible.
        kernel = Matern(length_scale=1.0, nu=1.5)
        gpEst = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9, alpha=.01, random_state=2)
        gpEst.fit(inputsActions, outputs)

        self.rat = gpEst
        self.actions = inputsActions
        self.out = outputs

        newBestParam, _ = self.get_bestParam()
        newBestParam = np.asarray(newBestParam).reshape(1, 1)
        if self.prevBestParams.shape[0] == 0:
            self.prevBestParams = newBestParam
        else:
            self.prevBestParams = np.vstack((self.prevBestParams, newBestParam))

        # Re-evaluate every best-parameter candidate so far on the simulator and
        # keep the one maximising parameter * simulated output.
        allbestParamOut = np.asarray(self.sim.simulateRat(self.prevBestParams))
        i = np.argmax(np.multiply(allbestParamOut.reshape(-1), self.prevBestParams.reshape(-1)))
        self.BestGuess = self.prevBestParams[i]

    def getRewardModel(self, testParams, testOut, testBestParams):
        """Score the surrogate on held-out data.

        Args:
            testParams: test inputs, shape ``(n, action_size)`` or 1-D of length n.
            testOut: true outputs for ``testParams``.
            testBestParams: true best parameter.

        Returns:
            ``(rNew, r, bestParamError)``: the relative decrease of the model
            error versus the previous call (absolute decrease if the previous
            error was exactly 0), the model error ``r`` (per-output L2 norm of
            the residuals, summed, divided by sqrt(n) -- the RMSE for a single
            output), and ``|bestParam - testBestParams|``.
        """
        testParams = np.asarray(testParams, dtype=float)
        if testParams.ndim == 1:
            testParams = testParams.reshape(-1, self.action_size)

        # Reshape both sides to (n, output_size): newer scikit-learn returns a
        # 1-D prediction for single-target GPs, which would otherwise broadcast
        # against the column vector into an (n, n) matrix.
        out_est = np.asarray(self.rat.predict(testParams)).reshape(-1, self.output_size)
        residuals = out_est - np.asarray(testOut).reshape(-1, self.output_size)
        r = np.sum(np.linalg.norm(residuals, axis=0)) / np.sqrt(testParams.shape[0])

        # The best parameter is found by grid search on the GP (get_bestParam).
        bestParam, _ = self.get_bestParam()
        bestParamError = abs(bestParam - testBestParams)

        if self.r_prev == 0:
            # Avoid dividing by zero after a perfect previous fit.
            rNew = self.r_prev - r
        else:
            rNew = (self.r_prev - r) / self.r_prev
        self.r_prev = r
        return rNew, r, bestParamError

    def getRewardParam(self, testParams, testOut, testBestParams, testbestDA):
        """Reward the change in best-parameter error since the previous call.

        ``testParams`` and ``testOut`` are accepted for interface compatibility
        but are not used.

        Returns:
            ``(rNew, bestParamError, bestDAError)``; ``rNew`` is the relative
            decrease of the best-parameter error (absolute decrease when the
            previous error is below 1e-3).
        """
        bestParam, bestDA = self.get_bestParam()
        bestParamError = abs(bestParam - testBestParams)
        bestDAError = abs(bestDA - testbestDA)

        if self.prev_error < .001:
            rNew = self.prev_error - bestParamError
        else:
            rNew = (self.prev_error - bestParamError) / self.prev_error
        print("GT best param", testBestParams)
        print("bestEst", bestParam)
        self.prev_reward = rNew
        self.prev_error = bestParamError
        return rNew, bestParamError, bestDAError

    def get_bestParam(self, numTest=500, low=0, high=4):
        """Grid-search the fitted GP for the best parameter.

        Args:
            numTest: number of evenly spaced grid points.
            low: first grid value (inclusive).
            high: last grid value (inclusive).

        Returns:
            ``(bestParam, bestDA)`` where ``bestParam`` is a length-1 array and
            ``bestDA = sim.predictRat(bestParam) * bestParam``.
        """
        x = np.atleast_2d(np.linspace(low, high, numTest)).T
        # Flatten: scikit-learn may return shape (n,) or (n, 1); mixing (n, 1)
        # with (n,) would broadcast to (n, n) and corrupt the argmax index.
        y_pred = np.asarray(self.rat.predict(x)).reshape(-1)
        obj = x.reshape(-1) * y_pred if self.optimizeModel else y_pred
        bestParam = x[np.argmax(obj)]
        bestDA = np.multiply(self.sim.predictRat(bestParam.reshape(-1, 1)), bestParam)
        return bestParam, bestDA

    def reset(self, sim=None):
        """Clear all data and reward state and attach simulator ``sim``."""
        self.states = np.array([])
        self.actions = np.array([])
        self.out = np.array([])

        self.r_prev = 1
        self.prev_error = 20
        self.rat = []
        # None sentinel avoids a shared mutable default; [] keeps the old value.
        self.sim = [] if sim is None else sim
        self.prevBestParams = np.array([])


