from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from collections import defaultdict

import os
from torch.autograd.grad_mode import no_grad
from torch.utils.data import Dataset, dataloader
import torch
import pickle
import networks.CLIP.clip.clip as clip

class DatasetOri(Dataset):
    """Trajectory dataset restricted to ``GotoLocation`` sub-goals.

    On construction the dataset either loads a previously generated cache
    (``{name}_goto_id2pt_processed.pt`` plus ``{name}_goal_tensor_dict``) or
    rebuilds it from the per-trajectory pickle files listed in
    ``{name}_id2pt.pt``. While rebuilding, each distinct sub-goal is encoded
    once with the CLIP ``RN50`` text encoder and stored in
    ``self.goal_tensor_record`` under a key derived from ``subgoal[0]``.

    NOTE(security): every data file is read with ``pickle.load``; only point
    this class at files you trust.

    Args:
        args: Namespace-like object. The code reads ``device``,
            ``generate_dataset``, ``expert``, ``data`` and ``eval_data``.
        max_words: Stored on the instance as ``self.max_words``.
        max_frames: Upper bound on the number of steps returned per item.
        name: Split name; used to build the data file names.
    """

    # Fields copied verbatim from a trajectory record into ``self.data``.
    _RECORD_KEYS = (
        'actions', 'index', 'reward', 'done', 'subgoal', 'split', 'root', 'mc',
    )

    def __init__(
            self,
            args,
            max_words=16,
            max_frames=40,
            name='train',
    ):
        self.clip, _ = clip.load("RN50", device=args.device)
        self.args = args
        self.name = name

        self.data = {}
        self.data['actions'] = []
        self.data['index'] = []
        self.data['reward'] = []
        self.data['done'] = []
        self.data['subgoal'] = []
        self.data['subgoal_instr'] = []
        self.data['split'] = []
        self.data['root'] = []
        self.data['mc'] = []

        # Maps a sub-goal key (see ``_goal_key``) to its CLIP text embedding.
        self.goal_tensor_record = {}

        path = f'ALFRED/data/preprocessed/{name}_goto_id2pt_processed.pt'
        goal_path = f'ALFRED/data/preprocessed/{name}_goal_tensor_dict'

        if os.path.exists(path) and not args.generate_dataset:
            self.data = self._load_pickle(path)
            self.goal_tensor_record = self._load_pickle(goal_path)
        else:
            self._generate(name)

            with open(path, 'wb') as f:
                pickle.dump(self.data, f)

            with open(goal_path, 'wb') as f:
                pickle.dump(self.goal_tensor_record, f)

            print("New data generated")

        self.feat_pt = 'feat_conv.pt'
        self.max_words = max_words
        self.max_frames = max_frames

        self.size = 256

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _goal_key(subgoal):
        """Build the ``goal_tensor_record`` key from the first row of ``subgoal``."""
        return "_".join([str(int(i)) for i in subgoal[0]])

    def _append_record(self, d):
        """Append one trajectory record and cache its goal embedding if new."""
        for key in self._RECORD_KEYS:
            self.data[key].append(d[key])

        goal_ids_str = self._goal_key(d['subgoal'])
        if goal_ids_str not in self.goal_tensor_record:
            # Encode each distinct sub-goal only once; keep the result on CPU.
            with torch.no_grad():
                goal_tensor = self.clip.encode_text(
                    d['subgoal'].to(torch.int).to(self.args.device)
                ).to(torch.float32)
            self.goal_tensor_record[goal_ids_str] = goal_tensor.cpu().detach()

    def _generate(self, name):
        """Populate ``self.data`` / ``self.goal_tensor_record`` from raw files."""
        data = self._load_pickle(f'ALFRED/data/preprocessed/{name}_id2pt.pt')

        self.subgoal_ct = defaultdict(int)

        # Keep only the trajectories whose sub-goal type is GotoLocation.
        for idx in data.keys():
            print(idx, end='\r')
            d = self._load_pickle(data[idx].replace('../', ''))
            if d['subgoal_type'] not in ['GotoLocation']:
                continue
            self._append_record(d)

        if not self.args.expert and name == 'train':
            data = self._load_pickle('ALFRED/data/additional/id2pt.pt')
            for idx in data.keys():
                print(idx, end='\r')

                # Load the record that belongs to this id. Previously the
                # loop re-appended the stale record left over from the loop
                # above instead of reading the additional trajectory.
                # NOTE(review): path handling mirrors the loop above and no
                # sub-goal-type filter is applied here -- confirm both match
                # the layout of the additional data.
                d = self._load_pickle(data[idx].replace('../', ''))
                self._append_record(d)

                if idx > 30000:
                    break

    def get_task_root(self, split, root):
        '''
        returns the folder path of a trajectory
        '''
        return os.path.join(self.args.data, split, *(root.split('/')[-2:]))

    def __len__(self):
        """Number of items; capped by ``args.eval_data`` for non-train splits."""
        if self.args.eval_data > 0 and self.name != 'train':
            return min(self.args.eval_data, len(self.data['actions']))
        else:
            return len(self.data['actions'])

    def __getitem__(self, idx):
        """Return one trajectory, truncated to at most ``max_frames`` steps.

        Returns:
            Tuple ``(states, actions, reward, done, goal, action_low_mask,
            valid_interacts, mc)``. ``states`` is the slice of the saved
            ``feat_conv.pt`` features selected by the record's ``index``
            pair; ``done`` has its last entry forced to 1;
            ``action_low_mask`` is an empty tensor and ``valid_interacts``
            is a zero tensor of shape ``(len(states), 1)``.
        """
        root = self.get_task_root(self.data['split'][idx], self.data['root'][idx])
        states = torch.load(os.path.join(root, self.feat_pt))
        states = states[self.data['index'][idx][0]:self.data['index'][idx][1]]

        actions = self.data['actions'][idx]
        reward = self.data['reward'][idx]
        done = self.data['done'][idx]

        goal_ids_str = self._goal_key(self.data['subgoal'][idx])
        goal = self.goal_tensor_record[goal_ids_str]

        states = states[:self.max_frames]
        actions = actions[:self.max_frames]
        reward = reward[:self.max_frames]
        done = done[:self.max_frames]
        if torch.is_tensor(done):
            # Slicing a tensor yields a view; clone so that marking the last
            # step as terminal does not write into the cached data.
            done = done.clone()
        done[-1] = 1

        action_low_mask = torch.tensor([])

        valid_interacts = torch.zeros(states.shape[0], 1)

        mc = self.data['mc'][idx]

        return (states, actions, reward, done, goal, action_low_mask, valid_interacts, mc)

