from os import stat
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np

from agents.agent_alfred_base import AlfredAgent

class CQLAgentC51(AlfredAgent):
    """Conservative Q-learning (CQL) agent with twin C51 categorical critics.

    Most state comes from ``AlfredAgent`` (not shown here). This class reads
    ``net``, ``optimizer_q``, ``state_buffer``, ``gamma``, ``atoms``,
    ``n_atoms``, ``v_min``, ``v_max``, ``delta_z``, ``device``, ``config`` and
    ``ct``, writes ``last_action``, and calls the helpers ``get_next_states``,
    ``cql_loss`` and ``clip_loss``. ``atoms`` is used as the tensor of support
    values and ``n_atoms`` as the number of support points.
    """

    def __init__(self, action_size, hidden_size=64, device="cpu", config=None):
        super().__init__(action_size, hidden_size=hidden_size, device=device, config=config)

    def masked_loss(self, x, y, mask):
        """Cross-entropy between target distributions ``x`` and predictions ``y``.

        Only positions where ``mask`` is non-zero contribute; the atom axis is
        the last one.

        Returns:
            Mean over kept positions of ``-sum(x * log(y + 1e-6))``.
        """
        keep = mask.to(torch.bool)
        x = x[keep]
        y = y[keep]
        # 1e-6 avoids log(0) for atoms with zero predicted probability.
        return -(x * torch.log(y + 1e-6)).sum(-1).mean()

    def get_m(self, next_values, next_dist, rewards, dones):
        """Project the Bellman-updated next-step distribution onto the fixed support.

        Args:
            next_values: next-step expected Q-values, flattened here to
                ``(bs * seq_len, n_actions)``; the greedy action is taken from it.
            next_dist: next-step atom probabilities, flattened here to
                ``(bs * seq_len, n_actions, n_atoms)``.
            rewards: per-step rewards, one value per (batch, step).
            dones: per-step terminal flags, one value per (batch, step).

        Returns:
            Tensor of shape ``(bs, seq_len, n_atoms)`` with the projected target
            distribution.
        """
        bs, seq_len = next_values.shape[0], next_values.shape[1]
        n = bs * seq_len

        next_values = next_values.reshape(n, -1)
        next_dist = next_dist.reshape(n, *next_dist.shape[2:])
        rewards = rewards.reshape(n, -1).view(-1, 1).to(torch.float32)
        dones = dones.reshape(n, -1).view(-1, 1).to(torch.float32)

        # Atom probabilities of the greedy next action.
        next_actions = torch.argmax(next_values, dim=-1)
        index = next_actions[..., None, None].expand(n, 1, self.n_atoms)
        next_chosen_dist = next_dist.gather(dim=1, index=index).squeeze(1)

        # Bellman-shifted support, clipped to [v_min, v_max], in atom-index units.
        target_support = rewards + self.gamma * (1 - dones) * self.atoms
        target_support = target_support.clamp(self.v_min, self.v_max)
        b = (target_support - self.v_min) / self.delta_z
        # Float round-off can push b slightly outside [0, n_atoms - 1]; an
        # out-of-range ceil would then scatter into the neighbouring row.
        b = b.clamp(0, self.n_atoms - 1)
        l = b.floor().long()
        u = b.ceil().long()

        # If b falls exactly on an atom, l == u and both weights (u - b) and
        # (b - l) are zero, which would lose that probability mass. Move one
        # bound off the atom. The bound here must be the atom COUNT
        # (n_atoms), not the support values in `atoms`.
        l[(l > 0) & (l == u)] -= 1
        u[(u < self.n_atoms - 1) & (l == u)] += 1

        m = torch.zeros(target_support.size(), dtype=next_chosen_dist.dtype, device=self.device)

        # Integer row offsets so one flat index_add_ scatters every row at once.
        offset = (torch.arange(n, device=self.device) * self.n_atoms).unsqueeze(1)

        flat_m = m.view(-1)
        flat_m.index_add_(0,
                          (l + offset).view(-1),
                          (next_chosen_dist * (u.to(torch.float32) - b)).view(-1))
        flat_m.index_add_(0,
                          (u + offset).view(-1),
                          (next_chosen_dist * (b - l.to(torch.float32))).view(-1))
        return m.view(bs, seq_len, -1)

    def get_action(self, state, goals, ht, out=None):
        """Pick the greedy action for ``state`` and advance the recurrent context.

        Args:
            state: current observation; appended to ``self.state_buffer``.
            goals: input for ``self.net.goal_encoder``.
            ht: context passed to ``self.net.get_next_ht``.
            out: previous context output; ``self.net.get_init_out()`` when None.

        Returns:
            ``(action, out, ht)``. ``action`` is the numpy argmax over the mean
            of both critics' Q-values. ``out`` and ``ht`` come from
            ``get_next_ht``.
        """
        if out is None:
            out = self.net.get_init_out()

        self.net.eval()
        self.state_buffer.append(state)
        try:
            # get_next_ht also runs without autograd. ht/out are fed back into
            # the next call, so any graph recorded here (if get_next_ht uses
            # trainable parameters) would be chained across every step.
            with torch.no_grad():
                states = torch.stack(list(self.state_buffer), dim=0).to(self.device)

                goals_emb = self.net.goal_encoder(goals)
                goals_emb = goals_emb.unsqueeze(1).expand(goals_emb.shape[0], 1, goals_emb.shape[1])
                # Two leading singleton axes (presumably batch and time — confirm).
                states = states.to(torch.float32).unsqueeze(0).unsqueeze(0)
                states = self.net.visual_encoder(states)
                action_values_1, action_values_2, _, _ = self.net.get_q_values(states, goals_emb, out)
                action_values = (action_values_1 + action_values_2) / 2

                action = np.argmax(action_values.cpu().numpy(), axis=2).squeeze(0)
                out, ht = self.net.get_next_ht(states, action, ht)
        finally:
            # Restore training mode even if inference raises.
            self.net.train()

        self.last_action = int(action)
        return action, out, ht

    def compute_loss(self, states, actions, goals, rewards, dones, masks):
        """Distributional Bellman loss plus CQL penalty for both critics.

        Args:
            states: observation sequence for ``self.net.visual_encoder``.
            actions: one-hot actions; the taken action is the argmax over dim 2.
            goals: input for ``self.net.goal_encoder``.
            rewards: per-step rewards, forwarded to ``get_m``.
            dones: per-step terminal flags, forwarded to ``get_m``.
            masks: non-zero at the steps that contribute to the loss.

        Returns:
            ``(metrics, ht, q)``. ``metrics`` has tensor values ``'loss'``,
            ``'cql_loss'`` and ``'q_loss'``. ``ht`` comes from
            ``self.net.get_seq_emb``. ``q`` is the detached element-wise minimum
            of the two critics' Q-values.
        """
        taken = torch.argmax(actions, dim=2, keepdim=True)
        goals_emb = self.net.goal_encoder(goals)
        states_seq = self.net.visual_encoder(states)

        goals_emb = goals_emb.unsqueeze(1).expand(goals_emb.shape[0], states.shape[1], goals_emb.shape[1])
        ht = self.net.get_seq_emb(states_seq, actions, masks)
        # ht shifted one step right: step t sees the context of step t-1, and
        # the first step sees init_out.
        init_out = self.net.init_out.expand(ht.shape[0], 1, ht.shape[-1])
        ht_ = torch.cat([init_out, ht[:, :-1]], dim=1)

        with torch.no_grad():
            next_values_1, next_values_2, next_dist_1, next_dist_2 = self.net.get_q_values(
                states_seq, goals_emb, ht_, target=True)

            next_values_1 = self.get_next_states(next_values_1)
            next_values_2 = self.get_next_states(next_values_2)
            next_dist_1 = self.get_next_states(next_dist_1)
            next_dist_2 = self.get_next_states(next_dist_2)

            m_1 = self.get_m(next_values_1, next_dist_1, rewards, dones)
            m_2 = self.get_m(next_values_2, next_dist_2, rewards, dones)

            # Clipped double-Q for distributions: per step, keep the target
            # distribution with the lower expected value. An element-wise min
            # of two distributions does not sum to one, so it is not a valid
            # cross-entropy target.
            mean_1 = (m_1 * self.atoms).sum(-1, keepdim=True)
            mean_2 = (m_2 * self.atoms).sum(-1, keepdim=True)
            m = torch.where(mean_1 <= mean_2, m_1, m_2)

        current_values_1, current_values_2, current_dist_1, current_dist_2 = self.net.get_q_values(
            states_seq, goals_emb, ht_)
        index = taken[..., None].expand(*taken.shape[:2], 1, self.n_atoms)
        current_dist_1 = current_dist_1.gather(dim=2, index=index).squeeze(2)
        current_dist_2 = current_dist_2.gather(dim=2, index=index).squeeze(2)

        bellman_error_1 = self.masked_loss(m, current_dist_1, masks)
        bellman_error_2 = self.masked_loss(m, current_dist_2, masks)

        cql_loss_1 = self.cql_loss(current_values_1, actions, masks)
        cql_loss_2 = self.cql_loss(current_values_2, actions, masks)

        loss = bellman_error_1 + bellman_error_2 + self.config.alpha * (cql_loss_1 + cql_loss_2)

        metrics = {
            'loss': loss,
            'cql_loss': (cql_loss_1 + cql_loss_2) / 2,
            'q_loss': (bellman_error_1 + bellman_error_2) / 2,
        }
        return metrics, ht, torch.min(current_values_1, current_values_2).detach()

    def learn_step(self, experiences, actor=False):
        """Run one optimiser step on a batch of sequences.

        Args:
            experiences: tuple ``(states, actions, rewards, dones, goals, _, _, masks)``.
            actor: unused; kept for interface compatibility.

        Returns:
            ``(metrics, q)``. ``metrics`` values are Python floats; ``q`` is the
            detached Q tensor from ``compute_loss``.
        """
        states, actions, rewards, dones, goals, _, _, masks = experiences

        # Insert a singleton axis at dim 2 before encoding.
        states = states.unsqueeze(2)
        metrics, ht, q = self.compute_loss(states, actions, goals, rewards, dones, masks)

        if self.config.if_clip:
            clip_loss = self.clip_loss(ht, actions, masks, goals)
            metrics['clip_loss'] = clip_loss
            metrics['loss'] = metrics['loss'] + 0.2 * clip_loss

        if self.config.if_regularize:
            # NOTE(review): reg_loss is only logged and never added to the
            # optimised loss — confirm that is intended. Because it is
            # metric-only, it is computed without building an autograd graph.
            with torch.no_grad():
                reg_loss = torch.zeros((), device=self.device)
                for param in self.net.parameters():
                    if param.requires_grad:
                        reg_loss = reg_loss + torch.norm(param) ** 2
            metrics['reg_loss'] = reg_loss

        self.optimizer_q.zero_grad()
        metrics['loss'].backward()
        clip_grad_norm_(self.net.parameters(), 1.)
        self.optimizer_q.step()

        # ------------------- update target network ------------------- #
        if (self.ct + 1) % self.config.update_frequency == 0:
            self.net.update()
        self.ct += 1

        metrics = {key: value.detach().item() for key, value in metrics.items()}
        return metrics, q

    def learn_clip(self, experiences, actor=False):
        """Run one optimiser step on ``self.clip_loss`` alone.

        States are first stacked with their ``config.history_frame - 1``
        predecessors. The start of each sequence is zero-padded.

        Args:
            experiences: tuple ``(states, actions, rewards, dones, goals, _, _, masks)``;
                rewards and dones are unused here.
            actor: unused; kept for interface compatibility.

        Returns:
            ``{'clip_loss': float}``.
        """
        states, actions, rewards, dones, goals, _, _, masks = experiences

        frames_num = self.config.history_frame
        # Zero frames prepended on the time axis so early steps get a full window.
        added_frames = torch.zeros((states.shape[0], frames_num - 1, *states.shape[2:]), device=self.device)
        states_added = torch.cat((added_frames, states), dim=1)

        # stacked_states[:, i] = padded frames i .. i + frames_num - 1, i.e. the
        # original step i together with its frames_num - 1 predecessors.
        stacked_states = torch.zeros((states.shape[0], states.shape[1], frames_num, *states.shape[2:]),
                                     device=self.device)
        for i in range(states.shape[1]):
            stacked_states[:, i] = states_added[:, i:i + frames_num]

        states_seq = self.net.visual_encoder(stacked_states)
        ht = self.net.get_seq_emb(states_seq, actions, masks)

        clip_loss = self.clip_loss(ht, actions, masks, goals)

        self.optimizer_q.zero_grad()
        clip_loss.backward()
        clip_grad_norm_(self.net.parameters(), 1.)
        self.optimizer_q.step()

        return {'clip_loss': clip_loss.detach().item()}
    
class CQLAgentNaive(CQLAgentC51):
    """CQL agent with twin scalar (non-distributional) Q-heads.

    Inherits ``learn_step`` / ``learn_clip`` from ``CQLAgentC51`` but uses an
    MSE Bellman loss. It expects ``self.net.get_q_values`` to return exactly
    two Q-value tensors.
    """

    def __init__(self, action_size, hidden_size=64, device="cpu", config=None):
        super().__init__(action_size, hidden_size=hidden_size, device=device, config=config)

    def masked_loss(self, x, y, mask):
        """Mean-squared error between ``x`` and ``y`` where ``mask`` is non-zero."""
        keep = mask.to(torch.bool)
        return F.mse_loss(x[keep], y[keep])

    def get_action(self, state, goals, ht, out=None):
        """Pick the greedy action for ``state`` and advance the recurrent context.

        Args:
            state: current observation; appended to ``self.state_buffer``.
            goals: input for ``self.net.goal_encoder``.
            ht: context passed to ``self.net.get_next_ht``.
            out: previous context output; ``self.net.get_init_out()`` when None.

        Returns:
            ``(action, out, ht)``. ``action`` is the numpy argmax over the mean
            of both Q-heads. ``out`` and ``ht`` come from ``get_next_ht``.
        """
        if out is None:
            out = self.net.get_init_out()

        self.net.eval()
        self.state_buffer.append(state)
        try:
            # get_next_ht also runs without autograd. ht/out are fed back into
            # the next call, so any graph recorded here (if get_next_ht uses
            # trainable parameters) would be chained across every step.
            with torch.no_grad():
                states = torch.stack(list(self.state_buffer), dim=0).to(self.device)

                goals_emb = self.net.goal_encoder(goals)
                goals_emb = goals_emb.unsqueeze(1).expand(goals_emb.shape[0], 1, goals_emb.shape[1])

                # Two leading singleton axes (presumably batch and time — confirm).
                states = states.to(torch.float32).unsqueeze(0).unsqueeze(0)
                states = self.net.visual_encoder(states)

                action_values_1, action_values_2 = self.net.get_q_values(states, goals_emb, out)
                action_values = (action_values_1 + action_values_2) / 2

                action = np.argmax(action_values.cpu().numpy(), axis=2).squeeze(0)
                out, ht = self.net.get_next_ht(states, action, ht)
        finally:
            # Restore training mode even if inference raises.
            self.net.train()

        self.last_action = int(action)
        return action, out, ht

    def compute_loss(self, states, actions, goals, rewards, dones, masks):
        """MSE Bellman loss plus CQL penalty for both Q-heads.

        Args:
            states: observation sequence for ``self.net.visual_encoder``.
            actions: one-hot actions; the taken action is the argmax over dim 2.
            goals: input for ``self.net.goal_encoder``.
            rewards: per-step rewards.
            dones: per-step terminal flags.
            masks: non-zero at the steps that contribute to the loss.

        Returns:
            ``(metrics, ht, q)``. ``metrics`` has tensor values ``'loss'``,
            ``'cql_loss'`` and ``'q_loss'``; both of the latter are averages
            over the two heads. ``ht`` comes from ``self.net.get_seq_emb``.
            ``q`` is the detached element-wise min of the two heads.
        """
        taken = torch.argmax(actions, dim=2, keepdim=True)
        goals_emb = self.net.goal_encoder(goals)
        states_seq = self.net.visual_encoder(states)

        goals_emb = goals_emb.unsqueeze(1).expand(goals_emb.shape[0], states.shape[1], goals_emb.shape[1])

        ht = self.net.get_seq_emb(states_seq, actions, masks)
        # ht shifted one step right: step t sees the context of step t-1, and
        # the first step sees init_out.
        init_out = self.net.init_out.expand(ht.shape[0], 1, ht.shape[-1])
        ht_ = torch.cat([init_out, ht[:, :-1]], dim=1)

        with torch.no_grad():
            Q_targets_1, Q_targets_2 = self.net.get_q_values(states_seq, goals_emb, ht_, target=True)

            Q_targets_1 = self.get_next_states(Q_targets_1)
            Q_targets_2 = self.get_next_states(Q_targets_2)

            # Clipped double-Q: min over the two target heads, then greedy max
            # over actions.
            best_next = torch.min(Q_targets_1, Q_targets_2).max(dim=2).values

            next_values = rewards.to(torch.float32) \
                + (self.gamma * best_next * (1 - dones.to(torch.float32)))

        Q_1, Q_2 = self.net.get_q_values(states_seq, goals_emb, ht_)
        Q_1_ = torch.gather(Q_1, dim=2, index=taken).squeeze(dim=2)
        Q_2_ = torch.gather(Q_2, dim=2, index=taken).squeeze(dim=2)

        q_loss_1 = self.masked_loss(Q_1_, next_values, masks)
        q_loss_2 = self.masked_loss(Q_2_, next_values, masks)

        cql_loss_1 = self.cql_loss(Q_1, actions, masks)
        cql_loss_2 = self.cql_loss(Q_2, actions, masks)

        loss = q_loss_1 + q_loss_2 + self.config.alpha * (cql_loss_1 + cql_loss_2)

        metrics = {
            'loss': loss,
            'cql_loss': (cql_loss_1 + cql_loss_2) / 2,
            # Logged as the per-head average, matching 'cql_loss' and the C51
            # agent's 'q_loss' (the optimised loss still uses the sum).
            'q_loss': (q_loss_1 + q_loss_2) / 2,
        }
        return metrics, ht, torch.min(Q_1, Q_2).detach()

    def load_model(self, path):
        """Restore network and optimiser state from the checkpoint at ``path``.

        Returns:
            The ``'batches'`` entry stored in the checkpoint.
        """
        # map_location lets a checkpoint saved on one device (e.g. GPU) load on
        # another (e.g. a CPU-only host). torch.load unpickles the file, so
        # only load checkpoints from trusted sources.
        d = torch.load(path, map_location=self.device)
        self.net.load_state_dict(d['model_state_dict'])
        self.optimizer_q.load_state_dict(d['optimizer_state_dict'])
        return d['batches']
