import numpy as np
import random
import torch

from copy import deepcopy
from collections import deque
from algs.utilities.LinearTransformations import transform_interval
from algs.utilities.OneAgent import OneAgent


class MADDPG:
    """Two-agent actor-critic learner with a shared replay memory.

    The two agents are called ``u`` and ``v``. Each owns a policy model, a
    Q model, and a target copy of both that is soft-updated with rate
    ``tau``. Both Q models receive the concatenation
    ``(state, u_action, v_action)``. The agents have opposite objectives:
    the ``u`` critic is regressed on ``-reward`` and the ``v`` critic on
    ``+reward``, and each policy is trained to increase its own critic.
    """

    def __init__(self, u_action_min, u_action_max, v_action_min, v_action_max, 
                 u_q_model, v_q_model, u_pi_model, v_pi_model, u_noise, v_noise,
                 q_model_lr=1e-3, pi_model_lr=1e-4, gamma=0.99, batch_size=64, 
                 tau=1e-3, memory_len=6000000):
        """Store the models, build optimizers and target copies.

        Args:
            u_action_min, u_action_max: bounds used to rescale and clip the
                ``u`` actions.
            v_action_min, v_action_max: same for the ``v`` actions.
            u_q_model, v_q_model: torch modules called on the concatenated
                ``(states, u_actions, v_actions)`` tensor.
            u_pi_model, v_pi_model: torch modules called on states.
            u_noise, v_noise: exploration noise objects; this class calls
                their ``get()``, ``reduce()`` and ``reset()`` methods.
            q_model_lr: Adam learning rate for both Q models.
            pi_model_lr: Adam learning rate for both policy models.
            gamma: discount factor used in the TD targets.
            batch_size: number of transitions sampled per training step;
                training starts once the memory holds at least this many.
            tau: soft-update rate of the target models.
            memory_len: maximum number of transitions kept in memory.
        """
        self.u_action_min = u_action_min
        self.u_action_max = u_action_max
        self.v_action_min = v_action_min
        self.v_action_max = v_action_max

        self.u_q_model = u_q_model
        self.v_q_model = v_q_model
        self.u_pi_model = u_pi_model
        self.v_pi_model = v_pi_model
        self.u_noise = u_noise
        self.v_noise = v_noise

        self.gamma = gamma
        self.batch_size = batch_size
        self.tau = tau
        self.memory = deque(maxlen=memory_len)

        # Per-agent wrappers; OneAgent is a project helper that is handed the
        # action function, the noise object and the shared reset callback.
        self.u_agent = OneAgent(self.get_u_action, u_noise, self.reset)
        self.v_agent = OneAgent(self.get_v_action, v_noise, self.reset)

        self.u_q_optimizer = torch.optim.Adam(self.u_q_model.parameters(), lr=q_model_lr)
        self.v_q_optimizer = torch.optim.Adam(self.v_q_model.parameters(), lr=q_model_lr)
        self.u_pi_optimizer = torch.optim.Adam(self.u_pi_model.parameters(), lr=pi_model_lr)
        self.v_pi_optimizer = torch.optim.Adam(self.v_pi_model.parameters(), lr=pi_model_lr)

        # Target networks start as exact copies of the online networks.
        self.u_q_target_model = deepcopy(self.u_q_model)
        self.v_q_target_model = deepcopy(self.v_q_model)
        self.u_pi_target_model = deepcopy(self.u_pi_model)
        self.v_pi_target_model = deepcopy(self.v_pi_model)

    def get_u_action(self, state):
        """Return a noisy ``u`` action for ``state`` as a clipped numpy array.

        The policy output plus ``u_noise.get()`` is rescaled with
        ``u_transform_interval`` and then clipped to the ``u`` bounds.
        ``state`` must be whatever ``u_pi_model`` accepts as input.
        """
        action = self.u_pi_model(state).detach().numpy() + self.u_noise.get()
        action = self.u_transform_interval(action)
        return np.clip(action, self.u_action_min, self.u_action_max)

    def get_v_action(self, state):
        """Return a noisy ``v`` action for ``state`` as a clipped numpy array.

        The policy output plus ``v_noise.get()`` is rescaled with
        ``v_transform_interval`` and then clipped to the ``v`` bounds.
        ``state`` must be whatever ``v_pi_model`` accepts as input.
        """
        action = self.v_pi_model(state).detach().numpy() + self.v_noise.get()
        action = self.v_transform_interval(action)
        return np.clip(action, self.v_action_min, self.v_action_max)

    def u_transform_interval(self, action):
        """Apply ``transform_interval`` to ``action`` with the ``u`` bounds."""
        return transform_interval(action, self.u_action_min, self.u_action_max)

    def v_transform_interval(self, action):
        """Apply ``transform_interval`` to ``action`` with the ``v`` bounds."""
        return transform_interval(action, self.v_action_min, self.v_action_max)

    def fit(self, state, u_action, v_action, reward, done, next_state):
        """Store one transition and, once enough are stored, train all models.

        A training step updates, in this order: the ``u`` critic, the ``v``
        critic, the ``u`` policy, the ``v`` policy. Each update is followed
        by a soft update of the matching target model. Both noise objects
        are reduced on every call, whether or not a training step ran.

        Returns:
            None.
        """
        self.memory.append([state, u_action, v_action, reward, done, next_state])

        if len(self.memory) >= self.batch_size:
            states, u_actions, v_actions, rewards, dones, next_states = self._sample_batch()

            # TD targets never receive gradients, so the target-network passes
            # run without autograd bookkeeping.
            with torch.no_grad():
                u_pred_next_actions = self.u_transform_interval(self.u_pi_target_model(next_states))
                v_pred_next_actions = self.v_transform_interval(self.v_pi_target_model(next_states))
                discount = (1 - dones) * self.gamma

                next_q_values = self.get_q_values(
                    self.u_q_target_model, next_states, u_pred_next_actions, v_pred_next_actions)
                # u is rewarded with the negated environment reward.
                u_targets = - rewards + discount * next_q_values
            self._fit_q_model(self.u_q_model, self.u_q_target_model, self.u_q_optimizer,
                              u_targets, states, u_actions, v_actions)

            with torch.no_grad():
                next_q_values = self.get_q_values(
                    self.v_q_target_model, next_states, u_pred_next_actions, v_pred_next_actions)
                v_targets = rewards + discount * next_q_values
            self._fit_q_model(self.v_q_model, self.v_q_target_model, self.v_q_optimizer,
                              v_targets, states, u_actions, v_actions)

            # Actions of the current policies; computed for both agents before
            # either policy is stepped.
            u_pred_actions = self.u_transform_interval(self.u_pi_model(states))
            v_pred_actions = self.v_transform_interval(self.v_pi_model(states))

            # Each policy ascends its own critic; the other agent's action is
            # detached so only one policy receives gradients per update.
            q_pred_values = self.get_q_values(self.u_q_model, states, u_pred_actions, v_pred_actions.detach())
            u_pi_loss = - torch.mean(q_pred_values)
            self.update_target_model(self.u_pi_target_model, self.u_pi_model, self.u_pi_optimizer, u_pi_loss)

            q_pred_values = self.get_q_values(self.v_q_model, states, u_pred_actions.detach(), v_pred_actions)
            v_pi_loss = - torch.mean(q_pred_values)
            self.update_target_model(self.v_pi_target_model, self.v_pi_model, self.v_pi_optimizer, v_pi_loss)

        self.u_noise.reduce()
        self.v_noise.reduce()

    def _sample_batch(self):
        """Draw ``batch_size`` transitions and return them as float tensors.

        Returns:
            ``(states, u_actions, v_actions, rewards, dones, next_states)``
            with ``rewards`` and ``dones`` reshaped to ``(batch_size, 1)``.
        """
        batch = random.sample(self.memory, self.batch_size)
        columns = map(np.array, zip(*batch))
        states, u_actions, v_actions, rewards, dones, next_states = map(torch.FloatTensor, columns)
        rewards = rewards.reshape(self.batch_size, 1)
        dones = dones.reshape(self.batch_size, 1)
        return states, u_actions, v_actions, rewards, dones, next_states

    def _fit_q_model(self, q_model, q_target_model, q_optimizer, targets,
                     states, u_actions, v_actions):
        """Run one mean-squared-error step of ``q_model`` towards ``targets``."""
        q_values = self.get_q_values(q_model, states, u_actions, v_actions)
        q_loss = torch.mean((q_values - targets) ** 2)
        self.update_target_model(q_target_model, q_model, q_optimizer, q_loss)

    def get_q_values(self, q_model, states, u_actions, v_actions):
        """Evaluate ``q_model`` on states and both actions joined along dim 1."""
        states_and_actions = torch.cat((states, u_actions, v_actions), dim=1)
        return q_model(states_and_actions)

    def update_target_model(self, target_model, model, optimizer, loss):
        """Take one optimizer step on ``loss`` and soft-update ``target_model``.

        After the step every target parameter becomes
        ``(1 - tau) * target + tau * online``.

        Returns:
            None.
        """
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # In-place update under no_grad instead of going through ``.data``.
        with torch.no_grad():
            for target_param, param in zip(target_model.parameters(), model.parameters()):
                target_param.copy_((1 - self.tau) * target_param + self.tau * param)

    def reset(self):
        """Reset both exploration noise objects."""
        self.u_noise.reset()
        self.v_noise.reset()
    