import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data as utils
from torch.autograd import Variable
from matplotlib import pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import random
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel
from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic,
                                              ExpSineSquared, DotProduct,
                                              ConstantKernel)
from sklearn.gaussian_process.kernels import ConstantKernel, Matern



class Damage:
    """Learned state-change model alongside a fixed nominal linear model.

    The class accumulates (state, action, state-change) samples, fits a
    Gaussian-process regressor that maps ``[action; state]`` to the state
    change, and scores that regressor on held-out samples.

    Data layout: every array is stored one sample per column --
    ``states`` is ``(11, N)``, ``actions`` is ``(4, N)`` and ``deltS`` is
    ``(11, N)``.  The regressor itself is fed the transposed layout
    (one sample per row).
    """

    def __init__(self, states=np.array([]), actions=np.array([]), deltS=np.array([])):
        """Create the model, optionally seeded with existing samples.

        Args:
            states: stored states, one sample per column (11 rows), or empty.
            actions: stored actions, one sample per column (4 rows), or empty.
            deltS: stored state changes, one sample per column (11 rows),
                or empty.
        """
        # Nominal linear model used by ``Nominal``:
        #   deltaS = AnominalEst @ state + BnominalEst @ action
        self.AnominalEst = np.array([
            [-0.02204, 0.001422, 0, -32.17, 0, 0, 0, 0, 0, 0, 0],
            [-0.05759, -0.3943, 871, 0, 0, 0, 0, 0, 0, 0, 0],
            [-9.5e-05, -0.001793, -0.5468, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, -0.06399, 0, -1, 0.03691, 0, 0, 0],
            [0, 0, 0, 0, -2.111, -0.5048, 0.193, 0, 0, 0, 0],
            [0, 0, 0, 0, 0.7732, 0.01087, -0.1794, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ])
        self.BnominalEst = np.array([
            [0, 1.91e-05, 0, 0],
            [-18.58, 0, 0, 0],
            [-1.2, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0.00839],
            [0, 0, 0.1848, 0.1056],
            [0, 0, -0.008699, -0.4591],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
        ])

        # ``A`` starts as a zero matrix and is replaced by the fitted
        # regressor in ``updateModel``; the other three are not touched by
        # any method of this class.
        self.A = np.zeros((11, 11))
        self.B = np.zeros((11, 4))
        self.Asigma = np.zeros((11, 11))
        self.Bsigma = np.zeros((11, 4))

        # Accumulated training samples (one sample per column).
        self.states = states
        self.actions = actions
        self.deltS = deltS

        self.errorTrain = None
        self.stateError = None

        # Error from the previous ``getReward`` call; 3000 is the initial
        # reference value that the first reward is measured against.
        self.r_prev = 3000

        # Not used by any method of this class.
        self.A1 = []
        self.A2 = []
        self.B1 = []
        self.B2 = []
        self.A3 = []
        self.B3 = []

    def updateModel(self, newinputsStates=np.array([]), newinputsActions=np.array([]),
                    newoutputs=np.array([])):
        """Append new samples (if any) and refit the regressor on all data.

        Args:
            newinputsStates: new states; reshaped to ``(11, -1)``.
            newinputsActions: new actions; reshaped to ``(4, -1)``.
            newoutputs: new state changes; reshaped to ``(11, -1)``.

        Raises:
            ValueError: if there are neither new nor stored samples.

        Side effects:
            Stores the fitted regressor in ``self.A`` and the combined data
            in ``self.states`` / ``self.actions`` / ``self.deltS``.
        """
        newinputsStates = np.asarray(newinputsStates)
        have_new = newinputsStates.size > 0
        have_stored = np.asarray(self.states).size > 0

        if not have_new and not have_stored:
            raise ValueError('No data')

        if not have_new:
            # Nothing to add: refit on what is already stored.
            inputsStates = self.states
            inputsActions = self.actions
            outputs = self.deltS
        else:
            inputsStates = newinputsStates.reshape(11, -1)
            inputsActions = np.asarray(newinputsActions).reshape(4, -1)
            outputs = np.asarray(newoutputs).reshape(11, -1)
            if have_stored:
                # New samples are appended as extra columns.
                inputsStates = np.hstack((self.states, inputsStates))
                inputsActions = np.hstack((self.actions, inputsActions))
                outputs = np.hstack((self.deltS, outputs))

        # Canonical column-per-sample layout (no-op for well-formed data).
        inputsStates = np.asarray(inputsStates).reshape(11, -1)
        inputsActions = np.asarray(inputsActions).reshape(4, -1)
        outputs = np.asarray(outputs).reshape(11, -1)

        # The regressor wants one sample per ROW.  This must be a transpose:
        # reshaping (15, N) to (N, 15) would interleave values belonging to
        # different samples/features whenever N > 1.
        features = np.vstack((inputsActions, inputsStates)).T  # (N, 15)
        targets = outputs.T                                    # (N, 11)

        kernel = ConstantKernel(1.0) * RationalQuadratic(length_scale=1.0, alpha=0.1)
        gpEst = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9,
                                         alpha=.5, random_state=3)

        print(features.shape, targets.shape)
        gpEst.fit(features, targets)

        self.A = gpEst

        self.states = inputsStates
        self.actions = inputsActions
        self.deltS = outputs

    def getReward(self, testStates, testActions, testDeltaS):
        """Score the fitted regressor on test samples.

        Args:
            testStates: test states, one sample per column (11 rows).
            testActions: test actions, one sample per column (4 rows).
            testDeltaS: true state changes to compare against, laid out like
                ``testStates``.

        Returns:
            ``(rNew, r)`` where ``r`` is the mean (over samples) Euclidean
            norm of the prediction error and ``rNew`` is the relative
            improvement ``(r_prev - r) / r_prev`` over the previous call.

        Raises:
            AttributeError: if no regressor has been fitted yet.

        Side effects:
            Sets ``self.r_prev`` to ``r``.
        """
        if not hasattr(self.A, "predict"):
            raise AttributeError(
                "no fitted model: call updateModel() before getReward()")

        testStates = np.asarray(testStates).reshape(11, -1)
        testActions = np.asarray(testActions).reshape(4, -1)

        # One sample per row for the regressor (transpose, not reshape --
        # see updateModel), then back to one sample per column.
        features = np.vstack((testActions, testStates)).T  # (N, 15)
        predicted = np.asarray(self.A.predict(features)).reshape(-1, 11).T

        print("H1", predicted)
        print("H2", testDeltaS)

        per_sample_error = np.linalg.norm(predicted - testDeltaS, axis=0)
        r = np.sum(per_sample_error) / testStates.shape[1]

        # NOTE(review): if a previous call produced r == 0 this divides by
        # zero (inf/nan with NumPy floats); behaviour left as in the original.
        rNew = (self.r_prev - r) / self.r_prev
        self.r_prev = r
        return rNew, r

    def Nominal(self, state, action):
        """Return the nominal linear state change for the given samples.

        Args:
            state: states; reshaped to ``(11, -1)``.
            action: actions; reshaped to ``(4, -1)``.

        Returns:
            ``AnominalEst @ state + BnominalEst @ action``, shape ``(11, N)``.
        """
        action = action.reshape(4, -1)
        state = state.reshape(11, -1)
        deltaS = np.dot(self.AnominalEst, state) + np.dot(self.BnominalEst, action)
        return deltaS

    def reset(self):
        """Discard the stored samples and restore the initial ``r_prev``.

        The fitted regressor in ``self.A`` (if any) is left in place.
        """
        self.states = np.array([])
        self.actions = np.array([])
        self.deltS = np.array([])
        self.r_prev = 3000


