import os
# Suppress TensorFlow/CUDA factory logs
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 0 = all logs, 3 = only fatal errors
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'  # Optional: consistent CUDA device order
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Optional: suppress OneDNN logs

import warnings
warnings.filterwarnings('ignore')

import json
import time
import numpy as np
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter

import random
from tqdm import tqdm

import sys
sys.path.append("..")
sys.path.append(".")
sys.path.append("map_nav_src")

from utils.misc import set_random_seed
from utils.logger import write_to_record_file, print_progress, timeSince
from utils.distributed import init_distributed, is_default_gpu
from utils.distributed import all_gather, merge_dist_results

from utils.data import ImageFeaturesDB, ImageFeaturesDB2
from moe.data_utils import construct_instrs
from r2r.env import R2RNavBatch
# from r2r.parser import parse_args

from models.vlnbert_init import get_tokenizer
from r2r.agent import GMapNavAgent

from moe.agent_moe import GMapNavAgents
from moe.parser_moe import parse_args




def build_dataset(args, rank=0, is_test=False):
    """Build the training, validation and (optional) augmentation environments.

    Args:
        args: Parsed options. This function reads ``aug``, ``aug_ft_file``,
            ``val_ft_file``, ``mp3d_ft_files``, ``image_feat_size``,
            ``anno_dir``, ``dataset``, ``tokenizer``, ``max_instr_len``,
            ``connectivity_dir``, ``batch_size``, ``angle_feat_size``,
            ``seed``, ``train_env_names``, ``partial_dataset``,
            ``val_env_names``, ``test`` and ``world_size``.
        rank: Index of the current process. It is added to ``args.seed`` for
            every environment and, when ``args.world_size >= 2``, selects the
            validation shard handled by this process.
        is_test: Forwarded unchanged to ``construct_instrs``.

    Returns:
        A tuple ``(train_env, val_envs, aug_env)`` where ``val_envs`` maps a
        split name to its environment and ``aug_env`` is ``None`` when
        ``args.aug`` is ``None``.
    """
    tok = get_tokenizer(args)

    if args.aug is not None:
        aug_feat_db = ImageFeaturesDB(args.aug_ft_file, args.image_feat_size)
    val_feat_db = ImageFeaturesDB(args.val_ft_file, args.image_feat_size)
    train_feat_db = ImageFeaturesDB2(args.mp3d_ft_files, args.image_feat_size)

    dataset_class = R2RNavBatch

    # No distributed sampler is used here, so each process shuffles its data
    # with a different seed (seed + rank) to see different training examples.
    if args.aug is not None:
        aug_instr_data = construct_instrs(
            args.anno_dir, args.dataset, [args.aug],
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
            is_test=is_test
        )
        aug_env = dataset_class(
            aug_feat_db, aug_instr_data, args.connectivity_dir,
            batch_size=args.batch_size, angle_feat_size=args.angle_feat_size,
            seed=args.seed+rank, sel_data_idxs=None, name='aug',
        )
    else:
        aug_env = None

    # Concatenate the instructions of every requested training split.
    train_instr_data = []
    for split in args.train_env_names:
        train_instr_data += construct_instrs(
            args.anno_dir, args.dataset, [split],
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
            is_test=is_test, tokenizer_obj=tok
        )
    # Shuffle the full instruction data (uses the global ``random`` state).
    random.shuffle(train_instr_data)

    # Use a portion of the dataset if partial_dataset < 1.0
    assert 0 < args.partial_dataset <= 1.0, "partial_dataset should be in (0, 1]"
    if args.partial_dataset < 1.0:
        # Keep at least one example: int() truncation could otherwise yield an
        # empty training set for small datasets / small fractions.
        train_len = max(1, int(len(train_instr_data) * args.partial_dataset))
        # NOTE(review): with seed 42 the subset is forced to a single example.
        # This looks like a debugging leftover; kept as-is -- confirm intent.
        if args.seed == 42:
            train_len = 1
        train_instr_data = train_instr_data[:train_len]

    train_env = dataset_class(
        train_feat_db, train_instr_data, args.connectivity_dir,
        batch_size=args.batch_size,
        angle_feat_size=args.angle_feat_size, seed=args.seed+rank,
        sel_data_idxs=None, name='train',
    )

    # Work on a copy so the caller's ``args.val_env_names`` is never mutated.
    val_env_names = list(args.val_env_names)

    if args.dataset == 'r4r' and (not args.test):
        # This used to be a bare comparison (``==``) and therefore a no-op.
        # R4R model selection tracks 'val_unseen_sampled', so swap it in for a
        # trailing 'val_unseen'; any other user-chosen split is left alone.
        if val_env_names and val_env_names[-1] == 'val_unseen':
            val_env_names[-1] = 'val_unseen_sampled'

    val_envs = {}
    for split in val_env_names:
        val_instr_data = construct_instrs(
            args.anno_dir, args.dataset, [split],
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
            is_test=is_test, tokenizer_obj=tok
        )
        # With several processes each one evaluates its own shard, identified
        # by (rank, world_size); a single process evaluates everything.
        val_envs[split] = dataset_class(
            val_feat_db, val_instr_data, args.connectivity_dir, batch_size=args.batch_size,
            angle_feat_size=args.angle_feat_size, seed=args.seed+rank,
            sel_data_idxs=None if args.world_size < 2 else (rank, args.world_size), name=split,
        )

    return train_env, val_envs, aug_env


def train(args, train_env, val_envs, aug_env=None, rank=-1):
    """Train a ``GMapNavAgent`` and validate it every ``args.log_every`` iterations.

    Args:
        args: Parsed options (``log_dir``, ``ckpt_dir``, ``resume_file``,
            ``eval_first``, ``dataset``, ``iters``, ``log_every``,
            ``feedback`` and ``aug_times`` are read here).
        train_env: Environment holding the ground-truth training data.
        val_envs: Mapping of split name to validation environment.
        aug_env: Optional augmentation environment. When given, every
            ground-truth iteration is followed by ``args.aug_times``
            iterations on augmented data.
        rank: Process rank forwarded to the agent constructor.

    Side effects:
        On the default GPU only: writes ``training_args.json``, ``train.txt``
        and tensorboard events to ``args.log_dir`` and saves ``latest_dict``
        / ``best_<split>`` checkpoints to ``args.ckpt_dir``.
    """
    default_gpu = is_default_gpu(args)

    if default_gpu:
        with open(os.path.join(args.log_dir, 'training_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        writer = SummaryWriter(log_dir=args.log_dir)
        record_file = os.path.join(args.log_dir, 'train.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    agent_class = GMapNavAgent
    listner = agent_class(args, train_env, rank=rank)

    # resume file
    start_iter = 0
    if args.resume_file is not None:
        start_iter = listner.load(os.path.join(args.resume_file))
        if default_gpu:
            # The format string previously had one placeholder for two
            # arguments, so the resumed iteration was silently dropped.
            write_to_record_file(
                "\nLOAD the model from {}, iteration {}".format(args.resume_file, start_iter),
                record_file
            )

    # first evaluation
    if args.eval_first:
        loss_str = "validation before training"
        for env_name, env in val_envs.items():
            listner.env = env
            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            preds = listner.get_results()
            # gather distributed results
            preds = merge_dist_results(all_gather(preds))
            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
        if default_gpu:
            write_to_record_file(loss_str, record_file)

    start = time.time()
    if default_gpu:
        write_to_record_file(
            '\nListener training starts, start iteration: %s' % str(start_iter), record_file
        )

    # Only the split listed here takes part in best-model selection.
    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
    if args.dataset == 'r4r':
        best_val = {'val_unseen_sampled': {"spl": 0., "sr": 0., "state":""}}

    for idx in range(start_iter, start_iter+args.iters, args.log_every):
        listner.logs = defaultdict(list)
        # NOTE(review): the range above runs to start_iter + args.iters while
        # this bound uses args.iters - idx, so after resuming (start_iter > 0)
        # the interval reaches zero/negative before the loop ends. Left
        # unchanged because the intended meaning of args.iters (total vs.
        # additional iterations) is not clear from this file -- confirm.
        interval = min(args.log_every, args.iters-idx)
        cur_iter = idx + interval

        # Train for log_every interval
        if aug_env is None:
            listner.env = train_env
            listner.train(interval, feedback=args.feedback)  # Train interval iters
        else:
            num_rounds = interval // (args.aug_times+1)
            for jdx in range(num_rounds):
                # Train with GT data
                listner.env = train_env
                listner.train(1, feedback=args.feedback)

                # Train with Augmented data (aug_times iterations per GT one)
                listner.env = aug_env
                listner.train(args.aug_times, feedback=args.feedback)

                if default_gpu:
                    print_progress(jdx, num_rounds, prefix='Progress:', suffix='Complete', bar_length=50)

        if default_gpu:
            # Log the training stats to tensorboard
            total = max(sum(listner.logs['total']), 1)          # RL: total valid actions for all examples in the batch
            length = max(len(listner.logs['critic_loss']), 1)   # RL: total (max length) in the batch
            critic_loss = sum(listner.logs['critic_loss']) / total
            policy_loss = sum(listner.logs['policy_loss']) / total
            RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
            IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
            entropy = sum(listner.logs['entropy']) / total
            writer.add_scalar("loss/critic", critic_loss, idx)
            writer.add_scalar("policy_entropy", entropy, idx)
            writer.add_scalar("loss/RL_loss", RL_loss, idx)
            writer.add_scalar("loss/IL_loss", IL_loss, idx)
            writer.add_scalar("total_actions", total, idx)
            writer.add_scalar("max_length", length, idx)
            write_to_record_file(
                "\ntotal_actions %d, max_length %d, entropy %.4f, IL_loss %.4f, RL_loss %.4f, policy_loss %.4f, critic_loss %.4f" % (
                    total, length, entropy, IL_loss, RL_loss, policy_loss, critic_loss),
                record_file
            )

        # Run validation
        loss_str = "iter {}".format(cur_iter)
        for env_name, env in val_envs.items():
            listner.env = env

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            preds = listner.get_results()
            preds = merge_dist_results(all_gather(preds))

            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
                    writer.add_scalar('%s/%s' % (metric, env_name), score_summary[metric], idx)

                # select model by spl (ties overwrite the previous best)
                if env_name in best_val:
                    if score_summary['spl'] >= best_val[env_name]['spl']:
                        best_val[env_name]['spl'] = score_summary['spl']
                        best_val[env_name]['sr'] = score_summary['sr']
                        best_val[env_name]['state'] = 'Iter %d %s' % (cur_iter, loss_str)
                        listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))

        # Always keep the most recent weights, independent of validation score.
        if default_gpu:
            listner.save(idx, os.path.join(args.ckpt_dir, "latest_dict"))

            write_to_record_file(
                ('%s (%d %d%%) %s' % (timeSince(start, float(cur_iter)/args.iters), cur_iter, float(cur_iter)/args.iters*100, loss_str)),
                record_file
            )
            write_to_record_file("BEST RESULT TILL NOW", record_file)
            for env_name in best_val:
                write_to_record_file(env_name + ' | ' + best_val[env_name]['state'], record_file)
                

def valid(args, train_env, val_envs, rank=-1):
    """Evaluate a ``GMapNavAgents`` agent on every validation split.

    Args:
        args: Parsed options (``log_dir``, ``pred_dir``, ``detailed_output``
            and ``submit`` are read here).
        train_env: Environment passed to the agent constructor.
        val_envs: Mapping of split name to evaluation environment.
        rank: Process rank forwarded to the agent constructor.

    Side effects:
        On the default GPU only: writes ``validation_args.json`` and
        ``valid.txt`` to ``args.log_dir`` and, when ``args.submit`` is set,
        one ``<prefix>_<split>.json`` prediction file per split to
        ``args.pred_dir``.
    """
    default_gpu = is_default_gpu(args)

    # NOTE(review): no checkpoint is loaded explicitly in this function;
    # presumably GMapNavAgents restores its weights from ``args`` -- confirm.
    agent_class = GMapNavAgents
    agent = agent_class(args, train_env, rank=rank)

    if default_gpu:
        with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        record_file = os.path.join(args.log_dir, 'valid.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    # Splits for which metrics are not computed here; predictions are only
    # dumped for them.
    unscored_splits = (
        'NavNuances_DC', 'NavNuances_LR', 'NavNuances_NU',
        'NavNuances_RR', 'NavNuances_VM', 'test',
    )

    print(f"Evaluating {val_envs.keys()}")
    for env_name, env in val_envs.items():
        prefix = 'submit' if args.detailed_output is False else 'detail'

        agent.logs = defaultdict(list)
        agent.env = env

        start_time = time.time()
        agent.test(
            use_dropout=False, feedback='argmax', iters=None)
        print(env_name, 'cost time: %.2fs' % (time.time() - start_time))
        preds = agent.get_results(detailed_output=args.detailed_output)
        # Gather and merge the results of all distributed processes.
        preds = merge_dist_results(all_gather(preds))

        if not default_gpu:
            continue

        if env_name not in unscored_splits:
            score_summary, metrics = env.eval_metrics(preds)

            # Attach per-sample metrics and the ground-truth path to each
            # prediction. Assumes metrics[key][i] belongs to preds[i], as the
            # original indexing did -- TODO confirm against eval_metrics.
            for i, pred in enumerate(preds):
                eval_result = {
                    key: values[i]
                    for key, values in metrics.items()
                    if key != 'instr_id'   # identifier, not a metric
                }
                scan, gt_traj = env.gt_trajs[pred['instr_id']]

                pred['scan'] = scan
                pred['gt_traj'] = gt_traj
                pred['evaluation'] = eval_result
                # ``not`` keeps this a plain Python bool for json.dump.
                pred['failure'] = not (eval_result['success'] == 1.0)

            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.2f' % (metric, val)
            write_to_record_file(loss_str+'\n', record_file)

        if args.submit:
            pred_path = os.path.join(args.pred_dir, "%s_%s.json" % (prefix, env_name))
            # Context manager guarantees the file is flushed and closed.
            with open(pred_path, 'w') as outf:
                json.dump(
                    preds, outf,
                    sort_keys=True, indent=4, separators=(',', ': ')
                )
            
              
def zero_shot(args, train_env, val_envs, aug_env=None, rank=-1):
    """Poll ``args.ckpt_dir`` forever and evaluate every new checkpoint file.

    Checkpoints are processed oldest first (by modification time). Results are
    accumulated in ``<log_dir>/zero_shot_eval.json`` so that a restarted run
    skips checkpoints that were already evaluated.

    Args:
        args: Parsed options (``log_dir``, ``ckpt_dir`` and ``dataset`` are
            read; ``bert_ckpt_file`` is overwritten with each checkpoint
            path before the agent is built).
        train_env: Environment passed to the agent constructor.
        val_envs: Mapping of split name to validation environment.
        aug_env: Unused; accepted for signature parity with ``train``.
        rank: Process rank forwarded to the agent constructor.

    This function never returns; stop it by terminating the process.
    """
    import glob

    default_gpu = is_default_gpu(args)
    agent_class = GMapNavAgent

    record_path = os.path.join(args.log_dir, 'zero_shot_eval.json')
    if os.path.exists(record_path):
        with open(record_path) as inf:
            zero_shot_record = json.load(inf)
    else:
        zero_shot_record = {}

    # Checkpoint paths that must not be evaluated again.
    evaluated = set(zero_shot_record)

    if args.dataset == 'r4r':
        best_val = {'val_unseen_sampled': {"spl": 0., "sr": 0., "state":""}}
    else:
        best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
        # Restore the best result from a previous run; records without a
        # 'val_unseen' entry (e.g. incomplete ones) are skipped.
        for record in zero_shot_record.values():
            prev = record.get('val_unseen')
            if not prev:
                continue
            if prev['spl'] + prev['sr'] > best_val['val_unseen']['spl'] + best_val['val_unseen']['sr']:
                best_val['val_unseen']['spl'] = prev['spl']
                best_val['val_unseen']['sr'] = prev['sr']
                best_val['val_unseen']['state'] = prev['state']

    if default_gpu:
        writer = SummaryWriter(log_dir=args.log_dir)
        record_file = os.path.join(args.log_dir, 'zero_shot.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    while True:
        # Wait until at least one not-yet-evaluated checkpoint file exists.
        current_ckpt = None
        while current_ckpt is None:
            pending = [
                p for p in glob.glob(args.ckpt_dir+"/*")
                if os.path.isfile(p) and p not in evaluated
            ]
            if pending:
                pending.sort(key=os.path.getmtime)
                current_ckpt = pending[0]
            time.sleep(2)  # sleep for 2 secs before polling again

        # Last '_'-separated token of the file stem, e.g. 'model_step_2000.pt'
        # -> '2000'. Used verbatim in the log line.
        ckpt_tag = current_ckpt.split('/')[-1].split('.')[0].split('_')[-1]
        # Tensorboard needs an integer step; fall back to "no step" when the
        # file name does not end in a number (e.g. 'latest_dict').
        step = int(ckpt_tag) if ckpt_tag.isdigit() else None

        zero_shot_record[current_ckpt] = {}
        args.bert_ckpt_file = current_ckpt

        listner = agent_class(args, train_env, rank=rank)
        if default_gpu:
            write_to_record_file(
                "\nLOAD the model from {}".format(current_ckpt),
                record_file
            )

        loss_str = "iter {}".format(ckpt_tag)
        for env_name, env in val_envs.items():
            listner.env = env
            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            preds = listner.get_results()
            # gather distributed results
            preds = merge_dist_results(all_gather(preds))
            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
                    writer.add_scalar('%s/%s' % (metric, env_name), score_summary[metric], step)
                zero_shot_record[current_ckpt][env_name] = score_summary
                zero_shot_record[current_ckpt][env_name]['state'] = 'Iter %s' % (loss_str)

                # select model by spl+sr
                if env_name in best_val:
                    if score_summary['spl'] + score_summary['sr'] >= best_val[env_name]['spl'] + best_val[env_name]['sr']:
                        best_val[env_name]['spl'] = score_summary['spl']
                        best_val[env_name]['sr'] = score_summary['sr']
                        best_val[env_name]['state'] = 'Iter %s' % (loss_str)

        # Only the default GPU owns ``record_file`` and the scores, so only it
        # may write the logs and the JSON record. Previously every rank ran
        # this code, which raised NameError on the other ranks.
        if default_gpu:
            write_to_record_file(
                loss_str,
                record_file
            )
            write_to_record_file("BEST RESULT TILL NOW", record_file)
            for env_name in best_val:
                write_to_record_file(env_name + ' | ' + best_val[env_name]['state'], record_file)

            with open(record_path, 'w') as outf:
                json.dump(zero_shot_record, outf, indent=4)
        evaluated.add(current_ckpt)


def main():
    """Parse the options, set up (distributed) execution and dispatch.

    Depending on the parsed flags this runs ``valid`` (``args.test``),
    ``zero_shot`` (``args.zero_shot``) or ``train`` (default).
    """
    args = parse_args()

    print(f"[Rank] Starting on local_rank={args.local_rank}, world_size={args.world_size}")

    if args.world_size > 1:
        rank = init_distributed(args)
        if args.local_rank != -1:
            print(f"[Rank] Starting on local_rank={args.local_rank}, world_size={args.world_size}")
            torch.cuda.set_device(args.local_rank)
            # init_distributed() is defined elsewhere and may already have
            # created the default process group; initialising it a second
            # time raises, so only do it when it is still missing.
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(
                    backend="nccl",
                    init_method="env://"
                )
    else:
        rank = 0

    # Offset the seed by rank so that each process shuffles differently.
    set_random_seed(args.seed + rank)
    train_env, val_envs, aug_env = build_dataset(args, rank=rank, is_test=args.test)

    if args.test:
        valid(args, train_env, val_envs, rank=rank)
    elif args.zero_shot:
        zero_shot(args, train_env, val_envs, aug_env=aug_env, rank=rank)
    else:
        train(args, train_env, val_envs, aug_env=aug_env, rank=rank)
            

# Run the training / evaluation entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
