import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from parl.utils import logger, summary

# Direct parl's logger output to this run directory (the `summary` scalars
# written below presumably end up here too — NOTE(review): confirm).
logger.set_dir('./results_dqn/train_log/')


class Matrix(object):
    """One-step, two-player matrix game.

    Every episode is a single joint move. Both agents see an observation from
    ``reset`` and each picks an action. ``step`` then reports one agent's
    payoff ``payoff_matrix[agent][a0, a1]``, with ``done`` always True.
    """

    def __init__(self):
        self.num_agents = 2
        self.obs_dim = 2
        self.act_dim = 2    # 2 or 3 (the 3-action games below need act_dim = 3)
        table_shape = (self.num_agents, self.act_dim, self.act_dim)
        self.payoff_matrix = np.zeros(table_shape)
        self._build_matrix()

    def _build_matrix(self):
        """Fill ``payoff_matrix``: row = agent 0's action, column = agent 1's."""
        # Alternative games (swap in to experiment):
        #   agent0 [[2, 0], [4, 6]],  agent1 [[2, 4], [0, 6]]
        #   3-action: agent0 [[20, 0, 0], [30, 10, 0], [0, 0, 5]],
        #             agent1 [[15, 0, 0], [0, 5, 0], [0, 0, 10]]
        #   3-action: both agents [[17.5, 0, 0], [15, 7.5, 0], [0, 0, 7.5]]
        agent0 = np.array([[10, 0], [18, 2]])
        agent1 = np.array([[10, 18], [0, 2]])
        self.payoff_matrix[0], self.payoff_matrix[1] = agent0, agent1

    def reset(self):
        """Return a fresh observation drawn uniformly from [-1, 1)^obs_dim."""
        return np.random.uniform(low=-1, high=+1, size=self.obs_dim)

    def step(self, agent_id, actions):
        """Return ``(reward, done)`` for ``agent_id`` under joint ``actions``.

        Agent 0 reads ``payoff_matrix[0]``. Any other id reads
        ``payoff_matrix[1]``. ``done`` is always True because the game lasts
        one step.
        """
        table = self.payoff_matrix[0] if agent_id == 0 else self.payoff_matrix[1]
        return table[actions[0], actions[1]], True


from parl.utils.scheduler import LinearDecayScheduler


# --- Optimisation / training schedule ---
TOTAL_STEP = 50000          # training iterations run after the replay-buffer warm-up
BATCH_SIZE = 128
LR_START = 0.01             # initial Adam learning rate (also the reward-baseline step size)
LR_END = 0.001              # floor for the decayed learning rate
GAMMA = 0.9                 # discount factor; not used by the current advantage-based target

# --- Replay buffer / target network / logging ---
MEMORY_CAPACITY = 5000      # replay rows; also the number of warm-up steps before learning
TARGET_REPLACE_ITER = 100   # learn() calls between target-network syncs
STAT_RATE = 1000            # logging interval, in training steps

# --- Environment and the network sizes derived from it ---
env = Matrix()
N_STATES, N_ACTIONS = env.obs_dim, env.act_dim


class Net(nn.Module):
    """Q-network: N_STATES inputs -> 64 ReLU units -> one value per action."""

    def __init__(self):
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(N_STATES, hidden_units)
        # Weights are re-drawn from N(0, 0.1^2); biases keep nn.Linear's default init.
        nn.init.normal_(self.fc1.weight, mean=0, std=0.1)
        self.out = nn.Linear(hidden_units, N_ACTIONS)
        nn.init.normal_(self.out.weight, mean=0, std=0.1)

    def forward(self, x):
        """Return one Q-value per action for each row of ``x``."""
        return self.out(F.relu(self.fc1(x)))


class DQN(object):
    """Q-network learner for one agent of the two-player matrix game.

    Transitions are stored in a fixed-size replay buffer. ``learn`` does not
    use a bootstrapped TD target. Instead it regresses Q(s, a) onto the
    per-sample advantage that ``calc`` stores in ``self.diff``. ``calc`` must
    therefore be called, with rewards for the same sample indices, before
    every ``learn``.
    """

    def __init__(self, steps):
        """Create networks, replay buffer, optimizer and schedules.

        Args:
            steps: the agent's index (0 or 1 in this script). It is stored as
                ``self.id`` and used in the logged loss tag. The parameter
                keeps its historical name for backward compatibility.
        """
        self.eval_net, self.target_net = Net(), Net()
        self.learn_step_counter = 0
        self.memory_counter = 0
        # One row per transition: [s (N_STATES), a, r, s_ (N_STATES)].
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))
        self.lr = LR_START
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)
        self.loss_func = nn.MSELoss()

        self.id = steps
        # Reward baseline, one entry per batch slot; the training loop updates it.
        self.mu = torch.zeros((BATCH_SIZE, 1), dtype=torch.float32)
        # Advantage targets for the next learn() call, written by calc().
        self.diff = None

        self.curr_ep = 1
        self.ep_end = 0.01
        self.lr_end = LR_END

        # Stepped once per learn() call; learn() floors the results at
        # ep_end / lr_end.
        self.ep_scheduler = LinearDecayScheduler(1, TOTAL_STEP)
        self.lr_scheduler = LinearDecayScheduler(LR_START, TOTAL_STEP)

    def choose_action(self, x):
        """Pick an action for observation ``x`` with epsilon-greedy exploration.

        With probability ``self.curr_ep`` a uniformly random action is
        returned. Otherwise the action with the largest ``eval_net`` value is
        returned. The result is always a Python ``int``.
        """
        explore = np.random.choice([True, False],
                                   p=[self.curr_ep, 1 - self.curr_ep])
        if explore:
            return np.random.randint(0, N_ACTIONS)
        obs = torch.unsqueeze(torch.FloatTensor(x), 0)
        # Inference only: no autograd graph is needed for action selection.
        with torch.no_grad():
            actions_value = self.eval_net(obs)
        return int(torch.max(actions_value, 1)[1][0])

    def store_transition(self, s, a, r, s_):
        """Append one transition to the replay buffer.

        The buffer is a ring indexed by ``memory_counter``, so once it is full
        the oldest row is overwritten.
        """
        transition = np.hstack((s, [a, r], s_))
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def calc(self, agent_id, muR, rew1, rew2):
        """Compute this agent's regression targets and store them in ``self.diff``.

        The target depends on whether both agents' advantages have the same
        ``sign`` (zero only matches zero):

        - same sign: the target is this agent's own advantage, ``rew1 - muR``
          for agent 0 and ``rew2 - muR`` otherwise;
        - different signs: the target is the advantage of the mean reward,
          ``(rew1 + rew2) / 2 - muR``.

        Accepts torch tensors (``torch.sign`` is used) or NumPy arrays/scalars
        (``np.sign`` is used). The type of ``self.diff`` follows the inputs.
        """
        adv1 = rew1 - muR
        adv2 = rew2 - muR
        adv_mean = 1/2 * (rew1 + rew2) - muR
        # torch.sign keeps tensors out of NumPy's ufunc/__array_wrap__ interop.
        sign = torch.sign if torch.is_tensor(adv1) else np.sign
        same_side = sign(adv1) == sign(adv2)
        own_adv = adv1 if agent_id == 0 else adv2
        # Boolean masks as 0/1 factors select per sample between the two targets.
        self.diff = same_side * own_adv + (~same_side) * adv_mean

    def learn(self, sample_index):
        """Take one gradient step regressing Q(s, a) toward ``self.diff``.

        Args:
            sample_index: replay-buffer row indices forming the batch. They must
                be the rows whose rewards were passed to the preceding ``calc``.

        Raises:
            RuntimeError: if ``calc`` has not been called yet.
        """
        if getattr(self, 'diff', None) is None:
            raise RuntimeError('calc() must be called before learn()')

        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int))

        q_eval = self.eval_net(b_s).gather(1, b_a)
        # The target is calc()'s advantage, not r + GAMMA * max_a' Q_target(s', a').
        # reshape_as fails loudly on a size mismatch instead of silently broadcasting.
        q_target = torch.as_tensor(self.diff, dtype=torch.float32).reshape_as(q_eval)

        loss = self.loss_func(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Epsilon and learning-rate decay: each schedule advances exactly once
        # per learn() call, however many optimizer param groups there are.
        self.curr_ep = max(self.ep_scheduler.step(1), self.ep_end)
        self.lr = max(self.lr_scheduler.step(1), self.lr_end)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

        summary.add_scalar('agent_%d_loss/critic_loss' % self.id, loss.item(), self.learn_step_counter)


from collections import deque

dqn1 = DQN(0)
dqn2 = DQN(1)

# Combined reward (agent 0 + agent 1) of the most recent STAT_RATE training
# steps. It is created ONCE, before the loop. The previous version re-created
# the list on every iteration, so the logged "mean_reward" was only the
# current step's reward.
mean_rews = deque(maxlen=STAT_RATE)

for steps in range(TOTAL_STEP + MEMORY_CAPACITY):
    # Each episode lasts one step: observe, act jointly, collect both rewards.
    s = env.reset()
    a1 = dqn1.choose_action(s)
    a2 = dqn2.choose_action(s)
    acts = [a1, a2]

    r1, done = env.step(0, acts)
    r2, done = env.step(1, acts)

    # The "next state" is simply another fresh draw from env.reset().
    s_ = env.reset()

    dqn1.store_transition(s, a1, r1, s_)
    dqn2.store_transition(s, a2, r2, s_)

    # Warm-up: only collect transitions until the replay buffer is full.
    if dqn1.memory_counter <= MEMORY_CAPACITY:
        continue
    train_step = steps - MEMORY_CAPACITY

    # One shared index set keeps row i of both batches on the same joint step,
    # which calc() relies on when comparing the two agents' rewards.
    sample_idx1 = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
    bm1 = dqn1.memory[sample_idx1]
    bm2 = dqn2.memory[sample_idx1]
    br1 = torch.FloatTensor(bm1[:, N_STATES + 1:N_STATES + 2])
    br2 = torch.FloatTensor(bm2[:, N_STATES + 1:N_STATES + 2])

    muR = 1/2 * (dqn1.mu + dqn2.mu)
    dqn1.calc(0, muR, br1, br2)
    dqn2.calc(1, muR, br1, br2)

    dqn1.learn(sample_idx1)
    dqn2.learn(sample_idx1)

    # Move each agent's reward baseline toward the sampled rewards
    # (exponential moving average with step size LR_START).
    dqn1.mu += LR_START * (br1 - dqn1.mu)
    dqn2.mu += LR_START * (br2 - dqn2.mu)

    mean_rews.append(r1 + r2)
    if train_step % STAT_RATE == 0:
        mean_steps_rew = round(np.mean(mean_rews), 2)
        summary.add_scalar('mean_reward', mean_steps_rew, train_step)
        print('---------------------------')
        print('step%s---reward: %s' % (train_step, np.round(r1, 2)), '------mu1', round(dqn1.mu.mean().item(), 2))
        print('step%s---reward: %s' % (train_step, np.round(r2, 2)), '------mu2', round(dqn2.mu.mean().item(), 2))
  