import torch
import torch.nn as nn
from networks.networks_cql_alfred import Network
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os
import json
from collections import deque


import collections
from alfred.gen import constants

def decompress_mask(compressed_mask, width=None, height=None):
    '''
    Decompress a run-length encoded binary mask.

    :param compressed_mask: iterable of ``(start_idx, run_len)`` pairs; every flat
        index in ``[start_idx, start_idx + run_len)`` is set to 1.
    :param width: first dimension of the output; defaults to
        ``constants.DETECTION_SCREEN_WIDTH``.
    :param height: second dimension of the output; defaults to
        ``constants.DETECTION_SCREEN_HEIGHT``.
    :return: float64 numpy array of shape ``(width, height)`` holding 0s and 1s.
    :raises IndexError: if a non-empty run falls outside the mask.

    Flat indices are interpreted in row-major order over the ``(width, height)``
    array (row = idx // height, col = idx % height). For the square screen sizes
    this is exactly the mapping the previous ``idx // WIDTH, idx % HEIGHT``
    formula produced; for non-square sizes that formula was inconsistent.
    '''
    if width is None:
        width = constants.DETECTION_SCREEN_WIDTH
    if height is None:
        height = constants.DETECTION_SCREEN_HEIGHT

    flat = np.zeros(width * height)
    for start_idx, run_len in compressed_mask:
        if run_len > 0 and (start_idx < 0 or start_idx + run_len > flat.size):
            raise IndexError(
                'mask run (%d, %d) out of range for %d pixels' % (start_idx, run_len, flat.size))
        # one slice assignment per run instead of one Python-level write per pixel
        flat[start_idx:start_idx + run_len] = 1
    return flat.reshape(width, height)

class CrossEn(nn.Module):
    '''
    Contrastive cross-entropy over a 2-D similarity matrix.

    Row i is treated as a query whose matching (positive) column is i, so the
    loss is ``mean_i(-log_softmax(sim_matrix, dim=-1)[i, i])``.
    '''

    def __init__(self):
        super().__init__()

    def forward(self, sim_matrix):
        # row-wise log-probabilities; the positives sit on the main diagonal
        log_probs = F.log_softmax(sim_matrix, dim=-1)
        return torch.diag(log_probs).neg().mean()

class AlfredAgent():
    '''
    Q-learning agent for ALFRED with a categorical value support, a CQL-style
    regularizer, and a CLIP-style goal/visual alignment loss.

    Holds a ``Network`` plus two Adam optimizers over the same parameters
    (``optimizer_q`` and ``optimizer_clip``, with separate learning rates), and
    the helpers that turn preprocessed ALFRED json into model inputs.

    ``get_action``, ``compute_loss`` and ``learn_step`` are empty here, and
    ``learn`` calls ``clip_learn`` / ``q_learn``, which this class does not
    define. Presumably a subclass provides them.
    '''

    def __init__(self, action_size, device, hidden_size=64, config=None):
        '''
        Build the network, losses and optimizers, and load the vocabulary.

        :param action_size: number of discrete actions (forwarded to ``Network``).
        :param device: torch device for the network and internal tensors.
        :param hidden_size: hidden size forwarded to ``Network``.
        :param config: experiment config. Despite the ``None`` default it is
            required: ``n_atoms``, ``q_learning_rate``, ``clip_learning_rate``
            and ``data`` are read unconditionally.
        '''
        self.config = config
        self.action_size = action_size
        self.device = device
        self.tau = 1e-2
        self.gamma = 0.99
        self.last_action = 0

        # sentinel tokens
        self.pad = 0
        self.seg = 1
        self.test_mode = False
        self.max_subgoals = 25
        self.feat_pt = 'feat_conv.pt'

        self.last_mask_loss = 0

        # support of the categorical value distribution: n_atoms evenly spaced
        # points in [v_min, v_max], delta_z apart
        self.v_max = 2
        self.v_min = -2
        self.n_atoms = config.n_atoms

        self.atoms = torch.linspace(self.v_min, self.v_max, self.n_atoms).to(self.device)
        self.delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)

        self.net = Network(
            atoms=self.atoms,
            action_size=self.action_size,
            config=config,
            hidden_size=hidden_size,
        ).to(self.device)

        self.loss_fct = CrossEn()
        # per-element BCE so weighted_mask_loss can weight mask/background pixels separately
        self.bce_with_logits = torch.nn.BCEWithLogitsLoss(reduction='none')

        # NOTE: both optimizers cover *all* network parameters, with separate learning rates
        self.optimizer_q = optim.Adam(params=self.net.parameters(), lr=config.q_learning_rate)
        self.optimizer_clip = optim.Adam(params=self.net.parameters(), lr=config.clip_learning_rate)

        # weights_only=False: the vocab file presumably holds pickled Python
        # objects rather than plain tensors (only load trusted files)
        self.vocab = torch.load(os.path.join(self.config.data, "pp.vocab"), weights_only=False)

        self.ct = 0

    def reset(self):
        '''
        Reset ``self.state_buffer`` to ``config.history_frame`` zero frames.
        '''
        frame_num = self.config.history_frame
        self.state_buffer = deque([], maxlen=frame_num)
        for _ in range(frame_num):
            # placeholder frames; (512, 7, 7) presumably matches the visual feature shape — confirm
            self.state_buffer.append(torch.zeros(512, 7, 7).to(self.device))

    def get_action(self):
        '''Placeholder; does nothing here (returns None).'''
        pass

    def cql_loss(self, q_values, current_actions, masks):
        '''
        CQL-style regularizer: mean over unmasked steps of
        ``logsumexp_a Q(s, a) - sum_a Q(s, a) * current_actions[a]``.

        :param q_values: Q-values whose last dim is the action dim and whose
            leading dims match ``masks`` (e.g. (B, T, A) with masks (B, T) —
            TODO confirm against callers).
        :param current_actions: same shape as ``q_values``; presumably one-hot
            dataset actions, so the weighted sum selects Q(s, a_data).
        :param masks: nonzero marks valid steps.
        :return: scalar tensor.
        '''
        valid = masks.to(torch.bool)
        q_values = q_values[valid]
        current_action = current_actions[valid]

        logsumexp = torch.logsumexp(q_values, dim=1, keepdim=True)
        q_a = (q_values * current_action).sum(dim=1).unsqueeze(1)

        return (logsumexp - q_a).mean()

    def compute_loss(self, states, next_states, actions, goals, rewards, dones):
        '''Placeholder; does nothing here (returns None).'''
        pass

    def learn_step(self, experiences, actor=False):
        '''Placeholder; does nothing here (returns None).'''
        pass

    def seq_add_stack_frames(self, states, actions, masks, goals):
        '''
        Stack a sliding window of ``config.history_frame`` frames per timestep.

        ``history_frame - 1`` zero frames are prepended along dim 1, so output
        step i holds input frames ``i - history_frame + 1 .. i`` (zeros before
        the start).

        :param states: tensor of shape (B, T, *feat).
        :return: ``(stacked_states, actions, masks, goals)``; ``stacked_states``
            has shape (B, T, history_frame, *feat); the others are returned unchanged.
        '''
        frames_num = self.config.history_frame
        added_frames = torch.zeros((states.shape[0], frames_num - 1, *(states.shape[2:])), device=self.device)
        states_added = torch.cat((added_frames, states), dim=1)

        stacked_states = torch.zeros((states.shape[0], states.shape[1], frames_num, *(states.shape[2:])), device=self.device)

        for i in range(states.shape[1]):
            stacked_states[:, i, :] = states_added[:, i:i + frames_num]

        return stacked_states, actions, masks, goals

    def clip_loss(self, states, actions, masks, goals):
        '''
        Symmetric contrastive loss between encoded goals and visual features.

        For each sequence the visual feature at index ``masks.sum(1) - 1`` is
        used. This is the last valid step, assuming valid steps form a prefix
        (TODO confirm). Both sides are L2-normalised; ``self.loss_fct`` is
        applied to the goal-by-visual logit matrix and its transpose, and the two
        results are averaged.

        :param states: 3-D tensor (B, T, D).
        :param actions: unused.
        :param masks: 2-D tensor (B, T), nonzero for valid steps.
        :param goals: input to ``self.net.goal_encoder``; its output must have
            feature size D.
        :return: scalar tensor.
        '''
        states = states.to(torch.float32)

        goals = self.net.goal_encoder(goals)

        # index of the last valid step, broadcast over the feature dim for gather
        masks = (masks.sum(dim=1) - 1)[..., None, None].expand(masks.shape[0], 1, states.shape[-1])
        visual_output = states.gather(dim=1, index=masks.to(torch.int64)).squeeze(1)

        visual_output = visual_output / (visual_output.norm(dim=-1, keepdim=True) + 1e-8)

        goals = goals / (goals.norm(dim=-1, keepdim=True) + 1e-8)

        logit_scale = self.net.logit_scale.exp()
        # the constant +1e-8 shift does not change the softmax inside loss_fct
        retrieve_logits = logit_scale * torch.matmul(goals, visual_output.t()) + 1e-8

        sim_loss1 = self.loss_fct(retrieve_logits)
        sim_loss2 = self.loss_fct(retrieve_logits.T)
        return (sim_loss1 + sim_loss2) / 2

    def get_next_states(self, states):
        '''
        Shift ``states`` one step left along dim 1, in place.

        ``states[:, t]`` becomes ``states[:, t + 1]``; the last step keeps its old
        value. The input tensor is modified and also returned.
        '''
        # clone: source and destination overlap, so copy before the in-place write
        states[:, :-1] = states[:, 1:].clone()
        return states

    def learn(self, experience, epoch):
        '''
        Run one CLIP update and one Q update on a batch.

        :param experience: 9-tuple; positions 0 (states), 1 (actions), 4 (goals)
            and 8 (masks) go to ``clip_learn``, and the whole tuple goes to ``q_learn``.
        :param epoch: unused here.
        :return: dict of the losses reported by ``clip_learn`` / ``q_learn``.
        '''
        padded_states, padded_actions, _, _, goals, _, _, _, padded_masks = experience
        clip_loss = self.clip_learn(padded_states, padded_actions, padded_masks, goals)
        q_loss, cql_loss, bellman_loss, mask_loss = self.q_learn(experience)

        return {
            "Loss_CLIP": clip_loss,
            "Q loss": q_loss,
            "CQL loss": cql_loss,
            "Bellman loss": bellman_loss,
            "Mask loss": mask_loss,
        }

    def flip_tensor(self, tensor, on_zero=1, on_non_zero=0):
        '''
        Return a copy of ``tensor`` with zeros replaced by ``on_zero`` and all
        other values replaced by ``on_non_zero``.
        '''
        res = tensor.clone()
        res[tensor == 0] = on_zero
        res[tensor != 0] = on_non_zero
        return res

    def load_task_json(self, task):
        '''
        Load the preprocessed annotation json for ``task``.

        Reads ``<config.data>/<task['task']>/<config.pp_folder>/ann_<repeat_idx>.json``.
        '''
        json_path = os.path.join(self.config.data, task['task'], '%s' % self.config.pp_folder, 'ann_%d.json' % task['repeat_idx'])
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding
        with open(json_path, encoding='utf-8') as f:
            data = json.load(f)
        return data

    def weighted_mask_loss(self, pred_masks, gt_masks):
        '''
        Mask loss that accounts for the imbalance between 1 and 0 pixels: the
        mean BCE over ground-truth-1 pixels plus the mean BCE over 0 pixels.

        :param pred_masks: mask logits, same shape as ``gt_masks``.
        :param gt_masks: float ground-truth masks (binary in practice).
        :return: scalar tensor.

        An empty region (e.g. an all-zero ground-truth mask) contributes 0
        instead of the NaN produced by a plain 0/0.
        '''
        bce = self.bce_with_logits(pred_masks, gt_masks)
        flipped_mask = self.flip_tensor(gt_masks)
        # Clamp the denominators only to avoid 0/0. For non-negative masks the
        # numerator is 0 whenever the denominator is, so the term becomes 0.
        eps = torch.finfo(gt_masks.dtype).eps
        inside = (bce * gt_masks).sum() / gt_masks.sum().clamp(min=eps)
        outside = (bce * flipped_mask).sum() / flipped_mask.sum().clamp(min=eps)
        return inside + outside

    def _mean_pooling_for_similarity_visual(self, visual_output, video_mask,):
        '''
        Masked mean over dim 1.

        Assumes ``visual_output`` is (B, T, D) and ``video_mask`` is (B, T) —
        TODO confirm. Rows whose mask sums to 0 divide by 1 instead, giving zeros.
        '''
        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        visual_output = visual_output * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        return torch.sum(visual_output, dim=1) / video_mask_un_sum

    def decompress_mask(self, compressed_mask):
        '''
        Decompress a run-length mask from the json files and add a leading
        axis of size 1.
        '''
        # resolves to the module-level decompress_mask, which returns a fresh array
        mask = np.asarray(decompress_mask(compressed_mask))
        return np.expand_dims(mask, axis=0)

    def serialize_lang_action(self, feat):
        '''
        Flatten per-subgoal lists into single sequences, in place.

        ``feat['num']['lang_instr']`` is flattened when its first element is a
        list. Outside test mode, ``feat['num']['action_low']`` is flattened too.
        Already-flat input is left unchanged.
        '''
        is_serialized = not isinstance(feat['num']['lang_instr'][0], list)
        if not is_serialized:
            feat['num']['lang_instr'] = [word for desc in feat['num']['lang_instr'] for word in desc]
            if not self.test_mode:
                feat['num']['action_low'] = [a for a_group in feat['num']['action_low'] for a in a_group]

    def zero_input(self, x, keep_end_token=True):
        '''
        Replace every token of ``x`` except the last with ``self.pad`` (used for
        ablations).

        The last token is kept when ``keep_end_token``; otherwise it is also
        replaced by ``self.pad``. Returns a list; the padded elements are numpy
        scalars.
        '''
        end_token = [x[-1]] if keep_end_token else [self.pad]
        return list(np.full_like(x[:-1], self.pad)) + end_token

    def get_task_root(self, ex):
        '''
        Return ``<config.data>/<ex['split']>/<last two components of ex['root']>``.
        '''
        return os.path.join(self.config.data, ex['split'], *(ex['root'].split('/')[-2:]))

    def featurize(self, batch, load_mask=True, load_frames=True):
        '''
        Collect per-example inputs and supervision from a batch of
        preprocessed ALFRED examples.

        No tensor conversion or padding happens here. Each entry of the
        returned ``defaultdict(list)`` is a list with one item per example, in
        batch order. Each example dict is modified in place by
        ``serialize_lang_action``.

        :param batch: iterable of example dicts (presumably as returned by
            ``load_task_json``).
        :param load_mask: also collect decompressed interaction masks.
        :param load_frames: also load the per-trajectory feature file
            ``self.feat_pt``; skipped in test mode.
        :return: ``collections.defaultdict(list)`` keyed by feature name.
        '''
        feat = collections.defaultdict(list)

        for ex in batch:
            ###########
            # auxiliary
            ###########

            if not self.test_mode:
                # subgoal completion supervision
                if self.config.subgoal_aux_loss_wt > 0:
                    feat['subgoals_completed'].append(np.array(ex['num']['low_to_high_idx']) / self.max_subgoals)

                # progress monitor supervision; action_low is still nested here
                # (it is flattened by serialize_lang_action below)
                if self.config.pm_aux_loss_wt > 0:
                    num_actions = len([a for sg in ex['num']['action_low'] for a in sg])
                    subgoal_progress = [(i + 1) / float(num_actions) for i in range(num_actions)]
                    feat['subgoal_progress'].append(subgoal_progress)

            #########
            # inputs
            #########

            # serialize segments
            self.serialize_lang_action(ex)

            # goal and instr language
            lang_goal, lang_instr = ex['num']['lang_goal'], ex['num']['lang_instr']

            # zero inputs if specified
            lang_goal = self.zero_input(lang_goal) if self.config.zero_goal else lang_goal
            lang_instr = self.zero_input(lang_instr) if self.config.zero_instr else lang_instr

            # append goal + instr
            feat['lang_goal_instr'].append(lang_goal + lang_instr)

            # load visual features from disk
            if load_frames and not self.test_mode:
                root = self.get_task_root(ex)
                im = torch.load(os.path.join(root, self.feat_pt))

                num_low_actions = len(ex['plan']['low_actions']) + 1  # +1 for additional stop action
                num_feat_frames = im.shape[0]

                # Modeling Quickstart (without filler frames)
                if num_low_actions == num_feat_frames:
                    feat['frames'].append(im)

                # Full Dataset (contains filler frames)
                else:
                    keep = [None] * num_low_actions
                    for i, d in enumerate(ex['images']):
                        # only keep the first frame linked to each low-level action
                        # (skips filler frames like smooth rotations and dish washing)
                        if keep[d['low_idx']] is None:
                            keep[d['low_idx']] = im[i]
                    keep[-1] = im[-1]  # stop frame
                    feat['frames'].append(torch.stack(keep, dim=0))

            #########
            # outputs
            #########
            if not self.test_mode:
                # low-level action
                feat['action_low'].append([a['action'] for a in ex['num']['action_low']])

                # low-level action mask (only actions that carry a mask)
                if load_mask:
                    feat['action_low_mask'].append([self.decompress_mask(a['mask']) for a in ex['num']['action_low'] if a['mask'] is not None])

                # low-level valid interact
                feat['action_low_valid_interact'].append([a['valid_interact'] for a in ex['num']['action_low']])

        return feat

    @classmethod
    def has_interaction(cls, action):
        '''
        Return False if ``action`` contains (as a substring) any navigation or
        special-token name, otherwise True.
        '''
        non_interact_actions = ['MoveAhead', 'Rotate', 'Look', '<<stop>>', '<<pad>>', '<<seg>>']
        return not any(a in action for a in non_interact_actions)

    def save_model(self, path, batches):
        '''
        Write a checkpoint to ``path``.

        The checkpoint holds the network weights, both optimizer states
        (``optimizer_state_dict`` for the Q optimizer,
        ``clip_optimizer_state_dict`` for the CLIP optimizer) and ``batches``.
        '''
        torch.save({
            'model_state_dict': self.net.state_dict(),
            'optimizer_state_dict': self.optimizer_q.state_dict(),
            'clip_optimizer_state_dict': self.optimizer_clip.state_dict(),
            'batches': batches,
        }, path)

    def load_model(self, path):
        '''
        Restore a checkpoint written by ``save_model``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on one
        device (e.g. GPU) loads on another (e.g. CPU). The CLIP optimizer state
        is restored when present; older checkpoints without it still load.

        :param path: checkpoint file path.
        :return: the ``batches`` value stored in the checkpoint.
        '''
        d = torch.load(path, map_location=self.device)
        self.net.load_state_dict(d['model_state_dict'])
        self.optimizer_q.load_state_dict(d['optimizer_state_dict'])
        if 'clip_optimizer_state_dict' in d:
            self.optimizer_clip.load_state_dict(d['clip_optimizer_state_dict'])
        return d['batches']