import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from parl.utils import logger, summary

# Directory for parl's logger output. The `summary.add_scalar` calls below
# presumably write their event files under it as well — confirm against
# parl's summary documentation.
logger.set_dir('./results_dqn/train_log/')


class Matrix(object):
    """Single-step, three-player matrix game.

    Agents 0 and 2 each choose among 3 actions and agent 1 among 2.
    ``payoff_matrix[i][a0, a1, a2]`` is agent ``i``'s reward for the joint
    action ``(a0, a1, a2)``. Observations are random noise: ``step`` never
    looks at them.
    """

    def __init__(self):
        self.num_agents = 3
        self.obs_dim = 2
        # Axis 0 selects the agent; axes 1-3 are the actions of agents 0, 1, 2.
        self.payoff_matrix = np.zeros((self.num_agents, 3, 2, 3))
        self._build_matrix()

    def _build_matrix(self):
        """Fill ``payoff_matrix`` in place with the fixed payoff tables.

        The joint action (1, 1, 1) pays (7, 7, 6), which is labelled the optimum.
        """
        tables = (
            # agent 0
            [[[-7, 6, 6], [9, 0, 0]],
             [[-1, 6, -2], [6, 7, 6]],
             [[-1, -1, 5], [-1, 0, 5]]],
            # agent 1
            [[[-7, -2, -2], [9, 6, -2]],
             [[5, 6, -2], [0, 7, -2]],
             [[5, -2, -2], [-2, 6, -2]]],
            # agent 2
            [[[10, 6, 6], [-4, 6, -2]],
             [[6, 0, 0], [6, 6, 6]],
             [[6, 0, 0], [0, 6, -4]]],
        )
        for agent, table in enumerate(tables):
            self.payoff_matrix[agent] = table

    def reset(self):
        """Return a fresh observation drawn uniformly from [-1, 1) per dimension."""
        return np.random.uniform(low=-1, high=+1, size=self.obs_dim)

    def step(self, agent_id, actions):
        """Return ``(reward, done)`` for ``agent_id`` under the joint ``actions``.

        An unrecognised ``agent_id`` yields a reward of 0. ``done`` is always
        True because every episode lasts exactly one step.
        """
        for candidate in (0, 1, 2):
            if agent_id == candidate:
                table = self.payoff_matrix[candidate]
                return table[actions[0], actions[1], actions[2]], True
        return 0, True


from parl.utils.scheduler import LinearDecayScheduler


# Number of learning iterations; the training loop runs
# TOTAL_STEP + MEMORY_CAPACITY environment steps (the extra steps fill the
# replay buffers first). Also the horizon of the epsilon/LR schedulers in DQN.
TOTAL_STEP = 50000
# Samples per gradient step; also the length of each agent's `mu` baseline.
BATCH_SIZE = 128
# Adam learning rate at the start and its lower clamp. LR_START is also reused
# as the step size of the `mu` baseline update in the training loop.
LR_START = 0.01
LR_END = 0.001
# NOTE(review): GAMMA is not referenced anywhere in this file; the game is
# single-step (done is always True), so there is nothing to discount.
GAMMA = 0.9
# Copy eval_net weights into target_net every this many learn() calls.
TARGET_REPLACE_ITER = 100
# Replay-buffer rows per agent; learning starts once more than this many
# transitions have been stored.
MEMORY_CAPACITY = 5000
# Print / log the mean reward every STAT_RATE learning steps.
STAT_RATE = 1000
# Shared game instance used by the training loop.
env = Matrix()
# Per-agent action counts are fixed inside DQN (2 for agent 1, 3 otherwise).
# N_ACTIONS = env.act_dim
# Observation size, i.e. the input width of every Net.
N_STATES = env.obs_dim


class Net(nn.Module):
    """Two-layer MLP mapping an observation to one value per action.

    Args:
        n_actions: number of output units (one per discrete action).
        n_states: input (observation) size. Defaults to the module-level
            ``N_STATES`` when omitted, so existing ``Net(n)`` calls are unchanged.
        hidden_dim: width of the single hidden layer (64, as before).
    """

    def __init__(self, n_actions, n_states=None, hidden_dim=64):
        super().__init__()
        if n_states is None:
            n_states = N_STATES

        self.fc1 = nn.Linear(n_states, hidden_dim)
        self.out = nn.Linear(hidden_dim, n_actions)
        # Small Gaussian weights (std 0.1) instead of nn.Linear's default
        # uniform init; biases keep their default initialisation.
        nn.init.normal_(self.fc1.weight, mean=0.0, std=0.1)
        nn.init.normal_(self.out.weight, mean=0.0, std=0.1)

    def forward(self, x):
        """Map ``(..., n_states)`` inputs to ``(..., n_actions)`` action values."""
        hidden = F.relu(self.fc1(x))
        return self.out(hidden)


class DQN(object):
    """Per-agent Q-learner for the one-shot matrix game.

    Each instance owns an evaluation network, a target network, a flat
    replay buffer and a reward baseline ``mu``. Training does not use a
    bootstrapped TD target. The caller first computes a shaped target with
    :meth:`calc`, and :meth:`learn` then regresses Q(s, a) onto it.

    Args:
        steps: agent index (0, 1 or 2). Agent 1 has 2 actions, the others 3.
            Stored as ``self.id`` and used in logged scalar names.
    """

    def __init__(self, steps):
        self.id = steps
        # Agent 1 has 2 actions; agents 0 and 2 have 3.
        self.n_actions = 2 if self.id == 1 else 3
        self.eval_net, self.target_net = Net(self.n_actions), Net(self.n_actions)
        self.learn_step_counter = 0  # drives the periodic target-network sync
        self.memory_counter = 0      # total transitions ever stored
        # One row per transition: [s (N_STATES), a, r, s_ (N_STATES)].
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))
        self.lr = LR_START
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)
        self.loss_func = nn.MSELoss()

        # Reward baseline with one entry per batch slot; updated externally
        # by the training loop after every learning step.
        self.mu = torch.zeros(BATCH_SIZE, 1)
        # Regression target written by calc() and consumed by learn().
        self.diff = None

        self.curr_ep = 1       # current exploration probability (epsilon)
        self.ep_end = 0.01     # lower clamp for epsilon
        self.lr_end = LR_END   # lower clamp for the learning rate
        self.ep_scheduler = LinearDecayScheduler(1, TOTAL_STEP)
        self.lr_scheduler = LinearDecayScheduler(LR_START, TOTAL_STEP)

    def choose_action(self, x):
        """Return an epsilon-greedy action for observation ``x``.

        With probability ``self.curr_ep`` a uniformly random action is drawn;
        otherwise the argmax of ``eval_net``'s outputs is returned.
        """
        explore = np.random.choice([True, False],
                                   p=[self.curr_ep, 1 - self.curr_ep])
        if explore:
            return np.random.randint(0, self.n_actions)
        obs = torch.unsqueeze(torch.FloatTensor(x), 0)
        # Inference only: no autograd graph is needed for action selection.
        with torch.no_grad():
            actions_value = self.eval_net(obs)
        return torch.max(actions_value, 1)[1].numpy()[0]

    def store_transition(self, s, a, r, s_):
        """Write (s, a, r, s_) into the buffer, overwriting the oldest row when full."""
        transition = np.hstack((s, [a, r], s_))
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def calc(self, agent_id, muR, rew1, rew2, rew3):
        """Compute this agent's shaped regression target into ``self.diff``.

        The agent's own deviation ``own - muR`` and the others' deviation
        ``mean(other two) - muR`` are compared per sample. Where their signs
        agree, the target is the agent's own deviation. Otherwise it is the
        team deviation ``mean(rew1, rew2, rew3) - muR``.

        Args:
            agent_id: 0, 1 or 2; selects which of the rewards is "own".
            muR: baseline, broadcastable against the rewards.
            rew1, rew2, rew3: per-sample rewards of agents 0, 1 and 2. They
                can be torch tensors (as the training loop passes) or numpy
                arrays.

        Raises:
            ValueError: if ``agent_id`` is not 0, 1 or 2. Previously this left
                a stale ``self.diff`` that learn() would silently train on.
        """
        if agent_id not in (0, 1, 2):
            raise ValueError('agent_id must be 0, 1 or 2, got %r' % (agent_id,))
        idx = int(agent_id)
        rews = (rew1, rew2, rew3)
        others = [r for i, r in enumerate(rews) if i != idx]

        rewR = 1 / 3 * (rew1 + rew2 + rew3)
        own_dev = rews[idx] - muR
        others_dev = (others[0] + others[1]) / 2 - muR
        team_dev = rewR - muR

        # torch.sign keeps tensors as tensors; np.sign on a tensor depends on
        # torch's __array_wrap__ to convert the result back.
        sign = torch.sign if torch.is_tensor(own_dev) else np.sign
        eq = sign(own_dev) == sign(others_dev)
        self.diff = eq * own_dev + (~eq) * team_dev

    def learn(self, agent_id, sample_index):
        """Take one gradient step regressing Q(s, a) onto ``self.diff``.

        Args:
            agent_id: unused; kept for call-site compatibility.
            sample_index: rows of ``self.memory`` to train on. These should be
                the same rows whose rewards were passed to calc().

        Raises:
            RuntimeError: if calc() has not produced a target yet.
        """
        if self.diff is None:
            raise RuntimeError('calc() must be called before learn()')

        # The target network is kept in sync, but the loss below does not use it.
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))

        # The game is single-step (done is always True), so there is no
        # bootstrapped GAMMA * max Q(s_) term; the target is the shaped diff.
        q_eval = self.eval_net(b_s).gather(1, b_a)
        loss = self.loss_func(q_eval, self.diff)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Advance the epsilon and learning-rate schedules once per learn call
        # and clamp them from below. This presumes parl's LinearDecayScheduler
        # decays linearly over TOTAL_STEP calls; confirm.
        self.curr_ep = max(self.ep_scheduler.step(1), self.ep_end)
        new_lr = max(self.lr_scheduler.step(1), self.lr_end)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        # Log a plain float, not a tensor that is still attached to the graph.
        summary.add_scalar('agent_%d_loss/critic_loss' % self.id, loss.item(),
                           self.learn_step_counter)


from collections import deque

dqn1 = DQN(0)
dqn2 = DQN(1)
dqn3 = DQN(2)
agents = (dqn1, dqn2, dqn3)

# Team rewards of the most recent STAT_RATE learning steps. This is created
# once, outside the loop. Re-creating it every iteration made the logged
# "mean" reward equal to the current step's reward alone.
mean_rews = deque(maxlen=STAT_RATE)

for steps in range(TOTAL_STEP + MEMORY_CAPACITY):
    # Every step is an independent one-shot episode with a fresh observation.
    s = env.reset()

    # All agents act on the same observation. Calling them in agent order
    # keeps the numpy RNG stream unchanged.
    acts = [agent.choose_action(s) for agent in agents]
    # Each agent receives its own payoff for the joint action.
    rewards = [env.step(agent_id, acts)[0] for agent_id in range(len(agents))]

    # The episode always ends after one step. s_ fills the buffer layout, but
    # the learning target does not bootstrap from it.
    s_ = env.reset()
    for agent, a, r in zip(agents, acts, rewards):
        agent.store_transition(s, a, r, s_)

    # Warm-up: only learn once the replay buffers have been filled.
    if dqn1.memory_counter <= MEMORY_CAPACITY:
        continue

    # Shared indices keep the three agents' samples aligned on the same joint step.
    sample_idx = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
    batch_rews = [
        torch.FloatTensor(agent.memory[sample_idx, N_STATES + 1:N_STATES + 2])
        for agent in agents
    ]

    muR = 1 / 3 * (dqn1.mu + dqn2.mu + dqn3.mu)
    for agent_id, agent in enumerate(agents):
        agent.calc(agent_id, muR, *batch_rews)
    for agent_id, agent in enumerate(agents):
        agent.learn(agent_id, sample_idx)

    # Exponential moving average of the sampled rewards (step size LR_START).
    for agent, br in zip(agents, batch_rews):
        agent.mu += LR_START * (br - agent.mu)

    mean_rews.append(sum(rewards))
    log_step = steps - MEMORY_CAPACITY
    if log_step % STAT_RATE == 0:
        mean_steps_rew = round(float(np.mean(mean_rews)), 2)
        summary.add_scalar('mean_reward', mean_steps_rew, log_step)
        print('---------------------------')
        for i, (agent, r) in enumerate(zip(agents, rewards), start=1):
            print('step%s---reward: %s' % (log_step, np.round(r, 2)),
                  '------mu%d' % i, round(agent.mu.mean().item(), 2))