import os
import json
import time
import numpy as np
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter

import sys
sys.path.append("..")
sys.path.append(".")

from utils.misc import set_random_seed
from utils.logger import write_to_record_file, print_progress, timeSince
from utils.distributed import init_distributed, is_default_gpu
from utils.distributed import all_gather, merge_dist_results

from utils.data import ImageFeaturesDB, ImageFeaturesDB2
from r2r.data_utils import construct_instrs
from r2r.env import R2RNavBatch
from r2r.parser import parse_args

from models.vlnbert_init import get_tokenizer
from r2r.agent import GMapNavAgent


def build_dataset(args, rank=0, is_test=False):
    """Build the training, validation and optional augmentation environments.

    Args:
        args: parsed command-line namespace.
        rank: process rank. It is added to ``args.seed`` for every environment.
            When ``args.world_size >= 2``, validation environments also receive
            ``sel_data_idxs=(rank, args.world_size)``.
        is_test: forwarded to ``construct_instrs`` for every split.

    Returns:
        ``(train_env, val_envs, aug_env)``:
        - ``train_env`` is ``None`` when ``args.test`` is set.
        - ``aug_env`` is ``None`` when ``args.aug`` is ``None``.
        - ``val_envs`` maps each validation split name to its environment.
    """
    # NOTE(review): the returned tokenizer is not used below (instructions use
    # args.tokenizer); the call is kept in case get_tokenizer has side effects.
    get_tokenizer(args)

    with_aug = args.aug is not None
    if with_aug:
        aug_feat_db = ImageFeaturesDB(args.aug_ft_file, args.image_feat_size)

    # HM3D augmentation test splits read their features from the aug file.
    use_aug_feats_for_val = args.test_aug and 'hm3d' in args.test_aug_split
    val_feat_db = ImageFeaturesDB(
        args.aug_ft_file if use_aug_feats_for_val else args.val_ft_file,
        args.image_feat_size,
    )

    if not args.test:
        train_feat_db = ImageFeaturesDB2(args.mp3d_ft_files, args.image_feat_size)

    env_cls = R2RNavBatch
    env_seed = args.seed + rank

    def make_env(feat_db, split_list, name, sel_data_idxs=None):
        # Tokenise the instructions for these splits, then wrap them in an env.
        instr_data = construct_instrs(
            args.anno_dir, args.dataset, split_list,
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
            is_test=is_test,
        )
        return env_cls(
            feat_db, instr_data, args.connectivity_dir,
            batch_size=args.batch_size, angle_feat_size=args.angle_feat_size,
            seed=env_seed, sel_data_idxs=sel_data_idxs, name=name,
        )

    # No distributed sampler is used for the training data. To make different
    # processes see different examples, each one shuffles with its own seed
    # (args.seed + rank).
    aug_env = make_env(aug_feat_db, [args.aug], 'aug') if with_aug else None
    train_env = None if args.test else make_env(train_feat_db, ['train'], 'train')

    # Choose validation splits. Later rules override earlier ones.
    dataset = args.dataset
    last_split = 'val_unseen_sampled' if dataset == 'r4r' and not args.test else 'val_unseen'
    split_names = ['val_train_seen', 'val_seen', last_split]
    if args.submit and dataset != 'r4r':
        split_names.append('test')
    if args.test_aug and dataset != 'r4r':
        split_names = [args.test_aug_split]
    if dataset == 'rxr':
        split_names = ['val_seen', 'val_unseen']

    # With several processes each validation env gets (rank, world_size),
    # presumably selecting this rank's share of the episodes.
    shard = None if args.world_size < 2 else (rank, args.world_size)
    val_envs = {
        name: make_env(val_feat_db, [name], name, sel_data_idxs=shard)
        for name in split_names
    }

    return train_env, val_envs, aug_env

def train(args, train_env, val_envs, aug_env=None, rank=-1):
    """Train a GMapNavAgent, validating every ``args.log_every`` iterations.

    The run performs ``args.iters`` training iterations, starting from
    iteration 0 or from the iteration returned by ``listner.load`` when
    ``args.resume_file`` is set.

    When ``aug_env`` is given, each logging interval alternates rounds of one
    iteration on ``train_env`` followed by ``args.aug_times`` iterations on
    ``aug_env``.

    Only the default GPU writes the args dump, the record file, tensorboard
    scalars and checkpoints. All processes run validation, and their
    predictions are gathered before scoring.

    Args:
        args: parsed command-line namespace.
        train_env: environment of ground-truth training instructions.
        val_envs: mapping from split name to validation environment.
        aug_env: optional environment of augmented instructions.
        rank: process rank forwarded to the agent.
    """
    default_gpu = is_default_gpu(args)

    if default_gpu:
        with open(os.path.join(args.log_dir, 'training_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        writer = SummaryWriter(log_dir=args.log_dir)
        record_file = os.path.join(args.log_dir, 'train.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    agent_class = GMapNavAgent
    listner = agent_class(args, train_env, rank=rank)

    # Resume: load() returns the iteration counter to continue from.
    start_iter = 0
    if args.resume_file is not None:
        start_iter = listner.load(args.resume_file)
        if default_gpu:
            write_to_record_file(
                "\nLOAD the model from {}, iteration {}".format(args.resume_file, start_iter),
                record_file
            )

    # Optional evaluation of the initial (or resumed) model before training.
    if args.eval_first:
        loss_str = "validation before training"
        for env_name, env in val_envs.items():
            listner.env = env
            # Greedy (argmax) decoding without dropout, i.e. test-time conditions.
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            # Gather predictions from all processes before scoring.
            preds = merge_dist_results(all_gather(listner.get_results()))
            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
        if default_gpu:
            write_to_record_file(loss_str, record_file)

    start = time.time()
    if default_gpu:
        write_to_record_file(
            '\nListener training starts, start iteration: %s' % str(start_iter), record_file
        )

    # Best results so far, keyed by the split used for model selection.
    # r2r selects by SR+SPL (and separately by 2*SR+SPL); r4r/rxr by SR+nDTW.
    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": ""}}
    best_val_2sr_1spl = {'val_unseen': {"spl": 0., "sr": 0., "state": ""}}
    if args.dataset == 'r4r':
        best_val = {'val_unseen_sampled': {"nDTW": 0., "sr": 0., "state": ""}}
    if args.dataset == 'rxr':
        best_val = {'val_unseen': {"nDTW": 0., "sr": 0., "state": ""}}

    end_iter = start_iter + args.iters
    for idx in range(start_iter, end_iter, args.log_every):
        listner.logs = defaultdict(list)
        # Measure the remaining budget against the same end point as the loop.
        # Using args.iters alone yields zero/negative intervals when resuming.
        interval = min(args.log_every, end_iter - idx)
        cur_iter = idx + interval

        # Train for one logging interval.
        if aug_env is None:
            listner.env = train_env
            listner.train(interval, feedback=args.feedback)
        else:
            # Each round: one ground-truth iteration, then args.aug_times
            # iterations on augmented data.
            num_rounds = interval // (args.aug_times + 1)
            for jdx in range(num_rounds):
                listner.env = train_env
                listner.train(1, feedback=args.feedback)

                listner.env = aug_env
                listner.train(args.aug_times, feedback=args.feedback)

                if default_gpu:
                    print_progress(jdx, num_rounds, prefix='Progress:', suffix='Complete', bar_length=50)

        if default_gpu:
            # Log the training stats to tensorboard.
            logs = listner.logs
            total = max(sum(logs['total']), 1)          # RL: total valid actions for all examples in the batch
            length = max(len(logs['critic_loss']), 1)   # RL: total (max length) in the batch
            critic_loss = sum(logs['critic_loss']) / total
            policy_loss = sum(logs['policy_loss']) / total
            RL_loss = sum(logs['RL_loss']) / max(len(logs['RL_loss']), 1)
            IL_loss = sum(logs['IL_loss']) / max(len(logs['IL_loss']), 1)
            entropy = sum(logs['entropy']) / total
            writer.add_scalar("loss/critic", critic_loss, idx)
            writer.add_scalar("policy_entropy", entropy, idx)
            writer.add_scalar("loss/RL_loss", RL_loss, idx)
            writer.add_scalar("loss/IL_loss", IL_loss, idx)
            writer.add_scalar("total_actions", total, idx)
            writer.add_scalar("max_length", length, idx)
            write_to_record_file(
                "\ntotal_actions %d, max_length %d, entropy %.4f, IL_loss %.4f, RL_loss %.4f, policy_loss %.4f, critic_loss %.4f" % (
                    total, length, entropy, IL_loss, RL_loss, policy_loss, critic_loss),
                record_file
            )

        # Run validation.
        loss_str = "iter {}".format(cur_iter)
        for env_name, env in val_envs.items():
            listner.env = env
            # Greedy (argmax) decoding without dropout, i.e. test-time conditions.
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            preds = merge_dist_results(all_gather(listner.get_results()))

            if not default_gpu:
                continue

            score_summary, _ = env.eval_metrics(preds)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.2f' % (metric, val)
                writer.add_scalar('%s/%s' % (metric, env_name), val, idx)

            # Model selection only on the split(s) tracked in best_val.
            if env_name not in best_val:
                continue
            state = 'Iter %d %s' % (cur_iter, loss_str)
            best = best_val[env_name]
            if args.dataset == 'r2r':
                if score_summary['sr'] + score_summary['spl'] >= best['sr'] + best['spl']:
                    best['spl'] = score_summary['spl']
                    best['sr'] = score_summary['sr']
                    best['state'] = state
                    if args.save:
                        listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))
                # Second criterion weighs SR twice as heavily as SPL.
                best2 = best_val_2sr_1spl[env_name]
                if 2 * score_summary['sr'] + score_summary['spl'] >= 2 * best2['sr'] + best2['spl']:
                    best2['spl'] = score_summary['spl']
                    best2['sr'] = score_summary['sr']
                    best2['state'] = state
                    if args.save:
                        listner.save(idx, os.path.join(args.ckpt_dir, "best_%s_2sr+spl" % (env_name)))
            elif args.dataset in ('rxr', 'r4r'):
                if score_summary['sr'] + score_summary['nDTW'] >= best['sr'] + best['nDTW']:
                    best['nDTW'] = score_summary['nDTW']
                    best['sr'] = score_summary['sr']
                    best['state'] = state
                    if args.save:
                        listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))

        if default_gpu:
            if args.dataset == 'r2r':
                listner.save(idx, os.path.join(args.ckpt_dir, "latest_dict"))

            # Fraction of this run's iteration budget completed so far.
            progress = (cur_iter - start_iter) / args.iters
            write_to_record_file(
                ('%s (%d %d%%) %s' % (timeSince(start, progress), cur_iter, progress * 100, loss_str)),
                record_file
            )
            write_to_record_file("BEST RESULT TILL NOW", record_file)
            for env_name, best in best_val.items():
                write_to_record_file(env_name + ' | ' + best['state'], record_file)
            if best_val_2sr_1spl['val_unseen']['sr'] > 0:
                for env_name, best in best_val_2sr_1spl.items():
                    write_to_record_file(env_name + ' | ' + best['state'], record_file)

def valid(args, train_env, val_envs, rank=-1):
    """Evaluate an agent on every split in ``val_envs`` and write the results.

    A split is skipped when ``<pred_dir>/<prefix>_<split>.json`` already exists.
    ``<prefix>`` is ``'submit'`` when ``args.detailed_output is False`` and
    ``'detail'`` otherwise.

    On the default GPU only:
    - splits whose name does not contain ``'test'`` are scored, and their
      summary is appended to ``<log_dir>/valid.txt``;
    - predictions are dumped to ``<prefix>_<split>.json`` when ``args.submit``
      is set;
    - the detailed metrics returned by ``env.eval_metrics`` are dumped to
      ``<prefix>_<split>_metrics.json`` for scored splits other than
      ``val_train_seen``.

    Args:
        args: parsed command-line namespace.
        train_env: environment passed to the agent constructor (may be None).
        val_envs: mapping from split name to validation environment.
        rank: process rank forwarded to the agent.
    """
    default_gpu = is_default_gpu(args)

    agent_class = GMapNavAgent
    agent = agent_class(args, train_env, rank=rank)

    if args.resume_file is not None:
        print("Loaded the listener model at iter %d from %s" % (
            agent.load(args.resume_file), args.resume_file))

    if default_gpu:
        with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        record_file = os.path.join(args.log_dir, 'valid.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    prefix = 'submit' if args.detailed_output is False else 'detail'
    for env_name, env in val_envs.items():
        pred_file = os.path.join(args.pred_dir, "%s_%s.json" % (prefix, env_name))
        # Predictions from an earlier run already exist for this split.
        if os.path.exists(pred_file):
            continue
        agent.logs = defaultdict(list)
        agent.env = env

        start_time = time.time()
        agent.test(
            use_dropout=False, feedback='argmax', iters=None, default_gpu=default_gpu)
        print(env_name, 'cost time: %.2fs' % (time.time() - start_time))
        preds = agent.get_results(detailed_output=args.detailed_output)
        # Gather predictions from all processes before scoring.
        preds = merge_dist_results(all_gather(preds))

        if not default_gpu:
            continue

        # Splits with 'test' in their name are not scored (presumably no ground truth).
        is_scored = 'test' not in env_name
        if is_scored:
            score_summary, metrics = env.eval_metrics(preds)
            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.2f' % (metric, val)
            write_to_record_file(loss_str + '\n', record_file)

        if args.submit:
            with open(pred_file, 'w') as outf:
                json.dump(preds, outf, sort_keys=True, indent=4, separators=(',', ': '))

        # Only dump metrics that were computed for *this* split. Gating on
        # is_scored avoids reusing another split's metrics (or an unbound name).
        if is_scored and env_name != 'val_train_seen':
            metrics_file = os.path.join(args.pred_dir, "%s_%s_metrics.json" % (prefix, env_name))
            with open(metrics_file, 'w') as outf:
                json.dump(metrics, outf, sort_keys=True)
                
def zero_shot(args, train_env, val_envs, aug_env=None, rank=-1):
    """Evaluate checkpoints from ``args.ckpt_dir`` as they appear, forever.

    Each pass:
    - picks the oldest (by mtime) file under ``args.ckpt_dir`` that has not
      been evaluated yet;
    - builds a fresh agent from it and evaluates it on every split in
      ``val_envs``;
    - on the default GPU, stores the scores in ``<log_dir>/zero_shot_eval.json``.

    Checkpoints already listed in that JSON file are skipped, so an interrupted
    run can be restarted.

    How a checkpoint is loaded depends on ``args.ckpt_dir``:
    - if it contains ``'ensemble'``, each file is loaded via
      ``args.resume_file`` with ``args.resume_optimizer=False``, and the file
      name is used as the iteration label;
    - otherwise the file is passed as ``args.bert_ckpt_file``, and the label is
      the last ``_``-separated token of the file name before its first ``.``.

    ``aug_env`` is accepted for signature parity with ``train`` and is unused.
    This function never returns.
    """
    import glob

    default_gpu = is_default_gpu(args)
    agent_class = GMapNavAgent
    is_ensemble = 'ensemble' in args.ckpt_dir
    record_json = os.path.join(args.log_dir, 'zero_shot_eval.json')

    # Scores from previous runs: {checkpoint path: {split: score summary}}.
    if os.path.exists(record_json):
        with open(record_json) as inf:
            zero_shot_record = json.load(inf)
    else:
        zero_shot_record = {}

    evaluated = set(zero_shot_record)

    if args.dataset == 'r4r':
        best_val = {'val_unseen_sampled': {"spl": 0., "sr": 0., "state": ""}}
    else:
        best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": ""}}
        best = best_val['val_unseen']
        for record in zero_shot_record.values():
            # Tolerate incomplete records that lack val_unseen scores.
            prev = record.get('val_unseen')
            if prev is None:
                continue
            if prev['spl'] + prev['sr'] > best['spl'] + best['sr']:
                best['spl'] = prev['spl']
                best['sr'] = prev['sr']
                best['state'] = prev['state']

    if default_gpu:
        if not is_ensemble:
            writer = SummaryWriter(log_dir=args.log_dir)
        record_file = os.path.join(args.log_dir, 'zero_shot.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    while True:
        # Poll until an unevaluated checkpoint appears; take the oldest one.
        current_ckpt = None
        while current_ckpt is None:
            candidates = [
                p for p in glob.glob(args.ckpt_dir + "/*")
                if os.path.isfile(p) and p not in evaluated
            ]
            if candidates:
                current_ckpt = min(candidates, key=os.path.getmtime)
            time.sleep(2)  # sleep for 2 secs before polling again

        ckpt_name = os.path.basename(current_ckpt)
        zero_shot_record[current_ckpt] = {}
        if is_ensemble:
            iter_label = ckpt_name
            args.resume_file = current_ckpt
            args.resume_optimizer = False
        else:
            # NOTE(review): assumes names like '<prefix>_<iter>.<ext>'.
            idx = ckpt_name.split('.')[0].split('_')[-1]
            iter_label = idx
            args.bert_ckpt_file = current_ckpt

        time.sleep(10)
        listner = agent_class(args, train_env, rank=rank)
        if default_gpu:
            write_to_record_file(
                "\nLOAD the model from {}".format(current_ckpt),
                record_file
            )

        if args.resume_file is not None:
            start_iter = listner.load(args.resume_file)
            if default_gpu:
                write_to_record_file(
                    "\nLOAD the model from {}, iteration {}".format(args.resume_file, start_iter),
                    record_file
                )

        loss_str = "iter {}".format(iter_label)
        for env_name, env in val_envs.items():
            listner.env = env
            # Greedy (argmax) decoding without dropout, i.e. test-time conditions.
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            # Gather predictions from all processes before scoring.
            preds = merge_dist_results(all_gather(listner.get_results()))
            if not default_gpu:
                continue

            score_summary, _ = env.eval_metrics(preds)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.2f' % (metric, val)
                if not is_ensemble:
                    writer.add_scalar('%s/%s' % (metric, env_name), val, idx)
            state = 'Iter %s' % (loss_str)
            zero_shot_record[current_ckpt][env_name] = score_summary
            score_summary['state'] = state

            # Select model by spl+sr.
            if env_name in best_val:
                best = best_val[env_name]
                if score_summary['spl'] + score_summary['sr'] >= best['spl'] + best['sr']:
                    best['spl'] = score_summary['spl']
                    best['sr'] = score_summary['sr']
                    best['state'] = state

        # record_file only exists on the default GPU, and only it holds
        # complete scores, so only it writes logs and the JSON record.
        if default_gpu:
            write_to_record_file(loss_str, record_file)
            write_to_record_file("BEST RESULT TILL NOW", record_file)
            for env_name, best in best_val.items():
                write_to_record_file(env_name + ' | ' + best['state'], record_file)

            with open(record_json, 'w') as outf:
                json.dump(zero_shot_record, outf, indent=4)
        evaluated.add(current_ckpt)

def main():
    """Parse arguments, set up (optional) distributed state, then dispatch.

    Runs ``valid`` when ``args.test`` is set. Otherwise it runs ``zero_shot``
    when ``args.zero_shot`` is set, and ``train`` in all remaining cases.
    """
    args = parse_args()
    # Distributed launchers (e.g. torchrun) export LOCAL_RANK and RANK. If they
    # are absent or not integers, keep whatever parse_args produced. Only these
    # two errors are expected; anything else (incl. Ctrl-C) must propagate.
    try:
        args.local_rank = int(os.environ["LOCAL_RANK"])
        args.rank = int(os.environ["RANK"])
    except (KeyError, ValueError):
        pass

    if args.world_size > 1:
        rank = init_distributed(args)
        torch.cuda.set_device(args.local_rank)
    else:
        rank = 0

    # Offset the seed by rank so each process uses a different random stream.
    set_random_seed(args.seed + rank)
    train_env, val_envs, aug_env = build_dataset(args, rank=rank, is_test=args.test)

    if args.test:
        valid(args, train_env, val_envs, rank=rank)
    elif args.zero_shot:
        zero_shot(args, train_env, val_envs, aug_env=aug_env, rank=rank)
    else:
        train(args, train_env, val_envs, aug_env=aug_env, rank=rank)
            

if __name__ == '__main__':
    # Script entry point: main() parses arguments and runs train / valid / zero_shot.
    main()
