from pathlib import Path
import time
import shutil
import torch
import argparse
import logging

from accelerate import Accelerator

from utils import load_module
from utils.pyutils import AverageMeter
from utils.utility import mask_iou, Eval_Fmeasure, save_mask_ms3
from utils.system import setup_logging
import pdb

def _parse_args(argv=None):
    """Parse the command-line options for MS3 evaluation.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="path to the config file")
    parser.add_argument("--session-name", default="MS3", type=str, help="the MS3 setting")

    parser.add_argument("--test_batch_size", default=1, type=int)
    parser.add_argument("--num_workers", default=8, type=int)

    parser.add_argument("--weights", type=str, required=True, help="path to the weights")
    parser.add_argument("--save_pred_mask", action='store_true', default=False, help="save predicted masks or not")
    parser.add_argument('--log_dir', default='./test_logs', type=str)
    return parser.parse_args(argv)


def _snapshot_scripts(script_dir, scripts):
    """Copy every file in *scripts* into *script_dir* for reproducibility.

    Relative paths that stay inside the working directory keep their
    directory structure (``configs/a.py`` -> ``<script_dir>/configs/a.py``).
    Absolute paths and paths containing ``..`` are flattened to their file
    name, because joining them onto *script_dir* would otherwise point
    outside it (an absolute right-hand operand replaces the left side of a
    ``Path`` join).
    """
    for script in scripts:
        src = Path(script)
        if src.is_absolute() or '..' in src.parts:
            dst = script_dir / src.name
        else:
            dst = script_dir / src
        # Create the destination's parent, not the destination itself, so the
        # copy is a file at `dst` rather than a file inside a directory `dst/`.
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(src, dst)


def main():
    """Evaluate a trained model on the MS3 test split and log mIoU / F-score."""
    args = _parse_args()

    # Each run gets its own timestamped directory under --log_dir. The session
    # name is concatenated after strftime (not passed through it) so a '%' in
    # the name is not treated as a format directive.
    log_dir = Path(args.log_dir) / f"{args.session_name}_{time.strftime('%Y%m%d-%H%M%S')}"
    args.log_dir = log_dir

    # Snapshot the config and this script next to the results. __file__ is
    # used instead of a hard-coded name so the copy works from any CWD.
    script_path = log_dir / 'scripts'
    script_path.mkdir(exist_ok=True, parents=True)
    _snapshot_scripts(script_path, [args.config, __file__])

    # Logging
    log_path = log_dir / 'log'
    log_path.mkdir(exist_ok=True, parents=True)

    accelerator = Accelerator()
    setup_logging(filename=str(log_path / 'log.txt'))
    logger = logging.getLogger(__name__)
    logger.info(f'==> Arguments: {vars(args)}')
    logger.info(f'==> Experiment: {args.session_name}')

    module_loader = load_module(args.config)

    # Model: weights are loaded onto the CPU, then handed to accelerate.
    # NOTE(review): torch.load unpickles the file, so --weights must be trusted;
    # consider weights_only=True if the installed torch supports it.
    model = module_loader.model
    model.load_state_dict(torch.load(args.weights, map_location='cpu'))
    model = accelerator.prepare_model(model)
    logger.info(f'Load trained model from {args.weights}')

    # Test data. The loader is deliberately NOT passed through
    # accelerator.prepare_data_loader: the metrics below are never gathered
    # across processes, so sharding would make each process report averages
    # over only its own shard. Batches are moved to the device explicitly.
    test_dataset = module_loader.test_dataset
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.test_batch_size,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True,
                                                  drop_last=False,
                                                  collate_fn=test_dataset.collate_fn
                                                  )

    avg_meter = AverageMeter('miou', 'F_score')
    mask_save_path = str(log_dir / 'pred_masks') if args.save_pred_mask else None

    # Test
    model.eval()
    with torch.no_grad():
        for n_iter, (imgs, audio, mask, video_name_list) in enumerate(test_dataloader):
            imgs = imgs.to(accelerator.device)
            audio = audio.to(accelerator.device)
            mask = mask.to(accelerator.device)

            # Fold the frame axis into the batch axis. imgs is 5-D
            # (B, frame, C, H, W) and audio is 5-D as well (its last three dims
            # are kept). mask is reshaped to (B * frame, H, W), so it must carry
            # one H x W mask per frame (presumably (B, frame, 1, H, W) — confirm
            # against the dataset's collate_fn).
            B, frame, C, H, W = imgs.shape
            imgs = imgs.view(B * frame, C, H, W)
            mask = mask.view(B * frame, H, W)
            audio = audio.view(-1, audio.shape[2], audio.shape[3], audio.shape[4])

            output = model(imgs, audio)  # presumably (B * frame, 1, H, W)
            pred = output.squeeze(1)     # drop the singleton channel dim

            miou = mask_iou(pred, mask)
            F_score = Eval_Fmeasure(pred, mask)
            avg_meter.add({'miou': miou, 'F_score': F_score})
            print(f'n_iter: {n_iter}, iou: {miou}, F_score: {F_score}')
            if mask_save_path is not None:
                save_mask_ms3(pred, mask_save_path, video_name_list)

        miou = avg_meter.pop('miou')
        F_score = avg_meter.pop('F_score')
        print(f'test miou: {miou.item()}')
        print(f'test F_score: {F_score}')
        logger.info(f'test miou: {miou.item()}, F_score: {F_score}')


if __name__ == "__main__":
    main()
