import time, argparse, os.path as osp, os
import torch, numpy as np
import torch.distributed as dist
from copy import deepcopy

import mmcv
from mmengine import Config
from mmengine.runner import set_random_seed
from mmengine.optim import build_optim_wrapper
from mmengine.logging import MMLogger
from mmengine.utils import symlink
from mmengine.registry import MODELS

from timm.scheduler import CosineLRScheduler, MultiStepLRScheduler
from utils.load_save_util import revise_ckpt, revise_ckpt_1
import warnings
warnings.filterwarnings("ignore")

from utils.vis_voxels import vis_voxels
import sys, os, pdb

class ForkedPdb(pdb.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    ``multiprocessing`` child processes do not read from the parent's terminal
    through ``sys.stdin``, so a plain ``pdb.set_trace()`` there cannot take
    input. This subclass temporarily rebinds ``sys.stdin`` to ``/dev/stdin``
    while the interactive prompt runs.
    """

    def interaction(self, *args, **kwargs):
        """Run the regular Pdb prompt with ``sys.stdin`` bound to ``/dev/stdin``.

        The ``/dev/stdin`` handle is opened per prompt and closed when the
        prompt returns, so repeated breakpoints do not leak file descriptors.
        The previous ``sys.stdin`` is restored before that handle is closed,
        even if the prompt raises.
        """
        saved_stdin = sys.stdin
        with open('/dev/stdin') as tty_stdin:
            sys.stdin = tty_stdin
            try:
                super().interaction(*args, **kwargs)
            finally:
                # Restore while tty_stdin is still open so sys.stdin never
                # points at a closed file.
                sys.stdin = saved_stdin

def pass_print(*args, **kwargs):
    """Swallow every call; used as a silent stand-in for ``print``.

    Accepts any positional and keyword arguments, ignores them all and
    returns ``None``.
    """
    return None

# Voxel label treated as "empty": ground-truth voxels not hit by any labelled
# query point are filled with it, and it is the only label mapped to 0 when
# building the binary occupancy target for the IoU metric.
_EMPTY_LABEL = 17

# STP3 planning metrics accumulated over batches, in the order they are logged:
# the plain metrics first, then their ``_single`` variants; each family has a
# 1s / 2s / 3s entry.
_STP3_METRIC_KEYS = tuple(
    f'{family}_{horizon}s{suffix}'
    for suffix in ('', '_single')
    for family in ('plan_L2', 'plan_obj_col', 'plan_obj_box_col')
    for horizon in (1, 2, 3))

# (output key, metric family, suffix) of every 1s/2s/3s mean, in logging order.
_STP3_AVERAGES = (
    ('avg_l2', 'plan_L2', ''),
    ('avg_obj_col', 'plan_obj_col', ''),
    ('avg_obj_box_col', 'plan_obj_box_col', ''),
    ('avg_obj_box_col_single', 'plan_obj_box_col', '_single'),
    ('avg_obj_col_single', 'plan_obj_col', '_single'),
    ('avg_l2_single', 'plan_L2', '_single'),
)

# Metrics whose log line wraps the key in square brackets.
_STP3_BRACKETED_KEYS = frozenset({
    'plan_L2_1s_single', 'plan_L2_2s_single', 'plan_L2_3s_single',
    'plan_obj_box_col_1s_single', 'plan_obj_box_col_2s_single',
    'plan_obj_box_col_3s_single', 'avg_l2_single', 'avg_obj_box_col_single',
})


def _as_float(value):
    """Return *value* as a Python float (tensors / NumPy scalars via ``.item()``)."""
    return value.item() if hasattr(value, 'item') else float(value)


def _all_reduce_mean(value, world_size):
    """Sum a scalar over all ranks and divide by *world_size*.

    Every rank must call this the same number of times, in the same order.
    Returns a float64 CUDA tensor.
    """
    tensor = torch.as_tensor(value, dtype=torch.float64).cuda()
    dist.all_reduce(tensor)
    return tensor / world_size


def _init_distributed(local_rank, args, cfg):
    """Initialise an NCCL process group when more than one GPU is in use.

    Connection details come from ``MASTER_ADDR`` / ``MASTER_PORT`` (falling
    back to ``cfg.port`` or 29881). ``WORLD_SIZE`` is read as the node count
    and ``RANK`` as the node id, so the global rank is
    ``RANK * gpus_per_node + local_rank``. On non-zero local ranks
    ``builtins.print`` is replaced by ``pass_print`` to silence stdout.

    Returns:
        ``(distributed, world_size)``; ``(False, 1)`` when ``args.gpus <= 1``.
    """
    if args.gpus <= 1:
        return False, 1

    ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
    port = os.environ.get("MASTER_PORT", cfg.get("port", 29881))
    hosts = int(os.environ.get("WORLD_SIZE", 1))  # number of nodes
    node_rank = int(os.environ.get("RANK", 0))  # node id
    gpus = torch.cuda.device_count()  # gpus per node
    print(f"tcp://{ip}:{port}")
    dist.init_process_group(
        backend="nccl", init_method=f"tcp://{ip}:{port}",
        world_size=hosts * gpus, rank=node_rank * gpus + local_rank)
    world_size = dist.get_world_size()
    cfg.gpu_ids = range(world_size)
    torch.cuda.set_device(local_rank)

    if local_rank != 0:
        import builtins
        builtins.print = pass_print
    return True, world_size


def _resolve_resume_path(args):
    """Return the checkpoint path to resume from, or ``''`` if there is none.

    An explicit ``args.resume_from`` wins; otherwise ``<work_dir>/latest.pth``
    is used when it exists.
    """
    if args.resume_from:
        return args.resume_from
    latest = osp.join(args.work_dir, 'latest.pth')
    return latest if osp.exists(latest) else ''


def _restore_checkpoint(cfg, raw_model, optimizer, scheduler, train_dataset_loader, logger):
    """Resume full state from ``cfg.resume_from`` or load weights from ``cfg.load_from``.

    Resuming restores model (non-strict), optimizer and scheduler state, plus
    the epoch, best IoU/mIoU lists and the train sampler position (when the
    sampler exposes ``set_last_iter``). Otherwise, when ``cfg.load_from`` is
    set, only model weights are loaded (non-strict), optionally rewritten
    according to ``cfg.revise_ckpt`` (1, 2 or 3).

    Returns:
        ``(epoch, best_val_iou, best_val_miou)``; defaults are ``0`` and
        six-element lists of zeros.
    """
    epoch = 0
    best_val_iou = [0] * 6
    best_val_miou = [0] * 6

    if cfg.resume_from and not osp.exists(cfg.resume_from):
        logger.warning(f'resume checkpoint {cfg.resume_from} does not exist; ignoring it')

    if cfg.resume_from and osp.exists(cfg.resume_from):
        ckpt = torch.load(cfg.resume_from, map_location='cpu')
        print(raw_model.load_state_dict(ckpt['state_dict'], strict=False))
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        epoch = ckpt['epoch']
        best_val_iou = ckpt.get('best_val_iou', best_val_iou)
        best_val_miou = ckpt.get('best_val_miou', best_val_miou)
        if hasattr(train_dataset_loader.sampler, 'set_last_iter'):
            train_dataset_loader.sampler.set_last_iter(ckpt.get('last_iter', 0))
        print(f'successfully resumed from epoch {epoch}')
    elif cfg.get('load_from'):
        ckpt = torch.load(cfg.load_from, map_location='cpu')
        state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
        revise = cfg.get('revise_ckpt', False)
        if not revise:
            raw_model.load_state_dict(state_dict, strict=False)
            print("load from:", cfg.load_from)
        elif revise == 1:
            print('revise_ckpt')
            print(raw_model.load_state_dict(revise_ckpt(state_dict), strict=False))
        elif revise == 2:
            print('revise_ckpt_1')
            print(raw_model.load_state_dict(revise_ckpt_1(state_dict), strict=False))
        elif revise == 3:
            # Mode 3 loads the raw state dict into the ``vae`` submodule only.
            print('revise_ckpt_2')
            print(raw_model.vae.load_state_dict(state_dict, strict=False))
        else:
            logger.warning(
                f'unknown revise_ckpt={revise!r}; no weights loaded from {cfg.load_from}')
    return epoch, best_val_iou, best_val_miou


def _dump_predictions(save_root, metas, input_occs, result_dict, i_iter_val, prev_steps):
    """Write one batch's inputs, predictions and poses to disk for inspection.

    Only batch sample 0 is written. For each future step ``t`` (there are
    ``input_occs.shape[1] - prev_steps``), ``<save_root>/<scene_name>/<t>/``
    receives the input and predicted semantic occupancy as ``.npy``, each also
    rendered to ``.ply`` via ``vis_voxels``, plus predicted and ground-truth
    poses as ``.npy``.
    """
    scene_dir = os.path.join(save_root, metas[0]['scene_name'])
    os.makedirs(scene_dir, exist_ok=True)
    output_meta = result_dict['output_metas'][0]
    future_inputs = input_occs[0, prev_steps:]
    for t in range(input_occs.shape[1] - prev_steps):
        step_dir = os.path.join(scene_dir, str(t))
        os.makedirs(step_dir, exist_ok=True)
        for stem, volume in (
                (f'input_{i_iter_val}_{t}', future_inputs[t]),
                (f'output_{i_iter_val}_{t}', result_dict['sem_pred'][0, t])):
            npy_path = os.path.join(step_dir, f'{stem}.npy')
            np.save(npy_path, volume.cpu().numpy())
            vis_voxels(npy_path, os.path.join(step_dir, f'{stem}.ply'))
        # Index the decoded poses with the boolean ``gt_mode`` mask
        # (presumably selecting the ground-truth mode -- confirm).
        gt_mode_mask = output_meta['gt_mode'][t].astype(np.bool_)
        pred_pose = result_dict['pose_decoded'][0, t].cpu().numpy()[gt_mode_mask].squeeze()
        np.save(os.path.join(step_dir, f'pred_pose_{i_iter_val}_{t}.npy'), pred_pose)
        np.save(os.path.join(step_dir, f'gt_pose_{i_iter_val}_{t}.npy'), output_meta['rel_poses'][t])


def _log_stp3_metrics(metric_stp3, logger):
    """Add the 1s/2s/3s means, convert every entry to float and log it.

    Mutates and returns *metric_stp3*.
    """
    for avg_key, family, suffix in _STP3_AVERAGES:
        metric_stp3[avg_key] = sum(
            metric_stp3[f'{family}_{horizon}s{suffix}'] for horizon in (1, 2, 3)) / 3
    for key in metric_stp3:
        metric_stp3[key] = _as_float(metric_stp3[key])
        if key in _STP3_BRACKETED_KEYS:
            logger.info(f'[{key}] is {metric_stp3[key]}')
        else:
            logger.info(f'{key} is {metric_stp3[key]}')
    return metric_stp3


def main(local_rank, args):
    """Run one evaluation pass over the validation set and log the metrics.

    Builds the model, loss, data loaders, optimizer and LR scheduler from the
    config at ``args.py_config``, restores a checkpoint if one is found, then
    evaluates on the validation loader under ``torch.no_grad()``. It logs the
    loss, occupancy IoU / mIoU, STP3 planning metrics and timing. No training
    step is taken: the train loader, optimizer and scheduler are built so that
    a resume checkpoint can be loaded into them.

    Args:
        local_rank: GPU index on this node (``0`` when run without spawning).
        args: parsed CLI namespace; reads ``seed``, ``py_config``, ``work_dir``,
            ``gpus``, ``iter_resume`` and ``resume_from``.

    Raises:
        ValueError: if the validation loader yields no batches.
        RuntimeError: if a batch's ``input_occs`` and ``target_occs`` differ
            (this evaluation expects them to be identical).
    """
    # global settings
    set_random_seed(args.seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    # load config
    cfg = Config.fromfile(args.py_config)
    cfg.work_dir = args.work_dir

    distributed, world_size = _init_distributed(local_rank, args, cfg)

    if local_rank == 0:
        os.makedirs(args.work_dir, exist_ok=True)
        cfg.dump(osp.join(args.work_dir, osp.basename(args.py_config)))
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.work_dir, f'{timestamp}.log')
    logger = MMLogger.get_instance('genocc', log_file=log_file)
    logger.info(f'Config:\n{cfg.pretty_text}')

    # ``model`` is not referenced directly; presumably imported so its modules
    # register with MODELS before MODELS.build below.
    import model  # noqa: F401
    from dataset import get_dataloader, get_nuScenes_label_name
    from loss import OPENOCC_LOSS
    from utils.metric_util import multi_step_MeanIou
    from utils.freeze_model import freeze_model

    # build model
    my_model = MODELS.build(cfg.model)
    my_model.init_weights()
    n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
    logger.info(f'Number of params: {n_parameters}')
    if cfg.get('freeze_dict', False):
        logger.info(f'Freezing model according to freeze_dict:{cfg.freeze_dict}')
        freeze_model(my_model, cfg.freeze_dict)
    n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
    logger.info(f'Number of params after freezed: {n_parameters}')
    if distributed:
        if cfg.get('syncBN', False):
            my_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(my_model)
            logger.info('converted sync bn.')
        my_model = torch.nn.parallel.DistributedDataParallel(
            my_model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=cfg.get('find_unused_parameters', False))
        raw_model = my_model.module
    else:
        my_model = my_model.cuda()
        raw_model = my_model
    logger.info('done ddp model')

    train_dataset_loader, val_dataset_loader = get_dataloader(
        cfg.train_dataset_config,
        cfg.val_dataset_config,
        cfg.train_wrapper_config,
        cfg.val_wrapper_config,
        cfg.train_loader,
        cfg.val_loader,
        dist=distributed,
        iter_resume=args.iter_resume)

    # get optimizer, loss, scheduler
    optimizer = build_optim_wrapper(my_model, cfg.optimizer)
    loss_func = OPENOCC_LOSS.build(cfg.loss).cuda()
    if cfg.get('multisteplr', False):
        scheduler = MultiStepLRScheduler(optimizer, **cfg.multisteplr_config)
    else:
        scheduler = CosineLRScheduler(
            optimizer,
            t_initial=len(train_dataset_loader) * cfg.max_epochs,
            lr_min=1e-6,
            warmup_t=cfg.get('warmup_iters', 500),
            warmup_lr_init=1e-6,
            t_in_epochs=False)

    # resume and load
    cfg.resume_from = _resolve_resume_path(args)
    logger.info('resume from: ' + cfg.resume_from)
    logger.info('work dir: ' + args.work_dir)
    epoch, best_val_iou, best_val_miou = _restore_checkpoint(
        cfg, raw_model, optimizer, scheduler, train_dataset_loader, logger)

    # metrics
    print_freq = cfg.print_freq
    prev_steps = cfg.model['prev_steps']
    label_name = get_nuScenes_label_name(cfg.label_mapping)
    unique_label = np.asarray(cfg.unique_label)
    unique_label_str = [label_name[lbl] for lbl in unique_label]
    ignore_label = cfg.get('ignore_label', -100)
    cal_miou_sem = multi_step_MeanIou(unique_label, ignore_label, unique_label_str, 'sem', times=6)
    cal_miou_vox = multi_step_MeanIou([1], ignore_label, ['occupied'], 'vox', times=6)

    num_batches = len(val_dataset_loader)
    if num_batches == 0:
        raise ValueError('validation dataloader is empty; nothing to evaluate')

    my_model.eval()
    # NOTE(review): presumably read by model code to enable eval-time behaviour.
    os.environ['eval'] = 'true'
    val_loss_list = []
    cal_miou_sem.reset()
    cal_miou_vox.reset()
    metric_stp3 = dict.fromkeys(_STP3_METRIC_KEYS, 0)
    time_used = {'total': 0, 'per_frame': 0}
    plan_loss = 0
    best_plan_loss = 100000
    # Opt-in debug dump of the first 10 batches (off unless cfg.dump_vis).
    dump_vis = cfg.get('dump_vis', False)

    with torch.no_grad():
        for i_iter_val, batch in enumerate(val_dataset_loader):
            (input_occs, target_occs, xys, xzs, yzs,
             querys, xyz_labels, xyz_centers, metas) = batch

            # Example shapes from the dataset: occs [1, 10, 200, 200, 16];
            # the remaining tensors drop their leading batch dim of 1.
            input_occs = input_occs.cuda()
            target_occs = target_occs.cuda()
            xys = xys.cuda().squeeze(0)
            xzs = xzs.cuda().squeeze(0)
            yzs = yzs.cuda().squeeze(0)
            querys = querys.cuda().squeeze(0)
            xyz_labels = xyz_labels.cuda().squeeze(0)
            xyz_centers = xyz_centers.cuda().squeeze(0)

            result_dict = my_model(
                input_occs=input_occs, querys=querys, xys=xys, xzs=xzs, yzs=yzs,
                xyz_labels=xyz_labels, xyz_centers=xyz_centers, metas=metas)

            # Dense ground truth for each predicted future step: start fully
            # empty, then scatter the query labels at their voxel indices.
            pred_output = result_dict['pred_output']
            gt_output = torch.full(
                (pred_output.shape[0], 200, 200, 16),
                fill_value=float(_EMPTY_LABEL), device=pred_output.device)
            for i in range(pred_output.shape[0]):
                step = prev_steps + i
                gt_output[i,
                          xyz_centers[step, :, 0],
                          xyz_centers[step, :, 1],
                          xyz_centers[step, :, 2]] = xyz_labels[step].float()

            loss_input = {
                'inputs': input_occs,
                'target_occs': target_occs,

                # loss 1
                'preds': result_dict['preds'],
                'xyz_labels': xyz_labels[prev_steps:],

                # loss 2
                'pred_output': pred_output,
                'gt_output': gt_output,

                # loss 3
                'hexplane_mask': result_dict['hexplane_mask'],
                'hexplane_pred': result_dict['hexplane'],
                'hexplane_gt': [xys[prev_steps:], xzs[prev_steps:], yzs[prev_steps:]],

                # PoseLoss
                'rel_pose': result_dict['pose_decoded'],
                'output_metas': result_dict['output_metas'],
            }

            if dump_vis and i_iter_val < 10:
                _dump_predictions('./test_results/t3former', metas, input_occs,
                                  result_dict, i_iter_val, prev_steps)

            for key in metric_stp3:
                metric_stp3[key] += result_dict['metric_stp3'][key]
            for key in time_used:
                time_used[key] += result_dict['time'][key]
            loss, loss_dict = loss_func(loss_input)
            for plan_term in ('PlanRegLoss', 'PlanRegLossLidar', 'PoseLoss'):
                plan_loss += loss_dict.get(plan_term, 0)

            if (input_occs != target_occs).any():
                raise RuntimeError(
                    'expected input_occs and target_occs to be identical during evaluation')
            # Binary occupancy target: 1 for any non-empty voxel, 0 for empty.
            target_occs_iou = (target_occs != _EMPTY_LABEL).to(target_occs.dtype)

            # Score only the predicted future steps (everything after prev_steps).
            cal_miou_sem._after_step(result_dict['sem_pred'], target_occs[:, prev_steps:])
            cal_miou_vox._after_step(result_dict['iou_pred'], target_occs_iou[:, prev_steps:])
            val_loss_list.append(loss.detach().cpu().numpy())
            if i_iter_val % print_freq == 0 and local_rank == 0:
                logger.info('[EVAL] Epoch %d Iter %5d: Loss: %.3f (%.3f)' % (
                    epoch, i_iter_val, loss.item(), np.mean(val_loss_list)))
                logger.info(', '.join(
                    f'{loss_name}: {loss_value:.5f}' for loss_name, loss_value in loss_dict.items()))

        metric_stp3 = {key: value / num_batches for key, value in metric_stp3.items()}
        time_used = {key: value / num_batches for key, value in time_used.items()}
        # Reduce across ranks; the collective order is identical on every rank.
        if distributed:
            plan_loss = _all_reduce_mean(plan_loss, world_size)
            metric_stp3 = {key: _all_reduce_mean(value, world_size)
                           for key, value in metric_stp3.items()}
            time_used = {key: _all_reduce_mean(value, world_size)
                         for key, value in time_used.items()}
        metric_stp3 = _log_stp3_metrics(metric_stp3, logger)

        time_used = {key: _as_float(value) for key, value in time_used.items()}
        logger.info(f'time_used is {time_used}')
        per_frame = time_used['per_frame']
        fps = 1 / per_frame if per_frame > 0 else float('inf')
        logger.info(f'FPS is {fps}')

        val_miou, _ = cal_miou_sem._after_epoch()
        val_iou, _ = cal_miou_vox._after_epoch()

        del target_occs, input_occs
        plan_loss = plan_loss / num_batches
        best_plan_loss = min(best_plan_loss, plan_loss)
        logger.info(f'PlanRegLoss is {plan_loss} while the best plan loss is {best_plan_loss}')
        best_val_iou = [max(best, cur) for best, cur in zip(best_val_iou, val_iou)]
        best_val_miou = [max(best, cur) for best, cur in zip(best_val_miou, val_miou)]
        logger.info(f'Current val iou is {val_iou} while the best val iou is {best_val_iou}')
        logger.info(f'Current val miou is {val_miou} while the best val miou is {best_val_miou}')
        # Mean over the odd-indexed future steps (1, 3, 5).
        logger.info(f'avg val iou is {(val_iou[1]+val_iou[3]+val_iou[5])/3}')
        logger.info(f'avg val miou is {(val_miou[1]+val_miou[3]+val_miou[5])/3}')
        torch.cuda.empty_cache()


if __name__ == '__main__':
    # Evaluation entry point: parse the CLI, record how many GPUs are visible,
    # then either run `main` in-process (0 or 1 GPU) or spawn one process per GPU.
    arg_parser = argparse.ArgumentParser(description='')
    arg_parser.add_argument('--py-config', default='config/tpv_lidarseg.py')
    arg_parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg')
    arg_parser.add_argument('--resume-from', type=str, default='')
    arg_parser.add_argument('--iter-resume', action='store_true', default=False)
    arg_parser.add_argument('--seed', type=int, default=42)
    args = arg_parser.parse_args()

    visible_gpus = torch.cuda.device_count()
    args.gpus = visible_gpus
    print(args)

    if visible_gpus <= 1:
        main(0, args)
    else:
        torch.multiprocessing.spawn(main, args=(args,), nprocs=args.gpus)
