import numpy as np
import random
import torch
from copy import deepcopy
from algs.utilities.OneAgent import OneAgent
from algs.utilities.Buffers.ExperienceReplayBuffer import ExperienceReplayBuffer


class IDQN:
    """DQN-style learner for two agents that share one joint Q-function.

    The Q-model scores every ``(u_action, v_action)`` pair. The ``u`` agent
    acts to minimise that score and the ``v`` agent acts to maximise it:

    * ``u`` picks ``argmin_u max_v Q(state, u, v)``;
    * ``v`` picks ``argmax_v min_u Q(state, u, v)``.

    Training uses a replay buffer, a one-step TD target built from a slowly
    tracking copy of the Q-model, and Adam on the mean squared TD error.

    Args:
        q_model: Callable torch module. The code indexes its output as
            ``[u_action, v_action]`` for a single state and as
            ``[batch, u_action, v_action]`` for a batch of states.
        u_noise: Exploration helper for the ``u`` agent. Must expose
            ``threshold``, ``get()``, ``reduce()`` and ``reset()``.
        v_noise: Same as ``u_noise`` but for the ``v`` agent.
        gamma: Discount factor applied to the bootstrapped next-state value.
        q_model_lr: Learning rate of the Adam optimiser.
        tau: Soft-update rate of the target model (0 keeps the target frozen,
            1 copies the online model every step).
        batch_size: Number of transitions sampled per training step; training
            starts once the buffer holds at least this many transitions.
        memory: Replay buffer exposing ``append``, ``__len__`` and
            ``get_batch``. When ``None`` a fresh ``ExperienceReplayBuffer``
            with ``DEFAULT_MEMORY_LEN`` slots is created for this instance.
    """

    # Capacity of the replay buffer created when the caller supplies none.
    DEFAULT_MEMORY_LEN = 100000

    def __init__(self, q_model, u_noise, v_noise,
                 gamma=1, q_model_lr=1e-3, tau=1e-2, batch_size=64,
                 memory=None):
        self.q_model = q_model
        self.u_noise = u_noise
        self.v_noise = v_noise

        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size

        # A buffer built in the signature would be created once, when the
        # class is defined, and silently shared by every IDQN instance.
        # Build it per instance instead.
        if memory is None:
            memory = ExperienceReplayBuffer(memory_len=self.DEFAULT_MEMORY_LEN)
        self.memory = memory

        self.u_agent = OneAgent(self.get_u_action, self.u_noise, self.reset)
        self.v_agent = OneAgent(self.get_v_action, self.v_noise, self.reset)

        self.q_target_model = deepcopy(self.q_model)
        self.optimizer = torch.optim.Adam(self.q_model.parameters(), lr=q_model_lr)

    def _q_table(self, state):
        """Return the online model's Q-values for ``state`` as a numpy array."""
        # Action selection never back-propagates, so skip graph construction.
        with torch.no_grad():
            return self.q_model(state).detach().numpy()

    def get_u_action(self, state):
        """Choose the ``u`` agent's action for ``state``.

        Returns ``u_noise.get()`` when a uniform [0, 1) draw falls below
        ``u_noise.threshold``; otherwise returns the index of the ``u`` action
        whose worst case (maximum over ``v`` actions) is smallest.
        """
        if np.random.uniform(0, 1) < self.u_noise.threshold:
            return self.u_noise.get()
        worst_case_per_u = np.max(self._q_table(state), axis=1)
        return np.argmin(worst_case_per_u)

    def get_v_action(self, state):
        """Choose the ``v`` agent's action for ``state``.

        Returns ``v_noise.get()`` when a uniform [0, 1) draw falls below
        ``v_noise.threshold``; otherwise returns the index of the ``v`` action
        whose worst case (minimum over ``u`` actions) is largest.
        """
        if np.random.uniform(0, 1) < self.v_noise.threshold:
            return self.v_noise.get()
        worst_case_per_v = np.min(self._q_table(state), axis=0)
        return np.argmax(worst_case_per_v)

    def fit(self, state, u_action, v_action, reward, done, next_state):
        """Store one transition and, once enough are stored, train one step.

        The transition is appended to the replay buffer. If the buffer holds
        at least ``batch_size`` transitions, a batch is sampled and used for
        one optimiser step followed by a soft target update. Both noise
        objects are reduced on every call, whether or not training happened.
        """
        self.memory.append([state, u_action, v_action, reward, done, next_state])

        if len(self.memory) >= self.batch_size:
            self._learn_from_batch(self.memory.get_batch(self.batch_size))

        self.u_noise.reduce()
        self.v_noise.reduce()

    def _learn_from_batch(self, batch):
        """Run one TD-learning step on a batch of stored transitions."""
        states, u_actions, v_actions, rewards, dones, next_states = map(
            np.array, zip(*batch)
        )
        # as_tensor with an explicit dtype converts bool/int/float64 inputs
        # without relying on the legacy typed-tensor constructors.
        u_actions = torch.as_tensor(u_actions, dtype=torch.long)
        v_actions = torch.as_tensor(v_actions, dtype=torch.long)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        dones = torch.as_tensor(dones, dtype=torch.float32)

        # Q-value of the joint action actually taken in each transition.
        # Row indices follow the sampled batch, not the configured size.
        rows = torch.arange(len(u_actions))
        q_values = self.q_model(states)[rows, u_actions, v_actions]

        # Next-state value: the average of the upper (min over u of max
        # over v) and lower (max over v of min over u) values of the target
        # Q-table.
        next_q_values = self.q_target_model(next_states).detach()
        min_max_next_v_values = next_q_values.max(dim=2).values.min(dim=1).values
        max_min_next_v_values = next_q_values.min(dim=1).values.max(dim=1).values
        next_v_values = 0.5 * (min_max_next_v_values + max_min_next_v_values)

        # (1 - dones) removes the bootstrap term on terminal transitions.
        deltas = rewards + self.gamma * (1 - dones) * next_v_values - q_values
        loss = torch.mean(deltas ** 2)

        self.update_target_model(self.q_target_model, self.q_model, self.optimizer, loss)

    def update_target_model(self, target_model, model, optimizer, loss):
        """Take one optimiser step on ``loss``, then soft-update the target.

        After the step each target parameter becomes
        ``(1 - tau) * target + tau * online``.
        """
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The blend is bookkeeping, not part of the loss: keep it off the graph.
        with torch.no_grad():
            for target_param, param in zip(target_model.parameters(), model.parameters()):
                target_param.copy_((1 - self.tau) * target_param + self.tau * param)

    def reset(self):
        """Reset both exploration-noise objects."""
        self.u_noise.reset()
        self.v_noise.reset()
