# -*- coding: utf-8 -*-
from model_A2C import model
from Obs_model import obs_model, combine_model, time_model
from memory import memory
import torch
import torch.nn as nn
import random
import math
import numpy as np
import torch.nn.functional as F
import copy


class high_agent(object):
    """High-level A2C agent with a goal-conditioned recurrent policy.

    Acting pipeline (see ``act_forSBG``)::

        obs(s) -> LSTML -> combL(., ins2emb(goal)) -> actorL / criticL

    Every sub-network has its own RMSprop optimizer. Training data is read
    per worker process from a multi-process ``memory`` buffer in ``train_mlt``.
    """

    # (attribute name, file suffix) in the canonical order shared by
    # save_model / load_model_frompath / get_curr_para / load_model_frommodel.
    _NET_FILES = (
        ('obs', '_obs_model'),
        ('LSTML', '_lstmL_model'),
        ('ins2emb', '_ins2emb_model'),
        ('combL', '_combL_model'),
        ('actorL', '_actorL_model'),
        ('criticL', '_criticL_model'),
    )

    def __init__(self, para):
        """Build the agent from the config object `para`.

        Attributes read from `para`: state_dim, action_dim, embed_dim,
        mltpro_num, mem_size, learn_rate, gamma, device, goal_dim.
        """
        self.state_dim = para.state_dim
        self.action_dim = para.action_dim
        self.embed_dim = para.embed_dim
        self.mltpro_num = para.mltpro_num

        self.memory = memory(para.mem_size, state_dim=self.state_dim, action_dim=self.action_dim,
                             CNN_FLAG=True, mltpro_flag=True, mlt_num=self.mltpro_num)
        self.memory.reset('all')

        self.lr = para.learn_rate
        self.gamma = para.gamma
        # Stored for callers; the methods of this class call .cuda() directly.
        self.device = torch.device(para.device)
        # Weight of the entropy bonus in the training loss (see train_mlt).
        self.entropy_rate = 0.8

        # Goal-conditioned training is required by train_mlt.
        self.lan_flag = True
        self.lan_vec_size = para.goal_dim
        self.build_net()

        self.agent_low = None
        # Row k (two slots per high-level action) is appended to the low-level
        # agent's state in select_action_low.
        self.action_vec = np.array([[1, 1, 0, 0, 0, 0, 0, 0],
                                    [0, 0, 1, 1, 0, 0, 0, 0],
                                    [0, 0, 0, 0, 1, 1, 0, 0],
                                    [0, 0, 0, 0, 0, 0, 1, 1]])

    def build_net(self):
        """Create every sub-network, its optimizer and the critic loss."""
        # Shared observation encoder.
        self.obs = obs_model(self.embed_dim)
        self.combL = combine_model(self.embed_dim)
        self.LSTML = time_model(self.embed_dim, self.embed_dim)
        # Projects the goal/instruction vector into the embedding space.
        self.ins2emb = nn.Linear(self.lan_vec_size, self.embed_dim)
        self.actorL = model(self.embed_dim, self.action_dim, self.embed_dim).actor
        self.criticL = model(self.embed_dim, self.action_dim, self.embed_dim).critic
        self._build_optimizers()

        self.loss_function_c = nn.MSELoss()

    def _build_optimizers(self):
        """(Re)create one RMSprop optimizer per sub-network.

        Must be called whenever a sub-network object is replaced; otherwise the
        optimizer keeps stepping the parameters of the discarded module.
        """
        self.obs_optimizer = torch.optim.RMSprop(self.obs.parameters(), self.lr)
        self.combL_optimizer = torch.optim.RMSprop(self.combL.parameters(), self.lr)
        self.LSTML_optimizer = torch.optim.RMSprop(self.LSTML.parameters(), self.lr)
        self.ins2emb_optimizer = torch.optim.RMSprop(self.ins2emb.parameters(), self.lr)
        self.actorL_optimizer = torch.optim.RMSprop(self.actorL.parameters(), self.lr)
        self.criticL_optimizer = torch.optim.RMSprop(self.criticL.parameters(), self.lr)

    def _all_optimizers(self):
        """Return the per-network optimizers in the order they are stepped."""
        return [self.criticL_optimizer, self.actorL_optimizer, self.combL_optimizer,
                self.ins2emb_optimizer, self.LSTML_optimizer, self.obs_optimizer]

    def net2cuda(self, device=torch.device('cuda:0')):
        """Move every sub-network (and the low-level agent, if set) to `device`."""
        for attr, _ in self._NET_FILES:
            getattr(self, attr).cuda(device=device)

        if self.agent_low is not None:
            self.agent_low.to_cuda(device)

    def select_action_low(self, act_high, state_pos):
        """Query the low-level agent for high-level action index `act_high`.

        The low-level state is `state_pos` concatenated with row `act_high`
        of ``self.action_vec`` along the last axis.
        """
        state_low = np.concatenate((state_pos, self.action_vec[act_high]), axis=-1)
        state_low = torch.from_numpy(state_low).cuda().float()
        act_low = self.agent_low.exploit(state_low)
        return act_low

    def act_forSBG(self, s, goal=None, graph=False):
        """Sample a high-level action for state `s` conditioned on `goal`.

        Args:
            s: state; anything ``np.ascontiguousarray`` accepts (ndarray, list,
                CPU tensor).
            goal: goal array. When None, the goal-conditioned embedding from
                the previous call (``self.embed_aL``) is reused.
                NOTE(review): `s` is then ignored - confirm this is intended.
            graph: if True, `goal` is first passed through the observation
                encoder ``self.obs``.

        Returns:
            int: index of the sampled action.

        Raises:
            ValueError: if `goal` is None and no embedding exists yet.
        """
        # NOTE(review): a non-negativity precondition on `s` may be intended;
        # none is enforced here.
        if goal is not None:
            s_cuda = torch.from_numpy(np.ascontiguousarray(s)).cuda().float()
            goal_cuda = torch.from_numpy(np.ascontiguousarray(goal)).cuda().float()
            if graph:
                goal_cuda = self.obs(goal_cuda)
            self.obs_sL = self.obs(s_cuda)
            # unsqueeze(0) adds a leading batch dimension to the goal vector.
            self.embed_aL = self.combL(self.LSTML(self.obs_sL), self.ins2emb(goal_cuda.unsqueeze(0)))
        elif getattr(self, 'embed_aL', None) is None:
            raise ValueError('act_forSBG: goal is required before any embedding has been computed')

        self.actL = self.actorL(self.embed_aL).detach()
        self.act_disL = F.softmax(self.actL, dim=1)[0]
        # Sanity check: the distribution must sum to 1 (two-sided).
        assert abs(float(self.act_disL.sum()) - 1.0) < 1e-5
        self.aL = int(torch.multinomial(self.act_disL, 1).item())

        return self.aL

    def train_mlt(self, goal_=None, pro_num_=None, calc_num=1):
        """Run one A2C update over the episodes buffered for each process.

        Args:
            goal_: per-process goals indexed by process number; the string
                ``'failed'`` marks a process whose episode is skipped.
            pro_num_: number of processes to read from ``self.memory``.
            calc_num: extra multiplier applied to the loss before backward.

        The summed critic + actor + entropy loss is divided by the number of
        non-failed processes. Nothing is updated (and memory is not reset)
        when no episode produced a loss.

        Raises:
            ValueError: if `pro_num_` is None.
        """
        if pro_num_ is None:
            raise ValueError('train_mlt: pro_num_ is required')
        self.eposide = pro_num_
        for opt in self._all_optimizers():
            opt.zero_grad()
        self.loss_a = 0.
        self.loss_c = 0.
        self.loss_e = 0.

        for pro_num in range(pro_num_):
            goal = goal_[pro_num]
            # isinstance guard: `ndarray == 'failed'` is elementwise and would
            # make the `if` ambiguous for array goals.
            if isinstance(goal, str) and goal == 'failed':
                self.eposide -= 1
                continue
            flag, s, a, r, _ = self.memory.get_batch(pro_num)
            if not flag:
                continue
            self._accumulate_episode_loss(s, a, r, goal, self.memory.size_now[pro_num])

        # Losses stay Python floats when no step was processed; .backward()
        # would then fail, so there is nothing to update.
        if self.eposide <= 0 or not torch.is_tensor(self.loss_c):
            return
        total = self.loss_c + self.loss_a + self.entropy_rate * self.loss_e
        (total * calc_num / self.eposide).backward()
        print('loss_low:', total)

        for opt in self._all_optimizers():
            opt.step()
        self.memory.reset('all')

    def _accumulate_episode_loss(self, s, a, r, goal, ep_num):
        """Add one episode's critic, actor and entropy losses to self.loss_*.

        The episode is traversed backwards in time so the discounted return R
        is built incrementally; every per-step term is divided by `ep_num`.
        """
        if not self.lan_flag:
            raise RuntimeError('train_mlt requires lan_flag=True (goal-conditioned policy)')

        # Reverse time order and copy into contiguous arrays (torch.from_numpy
        # rejects the negative strides left by [::-1]).
        a_cuda = torch.from_numpy(np.ascontiguousarray(a[::-1], dtype=np.int64)).cuda().unsqueeze(1)
        s_cuda = torch.from_numpy(np.ascontiguousarray(s[::-1])).squeeze().cuda().float()
        r_cuda = torch.from_numpy(np.ascontiguousarray(r[::-1])).cuda().float()
        # NOTE(review): s, a and r are reversed but goal is not; if goal holds
        # one entry per step the pairing in the loop below is misaligned.
        goal_cuda = torch.from_numpy(np.ascontiguousarray(goal)).cuda().float()
        self.switch_sbg = True

        self.LSTML.hidden_reset()
        # squeeze() above also drops the time axis of a single-step episode;
        # restore it. Presumably states are 3-channel images - TODO confirm.
        if s_cuda.shape[1] != 3:
            s_cuda = s_cuda.unsqueeze(0)

        R = torch.zeros(1, 1).cuda().float()
        for state, act, reward, g in zip(s_cuda, a_cuda, r_cuda, goal_cuda):
            # Discounted return; detached so it acts as a fixed target.
            R = (self.gamma * R + reward).detach()
            feat = self.LSTML(self.obs(state.unsqueeze(0), mlt=True))
            feat = self.combL(feat, self.ins2emb(g.unsqueeze(0)))

            V_now = self.criticL(feat)
            self.loss_c += self.loss_function_c(V_now, R) / ep_num
            advantage = (R - V_now).detach()

            self.temp_a = self.actorL(feat).squeeze()
            probs = F.softmax(self.temp_a, dim=-1)
            log_probs = F.log_softmax(self.temp_a, dim=-1)
            assert abs(float(probs.sum()) - 1.0) < 1e-5
            self.loss_a += -advantage * log_probs[act] / ep_num
            # Entropy of the full action distribution; subtracting it from the
            # loss rewards exploration.
            self.entropy = -torch.sum(probs * log_probs)
            self.loss_e -= self.entropy / ep_num

    def save_model(self, path):
        """Pickle each sub-network to ``path + <suffix>`` (see _NET_FILES)."""
        for attr, suffix in self._NET_FILES:
            torch.save(getattr(self, attr), path + suffix)

    def load_model_frompath(self, path):
        """Load sub-networks written by save_model(path) onto the CPU.

        The optimizers are rebuilt so they track the newly loaded modules.
        """
        # NOTE: torch.load unpickles whole modules - only load trusted files.
        # Newer torch versions default to weights_only=True, which may reject
        # these whole-module checkpoints - TODO confirm for the torch in use.
        for attr, suffix in self._NET_FILES:
            setattr(self, attr, torch.load(path + suffix, map_location='cpu'))
        self._build_optimizers()

    def load_model_frommodel(self, parameters):
        """Load state dicts in the order returned by get_curr_para."""
        for idx, (attr, _) in enumerate(self._NET_FILES):
            getattr(self, attr).load_state_dict(parameters[idx])

    def get_curr_para(self):
        """Return the state dicts of all sub-networks in _NET_FILES order."""
        return [getattr(self, attr).state_dict() for attr, _ in self._NET_FILES]

