#!/usr/bin/env python
import os
import os.path as osp
import argparse
import logging
import time
import socket
import warnings

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from common.solver.build import build_optimizer, build_scheduler
from common.utils.checkpoint import CheckpointerV2
from common.utils.logger import setup_logger
from common.utils.metric_logger import MetricLogger
from common.utils.torch_util import set_random_seed
from models.build import build_model_2d, build_model_3d,  build_model_2D_FeatureMapper, build_model_3D_FeatureMapper
from data.build import build_dataloader
from data.utils.validate import validate_3d
from models.losses import Compute_loss_c, Compute_loss_a

def log_weights(model):
    """Summarise every trainable parameter of ``model`` as one text line each.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        A list of strings ``"model:<name> - Mean: <m>, Std: <s>"`` (4 decimals),
        one per parameter with ``requires_grad=True``, in ``named_parameters()``
        order.
    """
    weight_logs = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        values = param.detach()
        mean = values.mean().item()
        # The unbiased std of a single element is undefined (torch returns NaN
        # and warns about degrees of freedom); report 0.0 so scalar parameters
        # still produce a readable line.
        std = values.std().item() if values.numel() > 1 else 0.0
        weight_logs.append(f"model:{name} - Mean: {mean:.4f}, Std: {std:.4f}")
    return weight_logs

def parse_args(argv=None):
    """Parse the LSB training command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets tests or
            other scripts drive the parser directly.

    Returns:
        argparse.Namespace with ``config_file`` (``--cfg``), ``output_dir``
        (``--output_dir``) and ``opts`` -- every remaining token, collected
        verbatim so it can be merged into the config as overrides.
    """
    parser = argparse.ArgumentParser(description='LSB training')
    parser.add_argument(
        '--cfg',
        dest='config_file',
        default='',
        metavar='FILE',
        help='path to config file',
        type=str,
    )
    parser.add_argument(
        '--output_dir',
        dest='output_dir',
        default='',
        metavar='DIR',
        help='output directory',
        type=str,
    )
    parser.add_argument(
        'opts',
        help='Modify config options using the command-line',
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser.parse_args(argv)

def init_metric_logger(metric_list):
    """Create a ``MetricLogger`` holding every metric in ``metric_list``.

    Entries that are lists/tuples are flattened one level; any other entry is
    registered as-is. The logger joins meters with two spaces.
    """
    flat_metrics = []
    for entry in metric_list:
        # Wrap single metrics in a 1-tuple so both cases share one extend().
        flat_metrics.extend(entry if isinstance(entry, (list, tuple)) else (entry,))
    meter_logger = MetricLogger(delimiter='  ')
    meter_logger.add_meters(flat_metrics)
    return meter_logger

def update_ema_variables(model, ema_model, alpha, global_step):
    """Move each parameter of ``ema_model`` towards the matching one in ``model``.

    Every teacher parameter becomes ``keep * teacher + (1 - keep) * student``
    with ``keep = min(1 - 1 / (global_step + 1), alpha)``. Parameters are paired
    positionally through ``.parameters()`` (``zip`` stops at the shorter list);
    buffers are not touched.

    Args:
        model: Student module whose weights are averaged in.
        ema_model: Teacher module, updated in place.
        alpha: Upper bound on the decay rate.
        global_step: Update counter; ``train`` passes the 1-based iteration.
    """
    # While 1 - 1/(step + 1) < alpha this is the cumulative mean of the weights
    # seen so far; once it exceeds alpha it becomes a plain EMA with rate alpha.
    keep = min(1 - 1 / (global_step + 1), alpha)
    blend = 1 - keep
    for teacher_param, student_param in zip(ema_model.parameters(), model.parameters()):
        teacher_data = teacher_param.data
        teacher_data.mul_(keep)
        teacher_data.add_(student_param.data, alpha=blend)

def train(cfg, output_dir='', run_name=''):
    """Run the LSB training loop.

    Each iteration draws one batch from the source, target and bridge
    dataloaders and jointly optimises:

    * 2D: cross-entropy of ``model_2d`` on labelled source images.
    * 3D: KL divergence between ``model_3d`` predictions on the bridge batch and
      the softmax of an EMA copy of ``model_2d`` (the teacher) on the same batch,
      plus ``cfg.TRAIN.LSB.lambda_a`` * alignment loss between feature-mapped 2D
      source and 3D target features, plus ``cfg.TRAIN.LSB.lambda_c`` *
      consistency loss between feature-mapped 2D and 3D bridge features (which
      includes an L2 penalty on the feature-mapper weights).

    After every optimiser step the teacher follows ``model_2d`` through
    ``update_ema_variables``. Every ``cfg.VAL.PERIOD`` iterations (and at the
    last iteration) ``model_3d`` is validated on the target domain, and a
    checkpoint is saved whenever ``cfg.VAL.METRIC_3d`` improves.

    Args:
        cfg: Frozen LSB config node.
        output_dir: Directory passed to the checkpointer; when non-empty a
            TensorBoard writer is also created under it.
        run_name: Suffix for the TensorBoard directory name.

    Note:
        Only ``model_3d`` (with its optimizer and scheduler) is checkpointed and
        resumed. The 2D model, its EMA teacher and the feature mappers are not
        saved here, so a resumed run restarts them from their freshly built state.
    """
    # ---------------------------------------------------------------------------- #
    # Build models, optimizer, scheduler, checkpointer, etc.
    # ---------------------------------------------------------------------------- #
    logger = logging.getLogger('LSB.train')

    set_random_seed(cfg.RNG_SEED)

    # build 2d model (student) and its EMA copy (teacher)
    model_2d, train_metric_2d = build_model_2d(cfg)
    model_2d_ema, _ = build_model_2d(cfg)
    # The teacher is never optimised; it only follows model_2d through
    # update_ema_variables, so its parameters are cut from autograd.
    for param in model_2d_ema.parameters():
        param.detach_()
    logger.info('Build 2D model:\n{}'.format(str(model_2d)))
    num_params = sum(param.numel() for param in model_2d.parameters())
    print('#Parameters: {:.2e}'.format(num_params))

    # build 3d model
    model_3d, train_metric_3d = build_model_3d(cfg)
    logger.info('Build 3D model:\n{}'.format(str(model_3d)))
    num_params = sum(param.numel() for param in model_3d.parameters())
    print('#Parameters: {:.2e}'.format(num_params))

    # build 2D FeatureMapping model
    model_FeatureMapping_2d = build_model_2D_FeatureMapper(cfg)
    logger.info('Build FeatureMapping model:\n{}'.format(str(model_FeatureMapping_2d)))
    num_params = sum(param.numel() for param in model_FeatureMapping_2d.parameters())
    print('#Parameters FeatureMapping_2d: {:.2e}'.format(num_params))

    # build 3D FeatureMapping model
    model_FeatureMapping_3d = build_model_3D_FeatureMapper(cfg)
    logger.info('Build FeatureMapping model:\n{}'.format(str(model_FeatureMapping_3d)))
    num_params = sum(param.numel() for param in model_FeatureMapping_3d.parameters())
    print('#Parameters FeatureMapping_3d: {:.2e}'.format(num_params))

    model_2d = model_2d.cuda()
    model_3d = model_3d.cuda()
    model_FeatureMapping_2d = model_FeatureMapping_2d.cuda()
    model_FeatureMapping_3d = model_FeatureMapping_3d.cuda()
    model_2d_ema = model_2d_ema.cuda()

    # build optimizer
    optimizer_2d = build_optimizer(cfg, model_2d)
    optimizer_3d = build_optimizer(cfg, model_3d)
    optimizer_FeatureMapping_2d = build_optimizer(cfg, model_FeatureMapping_2d)
    optimizer_FeatureMapping_3d = build_optimizer(cfg, model_FeatureMapping_3d)

    # build lr scheduler
    scheduler_2d = build_scheduler(cfg, optimizer_2d)
    scheduler_3d = build_scheduler(cfg, optimizer_3d)
    scheduler_FeatureMapping_2d = build_scheduler(cfg, optimizer_FeatureMapping_2d)
    scheduler_FeatureMapping_3d = build_scheduler(cfg, optimizer_FeatureMapping_3d)

    # Only the 3D branch is checkpointed / resumed (see the docstring note).
    checkpointer_3d = CheckpointerV2(model_3d,
                                     optimizer=optimizer_3d,
                                     scheduler=scheduler_3d,
                                     save_dir=output_dir,
                                     logger=logger,
                                     postfix='_3d',
                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)
    checkpoint_data_3d = checkpointer_3d.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)

    # build tensorboard logger (only when an output directory is given)
    if output_dir:
        tb_dir = osp.join(output_dir, 'tb.{:s}'.format(run_name))
        summary_writer = SummaryWriter(tb_dir)
    else:
        summary_writer = None

    # ---------------------------------------------------------------------------- #
    # Train
    # ---------------------------------------------------------------------------- #
    max_iteration = cfg.SCHEDULER.MAX_ITERATION
    start_iteration = checkpoint_data_3d.get('iteration', 0)

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader_src = build_dataloader(cfg, mode='train', domain='source', start_iteration=start_iteration)
    train_dataloader_trg = build_dataloader(cfg, mode='train', domain='target', start_iteration=start_iteration)
    train_dataloader_brg = build_dataloader(cfg, mode='train', domain='bridge', start_iteration=start_iteration)
    val_period = cfg.VAL.PERIOD
    val_dataloader = build_dataloader(cfg, mode='val', domain='target') if val_period > 0 else None

    best_metric_name = 'best_{}'.format(cfg.VAL.METRIC_3d)
    best_metric = checkpoint_data_3d.get(best_metric_name, None)
    best_metric_iter = -1
    logger.info('Start training from iteration {}'.format(start_iteration))

    # add metrics
    train_metric_logger = init_metric_logger([train_metric_2d, train_metric_3d])
    val_metric_logger = MetricLogger(delimiter='  ')

    def setup_train():
        # set training mode
        model_2d.train()
        model_3d.train()
        model_FeatureMapping_2d.train()
        model_FeatureMapping_3d.train()
        # setup_validate() switches the teacher to eval mode; switch it back so
        # it produces targets in the same mode for the whole run instead of
        # silently changing behaviour after the first validation.
        model_2d_ema.train()

        # reset metric
        train_metric_logger.reset()

    def setup_validate():
        # set evaluate mode; model_3d is the network validate_3d() runs, so it
        # must be in eval mode too (affects any BatchNorm/dropout it contains).
        model_2d.eval()
        model_3d.eval()
        model_2d_ema.eval()
        model_FeatureMapping_2d.eval()
        model_FeatureMapping_3d.eval()
        # reset metric
        val_metric_logger.reset()

    def log_best_metric():
        # best_metric stays None until a validation run reports cfg.VAL.METRIC_3d
        # (never, when VAL.PERIOD <= 0); 'None * 100' would raise TypeError.
        if best_metric is None:
            logger.info('Best val-{} = n/a (no validation result)'.format(cfg.VAL.METRIC_3d))
        else:
            logger.info('Best val-{} = {:.2f} at iteration {}'.format(
                cfg.VAL.METRIC_3d, best_metric * 100, best_metric_iter))

    if cfg.TRAIN.CLASS_WEIGHTS:
        # Explicit float dtype: F.cross_entropy rejects integer weight tensors.
        class_weights = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS, dtype=torch.float32).cuda()
    else:
        class_weights = None

    setup_train()
    end = time.time()
    train_iter_src = iter(train_dataloader_src)
    train_iter_trg = iter(train_dataloader_trg)
    train_iter_brg = iter(train_dataloader_brg)
    for iteration in range(start_iteration, max_iteration):
        # fetch data_batches for source & target & bridge
        # NOTE(review): assumes each loader yields at least
        # (max_iteration - start_iteration) batches; otherwise next() raises StopIteration.
        data_batch_src = next(train_iter_src)
        data_batch_trg = next(train_iter_trg)
        data_batch_brg = next(train_iter_brg)
        data_time = time.time() - end

        # copy data from cpu to gpu
        if 'SCN' in cfg.DATASET_SOURCE.TYPE and 'SCN' in cfg.DATASET_TARGET.TYPE:
            # source
            data_batch_src['seg_label'] = data_batch_src['seg_label'].cuda()
            data_batch_src['img'] = data_batch_src['img'].cuda()
            # target
            data_batch_trg['x'][1] = data_batch_trg['x'][1].cuda()
            # bridge
            data_batch_brg['x'][1] = data_batch_brg['x'][1].cuda()
            data_batch_brg['img'] = data_batch_brg['img'].cuda()
        else:
            raise NotImplementedError('Only SCN is supported for now.')

        optimizer_2d.zero_grad()
        optimizer_3d.zero_grad()
        optimizer_FeatureMapping_2d.zero_grad()
        optimizer_FeatureMapping_3d.zero_grad()

        # ---------------------------------------------------------------------------- #
        # Train on 2D
        # ---------------------------------------------------------------------------- #
        # 2D EMA teacher: its bridge predictions are the soft targets of the 3D
        # KL loss and the pseudo-labels of the 3D train metric. No gradient is
        # needed (its parameters are detached), so skip graph bookkeeping.
        with torch.no_grad():
            preds_2d_ema = model_2d_ema(data_batch_brg)
        pseudo_labels = torch.argmax(preds_2d_ema['seg_logit'], dim=1)

        # 2D student
        preds_2d_src = model_2d(data_batch_src)
        preds_2d_brg = model_2d(data_batch_brg)
        mapped_2d_src = model_FeatureMapping_2d(preds_2d_src['feats'])
        mapped_2d_brg = model_FeatureMapping_2d(preds_2d_brg['feats'])

        # segmentation loss: cross entropy
        seg_loss_2d = F.cross_entropy(preds_2d_src['seg_logit'], data_batch_src['seg_label'], weight=class_weights)

        train_metric_logger.update(seg_loss_2d=seg_loss_2d)
        loss_2d = seg_loss_2d

        # update metric (e.g. IoU)
        with torch.no_grad():
            train_metric_2d.update_dict(preds_2d_src, data_batch_src)

        # ---------------------------------------------------------------------------- #
        # Train on 3D
        # ---------------------------------------------------------------------------- #
        # 3D modal
        preds_3d_brg = model_3d(data_batch_brg)
        preds_3d_trg = model_3d(data_batch_trg)

        # Hard 3D labels on the target batch; softmax is monotonic, so taking
        # the argmax of the logits directly gives the same classes.
        labels_3d_trg = torch.argmax(preds_3d_trg['seg_logit'], dim=1)

        mapped_3d_brg = model_FeatureMapping_3d(preds_3d_brg['feats'])
        mapped_3d_trg = model_FeatureMapping_3d(preds_3d_trg['feats'])

        # consistency loss (+ L2 penalty on both feature mappers)
        l2_norm_2d = sum(p.pow(2.0).sum() for p in model_FeatureMapping_2d.parameters())
        l2_norm_3d = sum(p.pow(2.0).sum() for p in model_FeatureMapping_3d.parameters())
        consistency_loss = Compute_loss_c(mapped_2d_brg, mapped_3d_brg) + cfg.TRAIN.LSB.lambda_l2 * (l2_norm_2d + l2_norm_3d)

        # segmentation loss: KL divergence from the teacher's bridge predictions
        log_probs_3d_brg = F.log_softmax(preds_3d_brg['seg_logit'], dim=1)
        probs_2d_3d = F.softmax(preds_2d_ema['seg_logit'], dim=1)
        seg_loss_3d = F.kl_div(log_probs_3d_brg, probs_2d_3d, reduction='none').sum(1).mean()

        # alignment loss:
        alignment_loss = Compute_loss_a(mapped_2d_src, mapped_3d_trg, data_batch_src['seg_label'], labels_3d_trg, num_classes=cfg.MODEL_2D.NUM_CLASSES)

        train_metric_logger.update(seg_loss_3d=seg_loss_3d, alignment_loss=alignment_loss, consistency_loss=consistency_loss)
        loss_3d = seg_loss_3d + cfg.TRAIN.LSB.lambda_a * alignment_loss + cfg.TRAIN.LSB.lambda_c * consistency_loss

        # update metric (e.g. IoU)
        with torch.no_grad():
            train_metric_3d.update_dict(preds_3d_brg, pseudo_labels, pseudo_label=True)
        # backward
        loss_total = loss_2d + loss_3d
        loss_total.backward()

        # update students (2D, 3D and both feature mappers)
        optimizer_2d.step()
        optimizer_3d.step()
        optimizer_FeatureMapping_2d.step()
        optimizer_FeatureMapping_3d.step()
        # update teacher
        cur_iter = iteration + 1
        update_ema_variables(model_2d, model_2d_ema, 0.999, cur_iter)

        batch_time = time.time() - end
        train_metric_logger.update(time=batch_time, data=data_time)

        # log
        if cur_iter == 1 or (cfg.TRAIN.LOG_PERIOD > 0 and cur_iter % cfg.TRAIN.LOG_PERIOD == 0):
            logger.info(
                train_metric_logger.delimiter.join(
                    [
                        'iter: {iter:4d}',
                        '{meters}',
                        'lr: {lr:.2e}',
                        'max mem: {memory:.0f}',
                    ]
                ).format(
                    iter=cur_iter,
                    meters=str(train_metric_logger),
                    lr=optimizer_2d.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                )
            )

        # summary
        if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:
            keywords = ('loss', 'acc', 'iou')
            for name, meter in train_metric_logger.meters.items():
                if all(k not in name for k in keywords):
                    continue
                summary_writer.add_scalar('train/' + name, meter.avg, global_step=cur_iter)

        # ---------------------------------------------------------------------------- #
        # validate for one epoch
        # ---------------------------------------------------------------------------- #
        if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):
            start_time_val = time.time()
            setup_validate()

            validate_3d(cfg,
                        model_3d,
                        val_dataloader,
                        val_metric_logger)

            epoch_time_val = time.time() - start_time_val
            logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(
                cur_iter, val_metric_logger.summary_str, epoch_time_val))

            # summary
            if summary_writer is not None:
                keywords = ('loss', 'acc', 'iou')
                for name, meter in val_metric_logger.meters.items():
                    if all(k not in name for k in keywords):
                        continue
                    summary_writer.add_scalar('val/' + name, meter.avg, global_step=cur_iter)

            cur_metric_name = cfg.VAL.METRIC_3d
            if cur_metric_name in val_metric_logger.meters:
                cur_metric = val_metric_logger.meters[cur_metric_name].global_avg
                if best_metric is None or best_metric < cur_metric:
                    best_metric = cur_metric
                    best_metric_iter = cur_iter
                    checkpoint_data_3d['iteration'] = cur_iter
                    checkpoint_data_3d[best_metric_name] = best_metric
                    checkpointer_3d.save('model_3d_best', **checkpoint_data_3d)
            log_best_metric()
            # one line per parameter instead of a single list repr
            logger.info('\n'.join(log_weights(model_FeatureMapping_2d)))
            logger.info('\n'.join(log_weights(model_FeatureMapping_3d)))
            # restore training
            setup_train()

        scheduler_2d.step()
        scheduler_3d.step()
        scheduler_FeatureMapping_2d.step()
        scheduler_FeatureMapping_3d.step()
        end = time.time()

    log_best_metric()
    # flush pending TensorBoard events and release the event file
    if summary_writer is not None:
        summary_writer.close()


def main():
    """Script entry point: parse the CLI, build the frozen config, set up logging and train."""
    args = parse_args()
    from common.config import purge_cfg
    from config.LSB import cfg

    # YAML file first, then the command-line KEY VALUE overrides on top.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    purge_cfg(cfg)
    cfg.freeze()

    output_dir = args.output_dir
    if output_dir and osp.isdir(output_dir):
        warnings.warn('Output directory exists.')
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # run name: "<month-day_hour-min-sec>.<hostname>"
    run_name = '{:s}.{:s}'.format(time.strftime('%m-%d_%H-%M-%S'), socket.gethostname())

    logger = setup_logger('LSB', output_dir, comment='train.{:s}'.format(run_name))
    gpu_count = torch.cuda.device_count()
    logger.info('{:d} GPUs available'.format(gpu_count))
    logger.info(args)

    logger.info('Loaded configuration file {:s}'.format(args.config_file))
    logger.info('Running with config:\n{}'.format(cfg))

    train(cfg, output_dir, run_name)


if __name__ == '__main__':
    # Start training only when the file is executed as a script, not on import.
    main()
