import sys
sys.path.append('envs/alfred')
sys.path.append('envs/alfred/models')
import os

import torch


from agents.agent_alfred_gcbc import BCAgent
from alfred.models.eval.eval_subgoals_threads_prev import EvalSubgoals

import pprint
import json
from tqdm import trange

from utils.data_loader import DatasetOri
from torch.utils.data import DataLoader

from config import get_config
import torch.multiprocessing as mp
import pickle

from tensorboardX import SummaryWriter

import random
import numpy as np

import copy



# device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

def set_seed(seed):
    """Seed every random number generator used by this script.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and all CUDA
    generators.

    Args:
        seed: Value forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def collate_fn(batch):
    """Pad variable-length samples into dense, length-sorted batch tensors.

    Samples are reordered so the longest sequence comes first (ties keep
    their original relative order) and every per-step field is zero-padded
    up to the longest sequence in the batch.

    Each sample is indexed positionally. The field names below follow the
    local variable names used here -- confirm the exact schema against the
    dataset class that produces the samples:

    * ``data[0]``: states, first dimension is the sequence length.
    * ``data[1]``: actions, one row per step.
    * ``data[2]``, ``data[3]``, ``data[6]``: rewards, done flags and
      valid-interact flags; only column ``0`` of each is used.
    * ``data[4]``: goal; ``data[4][0]`` must fit a row of width 1024.
    * ``data[5]``: low-level action masks, or ``None`` / an empty tensor
      when the sample has none.

    NOTE: this function reads the module-level ``device`` global, which must
    be defined before the first batch is collated.

    Args:
        batch: Non-empty list of samples as described above.

    Returns:
        Tuple ``(states, actions, rewards, done, goals, alow_masks,
        valid_interacts, masks)``. ``masks`` holds 1 for real steps and 0
        for padding and is created on the CPU; ``alow_masks`` is the
        concatenation (dim 0) of all non-empty ``data[5]`` tensors in sorted
        order, or a CPU ``torch.zeros(1)`` placeholder when there are none.
        Every other tensor lives on ``device``.
    """
    # ``sorted`` is stable even with ``reverse=True``, so equal-length samples
    # stay in batch order -- the same result as sorting indices and gathering.
    sorted_batch = sorted(batch, key=lambda data: data[0].shape[0], reverse=True)
    lengths = [data[0].shape[0] for data in sorted_batch]

    batch_size = len(batch)
    max_length = max(lengths)

    # Trailing feature shapes are taken from the first sample of the batch.
    state_shape = batch[0][0].shape[1:]
    action_dim = batch[0][1][0].shape[0]

    padded_states = torch.zeros(batch_size, max_length, *state_shape, device=device)
    padded_actions = torch.zeros(batch_size, max_length, action_dim, device=device)
    padded_rewards = torch.zeros(batch_size, max_length, device=device)
    padded_done = torch.zeros(batch_size, max_length, device=device)
    padded_valid_interacts = torch.zeros(batch_size, max_length, device=device)
    goals = torch.zeros(batch_size, 1024, device=device)

    # The validity mask is deliberately left on the CPU, as it always was.
    padded_masks = torch.zeros(batch_size, max_length)

    alow_mask_chunks = []
    for i, (data, length) in enumerate(zip(sorted_batch, lengths)):
        padded_states[i, :length, :] = data[0].to(device)
        padded_actions[i, :length, :] = data[1].to(device)
        padded_rewards[i, :length] = data[2][:, 0].to(device)
        padded_done[i, :length] = data[3][:, 0].to(device)
        padded_valid_interacts[i, :length] = data[6][:, 0].to(device)

        # Identity check: ``!= None`` on a tensor only works by accident.
        sample_alow = data[5]
        if sample_alow is not None and sample_alow.shape[0] != 0:
            alow_mask_chunks.append(sample_alow.to(device))

        # Scalar fill avoids building a ones-tensor on ``device`` only to
        # copy it straight back into this CPU tensor.
        padded_masks[i, :length] = 1

        goals[i, :] = data[4][0].to(device)

    if alow_mask_chunks:
        alow_masks = torch.cat(alow_mask_chunks, dim=0)
    else:
        alow_masks = torch.zeros(1)

    return (
        padded_states,
        padded_actions,
        padded_rewards,
        padded_done,
        goals,
        alow_masks,
        padded_valid_interacts,
        padded_masks,
    )


class Eval:
    """Launches an ``EvalSubgoals`` run for one split and logs its results.

    Usage is two-phase: ``eval_real`` spawns the evaluation workers, and a
    later ``join_threads`` collects their two result values and writes them
    to a summary writer. ``join_threads`` does nothing until ``eval_real``
    has been called at least once.
    """

    def __init__(self, split, config) -> None:
        # True once eval_real has spawned workers that can be joined.
        self.flag = False
        # Private copy so the eval-specific fields set later do not leak
        # into the caller's config object.
        self.config = copy.deepcopy(config)
        self.split = split

    def eval_real(self, model, batches):
        """Start an evaluation of a snapshot of ``model``.

        Args:
            model: Agent to evaluate; it is deep-copied before being handed
                to ``EvalSubgoals``.
            batches: Step value remembered and later used as ``global_step``
                when the results are logged.
        """
        snapshot = copy.deepcopy(model)
        with torch.no_grad():
            manager = mp.Manager()

            # Point the copied config at this split and the subgoal type.
            self.config.eval_split = self.split
            self.config.subgoals = 'GotoLocation'
            self.batches = batches

            self.train_eval = EvalSubgoals(self.config, snapshot, manager)
            self.train_eval.spawn_threads()

            self.flag = True

    def join_threads(self, writer):
        """Collect the spawned evaluation and log its two result values.

        The values are written under ``loss_dict/<split>_sr`` and
        ``loss_dict/<split>_srw`` (presumably success rate and a weighted
        variant -- confirm against ``EvalSubgoals.join_threads``).

        Args:
            writer: Object exposing ``add_scalar(tag, value, global_step=...)``.
        """
        if not self.flag:
            # Nothing has been spawned yet, so there is nothing to join.
            return

        sr, srw = self.train_eval.join_threads()
        step = self.batches
        writer.add_scalar(f"loss_dict/{self.split}_sr", sr, global_step=step)
        writer.add_scalar(f"loss_dict/{self.split}_srw", srw, global_step=step)

def train(config, model, splits, writer):
    """Train ``model`` on the ``train`` split, then evaluate it once.

    The config is dumped to ``<config.dout>/config.json``. Each batch is
    passed to ``model.learn`` and every entry of the returned loss dict
    (plus the current epoch) is logged to ``writer`` with the running batch
    count as the step. Every ``config.eval_every`` batches the model is
    checkpointed twice: pickled whole to
    ``data/<model_type>/<model_type>_<epoch>.pt`` (so checkpoints taken in
    the same epoch overwrite each other) and saved through
    ``model.save_model``.

    After the last epoch an evaluation is started on the ``train`` and
    ``valid_seen`` splits and then joined, so that the results are actually
    written to ``writer``.

    Args:
        config: Namespace-like settings object; ``vars(config)`` must be
            JSON-serialisable.
        model: Agent exposing ``learn(experience, train=True)`` returning a
            dict of scalars, and ``save_model(path, batches)``.
        splits: Mapping that contains a ``'train'`` entry.
        writer: Summary writer exposing ``add_scalar``.
    """
    eval_train = Eval('train', config)
    eval_valid = Eval('valid_seen', config)

    # Named ``train_split`` so the local does not shadow this function.
    train_split = splits['train']

    # dump config
    fconfig = os.path.join(config.dout, 'config.json')
    with open(fconfig, 'wt') as f:
        json.dump(vars(config), f, indent=2)

    # display dout
    print("Saving to: %s" % config.dout)

    dataset = DatasetOri(config, train_split, name='train')
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)

    batches = 0

    for epoch in trange(0, config.epoch, desc='epoch'):

        print("----------epoches:", epoch)

        for experience in dataloader:
            if batches % 1000 == 0:
                print("---batches: ", batches)

            loss_dict = model.learn(experience, train=True)

            # Logged alongside the losses so curves can be aligned by epoch.
            loss_dict['epoch'] = epoch

            for key, value in loss_dict.items():
                writer.add_scalar(f"loss_dict/{key}", value, global_step=batches)

            batches += 1

            if batches % config.eval_every == 0:
                save_folder = f'data/{config.model_type}'
                os.makedirs(save_folder, exist_ok=True)

                with open(f'{save_folder}/{config.model_type}_{epoch}.pt', 'wb') as f:
                    pickle.dump(model, f)
                    print("Pickle done.")

                weights_path = f'data/models/alfred/{config.model_type}_{config.if_clip}_{config.alpha}_{config.if_regularize}_{config.seed}.pt'
                # Make sure the target directory exists; this is harmless if
                # save_model creates it on its own.
                os.makedirs(os.path.dirname(weights_path), exist_ok=True)
                model.save_model(weights_path, batches)

    # Spawn both evaluations first so they can run side by side, then join
    # them. Joining before spawning (the previous order) was a no-op and the
    # final results were never logged. ``batches`` is used as the step so the
    # points share an axis with the loss curves and the call no longer
    # depends on the loop variable, which is undefined when config.epoch is 0.
    eval_train.eval_real(model, batches)
    eval_valid.eval_real(model, batches)
    eval_train.join_threads(writer)
    eval_valid.join_threads(writer)
    

def set_seed(seed):
    """Apply ``seed`` to the Python, NumPy and PyTorch random generators.

    NOTE: this rebinds the identical ``set_seed`` defined near the top of
    this file; the two definitions behave the same.

    Args:
        seed: Value passed unchanged to each library's seeding function.
    """
    # The four generators are independent, so the call order is irrelevant.
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def main(config):
    """Build the agent and run training with the given configuration.

    Steps: resolve the torch device, force the ``spawn`` multiprocessing
    start method, open a summary writer under ``data/alfred_lstm/``, expand
    and create ``config.dout``, load the splits JSON named by
    ``config.splits``, construct a ``BCAgent`` and hand everything to
    ``train``. The writer is closed when training finishes or fails.

    Side effects: rebinds the module-level ``device`` global and overwrites
    ``config.dout`` with its formatted value.

    Args:
        config: Namespace-like settings object. Reads ``device``,
            ``model_type``, ``if_clip``, ``alpha``, ``seed``, ``dout``,
            ``splits`` and ``feature_size``.
    """
    # collate_fn reads the module-level ``device``. Publishing it here keeps
    # the pipeline working when main() is called after an import rather than
    # through the __main__ guard. The CPU fallback mirrors the expression used
    # in that guard, so the agent and collate_fn agree on one device.
    global device
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    mp.set_start_method('spawn', force=True)

    # set_seed(config.seed)

    writer = SummaryWriter(f'data/alfred_lstm/{config.model_type}-{config.if_clip}-{config.alpha}-{config.seed}')

    try:
        # ``dout`` is a template filled from the config's own attributes.
        config.dout = config.dout.format(**vars(config))
        os.makedirs(config.dout, exist_ok=True)

        # config.generate_dataset = True

        with open(config.splits) as f:
            splits = json.load(f)

        # Show how many entries each split contains.
        pprint.pprint({k: len(v) for k, v in splits.items()})

        agent = BCAgent(
            action_size=15,
            hidden_size=config.feature_size,
            device=device,
            config=config
        )

        train(config, agent, splits, writer)
    finally:
        # Flush and release the event file even if training raises.
        writer.close()

if __name__ == '__main__':
    # Script entry point: start from get_config() and layer the hard-coded
    # experiment settings on top before launching training.
    config = get_config()
    seed = 130
    # set_seed(seed)

    # Applied in insertion order via plain attribute assignment.
    overrides = {
        'device': 'cuda:1',
        'seed': seed,
        'generate_dataset': False,
        'model_type': 'BC',
        'preprocess': False,
        'expert': False,
        'epoch': 20,
        'num_eval_file': 15,
    }
    for option, value in overrides.items():
        setattr(config, option, value)

    # Module-level device read by collate_fn; falls back to the CPU when
    # CUDA is not available.
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    main(config)