import os
import pickle
from core.buffer import ReplayBuffer
from model import *
from config import Config
from core.random_process import OUNoise
from core.util import hard_update, soft_update, get_class_attr_val
import MILPs.MILPModel as actorMILP
from DamageModel import Damage
import numpy as np


class DDPG:
    """Critic-only DDPG variant whose "actor" is a MILP solved against the critic.

    Only a Q-network (``critic``) and its Polyak-averaged copy
    (``critic_target``) are trained. Actions come from
    ``actorMILP.solve_br``, which is given the critic's fc2/fc3 weights and
    the state encoding ``Z`` produced by the critic's first layer.
    """

    def __init__(self, config: Config):
        self.config = config
        # Seed *before* the networks are constructed; seeding afterwards leaves
        # the initial critic weights (and hence the whole run) non-reproducible.
        torch.manual_seed(0)
        self.init()
        self.d = Damage()

    def init(self):
        """Build hyper-parameters, exploration noise, replay buffer and critics."""
        self.Z_dim = self.config.Z_dim
        self.action_dim = self.config.action_dim
        self.batch_size = self.config.batch_size
        self.gamma = self.config.gamma
        self.epsilon = self.config.epsilon
        self.is_training = True
        # The noise is added to the full MILP action vector in get_action,
        # which is presumably action_dim + 4 long — TODO confirm.
        self.randomer = OUNoise(self.action_dim + 4)
        self.buffer = ReplayBuffer(self.config.max_buff)

        self.critic = Critic(self.Z_dim, self.action_dim,
                             h1=self.config.fC2_Hidden, h2=self.config.fC3_Hidden)
        self.critic_target = Critic(self.Z_dim, self.action_dim,
                                    h1=self.config.fC2_Hidden, h2=self.config.fC3_Hidden)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), self.config.learning_rate)

        hard_update(self.critic_target, self.critic)

        if torch.cuda.is_available():
            self.cuda()

    def _device(self):
        """Return the device the online critic's parameters live on.

        All tensors fed to the critics are moved here so they can never
        disagree with the networks' device.
        """
        return next(self.critic.parameters()).device

    def _critic_head_weights(self):
        """Return numpy copies ``(fC2, fB2, fC3, fB3)`` of the critic's fc2/fc3 layers.

        ``fC2`` is reshaped to ``(fC3_Hidden, -1)`` and ``fC3`` to
        ``(fC3_Hidden,)``, the layout ``actorMILP.solve_br`` is called with.
        """
        fC2 = self.critic.fc2.weight.data.cpu().numpy().reshape(self.config.fC3_Hidden, -1)
        fB2 = self.critic.fc2.bias.data.cpu().numpy()
        fC3 = self.critic.fc3.weight.data.cpu().numpy().reshape(self.config.fC3_Hidden)
        fB3 = self.critic.fc3.bias.data.cpu().numpy()
        return fC2, fB2, fC3, fB3

    def learning(self):
        """Run one critic update on a minibatch sampled from the replay buffer.

        For each next state the MILP actor is solved to get the next action
        (zero action if the solve fails). The target critic scores it to form
        the TD target ``r + gamma * not_done * Q_target(s', a')``. The online
        critic is then regressed onto that target with MSE, and the target
        network is soft-updated.

        Returns:
            float: the critic loss for this batch.
        """
        print("LEARNING")
        device = self._device()
        s1, a1, r1, t1, s2 = self.buffer.sample_batch(self.batch_size)

        # done flags -> "not done" mask. np.logical_not also works when the
        # buffer hands back a plain list (where ``list == False`` would
        # collapse to a single scalar False).
        not_done = np.logical_not(np.asarray(t1, dtype=bool)).astype(np.float32)

        # The critic only scores the first action_dim entries of a stored action.
        a1 = torch.tensor(a1, dtype=torch.float)[:, :self.action_dim].to(device)
        r1 = torch.tensor(r1, dtype=torch.float).to(device)
        not_done = torch.as_tensor(not_done, dtype=torch.float, device=device)

        # Each buffered state is (Adamage, Bdamage, ASigma, BSigma, ZPrime,
        # currentState, u0, prevS_A): damage matrices and their sigma matrices,
        # the damage encoding ZPrime (with mean/std of previous data), the
        # current state, the last action u0, and the history of previous
        # states and actions.
        _, _, _, _, Z1Prime, _, _, prevS_A1 = s1
        A2, B2, ASig2, BSig2, Z2Prime, currentState2, u02, prevS_A2 = s2
        Z1Prime = torch.as_tensor(Z1Prime, dtype=torch.float, device=device)
        Z2Prime = torch.as_tensor(Z2Prime, dtype=torch.float, device=device)

        # The critic encodes the state (fc1), combines encoding Z with the
        # action (fc2 + relu) and maps to a value (fc3). The MILP needs fc2/fc3.
        fC2, fB2, fC3, fB3 = self._critic_head_weights()

        batch_size = len(currentState2)

        # TODO: histories have variable length; pack them with
        # torch.nn.utils.rnn.pack_padded_sequence(batch_first=True,
        # enforce_sorted=False) once the critic accepts packed input.

        # Encode next states. The zero action is a placeholder — only the
        # encoding Z is used — so no autograd graph is needed.
        fake_action = torch.zeros((Z2Prime.shape[0], self.action_dim), device=device)
        with torch.no_grad():
            _, Z2 = self.critic(Z2Prime, fake_action, prevS_A2)

        # Expected-value adjustment of the MILP objective (calcValue) is
        # disabled; zeros are passed instead.
        value2 = np.zeros(batch_size)
        allA2 = np.zeros((batch_size, self.action_dim))
        success = 0
        failure = 0
        for i in range(batch_size):
            Z = Z2[i].cpu().numpy().reshape(self.config.fC2_Hidden, -1)
            A, B, AS, BS = A2[i], B2[i], ASig2[i], BSig2[i]
            currS, currU = currentState2[i], u02[i]
            try:
                a_full, _prob = actorMILP.solve_br(A, B, AS, BS, Z=Z, fC2=fC2, fC3=fC3,
                                                   fB2=fB2, fB3=fB3, xk=currS, u0=currU,
                                                   value=value2[i])
                # Only the first action_dim entries feed the critic.
                a_next = np.asarray(a_full[:self.action_dim])
            except Exception:
                # Best effort: a failed/infeasible solve falls back to the zero action.
                failure += 1
                a_next = np.zeros(self.action_dim)
            else:
                # Count a success only once the solve actually returned.
                success += 1
            allA2[i, :] = a_next

        print("success", success)
        print("failure", failure)

        allA2 = torch.as_tensor(allA2, dtype=torch.float, device=device)

        with torch.no_grad():
            target_q, _ = self.critic_target(Z2Prime, allA2, prevS_A2)
        y_expected = r1[:, None] + not_done[:, None] * self.config.gamma * target_q
        y_predicted, _ = self.critic(Z1Prime, a1, prevS_A1)

        loss_critic = nn.MSELoss()(y_predicted, y_expected)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        self.critic_optimizer.step()

        # Only the online critic is optimized; the target tracks it softly.
        soft_update(self.critic_target, self.critic, self.config.tau)

        return loss_critic.item()

    def calcValue(self, critic, state, prevSA, u0, num_samples=100):
        """Estimate a state's value as the mean critic output over random actions.

        Draws ``num_samples`` action vectors uniformly around the last action
        ``u0``: elevator/aileron/rudder +-5 and thrust +-100, repeated for
        two consecutive steps (8 columns). It then averages
        ``critic(state, action, prevSA)``.

        Args:
            critic: network to evaluate; inputs are moved to its device.
            state: encoded-state tensor, repeated ``num_samples`` times along dim 0.
            prevSA: state/action history tensor, repeated along dim 0.
            u0: last action; indices 0..3 are elevator, thrust, aileron, rudder.
            num_samples: number of random actions to average over.

        Returns:
            The mean critic value as a numpy scalar.
        """
        # (index into u0, half-width of the sampling window) per control surface.
        spans = ((0, 5), (1, 100), (2, 5), (3, 5))
        # Same RNG draw order as before: step-1 e,t,a,r then step-2 e,t,a,r.
        columns = [
            np.random.uniform(u0[k] - w, u0[k] + w, (1, num_samples))
            for _step in range(2)
            for k, w in spans
        ]
        random_actions = torch.from_numpy(np.vstack(columns).transpose(1, 0)).float()

        states = state.repeat(num_samples, 1)
        prev_sas = prevSA.repeat(num_samples, 1, 1)

        device = next(critic.parameters()).device
        with torch.no_grad():
            values, _ = critic(states.to(device), random_actions.to(device), prev_sas.to(device))
        return np.mean(values.cpu().numpy())

    def cuda(self):
        """Move both critics to the default CUDA device."""
        self.critic.cuda()
        self.critic_target.cuda()

    def decay_epsilon(self):
        """Linearly decay the exploration scale (floored at epsilon_min in get_action)."""
        self.epsilon -= self.config.eps_decay

    def get_action(self, state):
        """Solve the MILP actor for a single state and add exploration noise.

        Args:
            state: tuple (A, B, AS, BS, ZPrime, currState, u0, prevS_A), the
                same layout stored in the replay buffer.

        Returns:
            numpy array: the MILP action plus OU noise scaled by
            ``max(epsilon, epsilon_min)`` while training (no noise otherwise).
        """
        A, B, AS, BS, ZPrime, currState, u0, prevS_A = state
        device = self._device()
        ZPrime = torch.as_tensor(ZPrime, dtype=torch.float, device=device)

        # Expected-value adjustment (calcValue) is disabled; 0 is passed.
        value = 0

        fC2, fB2, fC3, fB3 = self._critic_head_weights()

        # Only the encoding Z is needed; the zero action is a placeholder.
        fake_action = torch.zeros((ZPrime.shape[0], self.action_dim), device=device)
        with torch.no_grad():
            _, Z = self.critic(ZPrime, fake_action, [prevS_A])
        Z = Z.cpu().numpy().reshape(self.config.fC2_Hidden, -1)

        action, _prob = actorMILP.solve_br(A, B, AS, BS, Z=Z, fC2=fC2, fC3=fC3, fB2=fB2,
                                           fB3=fB3, xk=currState, u0=u0, value=value)
        # Float copy: in-place noise addition must not fail on an integer
        # result or mutate the solver's own array.
        action = np.array(action, dtype=float)
        action += self.is_training * max(self.epsilon, self.config.epsilon_min) * self.randomer.noise()
        return action

    def reset(self):
        """Reset the exploration noise process (e.g. at episode start)."""
        self.randomer.reset()

    def load_weights(self, output):
        """Load critic weights from ``<output>/critic.pkl``; no-op if output is None.

        The target critic is synced to the loaded weights.
        """
        if output is None:
            return
        # map_location lets GPU-saved weights load on a CPU-only machine.
        state_dict = torch.load('{}/critic.pkl'.format(output), map_location=self._device())
        self.critic.load_state_dict(state_dict)
        hard_update(self.critic_target, self.critic)

    def save_model(self, output, name='{}/critic.pkl'):
        """Save the critic's state dict to ``name.format(output)``."""
        torch.save(self.critic.state_dict(), name.format(output))

    def save_config(self, output, save_obj=False):
        """Write config attributes to ``<output>/config.txt``; optionally pickle it.

        Args:
            output: destination directory (must exist).
            save_obj: when True, also pickle the Config object to
                ``<output>/config.obj``.
        """
        with open(output + '/config.txt', 'w') as f:
            for k, v in get_class_attr_val(self.config).items():
                f.write(str(k) + " = " + str(v) + "\n")

        if save_obj:
            # Context manager so the handle is closed even if pickling fails.
            with open(output + '/config.obj', 'wb') as f:
                pickle.dump(self.config, f)

    def save_checkpoint(self, ep, total_step, output):
        """Save episode, step count and critic weights under ``<output>/checkpoint_model``."""
        print("SAVED")
        checkpath = output + '/checkpoint_model'
        os.makedirs(checkpath, exist_ok=True)

        torch.save({
            'episodes': ep,
            'total_step': total_step,
            'critic': self.critic.state_dict()
        }, '%s/checkpoint_ep_%d.tar' % (checkpath, ep))

    def load_checkpoint(self, model_path):
        """Restore critic weights from a checkpoint written by ``save_checkpoint``.

        The target critic is synced to the restored critic, since the
        checkpoint does not store it.

        Returns:
            tuple: (episode, total_step) recorded in the checkpoint.
        """
        checkpoint = torch.load(model_path, map_location=self._device())
        episode = checkpoint['episodes']
        total_step = checkpoint['total_step']
        self.critic.load_state_dict(checkpoint['critic'])
        hard_update(self.critic_target, self.critic)

        return episode, total_step










