import time
import shutil
import argparse
import logging
from pathlib import Path

import torch
from accelerate import Accelerator
from accelerate.utils import set_seed

from utils import load_module
from utils.pyutils import AverageDict, track
from utils.system import setup_logging

def build_parser():
    """Build the command-line parser for MS3 training.

    Returns:
        argparse.ArgumentParser: parser whose positional ``config`` argument is the
        path handed to ``utils.load_module`` to obtain model, datasets, loss and metric.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="path to the config file")
    parser.add_argument("--session-name", default="MS3", type=str, help="the MS3 setting")

    parser.add_argument("--batch_per_gpu", default=4, type=int)
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--epoch", default=15, type=int)
    parser.add_argument("--lr", default=0.0001, type=float)
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--wt_dec", default=0, type=float)

    parser.add_argument("--load_s4_params", action='store_true', default=False,
                        help='use S4 parameters for initialization')
    parser.add_argument("--trained_s4_model_path", type=str, default='', help='pretrained S4 model')

    parser.add_argument("--seed", type=int, default=123, help="random seed")

    parser.add_argument('--log_dir', default='./train_logs', type=str)

    parser.add_argument('--tensorboard', action='store_true', default=False, help='use tensorboardX')

    parser.add_argument("--debug", action='store_true', default=False, help="debug")
    return parser


def main():
    """Train the configured model and keep the checkpoint with the best validation mIoU.

    Side effects: creates a timestamped run directory under ``--log_dir``, copies the
    config file and this script into it, writes ``log.txt`` there and, on the main
    process, saves ``best.pth`` whenever validation mIoU improves.
    """
    parser = build_parser()
    args = parser.parse_args()
    if args.load_s4_params and not args.trained_s4_model_path:
        # Fail fast with a usage message instead of an obscure torch.load('') error later.
        parser.error('--load_s4_params requires --trained_s4_model_path')

    # Fix seed
    set_seed(args.seed)
    # cuDNN autotuning: faster for fixed input sizes, but not bit-reproducible.
    torch.backends.cudnn.benchmark = True

    # Run directory: <log_dir>/<session-name or 'debug'>_<timestamp>.
    # The prefix is concatenated rather than passed through strftime so a '%'
    # in --session-name is not treated as a format directive.
    # NOTE(review): every process computes its own timestamp; under a multi-process
    # launch ranks could end up with different directories — confirm.
    prefix = args.session_name if not args.debug else 'debug'
    log_dir = Path(args.log_dir) / f"{prefix}_{time.strftime('%Y%m%d-%H%M%S')}"
    log_dir.mkdir(exist_ok=True, parents=True)
    args.log_dir = str(log_dir)

    # Archive the config and the running script next to the logs for reproducibility.
    # `__file__` (not a hard-coded name) keeps this working when the script is renamed
    # or launched from another working directory; log_dir already exists, so no retry.
    for script in (args.config, __file__):
        shutil.copy(script, log_dir)

    accelerator = Accelerator(log_with='tensorboard' if args.tensorboard else None, project_dir=str(log_dir))

    setup_logging(filename=str(log_dir / 'log.txt'))
    logger = logging.getLogger(__name__)
    logger.info(f'==> Arguments: {vars(args)}')
    logger.info(f'==> Experiment: {args.session_name}')
    print(f'==> Experiment: {args.session_name}')
    print(vars(args))

    accelerator.init_trackers(args.session_name, config=vars(args))

    module_loader = load_module(args.config)
    # Model
    model = module_loader.model

    # Optionally initialise from a pretrained single-sound-source (S4) model before fine-tuning.
    if args.load_s4_params:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(args.trained_s4_model_path, map_location='cpu')
        # strict=False tolerates architecture differences; report them so that a
        # wrong or incompatible checkpoint does not go unnoticed.
        load_result = model.load_state_dict(state_dict=state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(f"==> Missing keys when loading S4 model: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            logger.warning(f"==> Unexpected keys when loading S4 model: {load_result.unexpected_keys}")
        logger.info(f"==> Reload pretrained S4 model from {args.trained_s4_model_path}")
        print(f"==> Reload pretrained S4 model from {args.trained_s4_model_path}")

    param_info = f"==> Total params: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M"
    logger.info(param_info)
    print(param_info)

    # Data
    train_dataset = module_loader.train_dataset
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_per_gpu,
                                                   shuffle=True,
                                                   num_workers=args.num_workers,
                                                   pin_memory=True,
                                                   drop_last=True,
                                                   collate_fn=train_dataset.collate_fn)

    # Optimizer steps per epoch across all processes (each process sees 1/num_processes of the data).
    step_per_epoch = len(train_dataset) // (args.batch_per_gpu * accelerator.num_processes)
    max_step = step_per_epoch * args.epoch

    val_dataset = module_loader.val_dataset
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.val_batch_size,
                                                 shuffle=False,
                                                 num_workers=args.num_workers,
                                                 pin_memory=True,
                                                 drop_last=False,
                                                 collate_fn=val_dataset.collate_fn)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.wt_dec)
    # total_iters is scaled by num_processes, presumably because the Accelerate-prepared
    # scheduler advances num_processes times per optimizer step — confirm.
    scheduler = torch.optim.lr_scheduler.PolynomialLR(
        optimizer,
        total_iters=max_step * accelerator.num_processes,
        power=0.9
    )

    model, train_dataloader, val_dataloader, optimizer, scheduler = accelerator.prepare(model,
                                                                                        train_dataloader,
                                                                                        val_dataloader,
                                                                                        optimizer,
                                                                                        scheduler)

    loss_fn = module_loader.loss_fn
    metric = module_loader.metric

    avg_meter = AverageDict()

    # Train
    best_epoch = 0
    global_step = 0
    # Masks per training sample; the view below assumes 5 masks per clip.
    # NOTE(review): the shape comment allows "5 or 1" masks — a 1-mask batch would fail here.
    mask_num = 5
    max_miou = 0
    for epoch in range(args.epoch):
        model.train()
        for batch_data in track(train_dataloader, description=f"Train epoch {epoch}",
                                disable=not accelerator.is_local_main_process or args.debug):
            imgs, audio, mask = batch_data  # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5 or 1, 1, 224, 224]

            # Flatten the frame axis into the batch axis.
            B, frame, C, H, W = imgs.shape
            imgs = imgs.view(B * frame, C, H, W)
            mask = mask.view(B * mask_num, 1, H, W)
            audio = audio.view(-1, audio.shape[2], audio.shape[3], audio.shape[4])  # [B*T, 1, 96, 64]

            output, aux = model(imgs, audio)  # [bs*5, 1, 224, 224]
            loss, loss_dict = loss_fn(output,
                                      mask.unsqueeze(1).unsqueeze(1),
                                      aux)

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()

            avg_meter.add(total_loss=loss.item(), **loss_dict)

            global_step += 1
            if (global_step - 1) % 20 == 0:
                total_loss = avg_meter.pop('total_loss')
                each_loss = {name: avg_meter.pop(name) for name in avg_meter.keys if name != 'total_loss'}
                cur_lr = optimizer.param_groups[-1]['lr']
                train_log = f'Iter:{global_step - 1:5d}/{max_step:5d}, Total_Loss:{total_loss}, {each_loss}, lr: {cur_lr:.6f}'
                logger.info(train_log)
                print(train_log)
        # Validation:
        model.eval()
        with torch.no_grad():
            for batch_data in track(val_dataloader, f"Val epoch {epoch}",
                                    disable=not accelerator.is_local_main_process or args.debug):
                imgs, audio, mask, _ = batch_data  # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5, 1, 224, 224]

                B, frame, C, H, W = imgs.shape
                imgs = imgs.view(B * frame, C, H, W)
                mask = mask.view(B * frame, H, W)
                audio = audio.view(-1, audio.shape[2], audio.shape[3], audio.shape[4])

                output = model(imgs, audio)  # [bs*5, 1, 224, 224]

                # Gather across processes so every rank feeds the metric the full validation set.
                all_outputs, all_label = accelerator.gather_for_metrics((output.squeeze(1), mask))
                metric.add_batch(pred=all_outputs, label=all_label)

            # float() so trackers and comparisons see a plain Python number.
            miou = float(metric.compute()['mask_iou'])
            # Track the best score on every rank (miou is computed from gathered outputs, so
            # presumably identical everywhere); only the main process writes the checkpoint.
            if miou > max_miou:
                max_miou = miou
                best_epoch = epoch
                if accelerator.is_main_process:
                    model_save_path = log_dir / 'best.pth'
                    torch.save(accelerator.get_state_dict(model, unwrap=True), model_save_path)
                    logger.info(f'save best model to {model_save_path}')
                    print(f'save best model to {model_save_path}')
            accelerator.wait_for_everyone()

            val_log = f'Epoch: {epoch}, Miou: {miou}, maxMiou: {max_miou}'
            logger.info(val_log)
            print(val_log)
        accelerator.log(dict(
            mIoU=miou,
            lr=optimizer.param_groups[0]['lr']
        ), step=epoch)
        avg_meter.clear()

    logger.info(f'best val Miou {max_miou} at epoch: {best_epoch}')
    print(f'best val Miou {max_miou} at epoch: {best_epoch}')
    accelerator.end_training()


if __name__ == "__main__":
    main()

