import sys
sys.path.append('ALFRED')
sys.path.append('ALFRED/alfred')
sys.path.append('ALFRED/alfred/models')
import os
import pprint
import json
from tqdm import trange
import torch
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter

import random
import numpy as np

import copy

from datetime import datetime


from agents.agent_alfred_bcz import BCZAgent
from alfred.models.eval.eval_subgoals_threads_prev import EvalSubgoals
from utils.data_loader import DatasetOri
from config import get_config



# device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def collate_fn(batch, device=None):
    """Pad a list of variable-length trajectories into batched tensors.

    Samples are reordered longest-first. Every per-step field is zero-padded
    up to the longest trajectory in the batch.

    Sample fields are read positionally, as used below:
      data[0]: per-step states; the first dim is the trajectory length
      data[1]: per-step actions; data[1][0].shape[0] gives the action width
      data[2]: rewards (column 0 is used)
      data[3]: done flags (column 0 is used)
      data[4]: goal; data[4][0] is copied into a 1024-wide row
               (assumes a 1024-dim goal embedding -- TODO confirm)
      data[5]: optional masks; None or zero-length entries are skipped
      data[6]: valid-interaction flags (column 0 is used)

    Args:
        batch: list of samples as produced by the dataset.
        device: where to allocate the padded tensors. When omitted (the
            normal DataLoader call), the module-level ``device`` is used if
            the script defined one, otherwise CPU.

    Returns:
        Tuple ``(states, actions, rewards, done, goals, alow_masks,
        valid_interacts, masks)``. ``masks`` is 1 on real steps and 0 on
        padding. ``alow_masks`` concatenates the non-empty data[5] entries
        along dim 0, or is ``torch.zeros(1)`` when there are none. ``masks``
        and that placeholder live on CPU, as before.
    """
    if device is None:
        # The script assigns ``device`` only under ``if __name__ == '__main__'``.
        # Fall back to CPU instead of raising NameError when imported or when
        # DataLoader workers re-import this module.
        device = globals().get("device", torch.device("cpu"))

    lengths = [data[0].shape[0] for data in batch]
    # Longest first; Python's sort is stable, so ties keep their input order.
    order = sorted(range(len(lengths)), key=lengths.__getitem__, reverse=True)
    sorted_batch = [batch[i] for i in order]

    batch_size = len(batch)
    max_length = max(lengths)
    first = batch[0]

    padded_states = torch.zeros(batch_size, max_length, *first[0].shape[1:], device=device)
    padded_actions = torch.zeros(batch_size, max_length, first[1][0].shape[0], device=device)
    padded_rewards = torch.zeros(batch_size, max_length, device=device)
    padded_done = torch.zeros(batch_size, max_length, device=device)
    goals = torch.zeros(batch_size, 1024, device=device)
    padded_valid_interacts = torch.zeros(batch_size, max_length, device=device)
    # Kept on CPU to match the tensor callers have always received.
    padded_masks = torch.zeros(batch_size, max_length)
    alow_mask_parts = []

    for i, data in enumerate(sorted_batch):
        length = data[0].shape[0]

        padded_states[i, :length] = data[0].to(device)
        padded_actions[i, :length] = data[1].to(device)
        padded_rewards[i, :length] = data[2][:, 0].to(device)
        padded_done[i, :length] = data[3][:, 0].to(device)

        if data[5] is not None and data[5].shape[0] != 0:
            alow_mask_parts.append(data[5].to(device))

        padded_valid_interacts[i, :length] = data[6][:, 0].to(device)
        padded_masks[i, :length] = 1.0
        goals[i, :] = data[4][0].to(device)

    alow_masks = torch.cat(alow_mask_parts, dim=0) if alow_mask_parts else torch.zeros(1)

    return (
        padded_states,
        padded_actions,
        padded_rewards,
        padded_done,
        goals,
        alow_masks,
        padded_valid_interacts,
        padded_masks,
    )


class Eval:
    """Launch a subgoal evaluation on one split and later log its results.

    ``eval_real`` builds an ``EvalSubgoals`` evaluator and calls its
    ``spawn_threads``. ``join_threads`` calls the evaluator's
    ``join_threads`` and writes the two returned values to a TensorBoard
    writer.
    """

    def __init__(self, split, config) -> None:
        self.split = split
        # Private copy: eval_real overwrites eval_split/subgoals on it.
        self.config = copy.deepcopy(config)
        # True while an evaluation has been spawned but not yet joined.
        self.flag = False

    def eval_real(self, model, batches):
        """Start evaluating a deep copy of ``model`` on ``self.split``.

        Returns after ``EvalSubgoals.spawn_threads`` returns (presumably
        without waiting for the evaluation to finish -- confirm). Call
        ``join_threads`` to collect and log the results.

        Args:
            model: model to evaluate; later changes to it do not affect
                this run because the evaluator receives a deep copy.
            batches: value later used as ``global_step`` when logging.
        """
        model = copy.deepcopy(model)
        with torch.no_grad():
            manager = mp.Manager()
            self.config.eval_split = self.split
            self.config.subgoals = 'GotoLocation'
            self.batches = batches

            self.train_eval = EvalSubgoals(self.config, model, manager)
            self.train_eval.spawn_threads()

            self.flag = True

    def join_threads(self, writer):
        """Wait for the pending evaluation and log its two metrics.

        Does nothing if no evaluation is pending, including when this
        method is called a second time for the same run.

        Args:
            writer: object with ``add_scalar(tag, value, global_step=...)``.
        """
        if not self.flag:
            return
        res0, res1 = self.train_eval.join_threads()
        # Clear the pending state so a repeated call cannot re-join and
        # re-log the same evaluation.
        self.flag = False
        writer.add_scalar(f"loss_dict/{self.split}_sr", res0, global_step=self.batches)
        writer.add_scalar(f"loss_dict/{self.split}_srw", res1, global_step=self.batches)

def train(config, model, splits, writer):
    """Train ``model`` on ``splits['train']``, checkpoint periodically, then evaluate.

    Args:
        config: run configuration. Fields read here: dout, batch_size, epoch,
            eval_every, model_type, if_clip, alpha, if_regularize, seed. The
            whole config is dumped to ``<dout>/config.json``.
        model: agent with ``learn(experience, train=True)`` returning a dict
            of scalars, and ``save_model(path, step)``.
        splits: mapping from split name to its entries; only 'train' is
            trained on.
        writer: TensorBoard writer for per-batch losses and final eval metrics.
    """
    eval_train = Eval('train', config)
    eval_valid = Eval('valid_seen', config)

    fconfig = os.path.join(config.dout, 'config.json')
    with open(fconfig, 'wt') as f:
        json.dump(vars(config), f, indent=2)

    print("Saving to: %s" % config.dout)

    dataset = DatasetOri(config, splits['train'], name='train')
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)

    model_dir = 'data/models/alfred'
    batches = 0

    for epoch in trange(0, config.epoch, desc='epoch'):
        for experience in dataloader:
            loss_dict = model.learn(experience, train=True)
            loss_dict['epoch'] = epoch

            for key, value in loss_dict.items():
                writer.add_scalar(f"loss_dict/{key}", value, global_step=batches)

            batches += 1

            if batches % config.eval_every == 0:
                # Ensure the directory save_model writes into actually exists.
                os.makedirs(model_dir, exist_ok=True)
                model.save_model(
                    f'{model_dir}/{config.model_type}_{config.if_clip}_{batches}'
                    f'_{config.alpha}_{config.if_regularize}_{config.seed}_torch.pt',
                    batches,
                )

    # Final evaluation: spawn both runs first so they overlap, then wait and
    # log. Joining before spawning (as before) left these runs unjoined and
    # their metrics unlogged. ``batches`` is the step so eval metrics share
    # the loss curves' x-axis; it is also always bound, unlike ``epoch``
    # when config.epoch == 0.
    eval_train.eval_real(model, batches)
    eval_valid.eval_real(model, batches)
    eval_train.join_threads(writer)
    eval_valid.join_threads(writer)

def set_seed(seed):
    """Make random draws reproducible by seeding every RNG this script uses.

    NOTE(review): this is an identical redefinition of the earlier
    ``set_seed`` in this file and replaces it at import time; one copy
    could be dropped.

    Args:
        seed: integer seed applied to each generator in turn.
    """
    rng_seeders = {
        "python": random.seed,
        "numpy": np.random.seed,
        "torch": torch.manual_seed,
        "torch.cuda": torch.cuda.manual_seed_all,
    }
    for seed_fn in rng_seeders.values():
        seed_fn(seed)

def main(config):
    """Build the BCZ agent from ``config`` and train it, logging to TensorBoard.

    Side effects: forces the 'spawn' multiprocessing start method, formats
    ``config.dout`` with the config's own fields and creates that directory,
    and writes TensorBoard events under ``data/alfred_bcz/``.

    Args:
        config: run configuration; ``config.splits`` must name a JSON file
            mapping split names to sized collections.
    """
    device = torch.device(config.device)

    mp.set_start_method('spawn', force=True)

    # NOTE(review): seeding is disabled; config.seed currently only tags
    # output paths. Re-enable set_seed(config.seed) for reproducible runs.
    # set_seed(config.seed)

    run_name = (
        f'{config.model_type}-{config.if_clip}-{config.alpha}-{config.seed}-'
        f'{datetime.now().strftime("%b%d_%H-%M-%S")}'
    )
    writer = SummaryWriter(f'data/alfred_bcz/{run_name}')
    try:
        config.dout = config.dout.format(**vars(config))
        os.makedirs(config.dout, exist_ok=True)

        with open(config.splits) as f:
            splits = json.load(f)

        pprint.pprint({k: len(v) for k, v in splits.items()})

        agent = BCZAgent(
            action_size=15,
            hidden_size=config.feature_size,
            device=device,
            config=config
        )

        train(config, agent, splits, writer)
    finally:
        # Flush buffered events (including final eval metrics) to disk.
        writer.close()

if __name__ == '__main__':
    config = get_config()
    # Resolve the device once so both consumers agree. main() builds the
    # agent from config.device, and collate_fn reads the module-level
    # ``device`` below. Leaving config.device as 'cuda:3' on a CPU-only host
    # put the agent on a missing GPU while batches went to CPU.
    config.device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
    seed = 129
    config.seed = seed
    # set_seed(seed)
    config.generate_dataset = False
    config.model_type = 'BCZ'
    config.preprocess = False
    config.expert = False
    config.epoch = 20
    # Module-level on purpose: collate_fn looks up this global.
    device = torch.device(config.device)
    config.num_eval_file = 400
    main(config)