import numpy as np
import random
import torch
from copy import deepcopy
from collections import deque
from algs.utilities.OneAgent import OneAgent


class MADQN:
    """Minimax deep Q-learning for two agents with opposing objectives.

    Each agent (``u`` and ``v``) owns a Q-network plus a soft-updated target
    copy.  For a batch of states a Q-network is indexed as
    ``q[i][u_action][v_action]``, i.e. it yields one joint-action value matrix
    per state.  Agent ``u`` is the minimiser (it picks the row whose maximum
    over ``v``'s actions is smallest); agent ``v`` is the maximiser (it picks
    the column whose minimum over ``u``'s actions is largest).

    The noise objects must provide ``threshold`` (exploration probability),
    ``get()`` (random action), ``reduce()`` and ``reset()``.
    """

    def __init__(self, u_q_model, v_q_model, u_noise, v_noise,
                 gamma=1, u_q_model_lr=1e-3, v_q_model_lr=1e-3, tau=1e-2, batch_size=64, memory_len=10000):
        self.u_q_model = u_q_model
        self.v_q_model = v_q_model
        self.u_noise = u_noise
        self.v_noise = v_noise

        self.gamma = gamma
        self.tau = tau  # soft target-update rate
        self.batch_size = batch_size
        # Replay buffer of [state, u_action, v_action, reward, done, next_state].
        self.memory = deque(maxlen=memory_len)

        # Per-agent wrappers (OneAgent is defined elsewhere in the project).
        self.u_agent = OneAgent(self.get_u_action, self.u_noise, self.reset)
        self.v_agent = OneAgent(self.get_v_action, self.v_noise, self.reset)

        self.u_q_target_model = deepcopy(self.u_q_model)
        self.v_q_target_model = deepcopy(self.v_q_model)
        self.u_optimizer = torch.optim.Adam(self.u_q_model.parameters(), lr=u_q_model_lr)
        self.v_optimizer = torch.optim.Adam(self.v_q_model.parameters(), lr=v_q_model_lr)

    def get_u_action(self, state):
        """Return the minimiser's action for ``state`` (epsilon-greedy).

        Assumes ``u_q_model(state)`` returns a single (n_u, n_v) value matrix
        for one unbatched state -- TODO confirm against the model definition.
        """
        if np.random.uniform(0, 1) < self.u_noise.threshold:
            return self.u_noise.get()
        # No autograd graph is needed just to pick an action.
        with torch.no_grad():
            q_value = self.u_q_model(state).detach().numpy()
        # Worst case over v for each u action, then the smallest of those.
        return np.argmin(np.max(q_value, axis=1))

    def get_v_action(self, state):
        """Return the maximiser's action for ``state`` (epsilon-greedy).

        Assumes ``v_q_model(state)`` returns a single (n_u, n_v) value matrix
        for one unbatched state -- TODO confirm against the model definition.
        """
        if np.random.uniform(0, 1) < self.v_noise.threshold:
            return self.v_noise.get()
        with torch.no_grad():
            q_value = self.v_q_model(state).detach().numpy()
        # Worst case over u for each v action, then the largest of those.
        return np.argmax(np.min(q_value, axis=0))

    def fit(self, state, u_action, v_action, reward, done, next_state):
        """Store one transition and, once enough are stored, train both Q-networks.

        Exploration noise of both agents is reduced on every call, whether
        or not a training step happened.
        """
        self.memory.append([state, u_action, v_action, reward, done, next_state])

        if len(self.memory) >= self.batch_size:
            batch = random.sample(self.memory, self.batch_size)
            states, u_actions, v_actions, rewards, dones, next_states = map(np.array, zip(*batch))
            u_actions = torch.LongTensor(u_actions)
            v_actions = torch.LongTensor(v_actions)
            rewards = torch.FloatTensor(rewards)
            dones = torch.FloatTensor(dones)

            self.train_q_models(self.u_q_model, self.u_q_target_model, self.u_optimizer,
                                states, u_actions, v_actions, rewards, dones, next_states, var='minmax')
            self.train_q_models(self.v_q_model, self.v_q_target_model, self.v_optimizer,
                                states, u_actions, v_actions, rewards, dones, next_states, var='maximin')

        self.u_noise.reduce()
        self.v_noise.reduce()

    def train_q_models(self, q_model, q_target_model, optimizer,
                       states, u_actions, v_actions, rewards, dones, next_states, var='minmax'):
        """Do one minimax-Bellman regression step for ``q_model``.

        Args:
            q_model / q_target_model: online and target networks; indexed as
                ``[batch, u_action, v_action]`` on their output.
            optimizer: optimizer over ``q_model``'s parameters.
            u_actions, v_actions: integer tensors of taken actions, one per row.
            rewards, dones: float tensors, one entry per row.
            var: ``'minmax'`` bootstraps with min_u max_v Q(s', u, v) (the
                minimiser's value); any other value uses max_v min_u Q(s', u, v).

        Returns:
            The (detached) target tensor used for the regression.
        """
        q_values = q_model(states)
        # Start from the current predictions so non-taken joint actions
        # contribute zero error; only the taken entry per row is replaced.
        targets = q_values.detach().clone()

        with torch.no_grad():
            next_q_values = q_target_model(next_states)
            if var == 'minmax':
                # max over v (dim 2), then min over u (dim 1)
                next_v_values = next_q_values.max(dim=2).values.min(dim=1).values
            else:
                # min over u (dim 1), then max over v (remaining dim 1)
                next_v_values = next_q_values.min(dim=1).values.max(dim=1).values

        rows = torch.arange(targets.shape[0])
        new_values = rewards + self.gamma * (1 - dones) * next_v_values
        # Cast explicitly: vectorized index assignment requires matching dtypes.
        targets[rows, u_actions, v_actions] = new_values.to(targets.dtype)

        # The mean is taken over every matrix entry, not just the taken ones.
        loss = torch.mean((targets.detach() - q_values) ** 2)
        self.update_target_model(q_target_model, q_model, optimizer, loss)

        return targets

    def update_target_model(self, target_model, model, optimizer, loss):
        """Take one optimizer step on ``loss``, then soft-update ``target_model``.

        Each target parameter becomes ``(1 - tau) * target + tau * online``.
        """
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            for target_param, param in zip(target_model.parameters(), model.parameters()):
                target_param.copy_((1 - self.tau) * target_param + self.tau * param)

    def reset(self):
        """Reset both agents' exploration noise."""
        self.u_noise.reset()
        self.v_noise.reset()
