import sys

sys.path.append('ALFRED')
sys.path.append('ALFRED/alfred')
sys.path.append('ALFRED/alfred/models')

import os
import numpy as np
import torch
import pprint
import pickle
import json
import networks.CLIP.clip.clip as clip
from alfred.gen import constants

from config import get_config
from alfred.data.preprocess import Dataset

def decompress_mask(compressed_mask):
    '''
    decompress a run-length encoded mask into a dense 2D array

    compressed_mask: iterable of (start_idx, run_len) pairs; every pair marks
        run_len consecutive cells, starting at flat (row-major) index start_idx
    returns: float array of shape
        (constants.DETECTION_SCREEN_WIDTH, constants.DETECTION_SCREEN_HEIGHT)
        holding 1 in every covered cell and 0 elsewhere
    raises: IndexError if a run starts before the first cell or ends after the last one
    '''
    num_rows = constants.DETECTION_SCREEN_WIDTH
    num_cols = constants.DETECTION_SCREEN_HEIGHT

    # Fill a flat buffer and reshape at the end. This keeps the row/column
    # arithmetic consistent with the allocated shape (dividing by one screen
    # dimension and taking the remainder by the other is only correct for a
    # square screen) and replaces the per-pixel Python loop with one slice
    # assignment per run.
    flat_mask = np.zeros(num_rows * num_cols)
    for start_idx, run_len in compressed_mask:
        end_idx = start_idx + run_len
        # Slice assignment silently clips, so reject out-of-range runs explicitly.
        if start_idx < 0 or end_idx > flat_mask.size:
            raise IndexError(
                'mask run (%d, %d) is outside the %d x %d screen'
                % (start_idx, run_len, num_rows, num_cols)
            )
        flat_mask[start_idx:end_idx] = 1
    return flat_mask.reshape(num_rows, num_cols)

class Preprocesser:
    '''
    Convert preprocessed ALFRED annotation files into one pickled sample per subgoal.

    args: configuration object; this class reads args.processed_data and
        args.pp_folder to locate the annotation json files.
    '''

    # Folder that receives the per-split sample folders and index files.
    OUTPUT_ROOT = 'ALFRED/data/preprocessed'
    # Width of the stored subgoal token row.
    # NOTE(review): presumably CLIP's context length - confirm.
    SUBGOAL_LENGTH = 77
    # NOTE(review): presumably CLIP's start-of-text / end-of-text token ids -
    # confirm against the bundled tokenizer.
    START_TOKEN_ID = 49406
    END_TOKEN_ID = 49407
    # Width of the one-hot action vectors.
    NUM_ACTIONS = 15
    # Per-step decay of the "mc" score, counted backwards from the last step.
    DISCOUNT = 0.99

    def __init__(self, args):
        self.args = args

    def decompress_mask(self, compressed_mask):
        '''
        decompress mask from json files

        returns: the dense mask with an extra leading axis of size 1
        '''
        mask = np.array(decompress_mask(compressed_mask))
        return np.expand_dims(mask, axis=0)

    def load_task_json(self, task):
        '''
        load preprocessed json from disk

        task: mapping with the keys 'task' (sub-folder below
            args.processed_data) and 'repeat_idx' (annotation number)
        returns: the parsed content of <task>/<pp_folder>/ann_<repeat_idx>.json
        '''
        json_path = os.path.join(
            self.args.processed_data,
            task['task'],
            '%s' % self.args.pp_folder,
            'ann_%d.json' % task['repeat_idx'],
        )
        with open(json_path) as f:
            return json.load(f)

    def _encode_subgoal(self, subgoal):
        '''
        build the (1, SUBGOAL_LENGTH) token row stored under 'subgoal'

        Layout: start token, the selected tokens, end token, zero padding.
        '''
        input_ids = clip.tokenize(subgoal)
        # Only the token at position 1 of every tokenized row is kept.
        # NOTE(review): this yields one token per row of the tokenizer output;
        # whether subgoal is a single string or a list of words is not visible
        # here - confirm that this selection is intended.
        input_ids = input_ids[:, 1].reshape(1, -1)
        num_tokens = input_ids.shape[1]

        encoded = torch.zeros((1, self.SUBGOAL_LENGTH))
        encoded[0][0] = self.START_TOKEN_ID
        encoded[0][1:num_tokens + 1] = input_ids
        encoded[0][num_tokens + 1] = self.END_TOKEN_ID
        return encoded

    @staticmethod
    def _discounted_scores(steps):
        '''
        return an array of length steps with 1 at the last position and each
        earlier position multiplied by DISCOUNT once more
        '''
        scores = np.zeros(steps)
        score = 1
        for back in range(1, steps + 1):
            scores[-back] = score
            score *= Preprocesser.DISCOUNT
        return scores

    def data_preprocessing(self, data, name):
        '''
        write one pickled sample per subgoal plus an id -> file path index

        data: list of task descriptors as accepted by load_task_json
        name: split name; samples go to OUTPUT_ROOT/<name>/id_<n>_preprocessed.pt
            and the index to OUTPUT_ROOT/<name>_id2pt.pt

        Prints the average number of subgoals per task, the average number of
        actions per subgoal and the largest number of actions in one subgoal.
        '''
        file_folder = f'{self.OUTPUT_ROOT}/{name}'
        # Create the output folder once, up front: the index file written at
        # the end needs its parent directory even when no sample is produced.
        os.makedirs(file_folder, exist_ok=True)

        num_of_subgoals = 0.0
        num_of_actions = 0.0
        max_num_of_actions = 0.0
        id2pt = {}
        num = 0

        for i, task in enumerate(data):
            print('{:.2f}'.format(float(i)/len(data)), end='\r')
            d = self.load_task_json(task)
            goal = d['ann']['goal']
            # The instruction list carries one trailing stop entry that is not a subgoal.
            num_of_subgoals += len(d['ann']['instr']) - 1
            index_ct = 0

            for idx, subgoal in enumerate(d['ann']['instr']):
                if subgoal == '<<stop>>':
                    continue

                low_actions = d['num']['action_low'][idx]
                actions = [a['action'] for a in low_actions]
                valid_interacts = [a['valid_interact'] for a in low_actions]
                masks = [
                    self.decompress_mask(a['mask'])
                    for a in low_actions
                    if a['mask'] is not None
                ]

                steps = len(actions)
                num_of_actions += steps
                max_num_of_actions = max(max_num_of_actions, steps)
                # Half-open range of this subgoal's steps within the whole task.
                indexes = (index_ct, index_ct + steps)
                index_ct += steps

                # 1 only on the final step of the subgoal; stored as both reward and done.
                done = np.zeros(steps)
                done[-1] = 1

                actions_onehot = torch.zeros(steps, self.NUM_ACTIONS)
                for j, action in enumerate(actions):
                    actions_onehot[j][action] = 1

                if masks:
                    # Stack through numpy first; building a tensor straight
                    # from a list of ndarrays is a documented slow path.
                    action_low_mask = torch.tensor(np.array(masks))
                else:
                    action_low_mask = torch.tensor(masks)

                data_processed = {
                    'id': num,
                    'subgoal_type': d['plan']['high_pddl'][idx]['discrete_action']['action'],
                    'split': d['split'],
                    'root': d['root'],
                    'goal': goal,
                    "index": indexes,
                    "reward": torch.tensor(done).reshape(-1, 1),
                    "done": torch.tensor(done).reshape(-1, 1),
                    'actions': actions_onehot,
                    "action_low_mask": action_low_mask,
                    'valid_interacts': torch.tensor(valid_interacts).reshape(-1, 1),
                    'subgoal': self._encode_subgoal(subgoal),
                    "mc": torch.tensor(self._discounted_scores(steps)).reshape(-1, 1),
                }

                file_path = file_folder + f'/id_{num}_preprocessed.pt'
                with open(file_path, 'wb') as f:
                    pickle.dump(data_processed, f)

                id2pt[num] = file_path
                num += 1

        # Guard the averages so an empty split (or one without subgoals)
        # reports 0.0 instead of raising ZeroDivisionError.
        print(num_of_subgoals / len(data) if len(data) else 0.0)
        print(num_of_actions / num_of_subgoals if num_of_subgoals else 0.0)
        print(max_num_of_actions)

        with open(f'{self.OUTPUT_ROOT}/{name}_id2pt.pt', 'wb') as f:
            pickle.dump(id2pt, f)
        

if __name__ == '__main__':
    # Script entry point: optionally run the base dataset preprocessing, then
    # write per-subgoal samples for the valid_seen and train splits.
    args = get_config()

    # fill the placeholders of the output path from the config and seed torch
    args.dout = args.dout.format(**vars(args))
    torch.manual_seed(args.seed)

    # check if dataset has been preprocessed
    vocab_path = os.path.join(args.data, "%s.vocab" % args.pp_folder)
    if not os.path.exists(vocab_path) and not args.preprocess:
        raise Exception("Dataset not processed; run with --preprocess")

    # make output dir
    pprint.pprint(args)
    os.makedirs(args.dout, exist_ok=True)

    # load the splits file and show the number of tasks per split
    with open(args.splits) as f:
        splits = json.load(f)
        pprint.pprint({k: len(v) for k, v in splits.items()})

    if args.preprocess:
        print("\nPreprocessing dataset and saving to %s folders ... This will take a while. Do this once as required." % args.pp_folder)
        dataset = Dataset(args, None)
        dataset.preprocess_splits(splits)

        print("== First step done ==")

    # Only these splits are used below; the test splits are not read, so a
    # splits file without them is accepted.
    train = splits['train']
    valid_seen = splits['valid_seen']
    valid_unseen = splits['valid_unseen']

    # keep the first fifth of each validation split
    print(len(valid_seen), len(train))
    valid_seen = valid_seen[:int(len(valid_seen)/5)]
    valid_unseen = valid_unseen[:int(len(valid_unseen)/5)]

    # debugging: choose a small part of the dataset; dataset_fraction is used
    # as a task count that is divided 70% train / 15% per validation split
    if args.dataset_fraction > 0:
        small_train_size = int(args.dataset_fraction * 0.7)
        small_valid_size = int((args.dataset_fraction * 0.3) / 2)
        train = train[:small_train_size]
        valid_seen = valid_seen[:small_valid_size]
        valid_unseen = valid_unseen[:small_valid_size]

    # debugging: cap every split at 16 tasks for a quick end-to-end run
    if args.fast_epoch:
        train = train[:16]
        valid_seen = valid_seen[:16]
        valid_unseen = valid_unseen[:16]

    # dump config
    fconfig = os.path.join(args.dout, 'config.json')
    with open(fconfig, 'wt') as f:
        json.dump(vars(args), f, indent=2)

    # display dout
    print("Saving to: %s" % args.dout)

    # NOTE: valid_unseen is subset above but not preprocessed here.
    preprocessor = Preprocesser(args)
    preprocessor.data_preprocessing(valid_seen, 'valid_seen')
    preprocessor.data_preprocessing(train, 'train')
    