import numpy as np
import random
import torch
from copy import deepcopy
from collections import deque
from algs.utilities.OneAgent import OneAgent


class CounterDQN:
    """Two-agent DQN trained on a single joint-action Q-function.

    ``q_model`` scores joint actions of two agents. For one state its output is
    indexed ``q[u_action, v_action]``; for a batch it is ``q[i, u_action, v_action]``.
    Agent ``u`` picks actions that minimise Q; agent ``v`` picks actions that
    maximise it.

    ``pure_agent`` chooses which agent follows its own (noisy) min-max policy
    and which one greedily counters it:

    * ``'u'``: u plays ``argmin_u max_v Q``; v best-responds to u's action.
    * ``'v'``: v plays ``argmax_v min_u Q``; u best-responds to v's action.

    ``u_noise`` / ``v_noise`` must expose ``threshold`` (probability of taking
    ``get()`` instead of the greedy action), ``get()``, ``reduce()`` and
    ``reset()``.

    Args:
        q_model: torch module mapping state(s) to joint-action Q-values.
        u_noise, v_noise: exploration-noise objects for u and v.
        pure_agent: ``'u'`` or ``'v'`` (see above).
        gamma: discount factor of the TD target.
        q_model_lr: Adam learning rate for ``q_model``.
        tau: soft-update rate of the target network.
        batch_size: minibatch size; training starts once memory holds this many.
        memory_len: maximum number of stored transitions.

    Raises:
        ValueError: if ``pure_agent`` is neither ``'u'`` nor ``'v'``.
    """

    _PURE_AGENTS = ('u', 'v')

    def __init__(self, q_model, u_noise, v_noise, pure_agent,
                 gamma=1, q_model_lr=1e-3, tau=1e-2, batch_size=64, memory_len=10000):
        self.pure_agent = pure_agent
        # Every action-selection and TD-target branch depends on this value;
        # fail here instead of silently returning None / raising NameError later.
        if pure_agent not in self._PURE_AGENTS:
            raise self._unknown_pure_agent()

        self.q_model = q_model
        self.u_noise = u_noise
        self.v_noise = v_noise

        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.memory = deque(maxlen=memory_len)

        # Per-step action hand-off between the two agents (see get_u_action /
        # get_v_action). None means "nothing handed off yet".
        self.u_action = None
        self.v_action = None

        # NOTE(review): OneAgent is defined elsewhere; it receives the action
        # callback, the agent's noise object and the shared reset callback.
        self.u_agent = OneAgent(self.get_u_action, self.u_noise, self.reset)
        self.v_agent = OneAgent(self.get_v_action, self.v_noise, self.reset)

        self.q_target_model = deepcopy(self.q_model)
        self.optimizer = torch.optim.Adam(self.q_model.parameters(), lr=q_model_lr)

    def _unknown_pure_agent(self):
        """Build the error used when ``pure_agent`` is not ``'u'`` or ``'v'``."""
        return ValueError(f"pure_agent must be 'u' or 'v', got {self.pure_agent!r}")

    def _q_matrix(self, state):
        """Evaluate ``q_model`` on one state as a NumPy array indexed ``[u, v]``."""
        # Action selection never back-propagates, so skip building a graph.
        with torch.no_grad():
            return self.q_model(state).detach().cpu().numpy()

    def get_u_action(self, state):
        """Return the minimising agent's (u's) action for ``state``.

        In ``'u'`` mode this is u's own min-max action, stored in
        ``self.u_action`` so that v can counter it. In ``'v'`` mode, v's action
        is decided here first and stored in ``self.v_action`` (to be returned by
        the next ``get_v_action`` call); u then minimises Q against it.
        """
        if self.pure_agent == 'u':
            self.u_action = self.get_u_pure_action(state)
            return self.u_action
        if self.pure_agent == 'v':
            self.v_action = self.get_v_pure_action(state)
            q_value = self._q_matrix(state)
            return np.argmin(q_value[:, self.v_action])
        raise self._unknown_pure_agent()

    def get_v_action(self, state):
        """Return the maximising agent's (v's) action for ``state``.

        In ``'u'`` mode v maximises Q against the most recent ``u_action``.
        In ``'v'`` mode the action handed off by ``get_u_action`` is returned
        and consumed; without a pending hand-off v's own min-max action is
        computed.

        Raises:
            RuntimeError: in ``'u'`` mode, if ``get_u_action`` has not been called yet.
        """
        if self.pure_agent == 'u':
            if self.u_action is None:
                # q_value[None] would silently add an axis instead of failing.
                raise RuntimeError(
                    "get_u_action must be called before get_v_action when pure_agent == 'u'")
            q_value = self._q_matrix(state)
            return np.argmax(q_value[self.u_action])
        if self.pure_agent == 'v':
            # Consume the hand-off so a later call cannot reuse a stale action.
            cached, self.v_action = self.v_action, None
            if cached is not None:
                return cached
            return self.get_v_pure_action(state)
        raise self._unknown_pure_agent()

    def get_u_pure_action(self, state):
        """Return u's noisy min-max action: ``argmin_u max_v Q(state, u, v)``."""
        if np.random.uniform(0, 1) < self.u_noise.threshold:
            return self.u_noise.get()
        q_value = self._q_matrix(state)
        return np.argmin(np.max(q_value, axis=1))

    def get_v_pure_action(self, state):
        """Return v's noisy max-min action: ``argmax_v min_u Q(state, u, v)``."""
        if np.random.uniform(0, 1) < self.v_noise.threshold:
            return self.v_noise.get()
        q_value = self._q_matrix(state)
        return np.argmax(np.min(q_value, axis=0))

    def _game_value(self, q_values):
        """Reduce batched Q-values ``(batch, u, v)`` to the pure agent's game value."""
        if self.pure_agent == 'u':
            # u commits first: min over u of v's best response (max over v).
            return q_values.max(dim=2).values.min(dim=1).values
        if self.pure_agent == 'v':
            # v commits first: max over v of u's best response (min over u).
            return q_values.min(dim=1).values.max(dim=1).values
        raise self._unknown_pure_agent()

    def fit(self, state, u_action, v_action, reward, done, next_state):
        """Store one transition and, once enough are stored, take one TD step.

        The regression target is
        ``reward + gamma * (1 - done) * V_target(next_state)``, where ``V_target``
        is the target network's game value for the pure agent (see
        ``_game_value``). Both noise objects are reduced on every call.
        """
        self.memory.append([state, u_action, v_action, reward, done, next_state])

        if len(self.memory) >= self.batch_size:
            batch = random.sample(self.memory, self.batch_size)
            states, u_actions, v_actions, rewards, dones, next_states = map(np.array, zip(*batch))
            u_actions = torch.LongTensor(u_actions)
            v_actions = torch.LongTensor(v_actions)
            rewards = torch.FloatTensor(rewards)
            dones = torch.FloatTensor(dones)

            q_values = self.q_model(states)[torch.arange(self.batch_size), u_actions, v_actions]
            with torch.no_grad():
                next_v_values = self._game_value(self.q_target_model(next_states))

            targets = rewards + self.gamma * (1 - dones) * next_v_values
            loss = torch.mean((targets - q_values) ** 2)

            self.update_target_model(self.q_target_model, self.q_model, self.optimizer, loss)

        self.u_noise.reduce()
        self.v_noise.reduce()

    def update_target_model(self, target_model, model, optimizer, loss):
        """Take one optimizer step on ``loss``, then soft-update ``target_model``.

        Each target parameter becomes ``(1 - tau) * target + tau * online``.
        """
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for target_param, param in zip(target_model.parameters(), model.parameters()):
            target_param.data.copy_((1 - self.tau) * target_param.data + self.tau * param.data)

    def reset(self):
        """Reset both exploration-noise objects."""
        self.u_noise.reset()
        self.v_noise.reset()
