import numpy as np
import random
import torch
from copy import deepcopy
from algs.utilities.OneAgent import OneAgent
from algs.utilities.Buffers.ExperienceReplayBuffer import ExperienceReplayBuffer


class DIDQN:
    """Two-player DQN-style learner with one Q-model per player.

    The ``u`` player acts greedily by *minimising* its Q-model output and the
    ``v`` player by *maximising* its own. During training the sum of the two
    models' Q-values for the taken actions is regressed onto a shared
    one-step target built from the target networks.

    Args:
        u_q_model: Q-network of the ``u`` player. It is called directly on
            the states passed to ``get_u_action`` / ``fit`` and must expose
            ``parameters()``.
        v_q_model: Q-network of the ``v`` player (same requirements).
        u_noise: Exploration object of the ``u`` player; must provide
            ``threshold``, ``get()``, ``reduce()`` and ``reset()``.
        v_noise: Exploration object of the ``v`` player (same requirements).
        gamma: Discount factor applied to the bootstrapped next-state value.
        q_model_lr: Learning rate of both Adam optimizers.
        tau: Soft-update coefficient for the target networks.
        batch_size: Number of transitions sampled per training step.
        memory: Replay buffer providing ``append``, ``__len__`` and
            ``get_batch``. When ``None`` (the default) a fresh
            ``ExperienceReplayBuffer(memory_len=100000)`` is created for this
            instance.
    """

    def __init__(self, u_q_model, v_q_model, u_noise, v_noise,
                 gamma=1, q_model_lr=1e-3, tau=1e-2, batch_size=64,
                 memory=None):
        self.u_q_model = u_q_model
        self.v_q_model = v_q_model
        self.u_noise = u_noise
        self.v_noise = v_noise
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        # Build the default buffer per instance: a buffer created in the
        # signature would be evaluated once and shared by every agent.
        if memory is None:
            memory = ExperienceReplayBuffer(memory_len=100000)
        self.memory = memory

        self.u_agent = OneAgent(self.get_u_action, self.u_noise, self.reset)
        self.v_agent = OneAgent(self.get_v_action, self.v_noise, self.reset)

        self.u_q_target_model = deepcopy(self.u_q_model)
        self.v_q_target_model = deepcopy(self.v_q_model)
        self.u_optimizer = torch.optim.Adam(self.u_q_model.parameters(), lr=q_model_lr)
        self.v_optimizer = torch.optim.Adam(self.v_q_model.parameters(), lr=q_model_lr)

    def get_u_action(self, state):
        """Return the ``u`` player's action for ``state``.

        With probability ``u_noise.threshold`` the action comes from
        ``u_noise.get()``; otherwise it is the index of the smallest Q-value.
        """
        if np.random.uniform(0, 1) < self.u_noise.threshold:
            return self.u_noise.get()
        # Action selection needs no autograd graph.
        with torch.no_grad():
            q_values = self.u_q_model(state)
        return np.argmin(q_values.detach().cpu().numpy())

    def get_v_action(self, state):
        """Return the ``v`` player's action for ``state``.

        With probability ``v_noise.threshold`` the action comes from
        ``v_noise.get()``; otherwise it is the index of the largest Q-value.
        """
        if np.random.uniform(0, 1) < self.v_noise.threshold:
            return self.v_noise.get()
        with torch.no_grad():
            q_values = self.v_q_model(state)
        return np.argmax(q_values.detach().cpu().numpy())

    def fit(self, state, u_action, v_action, reward, done, next_state):
        """Store one transition, train if enough data is available, decay noise.

        A training step runs only once the buffer holds at least
        ``batch_size`` transitions. Both noise objects are reduced on every
        call, whether or not a training step ran.
        """
        self.memory.append([state, u_action, v_action, reward, done, next_state])

        if len(self.memory) >= self.batch_size:
            self._train_step()

        self.u_noise.reduce()
        self.v_noise.reduce()

    def _train_step(self):
        """Sample a batch and apply one gradient + soft-target update per model."""
        batch = self.memory.get_batch(self.batch_size)
        states, u_actions, v_actions, rewards, dones, next_states = map(np.array, zip(*batch))
        u_actions = torch.as_tensor(u_actions, dtype=torch.long)
        v_actions = torch.as_tensor(v_actions, dtype=torch.long)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        dones = torch.as_tensor(dones, dtype=torch.float32)

        # Q-values of the actions actually taken, one per sampled transition.
        rows = torch.arange(len(u_actions))
        u_q_values = self.u_q_model(states)[rows, u_actions]
        v_q_values = self.v_q_model(states)[rows, v_actions]

        # Bootstrapped target: min over u's target Q plus max over v's target
        # Q, zeroed for terminal transitions. No gradients flow through it.
        with torch.no_grad():
            next_v_values = (self.u_q_target_model(next_states).min(dim=1).values
                             + self.v_q_target_model(next_states).max(dim=1).values)
            targets = rewards + self.gamma * (1 - dones) * next_v_values

        # Each model is trained on the same residual with the other model's
        # contribution held constant.
        u_deltas = targets - u_q_values - v_q_values.detach()
        v_deltas = targets - u_q_values.detach() - v_q_values

        u_loss = torch.mean(u_deltas ** 2)
        v_loss = torch.mean(v_deltas ** 2)

        updates = [
            (self.u_q_target_model, self.u_q_model, self.u_optimizer, u_loss),
            (self.v_q_target_model, self.v_q_model, self.v_optimizer, v_loss),
        ]
        # Randomise which model is updated first.
        if np.random.randint(2) == 1:
            updates.reverse()
        for target_model, model, optimizer, loss in updates:
            self.update_target_model(target_model, model, optimizer, loss)

    def update_target_model(self, target_model, model, optimizer, loss):
        """Take one optimizer step on ``model``, then soft-update ``target_model``.

        Despite the name, this first back-propagates ``loss`` and steps
        ``optimizer``; afterwards every target parameter is moved towards the
        matching online parameter:
        ``target = (1 - tau) * target + tau * param``.
        """
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            for target_param, param in zip(target_model.parameters(), model.parameters()):
                target_param.copy_((1 - self.tau) * target_param + self.tau * param)

    def reset(self):
        """Reset both exploration-noise objects."""
        self.u_noise.reset()
        self.v_noise.reset()
