import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.distributed as dist

from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel import DistributedDataParallel as DDP


from agents.agent_alfred_base import AlfredAgent
from networks.networks_iql_alfred import IQLNetwork
from utils.utils import asymmetric_l2_loss, mean_pooling_for_similarity_visual
# from algorithms.agent_babyai import CrossEn

# Upper clamp applied to exp(beta * advantage) weights during policy extraction.
EXP_ADV_MAX = 100.0

class IQLAgent(AlfredAgent):
    """Implicit Q-Learning (IQL) agent for ALFRED built on ``IQLNetwork``.

    One ``IQLNetwork`` provides twin Q heads (``q_network_1/2``) with target
    copies (``q_target_network_1/2``), a value head and a policy head; a single
    Adam optimizer updates all of it. ``learn`` either fits Q and V
    (``policy_extract=False``) or fits the policy with advantage-weighted
    log-likelihood (``policy_extract=True``).
    """

    def __init__(self, action_size, hidden_size=512, device="cpu", config=None) -> None:
        super().__init__(action_size, hidden_size=hidden_size, device=device, config=config)
        self.args = config

        self.net = IQLNetwork(config=config, hidden_size=hidden_size, lstm_hidden_size=hidden_size, action_size=action_size)
        self.net.to(self.device)

        # if config.distributed:
        #     self.net = nn.SyncBatchNorm.convert_sync_batchnorm(self.net).to(self.device)
        #     self.net = DDP(self.net, device_ids=[device], find_unused_parameters=False)

        self.gamma = 0.99   # discount factor in the TD target
        self.tau = 1.0      # target-update mixing coefficient; 1.0 makes soft_update a hard copy
        self.tau_l2 = 0.9   # asymmetry parameter passed to asymmetric_l2_loss for the value loss
        self.beta = 2       # inverse temperature of the advantage weights in policy extraction

        self._on_calls = 0                  # optimizer steps taken by learn(train=True)
        self.target_update_interval = 1000  # target Q networks are synced every this many steps

        self.mc_q_weight = 0.8  # NOTE(review): not read by any method in this class
        # self.align_loss_fct = CrossEn()
        self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)  # NOTE(review): not used by any method in this class
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=config.learning_rate)
        # self.scheduler = CosineAnnealingLR(self.optimizer, T_max=config.episodes)

    def get_action(self, state, goal, h_t=None, out=None):
        """Greedily choose one action for a single observation.

        Args:
            state: observation for one step; a batch axis and a time axis are
                prepended before the forward pass.
            goal: goal tensor forwarded to the network.
            h_t: ``(hidden_states, cell_states)`` returned by the previous call,
                or ``None`` to start without a carried recurrent state.
            out: passed through unchanged.

        Returns:
            ``([action], out, (hidden_states, cell_states))`` where ``action`` is
            the argmax of the policy logits as a Python int.
        """
        # NOTE(review): assumes IQLNetwork accepts hidden_states/cell_states=None
        # as "no previous state" — confirm. (Unpacking None used to raise.)
        hidden_states, cell_states = h_t if h_t is not None else (None, None)
        self.net.eval()
        with torch.no_grad():
            # The observation must live on the model's device, just like the goal.
            state = state.to(device=self.device, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            goal = goal.to(self.device)

            _, _, logits, hidden_states, cell_states = self.net(
                state,
                goal,
                hidden_states=hidden_states,
                cell_states=cell_states,
                return_hidden=True,
            )

            action = torch.argmax(logits, dim=-1).item()

        return [action], out, (hidden_states, cell_states)

    def get_q_values(self, experiences):
        """Return the first Q head's value for the actions stored in a batch.

        ``experiences`` is unpacked as
        ``(states, actions, rewards, dones, goals, _, _, masks)``; only
        ``states``, ``actions`` and ``goals`` are used. The taken action is the
        argmax of ``actions`` over its last axis (presumably one-hot — confirm).

        Returns:
            Q_1(s, a) with the action axis removed, computed without gradients.
        """
        states, actions, _rewards, _dones, goals, _, _, _masks = experiences

        states = states.to(device=self.device, dtype=torch.float32)
        actions = actions.to(self.device)
        goals = goals.to(self.device)

        with torch.no_grad():
            (q_values_1, _q_values_2), _, _ = self.net(states, goals)
            # Select along the last (action) axis, matching learn(). For 2-D
            # (batch, n_actions) tensors this is identical to dim=1; for 3-D
            # (batch, time, n_actions) tensors dim=1 would index over time.
            action_idx = torch.argmax(actions, dim=-1, keepdim=True)
            q_taken = torch.gather(q_values_1, dim=-1, index=action_idx).squeeze(-1)

        return q_taken

    def soft_update(self, local_model, target_model):
        """Move target parameters toward local ones: target <- tau*local + (1-tau)*target.

        With ``self.tau == 1.0`` (as set in ``__init__``) this copies the local
        parameters into the target model.
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)

    def learn(self, experiences, policy_extract=False, train=True):
        """Run one IQL update (or evaluation) step on a batch of sequences.

        Args:
            experiences: tuple unpacked as
                ``(states, actions, rewards, dones, goals, _, _, masks)``.
                Judging from the indexing below, ``states``/``actions`` carry a
                batch axis and a time axis (dim 1), and the taken action is the
                argmax of ``actions`` over its last axis — TODO confirm.
            policy_extract: ``False`` fits Q and V; ``True`` fits the policy
                with advantage-weighted log-likelihood.
            train: ``True`` backpropagates and steps the optimizer; ``False``
                only evaluates the losses (eval mode, no autograd graph).

        Returns:
            dict of scalar losses: ``'v_loss'``, ``'q_loss(two)'`` and
            ``'total_loss'`` for Q/V learning, or ``'actor_loss'`` for policy
            extraction.
        """
        self.net.train(train)  # train(False) is equivalent to eval()

        metrics = {}

        states, actions, rewards, dones, goals, _, _, _masks = experiences
        # NOTE(review): the last element ("masks") is not used by any loss here.

        states = states.to(device=self.device, dtype=torch.float32)
        actions = actions.to(self.device)
        goals = goals.to(self.device)
        rewards = rewards.to(device=self.device, dtype=torch.float32)
        # Float cast so `1 - dones` also works when done flags are boolean.
        dones = dones.to(device=self.device, dtype=torch.float32)

        # s_{t+1} along the time axis; the last step has no successor and is
        # zero-filled. Built after the conversion so it sits on the model's device.
        next_states = torch.zeros_like(states)
        next_states[:, :-1] = states[:, 1:]

        # Index of the dataset action at each step, shaped for gather on dim 2.
        actions_target = torch.argmax(actions, dim=-1).unsqueeze(-1)

        with torch.no_grad():
            # if self.config.distributed:
            #     Q_targets_1, Q_targets_2 = self.net.module.get_q_values(states, goals, target=True)
            Q_targets_1, Q_targets_2 = self.net.get_q_values(states, goals, target=True)
            Q_targets_1 = torch.gather(Q_targets_1, dim=2, index=actions_target).squeeze(dim=-1)
            Q_targets_2 = torch.gather(Q_targets_2, dim=2, index=actions_target).squeeze(dim=-1)
            # Pessimistic target: element-wise minimum of the twin target Q heads.
            Q_targets = torch.min(Q_targets_1, Q_targets_2)

            _, next_state_values, _ = self.net(next_states, goals)
            next_state_values = next_state_values.squeeze(-1)

        with torch.set_grad_enabled(train):
            # A single forward pass supplies the Q, V and policy heads for every
            # loss term below.
            (Q_a_s_1, Q_a_s_2), values, logits = self.net(states, goals)
            values = values.squeeze(-1)
            advantage = Q_targets - values

            if not policy_extract:  # Q and V learning
                v_loss = asymmetric_l2_loss(advantage, self.tau_l2)
                metrics['v_loss'] = v_loss.detach().item()

                # One-step TD target; bootstrapping is cut off where dones == 1.
                targets_value = (rewards + (self.gamma * next_state_values * (1 - dones))).squeeze(dim=1)
                # if self.config.use_mc_help:
                #     targets_value = torch.max(targets_value, masks)

                Q_a_s_1 = torch.gather(Q_a_s_1, dim=2, index=actions_target).squeeze(dim=-1)
                Q_a_s_2 = torch.gather(Q_a_s_2, dim=2, index=actions_target).squeeze(dim=-1)
                q_loss = F.mse_loss(Q_a_s_1, targets_value) + F.mse_loss(Q_a_s_2, targets_value)
                metrics['q_loss(two)'] = q_loss.detach().item()

                loss = q_loss + v_loss
                metrics['total_loss'] = loss.detach().item()
            else:  # policy extraction
                # Advantage weights are constants for the actor loss; clamp avoids overflow.
                exp_advantage = torch.exp(self.beta * advantage.detach()).clamp(max=EXP_ADV_MAX)
                log_probs = F.log_softmax(logits, dim=2)
                data_log_probs = torch.gather(log_probs, dim=2, index=actions_target).squeeze(dim=-1)
                loss = -(data_log_probs * exp_advantage).mean()
                metrics['actor_loss'] = loss.detach().item()

        if train:
            self.optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(self.net.parameters(), 1.)
            self.optimizer.step()
            self._on_calls += 1
            # Periodically sync the target Q networks.
            if self._on_calls % self.target_update_interval == 0:
                # if self.config.distributed and dist.get_rank() == 0:
                #     self.soft_update(self.net.module.q_network_1, self.net.module.q_target_network_1)
                #     self.soft_update(self.net.module.q_network_2, self.net.module.q_target_network_2)
                self.soft_update(self.net.q_network_1, self.net.q_target_network_1)
                self.soft_update(self.net.q_network_2, self.net.q_target_network_2)

        return metrics

    def save_model(self, path, batches):
        """Save network and optimizer state plus ``batches`` to ``path`` with ``torch.save``."""
        torch.save({
            'model_state_dict': self.net.state_dict(),
            # self.optimizer is the only optimizer this agent creates (see __init__).
            'optimizer_state_dict': self.optimizer.state_dict(),
            'batches': batches,
        }, path)
