import sys
sys.path.append('ALFRED/alfred')
sys.path.append('ALFRED/alfred/models')
import os

import torch


from agents.agent_alfred_iql import IQLAgent
from alfred.models.eval.eval_subgoals_threads_prev import EvalSubgoals

import pprint
import json
from tqdm import trange

from utils.data_loader import DatasetOri
from torch.utils.data import DataLoader

from config import get_config
import torch.multiprocessing as mp
import pickle

from tensorboardX import SummaryWriter

import random
import numpy as np

import copy

def set_seed(seed):
    """Seed every RNG used by training: Python, NumPy, PyTorch CPU and all CUDA devices.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # lazily applied; safe when CUDA is unavailable
    )
    for seeder in seeders:
        seeder(seed)

def collate_fn(batch, device=None):
    """Pad a list of variable-length trajectories into fixed-size batch tensors.

    Each element of ``batch`` is indexed positionally:
    ``data[0]`` states (first dim is the trajectory length), ``data[1]`` actions,
    ``data[2]`` rewards, ``data[3]`` done flags, ``data[4]`` goal (``data[4][0]``
    is copied into a 1024-wide row), ``data[5]`` optional low-level action
    masks (may be ``None`` or have zero rows), ``data[6]`` valid-interact flags.
    Rewards, done flags and valid-interact flags are read from column 0.

    Trajectories are reordered longest-first before padding (stable for ties).

    Args:
        batch: list of per-trajectory tuples as produced by the dataset.
        device: target device for the padded tensors. When omitted, the
            module-level ``device`` global (set by the script entry point) is
            used if it exists, otherwise CPU.

    Returns:
        Tuple ``(states, actions, rewards, done, goals, alow_masks,
        valid_interacts, masks)``. ``masks`` is 1 on real timesteps and 0 on
        padding and stays on CPU. ``alow_masks`` concatenates every non-empty
        ``data[5]`` along dim 0, or is a CPU ``torch.zeros(1)`` placeholder when
        no trajectory has one.
    """
    if device is None:
        # The entry point defines ``device`` at module scope; fall back to CPU
        # so this function also works when the module is imported elsewhere.
        device = globals().get('device', torch.device('cpu'))

    lengths = [data[0].shape[0] for data in batch]
    # Longest trajectory first; sorted() keeps ties in their input order.
    sorted_indices = sorted(range(len(lengths)), key=lengths.__getitem__, reverse=True)
    sorted_batch = [batch[i] for i in sorted_indices]

    max_length = max(lengths)
    batch_size = len(batch)
    state_shape = batch[0][0].shape[1:]
    action_dim = batch[0][1][0].shape[0]

    padded_states = torch.zeros(batch_size, max_length, *state_shape, device=device)
    padded_actions = torch.zeros(batch_size, max_length, action_dim, device=device)
    padded_rewards = torch.zeros(batch_size, max_length, device=device)
    padded_done = torch.zeros(batch_size, max_length, device=device)
    goals = torch.zeros(batch_size, 1024, device=device)
    padded_valid_interacts = torch.zeros(batch_size, max_length, device=device)
    # NOTE(review): unlike the other outputs the mask is kept on CPU, as before.
    padded_masks = torch.zeros(batch_size, max_length)
    alow_mask_parts = []

    for i, data in enumerate(sorted_batch):
        length = data[0].shape[0]

        padded_states[i, :length] = data[0].to(device)
        padded_actions[i, :length] = data[1].to(device)
        padded_rewards[i, :length] = data[2][:, 0].to(device)
        padded_done[i, :length] = data[3][:, 0].to(device)

        alow_mask = data[5]
        if alow_mask is not None and alow_mask.shape[0] != 0:
            alow_mask_parts.append(alow_mask.to(device))

        padded_valid_interacts[i, :length] = data[6][:, 0].to(device)
        # Mark real (non-padded) timesteps.
        padded_masks[i, :length] = 1.0
        goals[i, :] = data[4][0].to(device)

    if alow_mask_parts:
        alow_masks = torch.cat(alow_mask_parts, dim=0)
    else:
        alow_masks = torch.zeros(1)

    return (
        padded_states,
        padded_actions,
        padded_rewards,
        padded_done,
        goals,
        alow_masks,
        padded_valid_interacts,
        padded_masks,
    )

class Eval:
    """Launch and later collect an ``EvalSubgoals`` evaluation for one split.

    ``eval_real`` snapshots the model and calls the evaluator's
    ``spawn_threads``. ``join_threads`` calls the evaluator's ``join_threads``
    and logs the two values it returns to TensorBoard.
    """

    def __init__(self, split, config) -> None:
        """Store the split name and a private copy of ``config``.

        Args:
            split: dataset split name. It is written into ``config.eval_split``
                and used in the logged scalar tags.
            config: run configuration. It is deep-copied so the settings
                ``eval_real`` writes do not leak into the caller's config.
        """
        self.split = split
        self.config = copy.deepcopy(config)
        # True while an evaluation has been spawned but not yet joined.
        self.flag = False
        # Manager backing the currently pending evaluation, if any.
        self.manager = None

    def eval_real(self, model, batches):
        """Start evaluating a snapshot of ``model`` on ``self.split``.

        Args:
            model: agent to evaluate. It is deep-copied so training can keep
                updating the original.
            batches: training step later used as ``global_step`` when the
                results are logged.
        """
        model = copy.deepcopy(model)
        with torch.no_grad():
            # Keep a handle so the manager can be shut down once joined.
            self.manager = mp.Manager()
            self.config.eval_split = self.split
            self.config.subgoals = 'GotoLocation'
            self.batches = batches

            self.train_eval = EvalSubgoals(self.config, model, self.manager)
            self.train_eval.spawn_threads()

            self.flag = True

    def join_threads(self, writer):
        """Wait for the pending evaluation, if any, and log its two results.

        Writes ``loss_dict/<split>_sr`` and ``loss_dict/<split>_srw`` at the
        step given to ``eval_real``. Does nothing when no evaluation is
        pending. A second call therefore neither re-joins nor re-logs.

        Args:
            writer: object with ``add_scalar(tag, value, global_step=...)``,
                e.g. a TensorBoard ``SummaryWriter``.
        """
        if not self.flag:
            return
        res0, res1 = self.train_eval.join_threads()
        self.flag = False
        writer.add_scalar(f"loss_dict/{self.split}_sr", res0, global_step = self.batches)
        writer.add_scalar(f"loss_dict/{self.split}_srw", res1, global_step = self.batches)
        if self.manager is not None:
            # Release the manager's server process now that results are logged.
            self.manager.shutdown()
            self.manager = None
        

def _save_checkpoint(config, model, epoch, batches):
    """Pickle the whole agent and write its ``save_model`` checkpoint for step ``batches``."""
    pickle_dir = f'data/{config.model_type}/{config.seed}'
    os.makedirs(pickle_dir, exist_ok=True)
    # Named by epoch, so several saves within one epoch overwrite each other.
    with open(f'{pickle_dir}/{config.model_type}_{epoch}.pt', 'wb') as f:
        pickle.dump(model, f)

    # save_model writes under a different tree (data/models/...). Create it so
    # saving does not depend on save_model making directories itself.
    model_dir = f'data/models/{config.model_type}/{config.seed}'
    os.makedirs(model_dir, exist_ok=True)
    model.save_model(
        f'{model_dir}/{config.model_type}_{config.if_clip}_{batches}_{config.alpha}_{config.if_regularize}_torch.pt',
        batches,
    )


def train(config, model, splits, writer):
    """Train ``model`` for ``config.epoch`` epochs.

    Each epoch runs two passes over the training data: one with
    ``policy_extract=False`` (counted by ``batches_critic``), then one with
    ``policy_extract=True`` (counted by ``batches``). Every ``config.eval_every``
    policy steps a checkpoint is written.

    Args:
        config: run configuration (reads ``dataset_fraction``, ``fast_epoch``,
            ``dout``, ``batch_size``, ``epoch``, ``eval_every``, ``model_type``,
            ``seed``, ``if_clip``, ``alpha``, ``if_regularize``).
        model: agent exposing ``learn(experience, policy_extract=..., train=...)``,
            which returns a dict of loggable scalars, and ``save_model(path, step)``.
        splits: mapping of split name to task list. ``'train'``, ``'valid_seen'``
            and ``'valid_unseen'`` must be present.
        writer: TensorBoard writer used for scalar logging.
    """
    # Only used by the (currently disabled) periodic evaluation at the end.
    eval_train = Eval('train', config)
    eval_valid = Eval('valid_seen', config)

    train_data = splits['train']
    valid_seen = splits['valid_seen']
    valid_unseen = splits['valid_unseen']

    # Keep the first 1/20 of each validation split.
    # NOTE(review): the validation lists are not consumed anywhere below.
    valid_seen = valid_seen[:int(len(valid_seen) / 20)]
    valid_unseen = valid_unseen[:int(len(valid_unseen) / 20)]

    # debugging: choose a small fraction of the dataset
    if config.dataset_fraction > 0:
        small_train_size = int(config.dataset_fraction * 0.7)
        small_valid_size = int((config.dataset_fraction * 0.3) / 2)
        train_data = train_data[:small_train_size]
        valid_seen = valid_seen[:small_valid_size]
        valid_unseen = valid_unseen[:small_valid_size]

    # debugging: check the training loop works without waiting for a full epoch
    if config.fast_epoch:
        train_data = train_data[:16]
        valid_seen = valid_seen[:16]
        valid_unseen = valid_unseen[:16]

    # dump config
    fconfig = os.path.join(config.dout, 'config.json')
    with open(fconfig, 'wt') as f:
        json.dump(vars(config), f, indent=2)

    print("Saving to: %s" % config.dout)

    dataset = DatasetOri(config, train_data, name='train')
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)

    batches = 0         # policy-extraction steps; also drives checkpointing
    batches_critic = 0  # steps of the policy_extract=False pass

    for epoch in trange(0, config.epoch, desc='epoch'):
        for experience in dataloader:
            loss_dict = model.learn(experience, policy_extract=False, train=True)
            batches_critic += 1

            loss_dict['epoch'] = epoch
            for key, value in loss_dict.items():
                writer.add_scalar(f"loss_dict/{key}", value, global_step=batches_critic)

        for experience in dataloader:
            loss_dict = model.learn(experience, policy_extract=True, train=True)

            loss_dict['epoch'] = epoch
            for key, value in loss_dict.items():
                writer.add_scalar(f"loss_dict/{key}", value, global_step=batches)

            batches += 1

            if batches % config.eval_every == 0:
                _save_checkpoint(config, model, epoch, batches)

        # if (epoch+1) % 3 == 0:
        #     eval_train.join_threads(writer)
        #     eval_valid.join_threads(writer)
        #     eval_train.eval_real(model, epoch)
        #     eval_valid.eval_real(model, epoch)

def main(config):
    """Seed RNGs, prepare the output directory, build the IQL agent and train it.

    Args:
        config: run configuration (reads ``device``, ``seed``, ``model_type``,
            ``if_clip``, ``alpha``, ``dout``, ``splits``, ``feature_size`` plus
            everything ``train`` reads). ``config.dout`` is rewritten in place
            with its ``{...}`` fields filled from the config itself.
    """
    device = torch.device(config.device)

    # Use 'spawn' for child processes (presumably so CUDA works in workers).
    mp.set_start_method('spawn', force=True)
    set_seed(config.seed)

    writer = SummaryWriter(f'data/alfred_lstm/{config.model_type}-{config.if_clip}-{config.alpha}-{config.seed}')
    try:
        config.dout = config.dout.format(**vars(config))
        os.makedirs(config.dout, exist_ok=True)

        with open(config.splits) as f:
            splits = json.load(f)

        pprint.pprint({k: len(v) for k, v in splits.items()})

        agent = IQLAgent(
            action_size=15,  # hard-coded action-space size
            hidden_size=config.feature_size,
            device=device,
            config=config
        )
        print(f"feature size:{config.feature_size}")

        train(config, agent, splits, writer)
    finally:
        # Flush pending TensorBoard events even if training raises.
        writer.close()

if __name__ == '__main__':
    # Settings forced on top of the parsed command-line config, applied in order.
    overrides = {
        'model_type': 'IQL',
        'expert': False,
        'preprocess': False,
        'num_eval_file': 15,
        'num_threads': 8,
        'epoch': 20,
        'seed': 128,
        'device': 'cuda:3',
    }
    for if_clip in [False]:
        config = get_config()
        # config.if_clip = if_clip
        for attr_name, attr_value in overrides.items():
            setattr(config, attr_name, attr_value)

        # Module-level ``device``: read by collate_fn.
        device = torch.device(config.device)

        main(config)
