# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

import os
import time
import tqdm
import yaml
import wandb
import random
import argparse
import numpy as np
from collections import defaultdict
from contextlib import redirect_stdout

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# Environment variables set before the project imports below, presumably so
# that those modules (or their dependencies) see them at import time:
# TF_CPP_MIN_LOG_LEVEL=3 is meant to quiet TensorFlow's C++ logging and
# BITSANDBYTES_NOWELCOME=1 the bitsandbytes banner. NOTE(review): confirm which
# imported modules actually read these.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

import config as exp_cfg_mod

import dual_stream.mvt.config as mvt_cfg_mod
from dual_stream.mvt.dual_mvt import DUAL_MVT
from dual_stream.vggt.models.vggt import VGGT
import dual_stream.utils.ddp_utils as ddp_utils
import dual_stream.models.cortical_agent as cortical_agent
from dual_stream.waypoint_extraction.select_keyframe import get_dataset
from dual_stream.models.cortical_agent import print_eval_log, print_loss_log
from dual_stream.utils.rvt_utils import (
    short_name,
    load_agent,
    get_num_feat,
    RLBENCH_TASKS,
    COLOSSEUM_TASKS
)
from dual_stream.utils.peract_utils import (
    CAMERAS,
    IMAGE_SIZE,
    DATA_FOLDER,    
    SCENE_BOUNDS,
)


def convert_vggt_features(vggt_array):
    """Batch per-sample VGGT feature dicts into float32 tensors.

    ``vggt_array`` is indexed as ``vggt_array[b, 0]`` for every ``b`` in
    ``range(vggt_array.shape[0])``; each such entry must be a mapping holding
    per-view arrays under the keys read below.

    Returns a dict (in this key order) of float32 tensors:

    * ``point_map``  -- ``point_map_view_{1,2,3}`` stacked per sample, then over the batch
    * ``point_conf`` -- only ``point_conf_view_1``, stacked over the batch
    * ``depth_pred`` -- ``depth_pred_{1,2,3}`` stacked per sample, then over the batch
    * ``extrinsic``  -- ``extrinsic_{1,2,3}`` stacked per sample, then over the batch
    * ``intrinsic``  -- ``intrinsic_{1,2,3}`` stacked per sample, then over the batch

    The shapes of the individual arrays are not validated here; ``np.stack``
    raises if they disagree.
    """

    def _three_views(record, prefix):
        # Views are numbered 1..3 in the stored keys.
        return np.stack([record[f"{prefix}{view}"] for view in range(1, 4)])

    collected = {
        name: []
        for name in ("point_map", "point_conf", "depth_pred", "extrinsic", "intrinsic")
    }
    for row in range(vggt_array.shape[0]):
        record = vggt_array[row, 0]
        collected["point_map"].append(_three_views(record, "point_map_view_"))
        collected["point_conf"].append(record["point_conf_view_1"])
        collected["depth_pred"].append(_three_views(record, "depth_pred_"))
        collected["extrinsic"].append(_three_views(record, "extrinsic_"))
        collected["intrinsic"].append(_three_views(record, "intrinsic_"))

    return {
        name: torch.from_numpy(np.stack(chunks)).float()
        for name, chunks in collected.items()
    }


def _window_mean(values, window):
    """Return the mean of the last ``window`` entries of ``values`` (NaN if there are none)."""
    recent = values[-window:]
    if len(recent) == 0:
        # Nothing logged yet: avoid ZeroDivisionError in the progress print.
        return float("nan")
    return sum(recent) / len(recent)


def train(agent, dataset, training_iterations, log_iter, rank=0, ifwandb=True,
          log_interval=100):
    """Run ``training_iterations`` update steps on ``agent`` using batches from ``dataset``.

    A fresh iterator over ``dataset`` is created on every call. For each step:

    * every ``torch.Tensor`` entry of the raw batch is moved to ``agent._device``;
    * ``tasks`` and ``lang_goal`` are copied through unchanged;
    * ``vggt_features`` / ``vggt_features_st2`` are converted with
      ``convert_vggt_features`` and each resulting tensor moved to the device;
    * ``agent.update`` is called with ``backprop=True``; ``reset_log`` is set
      on the first step of the call only.

    :param agent: agent exposing ``train()``, ``update(**kwargs)``, ``loss_log``
        (dict of lists with at least ``total_loss`` / ``trans_loss``) and ``_device``.
    :param dataset: iterable of dict batches; it must yield at least
        ``training_iterations`` batches or ``StopIteration`` propagates.
    :param training_iterations: number of update steps to run.
    :param log_iter: global step offset; only advanced locally here (it was
        used as the wandb step by the disabled logging code below).
    :param rank: process rank; the progress bar and loss prints appear on rank 0 only.
    :param ifwandb: currently unused (wandb logging is commented out).
    :param log_interval: on rank 0, print running-average losses over the last
        ``log_interval`` entries every ``log_interval`` steps. Must be positive.
    :return: on rank 0, whatever ``print_loss_log(agent)`` returns; on other
        ranks an empty ``defaultdict(list)``.
    """
    agent.train()
    log = defaultdict(list)

    data_iter = iter(dataset)

    for iteration in tqdm.tqdm(
        range(training_iterations), disable=(rank != 0), position=0, leave=True
    ):
        raw_batch = next(data_iter)
        batch = {
            k: v.to(agent._device)
            for k, v in raw_batch.items()
            if isinstance(v, torch.Tensor)
        }
        batch["tasks"] = raw_batch["tasks"]
        batch["lang_goal"] = raw_batch["lang_goal"]

        # Both stages carry their own precomputed VGGT features.
        for feat_key in ("vggt_features", "vggt_features_st2"):
            features = convert_vggt_features(raw_batch[feat_key])
            batch[feat_key] = {
                name: tensor.to(agent._device) for name, tensor in features.items()
            }

        agent.update(
            step=iteration,
            replay_sample=batch,
            backprop=True,
            reset_log=(iteration == 0),
            eval_log=False,
            compute_ap=True,
        )

        if (iteration + 1) % log_interval == 0 and rank == 0:
            loss_log = agent.loss_log
            total_loss_avg = _window_mean(loss_log["total_loss"], log_interval)
            trans_loss_avg = _window_mean(loss_log["trans_loss"], log_interval)

            print(f"total loss: {total_loss_avg} | trans loss: {trans_loss_avg}")

            # if ifwandb:
            #     wandb.log(data = {
            #                         'total_loss': loss_log['total_loss'][iteration],
            #                         'trans_loss': loss_log['trans_loss'][iteration],
            #                         'rot_loss_x': loss_log['rot_loss_x'][iteration],
            #                         'rot_loss_y': loss_log['rot_loss_y'][iteration],
            #                         'rot_loss_z': loss_log['rot_loss_z'][iteration],
            #                         'grip_loss': loss_log['grip_loss'][iteration],
            #                         'collision_loss': loss_log['collision_loss'][iteration],
            #                         'lr': loss_log['lr'][iteration],
            #                         },
            #                 step = log_iter)

        log_iter += 1

    if rank == 0:
        log = print_loss_log(agent)

    return log


def save_agent(agent, path, epoch):
    """Save the agent's network, optimizer and LR-scheduler state to ``path``.

    The checkpoint is a dict with keys ``epoch``, ``model_state``,
    ``optimizer_state`` and ``lr_sched_state``. A DDP-wrapped network is
    unwrapped first, so the saved parameter names have no ``module.`` prefix.

    The data is written to ``<path>.tmp`` and then moved over ``path`` with
    ``os.replace``. If the save is interrupted, ``path`` keeps its previous
    contents instead of becoming a truncated file. This matters for files that
    are later auto-resumed from, such as ``model_last.pth``.
    """
    network = agent._network
    if isinstance(network, DDP):
        network = network.module

    checkpoint = {
        "epoch": epoch,
        "model_state": network.state_dict(),
        "optimizer_state": agent._optimizer.state_dict(),
        "lr_sched_state": agent._lr_sched.state_dict(),
    }

    tmp_path = f"{path}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # The temp file only survives to here if saving or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def get_tasks(exp_cfg):
    """Resolve ``exp_cfg.tasks`` (a comma-separated string) into a task list.

    Only the first entry is inspected for presets. ``"all"`` returns
    ``RLBENCH_TASKS`` and ``"all_colosseum"`` returns ``COLOSSEUM_TASKS``; any
    further entries are then ignored. Otherwise the split list is returned as-is.
    """
    requested = exp_cfg.tasks.split(",")
    head = requested[0]
    if head == "all":
        return RLBENCH_TASKS
    if head == "all_colosseum":
        return COLOSSEUM_TASKS
    return requested


def get_logdir(cmd_args, exp_cfg):
    """Return ``<cmd_args.log_dir>/<exp_cfg.exp_id>``, creating it if missing."""
    run_dir = os.path.join(cmd_args.log_dir, exp_cfg.exp_id)
    os.makedirs(run_dir, exist_ok=True)  # idempotent across ranks / restarts
    return run_dir


def dump_log(exp_cfg, mvt_cfg, cmd_args, log_dir):
    """Write the run's configuration into ``log_dir``.

    Creates three files:

    * ``exp_cfg.yaml`` -- ``exp_cfg.dump()`` followed by a newline;
    * ``mvt_cfg.yaml`` -- ``mvt_cfg.dump()`` followed by a newline;
    * ``args.yaml``    -- ``yaml.dump`` of the command-line namespace's attributes.
    """

    def _write_text(filename, text):
        # Same bytes as print(text) did under redirect_stdout, but without
        # temporarily replacing the process-wide sys.stdout.
        with open(os.path.join(log_dir, filename), "w", encoding="utf-8") as fh:
            fh.write(f"{text}\n")

    _write_text("exp_cfg.yaml", exp_cfg.dump())
    _write_text("mvt_cfg.yaml", mvt_cfg.dump())

    with open(os.path.join(log_dir, "args.yaml"), "w", encoding="utf-8") as yaml_file:
        yaml.dump(vars(cmd_args), yaml_file)


def experiment(rank, cmd_args, devices, port):
    """Train the Cortical Policy agent in one worker process.

    Used as the ``mp.spawn`` target, with one process per entry of ``devices``.
    Each process:

    * builds the experiment/MVT configs and the training dataset;
    * builds the ``DUAL_MVT`` network, wrapped in DDP when more than one
      device is given, and the ``CorticalAgent``;
    * optionally resumes from ``exp_cfg.resume`` or ``<log_dir>/model_last.pth``;
    * trains until ``exp_cfg.epochs``.

    Rank 0 also prints progress, dumps the configs and saves checkpoints.

    :param rank: index of this process; selects ``devices[rank]``.
    :param cmd_args: parsed command-line arguments.
    :param devices: list of CUDA device indices, one per process. DDP is used
        when it has more than one entry.
    :param port: port handed to ``ddp_utils.setup`` for process-group setup.
    """
    print("begin Cortical Policy experiment!")

    device = f"cuda:{devices[rank]}"
    ddp = len(devices) > 1
    ddp_utils.setup(rank, world_size=len(devices), port=port)

    exp_cfg = exp_cfg_mod.get_cfg_defaults()
    if cmd_args.exp_cfg_path != "":
        exp_cfg.merge_from_file(cmd_args.exp_cfg_path)
    if cmd_args.exp_cfg_opts != "":
        exp_cfg.merge_from_list(cmd_args.exp_cfg_opts.split(" "))

    if ddp:
        print(f"Running DDP on rank {rank}.")

    # Remember the user-configured values. dump_log below writes these, not
    # the scaled/suffixed ones, so the dumped config reproduces this run.
    old_exp_cfg_peract_lr = exp_cfg.peract.lr
    old_exp_cfg_exp_id = exp_cfg.exp_id

    # Scale the base learning rate by the global batch size
    # (number of processes * per-process batch size).
    exp_cfg.peract.lr *= len(devices) * exp_cfg.bs
    # Make the run id (and therefore the log directory) reflect CLI overrides.
    if cmd_args.exp_cfg_opts != "":
        exp_cfg.exp_id += f"_{short_name(cmd_args.exp_cfg_opts)}"
    if cmd_args.mvt_cfg_opts != "":
        exp_cfg.exp_id += f"_{short_name(cmd_args.mvt_cfg_opts)}"

    if rank == 0:
        print(f"dict(exp_cfg)={dict(exp_cfg)}")
    exp_cfg.freeze()

    # Things to change
    BATCH_SIZE_TRAIN = exp_cfg.bs
    NUM_TRAIN = 100
    # To match PerAct: iterations per epoch, chosen so that one epoch covers
    # about exp_cfg.train_iter samples summed over all processes.
    TRAINING_ITERATIONS = int(exp_cfg.train_iter // (exp_cfg.bs * len(devices)))
    EPOCHS = exp_cfg.epochs
    TRAIN_REPLAY_STORAGE_DIR = "/dataset/rlbench/heuristic"
    log_dir = get_logdir(cmd_args, exp_cfg)
    tasks = get_tasks(exp_cfg)
    if rank == 0:
        print("Training on {} tasks: {}".format(len(tasks), tasks))

    t_start = time.time()

    def build_train_dataset():
        return get_dataset(
            tasks,
            BATCH_SIZE_TRAIN,
            None,
            TRAIN_REPLAY_STORAGE_DIR,
            None,
            DATA_FOLDER,
            NUM_TRAIN,
            None,
            cmd_args.refresh_replay,
            device,
            num_workers=exp_cfg.num_workers,
            only_train=True,
            sample_distribution_mode=exp_cfg.sample_distribution_mode,
        )

    # Rank 0 builds its dataset first while the other ranks wait, then the
    # other ranks build their own copy. Presumably this is so that anything
    # get_dataset writes to disk is produced only once.
    # NOTE(review): confirm.
    # Every rank must call this barrier. Previously only rank 0 did, so its
    # barrier was matched against other ranks' later collectives (DDP
    # construction, the post-resume barrier) and the job could hang.
    if rank == 0:
        train_dataset = build_train_dataset()
    dist.barrier()
    if rank != 0:
        train_dataset = build_train_dataset()
    t_end = time.time()
    if rank == 0:
        print("Created Dataset. Time Cost: {} minutes".format((t_end - t_start) / 60.0))

    mvt_cfg = mvt_cfg_mod.get_cfg_defaults()
    if cmd_args.mvt_cfg_path != "":
        mvt_cfg.merge_from_file(cmd_args.mvt_cfg_path)
    if cmd_args.mvt_cfg_opts != "":
        mvt_cfg.merge_from_list(cmd_args.mvt_cfg_opts.split(" "))

    mvt_cfg.feat_dim = get_num_feat(exp_cfg.peract)
    mvt_cfg.freeze()

    # for maintaining backward compatibility
    assert mvt_cfg.num_rot == exp_cfg.peract.num_rotation_classes, (
        f"mvt_cfg.num_rot ({mvt_cfg.num_rot}) != "
        f"exp_cfg.peract.num_rotation_classes ({exp_cfg.peract.num_rotation_classes})"
    )

    torch.cuda.set_device(device)
    # torch.cuda.empty_cache()
    cortical = DUAL_MVT(
        renderer_device=device,
        **mvt_cfg,
    ).to(device)

    # vggt = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
    # print("loading vggt")

    if ddp:
        print("[DEBUG] before DDP...")
        cortical = DDP(cortical, device_ids=[device], find_unused_parameters=True)
        print("[DEBUG] DDP cortical built.")
        # vggt = DDP(vggt, device_ids=[device], find_unused_parameters=True)

    agent = cortical_agent.CorticalAgent(
        network=cortical,
        image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
        add_lang=mvt_cfg.add_lang,
        stage_two=mvt_cfg.stage_two,
        rot_ver=mvt_cfg.rot_ver,
        scene_bounds=SCENE_BOUNDS,
        cameras=CAMERAS,
        log_dir=f"{log_dir}/test_run/",
        cos_dec_max_step=EPOCHS * TRAINING_ITERATIONS,
        **exp_cfg.peract,
        **exp_cfg.rvt,
    )
    agent.build(training=True, device=device)  # , vggt=vggt
    print("[DEBUG] Agent built.")

    # Resume priority: explicit exp_cfg.resume, else this run's model_last.pth.
    start_epoch = 0
    end_epoch = EPOCHS
    if exp_cfg.resume != "":
        agent_path = exp_cfg.resume
        if rank == 0:
            print(f"Recovering model and checkpoint from {exp_cfg.resume}")
        epoch = load_agent(agent_path, agent, only_epoch=False)
        start_epoch = epoch + 1
    elif os.path.exists(f'{log_dir}/model_last.pth'):
        agent_path = f'{log_dir}/model_last.pth'
        if rank == 0:
            print("resume from checkpoint")
        epoch = load_agent(agent_path, agent, only_epoch=False)
        if rank == 0:
            print(f"Recovering model and checkpoint from {agent_path}, model epoch: {epoch}")
        start_epoch = epoch + 1
    dist.barrier()

    if rank == 0:
        # Log the unscaled lr / original exp_id so the dumped config
        # reproduces this run, then restore the effective values.
        temp1 = exp_cfg.peract.lr
        temp2 = exp_cfg.exp_id
        exp_cfg.defrost()
        exp_cfg.peract.lr = old_exp_cfg_peract_lr
        exp_cfg.exp_id = old_exp_cfg_exp_id
        dump_log(exp_cfg, mvt_cfg, cmd_args, log_dir)
        exp_cfg.peract.lr = temp1
        exp_cfg.exp_id = temp2
        exp_cfg.freeze()

    if rank == 0:
        print("Start training ...", flush=True)
        if start_epoch >= end_epoch:
            # Previously `while i != end_epoch` looped forever in this case.
            print(f"Checkpoint epoch {start_epoch - 1} already reaches the "
                  f"configured {end_epoch} epochs; nothing to train.")

    # log_iter is the global step at the start of each epoch (for wandb).
    log_iter = 0
    for i in range(start_epoch, end_epoch):
        if rank == 0:
            print(f"Rank [{rank}], Epoch [{i}]: Training on train dataset")
            # try:
            #     # Try to initialise wandb online; fall back to offline mode on failure.
            #     wandb.init(project="test-dual", resume="allow")
            #     print("Wandb online mode initialized")
            # except Exception as e:
            #     print(f"Wandb online init failed ({e}), switching to offline mode")
            #     os.environ["WANDB_MODE"] = "offline"
            #     wandb.init(project="test-dual")
        train(agent, train_dataset, TRAINING_ITERATIONS, log_iter, rank, ifwandb=True)

        if rank == 0:
            # TODO: add logic to only save some models
            save_agent(agent, f"{log_dir}/model_{i}.pth", i)
            save_agent(agent, f"{log_dir}/model_last.pth", i)
        log_iter += TRAINING_ITERATIONS

    if rank == 0:
        print("[Finish]")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.set_defaults(entry=lambda cmd_args: parser.print_help())

    parser.add_argument("--refresh_replay", action="store_true", default=False)
    parser.add_argument("--device", type=str, default="0")
    parser.add_argument("--mvt_cfg_path", type=str, default="")
    parser.add_argument("--exp_cfg_path", type=str, default="")

    parser.add_argument("--mvt_cfg_opts", type=str, default="")
    parser.add_argument("--exp_cfg_opts", type=str, default="")

    parser.add_argument("--log-dir", type=str, default="runs")
    parser.add_argument("--with-eval", action="store_true", default=False)

    cmd_args = parser.parse_args()
    # Hack for multiprocessing: `entry` holds a lambda, which is not picklable
    # and would break mp.spawn when cmd_args is sent to the workers.
    del cmd_args.entry

    # Comma-separated CUDA device ids (e.g. "0,1"); one process per device.
    devices = [int(x) for x in cmd_args.device.split(",")]

    # port = (random.randint(0, 3000) % 3000) + 27000
    def find_free_port(start_port=27000, end_port=30000):
        """Return the first port in ``[start_port, end_port)`` bindable on localhost.

        The probe socket is closed before returning, so another process could
        still take the port before the workers bind it.

        :raises RuntimeError: if no port in the range can be bound.
        """
        import socket

        for candidate in range(start_port, end_port):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                try:
                    sock.bind(("localhost", candidate))
                except OSError:
                    continue
                return candidate
        raise RuntimeError("No free port found in range")

    try:
        port = find_free_port()
    except RuntimeError:
        # find_free_port raises instead of returning None, so the fallback
        # has to catch the exception to be reachable at all.
        port = int(os.environ.get("MASTER_PORT", 29500))
    print(f"Using port: {port}")
    try:
        mp.spawn(experiment, args=(cmd_args, devices, port), nprocs=len(devices), join=True)
    except Exception as e:
        import signal
        import sys
        import traceback

        # SIGKILL below skips interpreter shutdown, so flush now; otherwise
        # buffered output (e.g. stdout redirected to a file) is lost.
        print(f"Training failed: {e}", flush=True)
        traceback.print_exc()
        sys.stderr.flush()
        # Forcefully terminate all child processes by killing this process's
        # whole process group, including this process itself.
        # NOTE(review): if the script shares a process group with its launcher
        # (non-interactive shell, some schedulers), the launcher dies too.
        os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)