import os
import pickle
from core.buffer import ReplayBuffer
from model import *
from config import Config
from core.random_process import OUNoise
from core.util import hard_update, soft_update, get_class_attr_val
import MILPs.MILPModel_2layer as actorMILP
from RatBrainModel import RatEst
import numpy as np
from matplotlib import pyplot as plt


class DDPG:
    """Critic-only DDPG-style agent whose greedy action is produced by a MILP.

    There is no actor network: for a given state the action is obtained from
    ``actorMILP.solve_br`` using the critic's encoded state ``Z`` and the
    weights/biases of the critic's ``fc2`` and ``fc3`` layers.  ``learning``
    performs one TD update of ``self.critic`` (plus a cross-entropy update on
    the critic's third output, ``ratNums``) and soft-updates
    ``self.critic_target``.
    """

    def __init__(self, config: Config):
        self.config = config
        # Seed before any network is constructed.  Seeding after ``init()``
        # left the critic / critic_target weight initialisation unseeded.
        torch.manual_seed(0)
        self.init()
        self.rat = RatEst()

    def init(self):
        """Read hyper-parameters from ``self.config`` and build critic, target and optimizers."""
        self.Z_dim = self.config.Z_dim
        self.action_dim = self.config.action_dim
        self.batch_size = self.config.batch_size
        self.gamma = self.config.gamma
        self.epsilon = self.config.epsilon
        self.is_training = True
        self.randomer = OUNoise(self.action_dim)
        self.buffer = ReplayBuffer(self.config.max_buff)

        self.critic = Critic(self.Z_dim, self.action_dim, h1=self.config.fC2_Hidden, h2=self.config.fC3_Hidden)
        self.critic_target = Critic(self.Z_dim, self.action_dim, h1=self.config.fC2_Hidden, h2=self.config.fC3_Hidden)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), self.config.learning_rate)
        # NOTE(review): this optimizer also updates ``self.critic.parameters()``
        # (with its own, separate Adam state); ``self.rat`` is not optimised
        # anywhere in this class.  Confirm this is the intended parameter set.
        self.rat_optimizer = torch.optim.Adam(self.critic.parameters(), self.config.learning_rate)

        hard_update(self.critic_target, self.critic)

        if torch.cuda.is_available():
            self.cuda()

    def _device(self):
        """Return the device holding the critic's parameters (all inputs are sent there)."""
        return next(self.critic.parameters()).device

    def _critic_head_weights(self):
        """Return ``(fC2, fB2, fC3, fB3)``: numpy copies of critic fc2/fc3 weights and biases, shaped for the MILP."""
        fC2 = self.critic.fc2.weight.data.cpu().numpy().reshape(self.config.fC3_Hidden, -1)
        fB2 = self.critic.fc2.bias.data.cpu().numpy()
        fC3 = self.critic.fc3.weight.data.cpu().numpy().reshape(self.config.fC3_Hidden)
        fB3 = self.critic.fc3.bias.data.cpu().numpy()
        return fC2, fB2, fC3, fB3

    def learning(self):
        """Run one critic update on a batch sampled from the replay buffer.

        The bootstrap action for every next state is the MILP solution for the
        current critic; the TD target is evaluated with ``self.critic_target``.

        Returns:
            float: the critic's MSE TD loss for this batch.
        """
        print("LEARNING Q")
        device = self._device()
        s1, a1, r1, t1, s2 = self.buffer.sample_batch(self.batch_size)

        # Terminal flags -> "not done" mask: 1 where the episode continues, so
        # the bootstrap term is dropped on terminal transitions.
        not_done = (t1 == False) * 1  # noqa: E712 -- elementwise comparison
        a1 = torch.as_tensor(a1, dtype=torch.float, device=device)
        r1 = torch.as_tensor(r1, dtype=torch.float, device=device)
        not_done = torch.as_tensor(not_done, dtype=torch.float, device=device)

        # Each sampled state unpacks to (ZPrime, previous state/action
        # sequence, bestPGuess, true rat numbers); the last item of the next
        # state is unused.
        Z1Prime, prevS_A1, bestPGuess1, trueRatNums = s1
        Z2Prime, prevS_A2, bestPGuess2, _ = s2
        Z1Prime = Z1Prime.to(device)
        Z2Prime = Z2Prime.to(device)

        # Snapshot of the critic head taken before this update's optimizer step.
        fC2, fB2, fC3, fB3 = self._critic_head_weights()

        batch_size = len(Z2Prime)

        # TODO: to fix padding, save sequence lengths in a list and use
        # pack_padded_sequence(paddedS_A2, seq_length, batch_first=True).

        # Encode the next states; the action input is a placeholder because
        # only the encoding Z is used here.
        with torch.no_grad():
            fakeA = torch.zeros((Z2Prime.shape[0], self.config.action_dim), device=device)
            _, Z2, _ = self.critic(Z2Prime, fakeA, prevS_A2, bestPGuess2)
        Z2 = Z2.detach().cpu().numpy()  # one device->host copy for the whole batch

        # Solve one MILP per next state to get its greedy action.
        allA2 = np.zeros((batch_size, self.config.action_dim))
        for i in range(batch_size):
            Z = Z2[i].reshape(self.config.fC2_Hidden, -1)
            a = actorMILP.solve_br(Z=Z, fC2=fC2, fC3=fC3, fB2=fB2, fB3=fB3)
            allA2[i, :] = np.asarray(a)
        allA2 = torch.as_tensor(allA2, dtype=torch.float, device=device)

        with torch.no_grad():
            target_q, _, _ = self.critic_target(Z2Prime, allA2, prevS_A2, bestPGuess2)
        y_expected = r1[:, None] + not_done[:, None] * self.config.gamma * target_q
        y_predicted, _, ratNums = self.critic(Z1Prime, a1.reshape(-1, 1), prevS_A1, bestPGuess1)

        # Critic (TD) update.
        loss_critic = nn.MSELoss()(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # NOTE(review): ratNums comes from the same forward pass as
        # y_predicted.  If the two outputs share graph nodes, this backward()
        # runs after the first one freed that graph and after
        # critic_optimizer.step() changed the parameters in place -- confirm
        # Critic keeps the rat branch independent.
        target_rats = trueRatNums.reshape(-1).long().to(ratNums.device)
        loss_rat = nn.CrossEntropyLoss()(ratNums, target_rats)
        self.rat_optimizer.zero_grad()
        loss_rat.backward()
        self.rat_optimizer.step()

        soft_update(self.critic_target, self.critic, self.config.tau)

        return loss_critic.item()

    def calcValue(self, critic, state, prevSA, u0):
        """Return the mean value of ``state`` under ``critic`` over 100 random actions near ``u0``.

        Actions are 8-dimensional: two copies of (elevator, thrust, aileron,
        rudder), each sampled uniformly around ``u0`` (+-5, +-100 for thrust).

        NOTE(review): this helper looks stale -- it calls ``critic`` with three
        arguments and unpacks two outputs, while every other call site in this
        class passes four arguments and unpacks three.  Confirm before use.
        """
        numA = 100
        elevator1 = np.random.uniform(u0[0] - 5, u0[0] + 5, (1, numA))
        thrust1 = np.random.uniform(u0[1] - 100, u0[1] + 100, (1, numA))
        aileron1 = np.random.uniform(u0[2] - 5, u0[2] + 5, (1, numA))
        rudder1 = np.random.uniform(u0[3] - 5, u0[3] + 5, (1, numA))

        elevator2 = np.random.uniform(u0[0] - 5, u0[0] + 5, (1, numA))
        thrust2 = np.random.uniform(u0[1] - 100, u0[1] + 100, (1, numA))
        aileron2 = np.random.uniform(u0[2] - 5, u0[2] + 5, (1, numA))
        rudder2 = np.random.uniform(u0[3] - 5, u0[3] + 5, (1, numA))

        randomActions = np.vstack((elevator1, thrust1, aileron1, rudder1, elevator2, thrust2, aileron2, rudder2))
        randomActions = torch.from_numpy(randomActions.transpose(1, 0)).float()

        states = state.repeat(numA, 1)
        prevSAs = prevSA.repeat(numA, 1, 1)

        device = next(critic.parameters()).device
        values, _ = critic(states.to(device), randomActions.to(device), prevSAs.to(device))
        return np.mean(values.detach().cpu().numpy())

    def cuda(self):
        """Move critic and target critic to the default CUDA device."""
        self.critic.cuda()
        self.critic_target.cuda()

    def decay_epsilon(self):
        """Decrease the exploration scale by ``config.eps_decay`` (floored at use time by ``epsilon_min``)."""
        self.epsilon -= self.config.eps_decay

    def get_action(self, state):
        """Return the MILP action for ``state``, plus exploration noise while training.

        Args:
            state: 4-tuple ``(ZPrime, prevS_A, bestPGuess, _)``.

        Returns:
            np.ndarray: float column vector of actions, clipped to ``[0, 4]``.
        """
        ZPrime, prevS_A, bestPGuess, _ = state
        device = self._device()
        prevS_A = prevS_A.reshape(1, -1, self.config.LSTM_in_size).float()
        ZPrime = torch.as_tensor(ZPrime, dtype=torch.float)

        fC2, fB2, fC3, fB3 = self._critic_head_weights()

        # Only the encoded state Z is needed; the action input is a placeholder.
        with torch.no_grad():
            _, Z2, _ = self.critic(
                ZPrime.to(device),
                torch.zeros((ZPrime.shape[0], 1), device=device),
                [prevS_A.to(device)],
                bestPGuess,
                batch_size=1,
            )
        Z = Z2.detach().cpu().numpy().reshape(self.config.fC2_Hidden, -1)

        action = actorMILP.solve_br(Z=Z, fC2=fC2, fC3=fC3, fB2=fB2, fB3=fB3, value=0)
        action = np.asarray(action).reshape(-1, 1).astype('float')

        # Re-seeds numpy's global RNG from OS entropy on every call (this also
        # affects any other np.random user in the process).
        np.random.seed(seed=None)
        noise = self.is_training * max(self.epsilon, self.config.epsilon_min) * np.random.normal(scale=2.0)

        # Near the bounds of [0, 4], flip and shrink noise that would push the
        # action further out.  These scalar comparisons assume a single-element
        # action (action_dim == 1).
        if action < .01 and noise < 0:
            noise = abs(noise) / 5
        elif action > 3.9 and noise > 0:
            noise = -noise / 5

        noise = np.clip(noise, -2, 2)
        action += noise
        return np.clip(action, 0, 4)

    def reset(self):
        """Reset the OU noise process."""
        self.randomer.reset()

    def load_weights(self, output):
        """Load critic weights from ``{output}/critic.pkl``; do nothing when ``output`` is None.

        ``map_location`` lets a checkpoint saved on GPU load on a CPU-only host.
        """
        if output is None:
            return
        state_dict = torch.load('{}/critic.pkl'.format(output), map_location=self._device())
        self.critic.load_state_dict(state_dict)

    def save_model(self, output, name='{}/critic.pkl'):
        """Save the critic's state dict to ``name.format(output)``."""
        torch.save(self.critic.state_dict(), name.format(output))

    def save_config(self, output, save_obj=False):
        """Write ``config`` attributes to ``{output}/config.txt``; optionally pickle it to ``config.obj``."""
        with open(output + '/config.txt', 'w') as f:
            attr_val = get_class_attr_val(self.config)
            for k, v in attr_val.items():
                f.write(str(k) + " = " + str(v) + "\n")

        if save_obj:
            # Context manager so the handle is closed even if pickling fails.
            with open(output + '/config.obj', 'wb') as file:
                pickle.dump(self.config, file)

    def save_checkpoint(self, ep, total_step, output):
        """Save episode, step count and critic weights to ``{output}/checkpoint_model/checkpoint_ep_{ep}.tar``."""
        print("SAVED")
        checkpath = output + '/checkpoint_model'
        os.makedirs(checkpath, exist_ok=True)

        torch.save({
            'episodes': ep,
            'total_step': total_step,
            'critic': self.critic.state_dict()
        }, '%s/checkpoint_ep_%d.tar' % (checkpath, ep))

    def load_checkpoint(self, model_path):
        """Restore critic weights from a checkpoint; return ``(episode, total_step)``.

        NOTE(review): ``critic_target`` is not restored here -- confirm whether
        resumed training should ``hard_update`` it from the loaded critic.
        """
        checkpoint = torch.load(model_path, map_location=self._device())
        episode = checkpoint['episodes']
        total_step = checkpoint['total_step']
        self.critic.load_state_dict(checkpoint['critic'])

        return episode, total_step










