import sys
import os
import random
import numpy as np
import torch
import pprint
import json
from tqdm import trange
import pickle
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

# Extend the module search path with the (relative) ALFRED directories; this
# must run before the project imports below so they can be resolved.
sys.path.append('ALFRED/alfred')
sys.path.append('ALFRED/alfred/models')
# Force the DISPLAY environment variable for this process and its children.
# NOTE(review): presumably required by the simulator's renderer -- confirm
# that display :3.0 exists on the target machine.
os.environ['DISPLAY'] = ':3.0'

# Module-wide compute device: CUDA device index 1 when CUDA is available,
# otherwise the CPU. Read by collate_fn, Eval.eval and main.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

from utils.data_loader import DatasetOri
from alfred.models.eval.eval_subgoals_no_threads_prev import EvalSubgoals
from agents.agent_alfred_dail import CQLAgentC51, CQLAgentNaive
from config import get_config


def set_seed(seed):
    """Seed every random number generator this script relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and the generators
    of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def collate_fn(batch):
    """Collate variable-length trajectory samples into zero-padded tensors.

    Samples are reordered from longest to shortest (by ``sample[0].shape[0]``;
    ties keep their input order) and every per-step field is right-padded
    with zeros up to the longest length in the batch.

    Args:
        batch: Non-empty list of indexable samples. Fields read, by index:
            0 -- states, first dimension is the trajectory length;
            1 -- actions, one row per step;
            2 -- rewards (column 0 is used);
            3 -- done flags (column 0 is used);
            4 -- goal (first entry is written into a 1024-wide row);
            5 -- low-level action masks, may be ``None`` or empty;
            6 -- valid-interact flags (column 0 is used).
            Any further fields are ignored.

    Returns:
        Tuple ``(padded_states, padded_actions, padded_rewards, padded_done,
        goals, alow_masks, padded_valid_interacts, padded_masks)``.
        ``alow_masks`` is the concatenation along dim 0 of every non-empty
        field-5 tensor, or ``torch.zeros(1)`` when there is none.
        ``padded_masks`` is 1 on real steps and 0 on padding. Everything lives
        on the module-level ``device`` except ``padded_masks`` and the
        ``torch.zeros(1)`` fallback, which stay on the CPU.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    lengths = [sample[0].shape[0] for sample in batch]
    max_length = max(lengths)
    batch_size = len(batch)

    # sorted() is stable, so equal-length samples keep their relative order.
    sorted_batch = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)

    # Trailing dimensions are taken from the first (unsorted) sample and are
    # assumed to be identical across the batch.
    padded_states = torch.zeros(batch_size, max_length, *(batch[0][0].shape[1:]), device=device)
    padded_actions = torch.zeros(batch_size, max_length, batch[0][1][0].shape[0], device=device)
    padded_rewards = torch.zeros(batch_size, max_length, device=device)
    padded_done = torch.zeros(batch_size, max_length, device=device)
    padded_valid_interacts = torch.zeros(batch_size, max_length, device=device)
    goals = torch.zeros(batch_size, 1024, device=device)
    # Deliberately left on the CPU to keep the returned tensor's device
    # unchanged for existing consumers.
    padded_masks = torch.zeros(batch_size, max_length)

    alow_mask_chunks = []

    for i, data in enumerate(sorted_batch):
        length = data[0].shape[0]

        padded_states[i, :length, :] = data[0].to(device)
        padded_actions[i, :length, :] = data[1].to(device)
        padded_rewards[i, :length] = data[2][:, 0].to(device)
        padded_done[i, :length] = data[3][:, 0].to(device)
        padded_valid_interacts[i, :length] = data[6][:, 0].to(device)
        goals[i, :] = data[4][0].to(device)

        # Scalar fill: no need to build a ones tensor on `device` only to
        # copy it back into a CPU tensor.
        padded_masks[i, :length] = 1.0

        alow_mask = data[5]
        if alow_mask is not None and alow_mask.shape[0] != 0:
            alow_mask_chunks.append(alow_mask.to(device))

    if alow_mask_chunks:
        alow_masks = torch.cat(alow_mask_chunks, dim=0)
    else:
        # Placeholder so the returned tuple always has the same arity.
        alow_masks = torch.zeros(1)

    return (
        padded_states,
        padded_actions,
        padded_rewards,
        padded_done,
        goals,
        alow_masks,
        padded_valid_interacts,
        padded_masks,
    )

class Eval:
    """Runs subgoal evaluation and logs the results to a summary writer."""

    def eval(self, config, model, writer, epoch):
        """Evaluate ``model`` on the 'train' and 'valid_seen' splits.

        ``config.eval_split`` and ``config.subgoals`` are overwritten while the
        two ``EvalSubgoals`` evaluators are constructed. Each evaluator's
        ``run(model)`` must return a pair; the pair is logged under
        ``loss_dict/<tag>_sr`` and ``loss_dict/<tag>_srw`` with ``epoch`` as
        the global step, where ``<tag>`` is ``train`` or ``seen``.
        """
        split_by_tag = (('train', 'train'), ('seen', 'valid_seen'))

        with torch.no_grad():
            # Construct both evaluators first, then run them, so the order of
            # config mutations and constructor calls matches the run order.
            evaluators = []
            for tag, split_name in split_by_tag:
                config.eval_split = split_name
                config.subgoals = 'GotoLocation'
                evaluators.append((tag, EvalSubgoals(config, device)))

            for tag, evaluator in evaluators:
                sr, srw = evaluator.run(model)
                writer.add_scalar(f"loss_dict/{tag}_sr", sr, global_step=epoch)
                writer.add_scalar(f"loss_dict/{tag}_srw", srw, global_step=epoch)
            

def train(config, model, splits, writer):
    """Train ``model`` on the 'train' split, checkpointing and evaluating periodically.

    Args:
        config: Run configuration. Reads ``dataset_fraction``, ``fast_epoch``,
            ``dout``, ``batch_size``, ``epoch``, ``model_type``, ``if_clip``,
            ``alpha`` and ``if_regularize``; ``eval_every`` is overwritten on
            every batch.
        model: Agent exposing ``learn_step(experience) -> (loss_dict, q)`` and
            ``save_model(path, step)``. It is also pickled at each checkpoint.
        splits: Mapping whose ``'train'`` entry is a sliceable sequence of
            task entries. Only that entry is consumed here.
        writer: Summary writer exposing ``add_scalar(tag, value, global_step=...)``.

    Side effects:
        Writes ``config.json`` into ``config.dout``, pickles the model under
        ``data/<model_type>/``, calls ``model.save_model`` and starts one
        evaluation process per checkpoint (waiting for the previous one first).
    """
    train_split = splits['train']

    # debugging: choose a small fraction of the dataset
    if config.dataset_fraction > 0:
        train_split = train_split[:int(config.dataset_fraction * 0.7)]

    # debugging: use to check if training loop works without waiting for full epoch
    if config.fast_epoch:
        train_split = train_split[:16]

    # dump config
    fconfig = os.path.join(config.dout, 'config.json')
    with open(fconfig, 'wt') as f:
        json.dump(vars(config), f, indent=2)

    # display dout
    print("Saving to: %s" % config.dout)

    dataset = DatasetOri(config, train_split, name='train')
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)

    batches = 0
    validation_process = None

    for epoch in trange(0, config.epoch, desc='epoch'):
        for experience in dataloader:
            loss_dict, q = model.learn_step(experience)

            loss_dict['epoch'] = epoch
            for key, value in loss_dict.items():
                writer.add_scalar(f"loss_dict/{key}", value, global_step=batches)

            # Per-action maximum over all rows; 15 matches the action_size
            # the agent is built with in main().
            q_max, _ = torch.max(q.view(-1, 15), dim=0)
            for action_idx, value in enumerate(q_max):
                writer.add_scalar(f"value/{action_idx}", value, global_step=batches)

            batches += 1

            # NOTE(review): this overrides any configured eval_every; with a
            # value of 1 a checkpoint and an evaluation process are triggered
            # on every batch until batch 400000 -- confirm this is intended.
            config.eval_every = 1 if batches < 400000 else 10000

            if batches % config.eval_every != 0:
                continue

            # Only one evaluation process at a time: wait for the previous one.
            if validation_process is not None:
                print("Waiting...")
                validation_process.join()

            save_folder = f'data/{config.model_type}'
            os.makedirs(save_folder, exist_ok=True)

            with open(f'{save_folder}/{config.model_type}_{epoch}.pt', 'wb') as f:
                pickle.dump(model, f)
                print("Pickle done.")

            model.save_model(f'ALFRED/data/models/{config.model_type}_{config.if_clip}_{batches}_{config.alpha}_{config.if_regularize}_torch.pt', batches)

            # Eval.eval is an instance method, so the target must be a bound
            # method; passing the unbound function with four arguments would
            # bind `config` to `self` and fail with a missing `epoch`.
            # NOTE(review): under the 'spawn' start method the arguments are
            # pickled for the child -- confirm `writer` and `model` survive that.
            validation_process = mp.Process(target=Eval().eval, args=(config, model, writer, batches))
            validation_process.start()

    print("== start to learn q ==")
              
def main(config):
    """Prepare the run described by ``config`` and start training.

    Sets the multiprocessing start method to 'spawn', creates the summary
    writer, expands ``config.dout`` with the config's own attributes, checks
    that the preprocessed vocabulary exists (unless ``config.preprocess`` is
    set), creates the output directory, loads the split file, builds the
    agent selected by ``config.model_type`` and hands everything to ``train``.

    Raises:
        Exception: if the vocabulary file is missing and ``config.preprocess``
            is falsy.
    """
    mp.set_start_method('spawn')
    writer = SummaryWriter(f'ALFRED/data/alfred_lstm/{config.model_type}-{config.if_clip}-{config.alpha}-{config.if_regularize}')
    config.dout = config.dout.format(**vars(config))

    # The dataset counts as preprocessed when its vocab file is present.
    vocab_path = os.path.join(config.data, "%s.vocab" % config.pp_folder)
    if not (os.path.exists(vocab_path) or config.preprocess):
        raise Exception("Dataset not processed; run with --preprocess")

    # Show the configuration and make sure the output directory exists.
    pprint.pprint(config)
    if not os.path.isdir(config.dout):
        os.makedirs(config.dout)

    # Load the train/valid/test splits and report the size of each.
    with open(config.splits) as splits_file:
        splits = json.load(splits_file)
    split_sizes = {name: len(entries) for name, entries in splits.items()}
    pprint.pprint(split_sizes)

    # 'C51' selects the C51 agent; any other model type uses the naive one.
    agent_cls = CQLAgentC51 if config.model_type == 'C51' else CQLAgentNaive
    agent = agent_cls(
        action_size=15,
        hidden_size=config.feature_size,
        device=device,
        config=config,
    )

    train(config, agent, splits, writer)
    

# Script entry point: load the configuration, apply the hard-coded overrides
# below, then run main().
if __name__ == '__main__':
    config = get_config()
    # main() builds CQLAgentC51 for 'C51' and CQLAgentNaive for any other value.
    config.model_type = 'C51' # choose between `C51` and `CQL` to use distributional RL or vanilla RL
    config.if_clip = True # whether alignment is enabled
    
    main(config)