import time
import shutil
import numpy as np
import argparse
import logging
from pathlib import Path

import torch
from accelerate import Accelerator
from accelerate.utils import set_seed

from utils import load_module
from utils.pyutils import AverageDict, track
from utils.system import setup_logging

if __name__ == "__main__":
    # Training entry point for the AVSS setting.
    #
    # Builds the model, datasets, loss function and metric from the config module
    # given on the command line (via ``load_module``), trains with AdamW and a
    # per-step polynomial LR decay under HuggingFace Accelerate, validates every
    # ``--eval_interval`` epochs once ``--start_eval_epoch`` is reached, and keeps
    # two checkpoints in a timestamped run directory: the best-mIoU model and the
    # model with the best (mIoU + F-score) sum.
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="path to the config file")
    parser.add_argument("--session-name", default="AVSS", type=str, help="the AVSS setting")

    parser.add_argument("--batch_per_gpu", default=4, type=int)
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--epoch", default=30, type=int)
    parser.add_argument("--lr", default=0.0001, type=float)
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--wt_dec", default=0., type=float)

    parser.add_argument("--seed", type=int, default=123, help="random seed")

    parser.add_argument("--start_eval_epoch", default=10, type=int)
    parser.add_argument("--eval_interval", default=2, type=int)

    parser.add_argument('--log_dir', default='./train_logs', type=str)

    parser.add_argument('--tensorboard', action='store_true', default=False, help='use tensorboardX')

    parser.add_argument("--debug", action='store_true', default=False, help="debug")

    args = parser.parse_args()

    # Seed the RNGs via accelerate's set_seed; cudnn autotuning is enabled for
    # speed (it may make runs not bit-for-bit reproducible).
    set_seed(args.seed)
    torch.backends.cudnn.benchmark = True

    # Run directory: <log_dir>/<session-name, or "debug">_<YYYYmmdd-HHMMSS>.
    # NOTE(review): every process formats its own timestamp; under a
    # multi-process launch, ranks that straddle a second boundary would create
    # different run directories. Consider picking the name on the main process
    # and sharing it with the other ranks.
    prefix = 'debug' if args.debug else args.session_name
    log_dir = Path(args.log_dir) / time.strftime(prefix + '_%Y%m%d-%H%M%S')
    log_dir.mkdir(exist_ok=True, parents=True)
    args.log_dir = str(log_dir)

    # Snapshot the config and this very script next to the run outputs.
    # ``__file__`` (instead of a hard-coded 'train_avss.py' resolved against the
    # CWD) keeps this working when launched from another directory or renamed.
    # The run directory was just created above, so no mkdir fallback is needed.
    for script in (args.config, __file__):
        shutil.copy(script, log_dir)

    accelerator = Accelerator(log_with='tensorboard' if args.tensorboard else None, project_dir=log_dir)

    setup_logging(filename=str(log_dir / 'log.txt'))
    logger = logging.getLogger(__name__)
    logger.info(f'==> Arguments: {vars(args)}')
    logger.info(f'==> Experiment: {args.session_name}')
    # accelerator.print only prints on the local main process, so the console
    # does not get one duplicate line per GPU.
    accelerator.print(f'==> Experiment: {args.session_name}')
    accelerator.print(vars(args))

    accelerator.init_trackers(log_dir.stem, config=vars(args))

    # The config module supplies: model, train_dataset, val_dataset, loss_fn, metric.
    module_loader = load_module(args.config)

    # Model
    model = module_loader.model
    param_info = f"==> Total params: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M"
    logger.info(param_info)
    accelerator.print(param_info)

    # Data
    train_dataset = module_loader.train_dataset
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_per_gpu,
                                                   shuffle=True,
                                                   num_workers=args.num_workers,
                                                   pin_memory=True,
                                                   drop_last=True,
                                                   collate_fn=train_dataset.collate_fn
                                                   )
    # Optimizer steps per epoch on each process (global batch is
    # batch_per_gpu * num_processes), and over the whole run.
    step_per_epoch = len(train_dataset) // (args.batch_per_gpu * accelerator.num_processes)
    max_step = step_per_epoch * args.epoch

    val_dataset = module_loader.val_dataset
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.val_batch_size,
                                                 shuffle=False,
                                                 num_workers=args.num_workers,
                                                 pin_memory=True,
                                                 drop_last=False,
                                                 collate_fn=val_dataset.collate_fn
                                                 )

    # Optimizer / LR schedule (stepped once per training iteration).
    optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.wt_dec)
    # total_iters is scaled by num_processes, presumably because the
    # accelerate-prepared scheduler advances once per process on every step()
    # call — confirm against the installed accelerate version.
    scheduler = torch.optim.lr_scheduler.PolynomialLR(
        optimizer,
        total_iters=max_step * accelerator.num_processes,
        power=0.9
    )
    model, train_dataloader, val_dataloader, optimizer, scheduler = accelerator.prepare(model,
                                                                                        train_dataloader,
                                                                                        val_dataloader,
                                                                                        optimizer,
                                                                                        scheduler)

    loss_fn = module_loader.loss_fn
    metric = module_loader.metric

    avg_meter = AverageDict()

    # Train
    best_epoch = 0           # epoch of the best mIoU
    best_combined_epoch = 0  # epoch of the best mIoU + F-score sum
    global_step = 0
    miou_list = []
    max_miou = 0
    miou_noBg_list = []
    fscore_list, fscore_noBg_list = [], []
    max_fs, max_fs_noBg = 0, 0
    max_miou_fs = 0          # best mIoU + F-score sum seen so far
    # Number of mask frames per training sample (see the shape comment below).
    # NOTE(review): validation reshapes masks with `frame` instead — confirm the
    # two always agree for the training data.
    mask_num = 10

    for epoch in range(args.epoch):
        model.train()
        for batch_data in track(train_dataloader, description=f"Train epoch {epoch}",
                                disable=not accelerator.is_local_main_process or args.debug):
            imgs, audio, mask, vid_temporal_mask_flag, gt_temporal_mask_flag = batch_data  # [bs, 5, 3, 224, 224], ->[bs, 5, 1, 96, 64], [bs, 10, 1, 224, 224]

            # Fold the time axis into the batch axis for the model.
            B, frame, C, H, W = imgs.shape
            imgs = imgs.view(B * frame, C, H, W)
            mask = mask.view(B * mask_num, H, W)
            audio = audio.view(B * 10, audio.shape[-3], audio.shape[-2], audio.shape[-1])  # [bs*10, 1, 96, 64]
            # ! notice
            vid_temporal_mask_flag = vid_temporal_mask_flag.view(B * frame)  # [B*T]
            gt_temporal_mask_flag = gt_temporal_mask_flag.view(B * frame)  # [B*T]

            output, aux = model(imgs, audio, vid_temporal_mask_flag)  # [bs*5, 24, 224, 224]
            loss, loss_dict = loss_fn(output,
                                      mask.unsqueeze(1).unsqueeze(1),
                                      gt_temporal_mask_flag,
                                      aux)
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()

            avg_meter.add(total_loss=loss.item(), **loss_dict)

            global_step += 1
            if (global_step - 1) % 100 == 0:
                total_loss = avg_meter.pop('total_loss')
                # NOTE(review): relies on AverageDict.keys being a snapshot (not a
                # live dict view), since entries are popped while iterating — confirm.
                each_loss = {name: avg_meter.pop(name) for name in avg_meter.keys if name != 'total_loss'}
                cur_lr = optimizer.param_groups[-1]['lr']
                train_log = f'Iter:{global_step - 1:5d}/{max_step:5d}, Total_Loss:{total_loss}, {each_loss}, lr: {cur_lr:.6f}'
                logger.info(train_log)
                accelerator.print(train_log)

        # Validate only from start_eval_epoch on, every eval_interval epochs.
        if not (epoch >= args.start_eval_epoch and epoch % args.eval_interval == 0):
            continue

        model.eval()
        with torch.no_grad():
            for batch_data in track(val_dataloader, description=f"Val epoch {epoch}",
                                    disable=not accelerator.is_local_main_process or args.debug):
                imgs, audio, mask, vid_temporal_mask_flag, gt_temporal_mask_flag, _ = batch_data  # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5, 1, 224, 224]
                B, frame, C, H, W = imgs.shape
                imgs = imgs.view(B * frame, C, H, W)
                mask = mask.view(B * frame, H, W)
                audio = audio.view(B * 10, audio.shape[-3], audio.shape[-2], audio.shape[-1])  # [bs*10, 1, 96, 64]
                # ! notice
                vid_temporal_mask_flag = vid_temporal_mask_flag.view(B * frame)  # [B*T]
                gt_temporal_mask_flag = gt_temporal_mask_flag.view(B * frame)  # [B*T]

                # Unlike training, the return value is used directly as the
                # prediction (no aux output is unpacked).
                output = model(imgs, audio, vid_temporal_mask_flag)  # [bs*5, 21, 224, 224]

                # Gather predictions/labels from all processes so every rank
                # feeds the same data to the metric.
                all_outputs, all_label = accelerator.gather_for_metrics((output, mask))
                metric.add_batch(pred=all_outputs, label=all_label)

            # NOTE(review): assumes metric.compute() resets the accumulated
            # state between evaluations — confirm.
            result = metric.compute()
            miou = result['miou']
            miou_noBg = result['miou_noBg']
            f_score = result['f_score']
            f_score_noBg = result['f_score_noBg']

            # Checkpoint decisions are made identically on every rank (metrics
            # are computed from gathered data), so the state dict can be
            # collected on all ranks — required for sharded training such as
            # DeepSpeed ZeRO-3 / FSDP — while only the main rank writes files.
            checkpoints = []
            if miou > max_miou:
                best_epoch = epoch
                checkpoints.append(('miou_best.pth', 'miou best'))
            # Compare against the best *combined* score seen so far, not against
            # the sum of the independently achieved maxima of each metric.
            if (miou + f_score) > max_miou_fs:
                max_miou_fs = miou + f_score
                best_combined_epoch = epoch
                checkpoints.append(('miou_and_fscore_best.pth', 'miou and fscore best'))
            if checkpoints:
                state_dict = accelerator.get_state_dict(model, unwrap=True)
                if accelerator.is_main_process:
                    for file_name, tag in checkpoints:
                        model_save_path = log_dir / file_name
                        torch.save(state_dict, model_save_path)
                        logger.info(f'save {tag} model to {model_save_path}')
                        print(f'save {tag} model to {model_save_path}')
            accelerator.wait_for_everyone()

            miou_list.append(miou)
            miou_noBg_list.append(miou_noBg)
            max_miou = max(miou_list)
            max_miou_noBg = max(miou_noBg_list)
            fscore_list.append(f_score)
            fscore_noBg_list.append(f_score_noBg)
            max_fs = max(fscore_list)
            max_fs_noBg = max(fscore_noBg_list)

            val_log = (
                f'Epoch: {epoch}, Miou: {miou}, maxMiou: {max_miou}, Miou(no bg): {miou_noBg}, maxMiou (no bg): {max_miou_noBg}'
                f' Fscore: {f_score}, maxFs: {max_fs}, Fscore(no bg): {f_score_noBg}, max Fscore (no bg): {max_fs_noBg}'
            )
            logger.info(val_log)
            accelerator.print(val_log)

        accelerator.log(dict(
            mIoU=miou,
            mIoU_noBg=miou_noBg,
            F_score=f_score,
            F_score_noBg=f_score_noBg,
            lr=optimizer.param_groups[0]['lr']
        ), step=epoch)

    logger.info(f'best val Miou {max_miou} at epoch: {best_epoch}')
    accelerator.print(f'best val Miou {max_miou} at epoch: {best_epoch}')
    logger.info(f'best val Miou+Fscore {max_miou_fs} at epoch: {best_combined_epoch}')
    accelerator.print(f'best val Miou+Fscore {max_miou_fs} at epoch: {best_combined_epoch}')

    # Finalize experiment trackers (e.g. flush/close the tensorboard writer).
    accelerator.end_training()
