import torch
import sys
import os
import copy
import random

sys.path.append('ALFRED/alfred')
sys.path.append('ALFRED/alfred/models')
os.environ['DISPLAY'] = ':3.0'

import pprint
import json
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter
import numpy as np

from agents.agent_alfred_dail import CQLAgentC51, CQLAgentNaive
from alfred.models.eval.eval_subgoals_threads_prev import EvalSubgoals
from config import get_config

def set_seed(seed):
    """Seed Python's `random`, NumPy, PyTorch (CPU) and every visible CUDA device.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # lazy: safe to call without a GPU
    )
    for seed_fn in seeders:
        seed_fn(seed)
    
# Restrict PyTorch to GPUs 0-3. The CUDA runtime reads this variable when it
# initializes, so it only takes effect if CUDA has not been initialized yet.
os.environ.update(CUDA_VISIBLE_DEVICES='0,1,2,3')

def collate_fn(batch, device=None):
    """Pad a batch of variable-length trajectories into dense tensors.

    Each element of ``batch`` is indexed positionally:
      0: states, shape (length, *state_shape)
      1: actions, shape (length, action_dim)
      2: rewards, 2-D; column 0 is used
      3: done flags, 2-D; column 0 is used
      4: goal; row 0 is written into a 1024-wide goal slot
      5: optional low-level action mask (``None`` or empty means "absent")
      6: valid-interaction flags, 2-D; column 0 is used

    Args:
        batch: non-empty list of per-trajectory tuples laid out as above.
        device: device for the output tensors. Defaults to the module-level
            ``device`` global (set in the ``__main__`` section). If that global
            is undefined, e.g. when this module is imported, CPU is used.

    Returns:
        Tuple ``(states, actions, rewards, done, goals, alow_masks,
        valid_interacts, masks)``. Time-indexed tensors are zero-padded to
        the longest trajectory. ``masks`` is 1 on real steps and 0 on padding.
        ``alow_masks`` concatenates all present masks along dim 0, or is a
        single zero when no element has one.
    """
    if device is None:
        # Previously this function relied solely on the global set under
        # __main__ and raised NameError when imported elsewhere.
        device = globals().get('device', torch.device('cpu'))

    lengths = [item[0].shape[0] for item in batch]
    max_length = max(lengths)
    batch_size = len(batch)

    padded_states = torch.zeros(batch_size, max_length, *batch[0][0].shape[1:], device=device)
    padded_actions = torch.zeros(batch_size, max_length, batch[0][1][0].shape[0], device=device)
    padded_rewards = torch.zeros(batch_size, max_length, device=device)
    padded_done = torch.zeros(batch_size, max_length, device=device)
    goals = torch.zeros(batch_size, 1024, device=device)
    padded_valid_interacts = torch.zeros(batch_size, max_length, device=device)
    padded_masks = torch.zeros(batch_size, max_length, device=device)

    alow_mask_parts = []
    for i, (item, length) in enumerate(zip(batch, lengths)):
        padded_states[i, :length] = item[0].to(device)
        padded_actions[i, :length] = item[1].to(device)
        padded_rewards[i, :length] = item[2][:, 0].to(device)
        padded_done[i, :length] = item[3][:, 0].to(device)

        alow_mask = item[5]
        if alow_mask is not None and alow_mask.shape[0] != 0:
            alow_mask_parts.append(alow_mask.to(device))

        padded_valid_interacts[i, :length] = item[6][:, 0].to(device)
        padded_masks[i, :length] = 1.0
        goals[i, :] = item[4][0].to(device)

    if alow_mask_parts:
        alow_masks = torch.cat(alow_mask_parts, dim=0)
    else:
        # Placeholder kept on the same device as every other output.
        alow_masks = torch.zeros(1, device=device)

    return (
        padded_states,
        padded_actions,
        padded_rewards,
        padded_done,
        goals,
        alow_masks,
        padded_valid_interacts,
        padded_masks,
    )

def eval(config, model, dataloader, name, writer, epoch):
    """Run ``model.eval`` over ``dataloader`` and log its metrics to ``writer``.

    NOTE(review): a later module-level ``def eval(config, epoch)`` rebinds this
    name, so this function cannot be reached as ``eval`` once the module is
    loaded. It also shadows the builtin ``eval``.

    Args:
        config: configuration forwarded to ``model.eval``.
        model: object whose ``eval(config, dataloader, name)`` returns a pair
            ``(loss_dict, count_action)`` of metric mappings.
        dataloader: data forwarded to ``model.eval``.
        name: split label, used in the action-accuracy tag prefix.
        writer: TensorBoard-style writer exposing ``add_scalar``.
        epoch: global step recorded with every scalar.
    """
    loss_dict, count_action = model.eval(config, dataloader, name)

    for metric, value in loss_dict.items():
        writer.add_scalar(f"loss_dict/{metric}", value, global_step=epoch)

    for action, accuracy in count_action.items():
        writer.add_scalar(f'action_acc_{name}/{str(action)}', accuracy, global_step=epoch)

class Eval:
    """Run the ALFRED subgoal evaluation (``EvalSubgoals``) on one data split.

    Call ``eval_real`` to start the evaluation workers. Then call
    ``join_threads`` to wait for them and log the two returned scores.
    """

    def __init__(self, split, config) -> None:
        # Private copy: eval_real mutates eval_split / subgoals on it.
        self.config = copy.deepcopy(config)
        self.split = split
        if split == 'train':
            self.config.num_eval_file = 200
        # True once eval_real has spawned workers that join_threads can wait on.
        self.flag = False

    def eval_real(self, model, batches):
        """Start evaluation workers for ``self.split`` on a copy of ``model``.

        Args:
            model: agent to evaluate. It is deep-copied so the workers do not
                share the caller's instance.
            batches: stored on ``self.batches``. It is not read by this class.
        """
        model = copy.deepcopy(model)
        with torch.no_grad():
            # Keep the manager referenced on self so its lifetime does not
            # depend on whether EvalSubgoals holds on to it.
            self.manager = mp.Manager()
            self.config.eval_split = self.split
            self.config.subgoals = 'GotoLocation'
            self.batches = batches

            self.train_eval = EvalSubgoals(self.config, model, self.manager)
            self.train_eval.spawn_threads()

            self.flag = True

    def join_threads(self, epoch, writer):
        """Wait for the workers, log both scores, and return them.

        Args:
            epoch: global step used for the TensorBoard scalars.
            writer: TensorBoard-style writer exposing ``add_scalar``.

        Returns:
            ``(res0, res1)`` from ``EvalSubgoals.join_threads``, logged as
            ``eval/<split>_sr`` and ``eval/<split>_srw``. Returns
            ``(None, None)`` without logging if ``eval_real`` was never
            called. Previously that case raised UnboundLocalError.
        """
        if not self.flag:
            return None, None

        res0, res1 = self.train_eval.join_threads()
        writer.add_scalar(f"eval/{self.split}_sr", res0, global_step = epoch)
        writer.add_scalar(f"eval/{self.split}_srw", res1, global_step = epoch)
        return res0, res1

def eval(config, epoch):
    """Load a saved checkpoint and evaluate it on the train and valid_seen splits.

    NOTE(review): this rebinds the earlier module-level ``eval`` and shadows
    the builtin ``eval``. The name is kept because the ``__main__`` section
    calls it.

    Args:
        config: experiment configuration namespace. It is mutated in place:
            ``dout`` is formatted with the config's own fields.
        epoch: checkpoint step to load. It is also the TensorBoard global step.

    Returns:
        ``{'train': (sr, srw), 'valid_seen': (sr, srw)}`` as returned by
        ``Eval.join_threads``. The original returned ``None``.

    Raises:
        Exception: if the preprocessed vocabulary file is missing and
            ``config.preprocess`` is false.
    """
    # Fall back to CPU on machines without CUDA, as the __main__ section does.
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    # 'spawn' start method for the evaluation workers. Presumably required
    # because they use CUDA, which does not support fork.
    mp.set_start_method('spawn', force=True)

    config.dout = config.dout.format(**vars(config))

    writer = SummaryWriter(f'ALFRED/data/alfred/{config.model_type}-{config.if_clip}-{config.alpha}-{config.if_regularize}-{config.seed}')
    try:
        # check if dataset has been preprocessed
        if not os.path.exists(os.path.join(config.data, "%s.vocab" % config.pp_folder)) and not config.preprocess:
            raise Exception("Dataset not processed; run with --preprocess")

        pprint.pprint(config)
        os.makedirs(config.dout, exist_ok=True)

        # load train/valid/tests splits (only their sizes are reported here)
        with open(config.splits) as f:
            splits = json.load(f)
        pprint.pprint({k: len(v) for k, v in splits.items()})

        Agent = CQLAgentC51 if config.model_type == 'C51' else CQLAgentNaive
        agent = Agent(
            action_size=15,
            device=device,
            hidden_size=config.feature_size,
            config=config,
        )

        # dump config
        with open(os.path.join(config.dout, 'config.json'), 'wt') as f:
            json.dump(vars(config), f, indent=2)
        print("Saving to: %s" % config.dout)

        if config.model_type in ['CQL', 'C51']:
            agent.load_model(f'ALFRED/data/models/{config.model_type}_{config.if_clip}_{epoch}_{config.alpha}_{config.if_regularize}_{config.seed}_torch.pt')

        eval_train = Eval('train', config)
        eval_valid = Eval('valid_seen', config)

        # Start both splits before joining either.
        eval_train.eval_real(agent, 0)
        eval_valid.eval_real(agent, 0)

        results = {
            'train': eval_train.join_threads(epoch, writer),
            'valid_seen': eval_valid.join_threads(epoch, writer),
        }
    finally:
        # Flush pending events and release the event file even on failure.
        writer.close()

    return results
    
# Run presets keyed by id. Each entry is
# [torch device string, model type ('C51' or 'CQL'), value for config.if_clip].
pre_set = {
    2: ['cuda:2', 'C51', True],
}

if __name__ == '__main__':
    sets = pre_set[2]

    config = get_config()
    config.seed = 129  # seed of the checkpoint to evaluate (part of its file name)

    config.device, config.model_type, config.if_clip = sets
    config.alpha = 2
    config.shuffle = True
    config.expert = False

    # Module-level `device` is also read as a global by collate_fn.
    if torch.cuda.is_available():
        device = torch.device(config.device)
    else:
        device = torch.device("cpu")

    config.num_eval_file = 200
    config.num_threads = 4
    config.goal_size = config.hidden_size = config.feature_size = 512

    checkpoint_epoch = 30000
    eval(config, checkpoint_epoch)