from config import *
from utilis import *
from models.DC2Net import DC2Net
import numpy as np
import torch
import torch.optim as optim
from buffer import ReplayBuffer

class DC2NET_ALG:
    """Q-learning trainer/actor built around a ``DC2Net`` model.

    Holds an online network, a target network that follows it through a soft
    update, an RMSprop optimizer and a replay buffer.  ``hidden_dim``,
    ``capacity``, ``batch_size``, ``GAMMA`` and ``tau`` are module-level names
    that are not defined in this file; they are expected to come from the star
    imports at the top (presumably ``config``).
    """

    def __init__(self, env_args):
        """Build both networks, the optimizer and the replay buffer.

        Args:
            env_args: object exposing ``n_ant``, ``obs_space`` and
                ``n_actions`` attributes.
        """
        self.env_args = env_args
        # Same placement as the former unconditional ``.cuda()`` calls when a
        # GPU is present, but no crash at construction time on CPU-only hosts.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DC2Net(self.env_args.n_ant, self.env_args.obs_space, hidden_dim, self.env_args.n_actions)
        self.model_tar = DC2Net(self.env_args.n_ant, self.env_args.obs_space, hidden_dim, self.env_args.n_actions)
        self.model.to(self.device)
        self.model_tar.to(self.device)
        # Start the target network as an exact copy of the online network.
        self.model_tar.load_state_dict(self.model.state_dict())
        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.0005)
        self.buff = ReplayBuffer(capacity, self.env_args.obs_space, self.env_args.n_actions, self.env_args.n_ant)

    def _to_tensor(self, array):
        """Return ``array`` as a ``torch.Tensor`` placed on ``self.device``."""
        return torch.Tensor(np.asarray(array)).to(self.device)

    def _q_values(self, obs, adj):
        """Return the online network's Q-values for one step as a NumPy array.

        A leading batch axis of size 1 is added before the model call and
        removed from its first output afterwards.  No autograd graph is built.
        """
        with torch.no_grad():
            q, _ = self.model(self._to_tensor(np.array([obs])), self._to_tensor(np.array([adj])))
        return q[0].detach().cpu().numpy()

    def _greedy_actions(self, obs, adj, mask):
        """Return, per agent, the index of the highest Q-value among allowed actions.

        Entries where ``mask`` is 0 are pushed down by 9e15 so they cannot win
        the argmax while any allowed action exists.
        """
        q = self._q_values(obs, adj)
        return np.argmax(q - 9e15 * (1 - np.asarray(mask)), axis=1)

    def forward(self, obs, adj, mask, epsilon):
        """Pick one action per agent with epsilon-greedy exploration.

        Args:
            obs: observations for a single step (without batch axis).
            adj: adjacency input for the same step (without batch axis).
            mask: per-agent action availability; non-zero means allowed.
            epsilon: probability of taking a uniformly random allowed action.

        Returns:
            A list with ``n_ant`` action indices.

        Raises:
            ValueError: from ``np.random.choice`` when an agent explores and
                its mask row has no non-zero entry.
        """
        mask = np.asarray(mask)
        greedy = self._greedy_actions(obs, adj, mask)
        action = []
        for i in range(self.env_args.n_ant):
            if np.random.rand() < epsilon:
                # Explore: sample only among the actions the mask allows.
                a = np.random.choice(np.nonzero(mask[i])[0])
            else:
                a = greedy[i]
            action.append(a)
        return action

    def train(self):
        """Run one optimization step on a replay batch and soft-update the target.

        Returns:
            The loss tensor of this step (squared error to the TD targets plus
            0.001 times the mean of the model's second output).
        """
        O, A, R, Next_O, Matrix, Next_Matrix, Next_Mask, D = self.buff.getBatch(batch_size)
        n_ant = self.env_args.n_ant

        q_values, reg = self.model(self._to_tensor(O), self._to_tensor(Matrix))

        # The target network only supplies constants, so skip graph building.
        with torch.no_grad():
            target_q, _ = self.model_tar(self._to_tensor(Next_O), self._to_tensor(Next_Matrix))
            # Masked max over the action axis: disallowed next actions get -9e15.
            target_q = (target_q - 9e15 * (1 - self._to_tensor(Next_Mask))).max(dim=2)[0]
        target_q = target_q.cpu().numpy()

        # Explicit copy: on CPU ``.numpy()`` shares memory with ``q_values``,
        # and the targets are written into this array in place below.
        expected_q = q_values.detach().cpu().numpy().copy()

        # Only the entry of the action actually taken receives a TD target;
        # every other entry keeps the current prediction (zero error).
        # The reshapes accept both (batch,) and (batch, 1) layouts for the
        # per-sample reward/done values.
        actions = np.asarray(A).reshape(-1, n_ant)
        rewards = np.asarray(R).reshape(-1, 1)
        dones = np.asarray(D).reshape(-1, 1)
        batch_idx = np.arange(actions.shape[0])[:, None]
        agent_idx = np.arange(n_ant)[None, :]
        expected_q[batch_idx, agent_idx, actions] = rewards + (1 - dones) * GAMMA * target_q

        reg = torch.mean(reg)
        loss = (q_values - self._to_tensor(expected_q)).pow(2).mean() + 0.001 * reg
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Soft update: target <- tau * target + (1 - tau) * online.
        with torch.no_grad():
            for p, p_targ in zip(self.model.parameters(), self.model_tar.parameters()):
                p_targ.mul_(tau)
                p_targ.add_((1 - tau) * p)
        return loss

    def addbuff(self, obs, action, reward, next_obs, adj, next_adj, mask, terminated):
        """Store one transition in the replay buffer (arguments passed through unchanged)."""
        self.buff.add(obs, action, reward, next_obs, adj, next_adj, mask, terminated)

    def test(self, test_env, n_episodes=20):
        """Run greedy evaluation episodes and return the mean episode reward.

        Args:
            test_env: environment whose ``reset()`` returns a dict of
                observations, whose ``step(action_dict)`` returns
                ``(obs, reward, dones, infos)`` with dict-valued ``obs``,
                ``reward`` and ``dones``, and which exposes ``agents``.
            n_episodes: number of episodes to average over (default 20).

        Returns:
            Sum of all agents' rewards over all episodes divided by
            ``n_episodes``.

        Raises:
            ValueError: if ``n_episodes`` is not positive.
        """
        if n_episodes <= 0:
            raise ValueError("n_episodes must be positive")

        n_ant = self.env_args.n_ant
        test_r = 0
        for _ in range(n_episodes):
            test_obs = get_obs(list(test_env.reset().values()), n_ant)
            # NOTE(review): evaluation always uses an identity adjacency and an
            # all-ones action mask, and neither is refreshed during the
            # episode -- confirm this matches how the environment is used.
            test_adj = np.eye(n_ant)
            test_mask = np.ones([n_ant, self.env_args.n_actions])
            terminated = False
            agent_list = test_env.agents
            while not terminated:
                action = list(self._greedy_actions(test_obs, test_adj, test_mask))
                action_dict = dict(zip(agent_list, action))
                next_obs, reward, dones, infos = test_env.step(action_dict)
                test_r += sum(reward.values())
                # The episode ends only once every agent reports done.
                terminated = all(dones.values())
                test_obs = get_obs(list(next_obs.values()), n_ant)

        return test_r / n_episodes