import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

from agents.agent_alfred_base import AlfredAgent
from networks.networks_grif_alfred import GRIFNetwork

class CrossEn(nn.Module):
    """Cross-entropy over a similarity matrix whose matching pairs lie on the diagonal.

    A log-softmax is applied along the last dimension of the input, the
    main-diagonal entries are picked out with ``torch.diag``, and the mean of
    their negatives is returned.
    """

    def __init__(self,):
        super().__init__()

    def forward(self, sim_matrix):
        """Return the mean negative log-probability of the diagonal entries.

        Args:
            sim_matrix: Similarity scores; for a 2-D input, entry ``(i, i)``
                is treated as the score of the correct pairing for row ``i``.

        Returns:
            A scalar tensor holding the averaged loss.
        """
        # Normalise each row into log-probabilities over its candidates.
        log_probs = F.log_softmax(sim_matrix, dim=-1)
        # Keep only the log-probability assigned to the diagonal entry.
        matched_log_probs = torch.diag(log_probs)
        return (-matched_log_probs).mean()


class GRIFAgent(AlfredAgent):
    """Agent that trains a ``GRIFNetwork`` with two action heads plus an alignment loss.

    Each call to :meth:`learn` combines three terms:

    * a cross-entropy action loss on the network's "lang" policy output,
    * a cross-entropy action loss on the network's "prior" policy output,
    * a symmetric contrastive loss (``CrossEn``) between the video and
      language embeddings returned by ``net.get_prior_lang_emb``, scaled by
      ``align_weight``.
    """

    def __init__(self, action_size, hidden_size=512, device="cpu", config=None):
        """Build the network, the losses and the optimizer.

        Args:
            action_size: Forwarded to the base class and to ``GRIFNetwork``.
            hidden_size: Used for both ``hidden_size`` and
                ``lstm_hidden_size`` of ``GRIFNetwork``.
            device: Forwarded to the base class. This class reads
                ``self.device`` afterwards (presumably set by the base
                class -- confirm in ``AlfredAgent``).
            config: Configuration object. Despite the ``None`` default it is
                required in practice, because ``config.learning_rate`` is
                read here.
        """
        super().__init__(action_size, hidden_size=hidden_size, device=device, config=config)

        self.args = config

        self.net = GRIFNetwork(
            config=config,
            hidden_size=hidden_size,
            lstm_hidden_size=hidden_size,
            action_size=action_size,
        )
        self.net.to(self.device)

        # Targets equal to -1 are skipped; `learn` writes -1 into every
        # time step at or beyond a sequence's length.
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)
        self.align_loss_fct = CrossEn()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=config.learning_rate)

        # Weight applied to the averaged two-direction alignment loss.
        self.align_weight = 0.5

    def get_action(self, state, goals, h_t=None, out=None):
        """Pick the greedy action for a single observation.

        Args:
            state: Observation tensor; it is cast to float32, moved to
                ``self.device`` and given two leading singleton dimensions
                before being passed to the network.
            goals: Goal specification, passed to the network unchanged.
            h_t: ``(hidden_states, cell_states)`` from the previous step, or
                ``None``.
            out: Returned unchanged.

        Returns:
            ``([action], out, (hidden_states, cell_states))`` where ``action``
            is the integer argmax of the network output and the state tuple
            is what the network returned for this step.
        """
        self.net.eval()
        # The signature advertises h_t=None, but unpacking None raises
        # TypeError, so pass the Nones through instead.
        # NOTE(review): whether GRIFNetwork.get_action accepts None for the
        # recurrent state is not visible from this file -- confirm.
        hidden_states, cell_states = (None, None) if h_t is None else h_t

        with torch.no_grad():
            # Same device/dtype handling as `learn`, so a CPU observation
            # does not clash with a network living on another device.
            state = state.to(self.device, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            action_prob, hidden_states, cell_states = self.net.get_action(
                state,
                goals,
                hidden_states,
                cell_states,
                single_step=True,
            )
            action = torch.argmax(action_prob, dim=-1).item()

        return [action], out, (hidden_states, cell_states)

    @staticmethod
    def _first_occurrence_indices(goals):
        """Return the first batch index of every distinct entry of ``goals``.

        Entries are distinguished along dim 0 and the result follows the
        sorted order produced by ``torch.unique``. The index tensor stays on
        the device of ``goals``.
        """
        # assumes goals is 2-D (batch, features), so that `.all(dim=1)`
        # yields one flag per batch entry -- TODO confirm against caller
        distinct_goals = torch.unique(goals, dim=0, sorted=True)
        return torch.stack([
            torch.nonzero((goals == goal).all(dim=1))[0][0]
            for goal in distinct_goals
        ])

    def _alignment_loss(self, states, lengths, goals):
        """Compute the weighted symmetric contrastive loss between video and language embeddings."""
        video_embs, lang_embs, logit_scale = self.net.get_prior_lang_emb(states, lengths, goals)
        # L2-normalise; the epsilon guards against division by a zero norm.
        video_embs = video_embs / (video_embs.norm(dim=-1, keepdim=True) + 1e-6)
        lang_embs = lang_embs / (lang_embs.norm(dim=-1, keepdim=True) + 1e-6)

        retrieve_logits = logit_scale.exp() * torch.matmul(video_embs, lang_embs.T)

        # Average the video->language and language->video directions.
        video_to_lang = self.align_loss_fct(retrieve_logits)
        lang_to_video = self.align_loss_fct(retrieve_logits.T)
        return self.align_weight * (video_to_lang + lang_to_video) / 2

    def _policy_losses(self, states, goals, lengths, actions):
        """Compute cross-entropy action losses for the lang and prior policy outputs."""
        lang_policy_output, prior_policy_output = self.net(states, goals, lengths, single_step=False)
        actions_pred_lang, _, _ = lang_policy_output
        actions_pred_prior, _, _ = prior_policy_output

        # The target index is the argmax over the last dimension of
        # `actions` (presumably one-hot encoded -- confirm with the dataset).
        actions_target = torch.argmax(actions, dim=-1)

        # Mark every step at or beyond the sequence length with -1 so that
        # CrossEntropyLoss(ignore_index=-1) skips it.
        step_idx = torch.arange(actions_target.size(1), device=actions_target.device).unsqueeze(0)
        actions_target.masked_fill_(step_idx >= lengths.unsqueeze(1), -1)
        flat_target = actions_target.view(-1)

        policy_loss_lang = self.loss_func(
            actions_pred_lang.reshape(-1, actions_pred_lang.shape[2]), flat_target)
        policy_loss_prior = self.loss_func(
            actions_pred_prior.reshape(-1, actions_pred_prior.shape[2]), flat_target)
        return policy_loss_lang, policy_loss_prior

    def learn(self, experiences, train=True):
        """Run one optimisation (or evaluation) step on a batch.

        The batch is first reduced to the first occurrence of each distinct
        goal; both the alignment loss and the action losses are computed on
        that reduced batch.

        Args:
            experiences: 8-tuple ``(states, actions, rewards, dones, goals,
                _, _, masks)``. Only ``states``, ``actions``, ``goals`` and
                ``masks`` are used.
            train: When true, the network is put in train mode, gradients are
                clipped to norm 1.0 and the optimizer is stepped. When false,
                the network is put in eval mode and only the metrics are
                computed, with autograd disabled.

        Returns:
            Dict of Python floats with keys ``align_loss``,
            ``policy_loss_lang``, ``policy_loss_prior`` and ``total_loss``.
        """
        metrics = {}

        states, actions, _, _, goals, _, _, masks = experiences

        states = states.to(self.device).to(torch.float32)
        actions = actions.to(self.device)
        goals = goals.to(self.device)
        masks = masks.to(self.device)
        # Per-sequence sum of the mask along dim 1 (presumably the number of
        # valid time steps -- confirm the mask convention).
        lengths = torch.sum(masks, dim=1)

        if train:
            self.net.train()
            self.optimizer.zero_grad()
        else:
            self.net.eval()

        # Skip building the autograd graph when only evaluating.
        with torch.set_grad_enabled(train):
            # Keep a single sample per distinct goal.
            unique_positions = self._first_occurrence_indices(goals)
            unique_states_seq = states[unique_positions]
            unique_goals_lang_clip = goals[unique_positions]
            unique_lengths = lengths[unique_positions]
            unique_actions = actions[unique_positions]

            # -------------------Aligning Loss-----------------
            sim_loss = self._alignment_loss(unique_states_seq, unique_lengths, unique_goals_lang_clip)
            metrics['align_loss'] = sim_loss.item()

            # -------------------Action Loss-------------------
            policy_loss_lang, policy_loss_prior = self._policy_losses(
                unique_states_seq, unique_goals_lang_clip, unique_lengths, unique_actions)
            metrics['policy_loss_lang'] = policy_loss_lang.item()
            metrics['policy_loss_prior'] = policy_loss_prior.item()

            loss = policy_loss_lang + policy_loss_prior + sim_loss
            metrics['total_loss'] = loss.item()

        if train:
            loss.backward()
            clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()

        return metrics

    def save_model(self, path, batches):
        """Write network weights, optimizer state and ``batches`` to ``path``."""
        torch.save({
            'model_state_dict': self.net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'batches': batches,
        }, path)

    def load_model(self, path):
        """Restore network and optimizer state from ``path``.

        Returns:
            The ``batches`` value stored in the checkpoint by
            :meth:`save_model`.
        """
        # map_location lets a checkpoint written on one device (e.g. a GPU)
        # be loaded on whatever device this agent is configured for.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=self.device)
        self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['batches']