from pathlib import Path
import time
import shutil
import torch
import numpy as np
import argparse
import logging
from accelerate import Accelerator

from dataset.avss_dataset import get_v2_pallete

from utils import load_module
from utils.system import setup_logging
from utils.vis_mask import save_color_mask
from utils.compute_color_metrics import calc_color_miou_fscore

import pdb

def build_arg_parser():
    """Build the command-line parser for this test script.

    Returns:
        argparse.ArgumentParser: parser with the positional ``config`` and the
        ``--session_name``, ``--test_batch_size``, ``--num_workers``,
        ``--weights``, ``--save_pred_mask`` and ``--log_dir`` options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="path to the config file")
    parser.add_argument("--session_name", default="AVSS", type=str, help="the AVSS setting")

    parser.add_argument("--test_batch_size", default=1, type=int)
    parser.add_argument("--num_workers", default=8, type=int)

    # The checkpoint is always loaded below, so a missing value is reported as
    # a usage error here instead of an obscure failure inside torch.load(None).
    parser.add_argument("--weights", type=str, required=True,
                        help="path to the trained model weights")
    parser.add_argument("--save_pred_mask", action='store_true', default=False, help="save predited masks or not")
    parser.add_argument('--log_dir', default='./test_logs', type=str)
    return parser


def create_run_dir(base_dir, session_name, stamp=None):
    """Create and return a new, uniquely named run directory.

    The directory is ``<base_dir>/<session_name>_<stamp>``. If that name is
    already taken, ``_1``, ``_2``, ... is appended until creation succeeds.

    Args:
        base_dir: root directory for all runs (str or Path); created if missing.
        session_name: prefix of the run directory name. It is inserted
            verbatim, so a ``%`` in it is not treated as a strftime directive.
        stamp: timestamp text; defaults to the current local time formatted
            as ``%Y%m%d-%H%M%S``.

    Returns:
        pathlib.Path: the directory that was just created.
    """
    base_dir = Path(base_dir)
    base_dir.mkdir(exist_ok=True, parents=True)
    if stamp is None:
        stamp = time.strftime('%Y%m%d-%H%M%S')

    stem = f'{session_name}_{stamp}'
    run_dir = base_dir / stem
    attempt = 0
    while True:
        try:
            # mkdir without exist_ok acts as the existence check and the
            # creation in one step, so two runs cannot claim the same name.
            run_dir.mkdir(parents=True)
            return run_dir
        except FileExistsError:
            attempt += 1
            run_dir = base_dir / f'{stem}_{attempt}'


def snapshot_scripts(scripts, dst_root):
    """Copy each script into ``dst_root`` as a regular file.

    Plain relative paths keep their directory layout under ``dst_root``.
    Absolute paths and paths containing ``..`` are stored by file name only,
    so the copy can never land outside ``dst_root``.

    Args:
        scripts: iterable of source file paths (str or Path).
        dst_root: directory that receives the copies (str or Path).

    Raises:
        OSError: if a source file cannot be read or a copy cannot be written.
    """
    dst_root = Path(dst_root)
    for script in scripts:
        src = Path(script)
        if src.is_absolute() or '..' in src.parts:
            rel = Path(src.name)
        else:
            rel = src
        dst = dst_root / rel
        # Only the parent is a directory; the destination itself is the file.
        dst.parent.mkdir(exist_ok=True, parents=True)
        shutil.copy(str(src), str(dst))


def finalize_per_class(total_pc, cls_pc):
    """Divide per-class sums by per-class counts, zeroing empty classes.

    Args:
        total_pc: 1-D tensor of per-class metric sums.
        cls_pc: 1-D tensor of per-class counts, same length as ``total_pc``.

    Returns:
        tuple: ``(per_class, n_missing)`` where ``per_class`` is a new tensor
        with NaN entries (0 / 0) replaced by 0 and ``n_missing`` is the number
        of entries that were NaN.
    """
    per_class = total_pc / cls_pc
    nan_mask = torch.isnan(per_class)
    n_missing = torch.sum(nan_mask).item()
    per_class[nan_mask] = 0
    return per_class, n_missing


def running_mean_over_seen(total_pc, cls_pc):
    """Return the metric averaged over classes with a non-zero count.

    Args:
        total_pc: 1-D tensor of per-class metric sums.
        cls_pc: 1-D tensor of per-class counts, same length as ``total_pc``.

    Returns:
        torch.Tensor: scalar tensor; NaN when no class has been counted yet.
    """
    per_class, _ = finalize_per_class(total_pc, cls_pc)
    return torch.sum(per_class) / torch.sum(cls_pc != 0)


def main():
    """Evaluate a trained model on the test set and report mIoU / F-score."""
    args = build_arg_parser().parse_args()

    # Log directory
    log_dir = create_run_dir(args.log_dir, args.session_name)
    args.log_dir = log_dir

    # Save scripts
    script_path = log_dir / 'scripts'
    script_path.mkdir(exist_ok=True, parents=True)
    snapshot_scripts([args.config, 'test_avss.py'], script_path)

    # Set logger
    log_path = log_dir / 'log'
    log_path.mkdir(exist_ok=True, parents=True)

    accelerator = Accelerator()
    device = accelerator.device

    setup_logging(filename=str(log_path / 'log.txt'))
    logger = logging.getLogger(__name__)
    logger.info(f'==> Arguments: {vars(args)}')
    logger.info(f'==> Experiment: {args.session_name}')

    module_loader = load_module(args.config)
    # Model
    model = module_loader.model
    # NOTE(review): torch.load unpickles the file; only pass checkpoints from
    # a trusted source.
    model.load_state_dict(torch.load(args.weights, map_location='cpu'))
    model = accelerator.prepare_model(model)
    logger.info(f'Load trained model from {args.weights}')

    # Test data
    test_dataset = module_loader.test_dataset
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.test_batch_size,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True,
                                                  drop_last=False,
                                                  collate_fn=test_dataset.collate_fn
                                                  )

    # for save predicted rgb masks
    v2_pallete = get_v2_pallete(test_dataset.label_idx_path)
    mask_save_path = log_dir / 'pred_masks'

    # metrics
    num_classes = module_loader.num_classes
    miou_pc = torch.zeros(num_classes, device=device)  # miou value per class (total sum)
    Fs_pc = torch.zeros(num_classes, device=device)  # f-score per class (total sum)
    cls_pc = torch.zeros(num_classes, device=device)  # count per class

    # Test
    model.eval()
    with torch.no_grad():
        for n_iter, batch_data in enumerate(test_dataloader):
            # The ground-truth temporal flag is part of the batch but is not
            # needed for inference or for the metrics computed here.
            imgs, audio, mask, vid_temporal_mask_flag, _gt_temporal_mask_flag, video_name_list = batch_data
            imgs = imgs.to(device)
            audio = audio.to(device)
            mask = mask.to(device)
            vid_temporal_mask_flag = vid_temporal_mask_flag.to(device)

            # Fold the frame axis into the batch axis.
            B, frame, C, H, W = imgs.shape
            imgs = imgs.view(B * frame, C, H, W)
            mask = mask.view(B * frame, H, W)
            # NOTE(review): assumes 10 audio clips per video — TODO confirm
            # this always equals `frame`.
            audio = audio.view(B * 10, audio.shape[-3], audio.shape[-2], audio.shape[-1])
            vid_temporal_mask_flag = vid_temporal_mask_flag.view(B * frame)  # [B*T]

            output = model(imgs, audio, vid_temporal_mask_flag)
            if args.save_pred_mask:
                # NOTE(review): T=10 presumably matches the clip count above — TODO confirm.
                save_color_mask(output, mask_save_path, video_name_list, v2_pallete, True,
                                (360, 240), T=10)

            _miou_pc, _fscore_pc, _cls_pc, _ = calc_color_miou_fscore(output, mask)
            # compute miou, J-measure
            miou_pc += _miou_pc
            cls_pc += _cls_pc
            # compute f-score, F-measure
            Fs_pc += _fscore_pc

            # Running scores, averaged over the classes counted so far.
            batch_iou = running_mean_over_seen(miou_pc, cls_pc)
            batch_fscore = running_mean_over_seen(Fs_pc, cls_pc)
            print(
                f'n_iter: {n_iter}, iou: {batch_iou}, F_score: {batch_fscore}, cls_num: {torch.sum(cls_pc != 0).item()}')

        # Final scores average over ALL classes: classes with a zero count
        # contribute 0, unlike the running scores printed inside the loop.
        miou_pc, n_missing = finalize_per_class(miou_pc, cls_pc)
        print(f"[test miou] {n_missing} classes are not predicted in this batch")
        miou = torch.mean(miou_pc).item()
        # The "noBg" variants drop the last class index.
        miou_noBg = torch.mean(miou_pc[:-1]).item()

        f_score_pc, n_missing = finalize_per_class(Fs_pc, cls_pc)
        print(f"[test fscore] {n_missing} classes are not predicted in this batch")
        f_score = torch.mean(f_score_pc).item()
        f_score_noBg = torch.mean(f_score_pc[:-1]).item()

        summary = f'test | cls {torch.sum(cls_pc != 0).item()}, miou: {miou:.4f}, miou_noBg: {miou_noBg:.4f}, F_score: {f_score:.4f}, F_score_noBg: {f_score_noBg:.4f}'
        logger.info(summary)
        print(summary)


if __name__ == "__main__":
    main()
