import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data as utils
from torch.autograd import Variable
from MILPs.MILP_RatBrain import solve_br
from matplotlib import pyplot as plt

# Seed numpy from OS entropy, then derive an integer seed for torch from it.
# torch.manual_seed() casts its argument with int(); the previous uniform
# float in [0, 1) was therefore always truncated to 0, making torch fully
# deterministic despite the entropy reseed above. Draw an integer instead.
np.random.seed()
seed = int(np.random.randint(0, 2**31 - 1))
torch.manual_seed(seed)





class RatEst:
    """Surrogate estimator that learns a scalar parameter -> output mapping.

    Collected (action, output) samples are used to fit a small feed-forward
    network (``RatBrain``); the fitted network is then evaluated on a grid of
    candidate parameters to pick the best one (grid search, not a MILP).
    ``self.sim`` must be set (normally via :meth:`reset`) to an object that
    provides ``predictRat`` and ``simulateRat``; both are called here to score
    candidate parameters.
    """

    def __init__(self, states=None, actions=None, out=None, optimizeModel=True):
        """Create an estimator, optionally pre-loaded with data.

        Args:
            states: stored as-is in ``self.states``; not otherwise used here.
            actions: previously observed inputs. When non-empty, new inputs
                are ``np.vstack``-ed onto it, so it should be 2-D with
                ``action_size`` columns.
            out: observed outputs matching ``actions`` (2-D, ``output_size``
                columns when non-empty).
            optimizeModel: when truthy, :meth:`get_bestParam` maximises
                ``param * prediction``; otherwise it maximises the prediction.
        """
        self.state_size = 1
        self.action_size = 1
        self.output_size = 1

        # Fresh arrays per instance rather than shared default-argument objects.
        self.states = np.array([]) if states is None else states
        self.actions = np.array([]) if actions is None else actions
        self.out = np.array([]) if out is None else out
        self.maxOut = 0

        # History of best parameters found after each updateModel() call.
        self.prevBestParams = np.array([])
        # Best of prevBestParams according to sim.simulateRat; set by updateModel().
        self.BestGuess = None

        # Baselines used to turn errors into relative-improvement rewards.
        self.r_prev = 1
        self.prev_error = 20
        self.prev_reward = None

        self.rat = []      # trained RatBrain once updateModel() has run
        self.losses = []   # summed training loss per epoch of the last fit
        self.model = []
        self.optimizeModel = optimizeModel
        self.sim = []

    def updateModel(self, newinputsActions=None, newoutputs=None):
        """Merge new samples into the dataset, retrain the network, update best guess.

        Args:
            newinputsActions: new inputs, reshaped to ``(-1, action_size)``.
                ``None`` or empty means "retrain on the stored data only".
            newoutputs: outputs for ``newinputsActions``, reshaped to
                ``(-1, output_size)``.

        Raises:
            ValueError: if there is neither stored nor new data.
        """
        if newinputsActions is None:
            newinputsActions = np.array([])
        if newoutputs is None:
            newoutputs = np.array([])
        inputsActions, outputs = self._mergeData(newinputsActions, newoutputs)

        inputs = torch.Tensor(inputsActions).float()
        targets = torch.Tensor(outputs).float()
        self.rat, self.losses = self._trainNetwork(inputs, targets)
        # Stored data is kept as float32, matching what the network was trained on.
        self.actions = inputs.detach().numpy()
        self.out = targets.detach().numpy()

        self._updateBestGuess()

    def _mergeData(self, newinputsActions, newoutputs):
        """Return (inputs, outputs) combining stored data with the new samples."""
        newinputsActions = np.asarray(newinputsActions)
        newoutputs = np.asarray(newoutputs)
        if newinputsActions.shape[0] == 0 and self.actions.shape[0] == 0:
            raise ValueError('No data')
        if newinputsActions.shape[0] == 0:
            return self.actions, self.out

        newinputsActions = newinputsActions.reshape(-1, self.action_size)
        newoutputs = newoutputs.reshape(-1, self.output_size)
        if self.actions.shape[0] == 0:
            return newinputsActions, newoutputs
        return (np.vstack((self.actions, newinputsActions)),
                np.vstack((self.out, newoutputs)))

    def _trainNetwork(self, inputs, targets, epochs=500, batch_size=50, lr=0.0001):
        """Fit a fresh RatBrain with Adam + MSE; return (net, per-epoch summed losses)."""
        net = RatBrain(self.action_size, 200, self.output_size)
        optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))
        criterion = nn.MSELoss()
        loader = utils.DataLoader(utils.TensorDataset(inputs, targets),
                                  batch_size=batch_size, shuffle=True)

        losses = []
        for _ in range(epochs):
            epochLoss = 0.0
            for data, target in loader:
                optimizer.zero_grad()
                # MSELoss expects (input, target): prediction first.
                loss = criterion(net(data), target)
                loss.backward()
                optimizer.step()
                epochLoss += loss.item()
            losses.append(epochLoss)
        return net, losses

    def _updateBestGuess(self):
        """Append the model's current best parameter to the history and re-pick BestGuess."""
        newBestParam, _ = self.get_bestParam()
        newBestParam = np.asarray(newBestParam).reshape(1, 1)
        if self.prevBestParams.shape[0] == 0:
            self.prevBestParams = newBestParam
        else:
            self.prevBestParams = np.vstack((self.prevBestParams, newBestParam))

        # Score every historical candidate by param * simulated output and keep the best.
        allbestParamOut = np.asarray(self.sim.simulateRat(self.prevBestParams))
        i = np.argmax(np.multiply(allbestParamOut.reshape(-1), self.prevBestParams.reshape(-1)))
        self.BestGuess = self.prevBestParams[i]

    def getRewardModel(self, testParams, testOut, testBestParams):
        """Reward based on the model's prediction error on a test set.

        Args:
            testParams: test inputs accepted by ``self.rat`` (converted with
                ``torch.Tensor``).
            testOut: true outputs for ``testParams``; flattened to a column.
            testBestParams: reference best parameter.

        Returns:
            ``(rNew, r, bestParamError)``. ``r`` is the root-mean-square
            prediction error. ``rNew`` is the relative improvement of ``r``
            over the previous call's error. ``bestParamError`` is
            ``|best_param - testBestParams|``.
        """
        testParams = torch.Tensor(testParams)
        with torch.no_grad():
            out_est = self.rat(testParams).detach().numpy()

        r = (np.sum(np.linalg.norm(out_est - np.asarray(testOut).reshape(-1, 1), axis=0))
             / np.sqrt(testParams.shape[0]))

        # get_bestParam returns (bestParam, bestDA); only the parameter is compared here.
        bestParam, _ = self.get_bestParam()
        bestParamError = abs(bestParam - testBestParams)

        rNew = (self.r_prev - r) / self.r_prev
        self.r_prev = r
        return rNew, r, bestParamError

    def getRewardParam(self, testParams, testOut, testBestParams, testbestDA):
        """Reward based on how close the model's best parameter is to a reference.

        ``testParams`` and ``testOut`` are accepted for interface parity with
        :meth:`getRewardModel` but are not used.

        Returns:
            ``(rNew, bestParamError, bestDAError)``. ``rNew`` is the relative
            decrease in parameter error since the previous call. When the
            previous error is below 1e-3, ``rNew`` is the absolute decrease
            instead, which avoids dividing by ~0.
        """
        bestParam, bestDA = self.get_bestParam()
        bestParamError = abs(bestParam - testBestParams)
        bestDAError = abs(bestDA - testbestDA)

        if self.prev_error < .001:
            rNew = self.prev_error - bestParamError
        else:
            rNew = (self.prev_error - bestParamError) / self.prev_error

        self.prev_reward = rNew
        self.prev_error = bestParamError
        return rNew, bestParamError, bestDAError

    def get_bestParam(self, numTest=500, low=0.0, high=4.0):
        """Grid-search the trained model for the best parameter.

        Evaluates ``self.rat`` on ``numTest`` evenly spaced points in
        ``[low, high]``. The objective is ``x * prediction`` when
        ``optimizeModel`` is truthy, otherwise the prediction alone.

        Returns:
            ``(bestParam, bestDA)``: the arg-max grid point (shape ``(1,)``)
            and ``sim.predictRat(bestParam) * bestParam``.
        """
        x = torch.Tensor(np.atleast_2d(np.linspace(low, high, numTest)).T).float()
        with torch.no_grad():
            y_pred = self.rat(x).detach().numpy()
        x = x.numpy()

        obj = np.multiply(x, y_pred) if self.optimizeModel else y_pred
        bestParam = x[np.argmax(obj)]

        bestDA = np.multiply(self.sim.predictRat(bestParam.reshape(-1, 1)), bestParam)
        return bestParam, bestDA

    def reset(self, sim=None):
        """Clear data, model and reward baselines, and attach a simulator.

        Args:
            sim: object providing ``predictRat``/``simulateRat``. Defaults to
                a fresh empty list, as before.
        """
        self.states = np.array([])
        self.actions = np.array([])
        self.out = np.array([])

        self.r_prev = 1
        self.prev_error = 20
        self.rat = []

        self.prevBestParams = np.array([])
        self.sim = [] if sim is None else sim


class RatBrain(torch.nn.Module):
    """Fully connected regressor: input -> hidden_size -> 50 -> output, ReLU between layers."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        # Submodules are registered in the order fc1, relu, fc2, fc3; this
        # order fixes both the state_dict layout and how the torch RNG is
        # consumed by weight initialisation.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, 50)
        self.fc3 = torch.nn.Linear(50, output_size)

    def forward(self, x):
        """Map ``x`` with trailing dimension ``input_size`` to trailing dimension ``output_size``."""
        first = self.relu(self.fc1(x))
        second = self.relu(self.fc2(first))
        return self.fc3(second)
