import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

from agents.agent_alfred_base import AlfredAgent
from networks.networks_bcz_alfred import BCZNetwork


class BCZAgent(AlfredAgent):
    """Behaviour-cloning agent built around a ``BCZNetwork``.

    The agent is trained with two terms that are summed into one loss:
    a cross-entropy loss over the predicted action logits and a
    ``1 - cosine similarity`` loss between the video and language embeddings
    returned by ``BCZNetwork.get_video_lang_emb``.
    """

    # Target value written into padded time steps; must match the
    # ``ignore_index`` of the cross-entropy loss so those steps are skipped.
    _IGNORE_INDEX = -1

    def __init__(self, action_size, hidden_size=512, device="cpu", config=None):
        """Build the network, the loss functions and the optimizer.

        Args:
            action_size: Forwarded to the base class and to ``BCZNetwork``.
            hidden_size: Used for both ``hidden_size`` and
                ``lstm_hidden_size`` of the network.
            device: Forwarded to the base class (``self.device`` is used
                afterwards to place the network).
            config: Configuration object. Although it defaults to ``None``,
                ``config.learning_rate`` is read here, so a real config is
                required in practice.
        """
        super().__init__(action_size, hidden_size=hidden_size, device=device, config=config)

        self.args = config

        self.net = BCZNetwork(
            config=config,
            hidden_size=hidden_size,
            lstm_hidden_size=hidden_size,
            action_size=action_size,
        )
        self.net.to(self.device)

        self.loss_func = nn.CrossEntropyLoss(ignore_index=self._IGNORE_INDEX)
        self.video_loss_func = nn.CosineSimilarity()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=config.learning_rate)

    def get_action(self, state, goals, h_t=None, out=None):
        """Greedily pick an action for a single observation.

        Args:
            state: Observation tensor for one step; it is cast to float32 and
                given two leading singleton dimensions before the forward pass
                (presumably batch and time -- confirm against ``BCZNetwork``).
            goals: Goal input passed through to the network.
            h_t: Optional ``(hidden_states, cell_states)`` pair from the
                previous call. When ``None`` the recurrent arguments are
                omitted so the network falls back to its own defaults.
            out: Opaque value returned unchanged.

        Returns:
            ``([action], out, (hidden_states, cell_states))`` where ``action``
            is the arg-max of the logits as a Python int.
        """
        self.net.eval()

        # Only forward the recurrent state when the caller supplied one;
        # unpacking the ``None`` default would raise a TypeError.
        recurrent_kwargs = {}
        if h_t is not None:
            hidden_states, cell_states = h_t
            recurrent_kwargs = {
                "hidden_states": hidden_states,
                "cell_states": cell_states,
            }

        with torch.no_grad():
            # The network lives on self.device, so its inputs must too.
            state = state.to(self.device, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            if torch.is_tensor(goals):
                goals = goals.to(self.device)

            logits, hidden_states, cell_states = self.net(
                state,
                goals,
                single_step=True,
                **recurrent_kwargs,
            )

            action = torch.argmax(logits, dim=-1).item()

        return [action], out, (hidden_states, cell_states)

    def learn(self, experiences, train=True):
        """Compute the losses on one batch and optionally take an optimizer step.

        Args:
            experiences: 8-tuple
                ``(states, actions, rewards, dones, goals, _, _, masks)``.
                Only ``states``, ``actions``, ``goals`` and ``masks`` are used.
                ``actions`` is reduced with arg-max over its last dimension
                (presumably one-hot vectors -- confirm with the data loader),
                and ``masks`` is summed over dim 1 to obtain per-sequence
                lengths.
            train: When true, run in training mode and update the weights;
                otherwise run in eval mode without tracking gradients.

        Returns:
            Dict with float entries ``policy_loss``, ``video_loss`` and
            ``total_loss``.
        """
        metrics = {}

        states, actions, _rewards, _dones, goals, _, _, masks = experiences

        states = states.to(self.device).to(torch.float32)
        actions = actions.to(self.device)
        goals = goals.to(self.device)
        masks = masks.to(self.device)

        if train:
            self.net.train()
            self.optimizer.zero_grad()
        else:
            self.net.eval()

        # Skip autograd bookkeeping during evaluation: nothing is
        # back-propagated there, so building the graph only costs memory.
        with torch.set_grad_enabled(bool(train)):
            logits, _, _ = self.net(states, goals, single_step=False)

            actions_target = torch.argmax(actions, dim=-1)  # (batch_size, seq_len)

            # Every step at or beyond a sequence's length is padding; mark it
            # with the ignore index so it does not contribute to the loss.
            lengths = torch.sum(masks, dim=1)
            col_idx = torch.arange(
                actions_target.size(1), device=actions_target.device
            ).unsqueeze(0)
            padding = col_idx >= lengths.unsqueeze(1)
            actions_target = actions_target.masked_fill(padding, self._IGNORE_INDEX)

            # reshape (rather than view) also handles non-contiguous logits.
            policy_loss = self.loss_func(
                logits.reshape(-1, logits.shape[-1]),
                actions_target.reshape(-1),
            )
            metrics['policy_loss'] = policy_loss.item()

            # ---------- video loss ------------
            video_emb, lang_emb = self.net.get_video_lang_emb(states, masks, goals)
            video_loss = 1 - self.video_loss_func(video_emb, lang_emb).mean()
            metrics['video_loss'] = video_loss.item()

            loss = policy_loss + video_loss
            metrics['total_loss'] = loss.item()

        if train:
            loss.backward()
            clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()

        return metrics

    def save_model(self, path, batches):
        """Write network weights, optimizer state and ``batches`` to ``path``."""
        torch.save({
            'model_state_dict': self.net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'batches': batches,
        }, path)

    def load_model(self, path):
        """Restore a checkpoint written by ``save_model``.

        Tensors are mapped onto ``self.device`` while loading so that a
        checkpoint saved on another device (e.g. a GPU) can be restored on
        this one.

        Returns:
            The ``batches`` value stored in the checkpoint.
        """
        d = torch.load(path, map_location=self.device)
        self.net.load_state_dict(d['model_state_dict'])
        self.optimizer.load_state_dict(d['optimizer_state_dict'])
        return d['batches']