import os
import pickle
from core.buffer import ReplayBuffer
from model import *
from config import Config
from core.random_process import OUNoise
from core.util import hard_update, soft_update, get_class_attr_val
import MILPs.MILPModel_2layer as actorMILP
from RatBrainModel import RatEst
import numpy as np
from matplotlib import pyplot as plt


class DDPG:
    """Critic-only DDPG variant whose "actor" is a MILP over the critic's weights.

    There is no actor network: an action is obtained by calling
    ``actorMILP.solve_br`` with the critic's latent encoding ``Z`` and the
    weights/biases of its ``fc2`` and ``fc3`` layers. Only the critic (and its
    target copy, updated via ``soft_update``) is trained.
    """

    def __init__(self, config: "Config"):
        """Build the agent from ``config``.

        The torch seed is set *before* the networks are constructed so that
        their initial weights are reproducible (seeding afterwards would only
        affect later random draws).
        """
        self.config = config
        torch.manual_seed(0)
        self.init()
        self.rat = RatEst()

    def init(self):
        """(Re)create hyper-parameters, noise process, replay buffer and critics from ``self.config``."""
        self.Z_dim = self.config.Z_dim
        self.action_dim = self.config.action_dim
        self.batch_size = self.config.batch_size
        self.gamma = self.config.gamma
        self.epsilon = self.config.epsilon
        self.is_training = True
        self.randomer = OUNoise(self.action_dim)
        self.buffer = ReplayBuffer(self.config.max_buff)

        self.critic = Critic(self.Z_dim, self.action_dim,
                             h1=self.config.fC2_Hidden, h2=self.config.fC3_Hidden)
        self.critic_target = Critic(self.Z_dim, self.action_dim,
                                    h1=self.config.fC2_Hidden, h2=self.config.fC3_Hidden)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), self.config.learning_rate)

        # Start the target critic as an exact copy of the online critic.
        hard_update(self.critic_target, self.critic)

        if torch.cuda.is_available():
            self.cuda()

    @property
    def _device(self):
        """Device holding the online critic's parameters; every input is moved here."""
        return next(self.critic.parameters()).device

    def _critic_milp_weights(self):
        """Return ``(fC2, fB2, fC3, fB3)``: the critic layers the MILP encodes, as numpy arrays.

        ``fC2`` is reshaped to ``(fC3_Hidden, -1)`` and ``fC3`` to ``(fC3_Hidden,)``.
        On CPU these arrays may share memory with the live parameters, so use
        them before the next optimizer step.
        """
        fC2 = self.critic.fc2.weight.data.cpu().numpy().reshape(self.config.fC3_Hidden, -1)
        fB2 = self.critic.fc2.bias.data.cpu().numpy()
        fC3 = self.critic.fc3.weight.data.cpu().numpy().reshape(self.config.fC3_Hidden)
        fB3 = self.critic.fc3.bias.data.cpu().numpy()
        return fC2, fB2, fC3, fB3

    @staticmethod
    def _pack_history(histories, max_len=None):
        """Zero-pad per-sample histories and pack them into a ``PackedSequence``.

        Args:
            histories: iterable of array-likes, each flattened to ``(-1, 2)``
                (one row per timestep).
            max_len: padded length; defaults to the longest history. Only
                affects the intermediate padded tensor, not the packed data.

        Returns:
            A CPU ``PackedSequence`` (batch_first, unsorted) whose lengths are
            the number of rows of each history.
        """
        rows = [np.asarray(h).reshape(-1, 2) for h in histories]
        lengths = [r.shape[0] for r in rows]
        if max_len is None:
            max_len = max(lengths)
        padded = torch.zeros((len(rows), max_len, 2))
        for i, r in enumerate(rows):
            padded[i, :r.shape[0], :] = torch.as_tensor(r, dtype=torch.float)
        return torch.nn.utils.rnn.pack_padded_sequence(padded, lengths, batch_first=True,
                                                       enforce_sorted=False)

    def learning(self):
        """Run one critic update on a sampled minibatch and soft-update the target critic.

        The Bellman target uses next actions obtained by solving the MILP on the
        critic's (pre-update) weights and the encoding of each next state.

        Returns:
            float: the critic's MSE loss on this minibatch.
        """
        print("LEARNING")
        s1, a1, r1, t1, s2 = self.buffer.sample_batch(self.batch_size)
        dev = self._device

        # Terminal flags -> continuation mask: 1.0 where the episode did NOT end.
        not_done = np.logical_not(np.asarray(t1, dtype=bool))
        a1 = torch.tensor(a1, dtype=torch.float, device=dev)
        r1 = torch.tensor(r1, dtype=torch.float, device=dev)
        t1 = torch.tensor(not_done, dtype=torch.float, device=dev)

        # Each state is (ZPrime, flat history of (state, action) pairs, step index).
        Z1Prime, prevS_A1, step1 = s1
        Z2Prime, prevS_A2, step2 = s2
        batch_size = len(Z2Prime)

        # Read the weights the MILP encodes before this update changes them.
        fC2, fB2, fC3, fB3 = self._critic_milp_weights()

        max_len = self.config.max_steps + 2
        # PackedSequence.to() is not in-place, so the result must be kept.
        paddedS_A1Converted = self._pack_history(prevS_A1, max_len).to(dev)
        paddedS_A2Converted = self._pack_history(prevS_A2, max_len).to(dev)

        Z1Prime = torch.as_tensor(Z1Prime, device=dev)
        Z2Prime = torch.as_tensor(Z2Prime, device=dev)
        step1 = torch.as_tensor(step1, dtype=torch.float, device=dev)
        step2 = torch.as_tensor(step2, dtype=torch.float, device=dev)

        # Only the encoding Z2 is used from this pass, so a zero action is passed.
        fakeA = torch.zeros((batch_size, self.config.action_dim), device=dev)
        _, Z2 = self.critic(Z2Prime, fakeA, paddedS_A2Converted, step2)

        # Next action for every sample: MILP over the encoded state and critic weights.
        allA2 = np.zeros((batch_size, self.config.action_dim))
        for i in range(batch_size):
            Z = Z2[i].detach().cpu().numpy().reshape(self.config.fC2_Hidden, -1)
            a = actorMILP.solve_br(Z=Z, fC2=fC2, fC3=fC3, fB2=fB2, fB3=fB3)
            allA2[i, :] = np.asarray(a)
        allA2 = torch.tensor(allA2, dtype=torch.float, device=dev)

        # Bellman target from the target critic; detached so no gradient flows into it.
        target_q, _ = self.critic_target(Z2Prime, allA2, paddedS_A2Converted, step2)
        target_q = target_q.detach()
        y_expected = r1[:, None] + t1[:, None] * self.config.gamma * target_q
        y_predicted, _ = self.critic(Z1Prime, a1.reshape(-1, 1), paddedS_A1Converted, step1)
        print('truth', r1)
        print('yexpected', y_expected)
        print('predicted', y_predicted)
        print('actions', a1)
        print("PREDICTED", y_predicted)
        print("YEXP", y_expected)
        print("R1", r1)
        print('step1', step1)
        print('step2', step2)

        loss_critic = torch.nn.MSELoss()(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Only the online critic is optimized; the target follows via soft_update (rate tau).
        soft_update(self.critic_target, self.critic, self.config.tau)

        return loss_critic.item()

    def calcValue(self, critic, state, prevSA, u0):
        """Estimate the value at ``state`` as the mean critic output over 100 random actions near ``u0``.

        NOTE(review): not called from this file. It samples 8 action columns
        (two elevator/thrust/aileron/rudder blocks) and calls the critic without
        the ``step`` argument that ``learning``/``get_action`` pass, so it likely
        no longer matches the current Critic signature — confirm before use.
        """
        numA = 100
        elevator1 = np.random.uniform(u0[0] - 5, u0[0] + 5, (1, numA))
        thrust1 = np.random.uniform(u0[1] - 100, u0[1] + 100, (1, numA))
        aileron1 = np.random.uniform(u0[2] - 5, u0[2] + 5, (1, numA))
        rudder1 = np.random.uniform(u0[3] - 5, u0[3] + 5, (1, numA))

        elevator2 = np.random.uniform(u0[0] - 5, u0[0] + 5, (1, numA))
        thrust2 = np.random.uniform(u0[1] - 100, u0[1] + 100, (1, numA))
        aileron2 = np.random.uniform(u0[2] - 5, u0[2] + 5, (1, numA))
        rudder2 = np.random.uniform(u0[3] - 5, u0[3] + 5, (1, numA))

        randomActions = np.vstack((elevator1, thrust1, aileron1, rudder1,
                                   elevator2, thrust2, aileron2, rudder2))
        randomActions = torch.from_numpy(randomActions.transpose(1, 0)).float()

        states = state.repeat(numA, 1)
        prevSAs = prevSA.repeat(numA, 1, 1)

        dev = next(critic.parameters()).device
        values, _ = critic.forward(states.to(dev), randomActions.to(dev), prevSAs.to(dev))
        return np.mean(values.detach().cpu().numpy())

    def cuda(self):
        """Move the online and target critics to the GPU."""
        self.critic.cuda()
        self.critic_target.cuda()

    def decay_epsilon(self):
        """Decrease the exploration scale by ``config.eps_decay``."""
        self.epsilon -= self.config.eps_decay

    def get_action(self, state):
        """Return the MILP action for ``state`` plus Gaussian exploration noise, clipped to [0, 4].

        Args:
            state: ``(ZPrime, prevS_A, step)``; ``prevS_A`` is a numpy history
                that is flattened to ``(-1, 2)`` rows (one per timestep).

        Returns:
            numpy float array of shape ``(-1, 1)``.
        """
        ZPrime, prevS_A, step = state
        dev = self._device

        ZPrime = torch.tensor(ZPrime, dtype=torch.float, device=dev)
        step = torch.tensor(step, dtype=torch.float, device=dev).reshape(1, 1)
        # Packed length = number of timesteps (rows), matching how learning() packs.
        history = self._pack_history([prevS_A]).to(dev)

        fC2, fB2, fC3, fB3 = self._critic_milp_weights()

        _, Z2 = self.critic(ZPrime, torch.zeros((ZPrime.shape[0], 1), device=dev),
                            history, step, batch_size=1)
        Z = Z2.detach().cpu().numpy().reshape(self.config.fC2_Hidden, -1)

        action = actorMILP.solve_br(Z=Z, fC2=fC2, fC3=fC3, fB2=fB2, fB3=fB3, value=0)
        action = np.asarray(action).reshape(-1, 1).astype('float')

        # Reseed numpy's global RNG from OS entropy so the noise is not fixed by earlier seeding.
        np.random.seed(seed=None)
        noise = self.is_training * max(self.epsilon, self.config.epsilon_min) * np.random.normal()

        action += noise
        action = np.clip(action, 0, 4)
        print('action after noise', action)
        return action

    def reset(self):
        """Reset the OU noise process."""
        self.randomer.reset()

    def load_weights(self, output):
        """Load critic weights from ``<output>/critic.pkl``; no-op when ``output`` is None."""
        if output is None:
            return
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        self.critic.load_state_dict(
            torch.load(os.path.join(output, 'critic.pkl'), map_location=self._device))

    def save_model(self, output):
        """Save the critic's state dict to ``<output>/critic.pkl``."""
        torch.save(self.critic.state_dict(), os.path.join(output, 'critic.pkl'))

    def save_config(self, output, save_obj=False):
        """Write config attributes to ``<output>/config.txt``; optionally pickle the config object.

        Args:
            output: existing directory to write into.
            save_obj: when True, also pickle ``self.config`` to ``<output>/config.obj``.
        """
        with open(os.path.join(output, 'config.txt'), 'w') as f:
            for k, v in get_class_attr_val(self.config).items():
                f.write(str(k) + " = " + str(v) + "\n")

        if save_obj:
            with open(os.path.join(output, 'config.obj'), 'wb') as file:
                pickle.dump(self.config, file)

    def save_checkpoint(self, ep, total_step, output):
        """Save episode, step count and critic weights to ``<output>/checkpoint_model/checkpoint_ep_<ep>.tar``."""
        print("SAVED")
        checkpath = os.path.join(output, 'checkpoint_model')
        os.makedirs(checkpath, exist_ok=True)

        torch.save({
            'episodes': ep,
            'total_step': total_step,
            'critic': self.critic.state_dict(),
        }, '%s/checkpoint_ep_%d.tar' % (checkpath, ep))

    def load_checkpoint(self, model_path):
        """Restore critic weights from a checkpoint written by ``save_checkpoint``.

        Returns:
            tuple: ``(episode, total_step)`` stored in the checkpoint.
        """
        checkpoint = torch.load(model_path, map_location=self._device)
        episode = checkpoint['episodes']
        total_step = checkpoint['total_step']
        self.critic.load_state_dict(checkpoint['critic'])

        return episode, total_step





