#!/usr/bin/env python
import os
import os.path as osp
import argparse
import logging
import time
import socket
import warnings

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from common.solver.build import build_optimizer, build_scheduler
from common.utils.checkpoint import CheckpointerV2
from common.utils.logger import setup_logger
from common.utils.metric_logger import MetricLogger
from common.utils.torch_util import set_random_seed
from models.build import build_model_2d, build_model_3d
from data.build import build_dataloader
from data.utils.validate import validate_source_only_2d, validate_source_only_3d


def parse_args():
    """Parse the command line for source-only training.

    Returns:
        argparse.Namespace with:
            config_file (str): value of ``--cfg`` ('' when omitted).
            output_dir (str): value of ``--output_dir`` ('' when omitted).
            opts: all remaining positional tokens, forwarded to the config as
                ``KEY VALUE`` overrides.
    """
    cli = argparse.ArgumentParser(description='source_only training')

    # Both path-like options share the same shape: str, defaulting to ''.
    path_options = (
        ('--cfg', 'config_file', 'FILE', 'path to config file'),
        ('--output_dir', 'output_dir', 'DIR', 'output directory'),
    )
    for flag, dest_name, meta, help_text in path_options:
        cli.add_argument(flag, dest=dest_name, default='', metavar=meta, help=help_text, type=str)

    # Everything after the known options is swallowed verbatim as config overrides.
    cli.add_argument(
        'opts',
        help='Modify config options using the command-line',
        default=None,
        nargs=argparse.REMAINDER,
    )
    return cli.parse_args()

def init_metric_logger(metric_list):
    """Build a MetricLogger (two-space delimiter) holding the given metrics.

    Args:
        metric_list: iterable whose items are either single metric objects or
            lists/tuples of metric objects; nested lists/tuples are flattened
            one level before registration.

    Returns:
        The MetricLogger with every metric registered via ``add_meters``.
    """
    flattened = []
    for entry in metric_list:
        # A list/tuple contributes its members; anything else is one metric.
        flattened.extend(entry if isinstance(entry, (list, tuple)) else [entry])

    logger_obj = MetricLogger(delimiter='  ')
    logger_obj.add_meters(flattened)
    return logger_obj

def train_2d(cfg, output_dir='', run_name=''):
    """Train the 2D segmentation model on the source domain only.

    Builds the 2D model, optimizer, LR scheduler and checkpointer from ``cfg``,
    optionally resumes, then runs iteration-based training with a cross-entropy
    loss on source labels. Periodically logs, writes TensorBoard scalars, saves
    checkpoints (always at the final iteration) and validates on the
    bridge-domain validation split.

    Args:
        cfg: config node (presumably a frozen yacs ``CfgNode``).
        output_dir: checkpoint directory; also hosts the TensorBoard run. When
            empty, no TensorBoard writer is created.
        run_name: suffix of the TensorBoard directory ``tb.<run_name>``.

    Returns:
        int: iteration at which the best validation ``cfg.VAL.METRIC_2d`` was
        reached, or -1 if no validation result was ever recorded.
    """
    # ---------------------------------------------------------------------------- #
    # Build models, optimizer, scheduler, checkpointer, etc.
    # ---------------------------------------------------------------------------- #
    logger = logging.getLogger('source_only_2d.train')
    set_random_seed(cfg.RNG_SEED)

    # build 2d model
    model_2d, train_metric_2d = build_model_2d(cfg)
    logger.info('Build 2D model:\n{}'.format(str(model_2d)))
    num_params = sum(param.numel() for param in model_2d.parameters())
    print('#Parameters: {:.2e}'.format(num_params))
    model_2d = model_2d.cuda()

    optimizer_2d = build_optimizer(cfg, model_2d)
    scheduler_2d = build_scheduler(cfg, optimizer_2d)

    # The checkpointer loads the state_dict of model, optimizer and scheduler and
    # returns a dict-like object with extra saved entries (e.g. 'iteration').
    checkpointer_2d = CheckpointerV2(model_2d,
                                     optimizer=optimizer_2d,
                                     scheduler=scheduler_2d,
                                     save_dir=output_dir,
                                     logger=logger,
                                     postfix='_2d',
                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)
    checkpoint_data_2d = checkpointer_2d.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build tensorboard logger only when there is a directory to write to
    if output_dir:
        summary_writer = SummaryWriter(osp.join(output_dir, 'tb.{:s}'.format(run_name)))
    else:
        summary_writer = None

    # ---------------------------------------------------------------------------- #
    # Train
    # ---------------------------------------------------------------------------- #
    max_iteration = cfg.SCHEDULER.MAX_ITERATION
    start_iteration = checkpoint_data_2d.get('iteration', 0)

    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader_src = build_dataloader(cfg, mode='train', domain='source', start_iteration=start_iteration)
    val_period = cfg.VAL.PERIOD
    val_dataloader = build_dataloader(cfg, mode='val', domain='bridge') if val_period > 0 else None

    metric_key = cfg.VAL.METRIC_2d
    best_metric_name = 'best_{}'.format(metric_key)
    best_iter_name = best_metric_name + '_iter'
    best_metric = checkpoint_data_2d.get(best_metric_name, None)
    # Also restore *where* the best metric was reached; otherwise a resumed run
    # would report (and return) -1 even though best_metric is known.
    best_metric_iter = checkpoint_data_2d.get(best_iter_name, -1)
    logger.info('Start training from iteration {}'.format(start_iteration))

    train_metric_logger = init_metric_logger([train_metric_2d])
    val_metric_logger = MetricLogger(delimiter='  ')
    summary_keywords = ('loss', 'acc', 'iou')

    def setup_train():
        # training mode + fresh running training metrics
        model_2d.train()
        train_metric_logger.reset()

    def setup_validate():
        # evaluation mode + fresh validation metrics
        model_2d.eval()
        val_metric_logger.reset()

    def write_summary(prefix, metric_logger, step):
        # Export only meters whose name mentions one of summary_keywords.
        for name, meter in metric_logger.meters.items():
            if any(k in name for k in summary_keywords):
                summary_writer.add_scalar(prefix + name, meter.avg, global_step=step)

    def log_best():
        # best_metric stays None until a validation produces metric_key;
        # formatting None * 100 would raise TypeError.
        if best_metric is None:
            logger.info('Best val-{} = N/A (no validation result)'.format(metric_key))
        else:
            logger.info('Best val-{} = {:.2f} at iteration {}'.format(metric_key, best_metric * 100, best_metric_iter))

    class_weights = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS).cuda() if cfg.TRAIN.CLASS_WEIGHTS else None

    setup_train()
    end = time.time()
    train_iter_src = iter(train_dataloader_src)

    try:
        for iteration in range(start_iteration, max_iteration):
            data_batch_src = next(train_iter_src)
            data_time = time.time() - end

            # copy data from cpu to gpu
            if 'SCN' in cfg.DATASET_SOURCE.TYPE and 'SCN' in cfg.DATASET_TARGET.TYPE:
                data_batch_src['seg_label'] = data_batch_src['seg_label'].cuda()
                data_batch_src['img'] = data_batch_src['img'].cuda()
            else:
                raise NotImplementedError('Only SCN is supported for now.')

            optimizer_2d.zero_grad()

            # supervised cross-entropy on source labels
            preds_2d_src = model_2d(data_batch_src)
            seg_loss_2d = F.cross_entropy(preds_2d_src['seg_logit'], data_batch_src['seg_label'], weight=class_weights)
            train_metric_logger.update(seg_loss_2d=seg_loss_2d)

            # update metric (e.g. IoU)
            with torch.no_grad():
                train_metric_2d.update_dict(preds_2d_src, data_batch_src)

            seg_loss_2d.backward()
            optimizer_2d.step()

            cur_iter = iteration + 1
            batch_time = time.time() - end
            train_metric_logger.update(time=batch_time, data=data_time)

            # log
            if cur_iter == 1 or (cfg.TRAIN.LOG_PERIOD > 0 and cur_iter % cfg.TRAIN.LOG_PERIOD == 0):
                logger.info(
                    train_metric_logger.delimiter.join(
                        [
                            'iter: {iter:4d}',
                            '{meters}',
                            'lr: {lr:.2e}',
                            'max mem: {memory:.0f}',
                        ]
                    ).format(
                        iter=cur_iter,
                        meters=str(train_metric_logger),
                        lr=optimizer_2d.param_groups[0]['lr'],
                        memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                    )
                )

            # summary
            if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:
                write_summary('train/', train_metric_logger, cur_iter)

            # checkpoint (always written at the final iteration)
            if (ckpt_period > 0 and cur_iter % ckpt_period == 0) or cur_iter == max_iteration:
                checkpoint_data_2d['iteration'] = cur_iter
                checkpoint_data_2d[best_metric_name] = best_metric
                checkpoint_data_2d[best_iter_name] = best_metric_iter
                checkpointer_2d.save('model_2d_{:06d}'.format(cur_iter), **checkpoint_data_2d)

            # ------------------------------------------------------------------------ #
            # validate for one epoch
            # ------------------------------------------------------------------------ #
            if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):
                start_time_val = time.time()
                setup_validate()
                validate_source_only_2d(cfg, model_2d, val_dataloader, val_metric_logger)
                epoch_time_val = time.time() - start_time_val
                logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(
                    cur_iter, val_metric_logger.summary_str, epoch_time_val))

                if summary_writer is not None:
                    write_summary('val/', val_metric_logger, cur_iter)

                if metric_key in val_metric_logger.meters:
                    cur_metric = val_metric_logger.meters[metric_key].global_avg
                    if best_metric is None or best_metric < cur_metric:
                        best_metric = cur_metric
                        best_metric_iter = cur_iter
                log_best()
                # restore training
                setup_train()

            scheduler_2d.step()
            end = time.time()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        if summary_writer is not None:
            summary_writer.close()

    log_best()
    return best_metric_iter

def train_3d(cfg, best_metric_iter, output_dir='', checkpoint_dir_2d='', run_name=''):
    """Train the 3D segmentation model on the bridge domain with 2D pseudo-labels.

    A 2D model is rebuilt from ``cfg`` and loaded from
    ``<checkpoint_dir_2d>/model_2d_<best_metric_iter:06d>.pth``. It is never
    optimized here: it stays in eval mode and, for every bridge batch, the argmax
    of its ``seg_logit`` is the cross-entropy target of the 3D model. The 3D
    model is validated on the target-domain validation split.

    Args:
        cfg: config node (presumably a frozen yacs ``CfgNode``).
        best_metric_iter: iteration of the 2D checkpoint to load.
        output_dir: 3D checkpoint directory; also hosts the TensorBoard run.
            When empty, no TensorBoard writer is created.
        checkpoint_dir_2d: directory containing the 2D checkpoints.
        run_name: suffix of the TensorBoard directory ``tb.<run_name>``.

    Raises:
        FileNotFoundError: if the requested 2D checkpoint file does not exist.
    """
    # ---------------------------------------------------------------------------- #
    # Build models, optimizer, scheduler, checkpointer, etc.
    # ---------------------------------------------------------------------------- #
    logger = logging.getLogger('source_only_3d.train')
    set_random_seed(cfg.RNG_SEED)

    # build the 2d model that produces pseudo-labels (not trained here)
    model_2d, _ = build_model_2d(cfg)
    model_2d = model_2d.cuda()
    checkpoint_path = osp.join(checkpoint_dir_2d, 'model_2d_{:06d}.pth'.format(best_metric_iter))
    if not osp.isfile(checkpoint_path):
        raise FileNotFoundError('2D checkpoint not found: {}'.format(checkpoint_path))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_2d.load_state_dict(checkpoint['model'])

    # build 3d model
    model_3d, train_metric_3d = build_model_3d(cfg)
    logger.info('Build 3D model:\n{}'.format(str(model_3d)))
    num_params = sum(param.numel() for param in model_3d.parameters())
    print('#Parameters: {:.2e}'.format(num_params))
    model_3d = model_3d.cuda()

    optimizer_3d = build_optimizer(cfg, model_3d)
    scheduler_3d = build_scheduler(cfg, optimizer_3d)

    # The checkpointer loads the state_dict of model, optimizer and scheduler and
    # returns a dict-like object with extra saved entries (e.g. 'iteration').
    checkpointer_3d = CheckpointerV2(model_3d,
                                     optimizer=optimizer_3d,
                                     scheduler=scheduler_3d,
                                     save_dir=output_dir,
                                     logger=logger,
                                     postfix='_3d',
                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)
    checkpoint_data_3d = checkpointer_3d.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build tensorboard logger only when there is a directory to write to
    if output_dir:
        summary_writer = SummaryWriter(osp.join(output_dir, 'tb.{:s}'.format(run_name)))
    else:
        summary_writer = None

    # ---------------------------------------------------------------------------- #
    # Train
    # ---------------------------------------------------------------------------- #
    max_iteration = cfg.SCHEDULER.MAX_ITERATION
    start_iteration = checkpoint_data_3d.get('iteration', 0)

    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader_brg = build_dataloader(cfg, mode='train', domain='bridge', start_iteration=start_iteration)
    val_period = cfg.VAL.PERIOD
    val_dataloader = build_dataloader(cfg, mode='val', domain='target') if val_period > 0 else None

    metric_key = cfg.VAL.METRIC_3d
    best_metric_name = 'best_{}'.format(metric_key)
    best_iter_name = best_metric_name + '_iter'
    best_metric = checkpoint_data_3d.get(best_metric_name, None)
    # Restore where the best metric was reached so resumed runs log it correctly.
    best_metric_iter = checkpoint_data_3d.get(best_iter_name, -1)
    logger.info('Start training from iteration {}'.format(start_iteration))

    train_metric_logger = init_metric_logger([train_metric_3d])
    val_metric_logger = MetricLogger(delimiter='  ')
    summary_keywords = ('loss', 'acc', 'iou')

    def setup_train():
        # training mode (3D only) + fresh running training metrics
        model_3d.train()
        train_metric_logger.reset()

    def setup_validate():
        # evaluation mode + fresh validation metrics
        model_3d.eval()
        val_metric_logger.reset()

    def write_summary(prefix, metric_logger, step):
        # Export only meters whose name mentions one of summary_keywords.
        for name, meter in metric_logger.meters.items():
            if any(k in name for k in summary_keywords):
                summary_writer.add_scalar(prefix + name, meter.avg, global_step=step)

    def log_best():
        # best_metric stays None until a validation produces metric_key;
        # formatting None * 100 would raise TypeError.
        if best_metric is None:
            logger.info('Best val-{} = N/A (no validation result)'.format(metric_key))
        else:
            logger.info('Best val-{} = {:.2f} at iteration {}'.format(metric_key, best_metric * 100, best_metric_iter))

    setup_train()
    # The 2D model only produces pseudo-labels; setup_train() never switches it back.
    model_2d.eval()
    end = time.time()
    train_iter_brg = iter(train_dataloader_brg)

    try:
        for iteration in range(start_iteration, max_iteration):
            data_batch_brg = next(train_iter_brg)
            data_time = time.time() - end

            # copy data from cpu to gpu
            if 'SCN' in cfg.DATASET_BRIDGE.TYPE and 'SCN' in cfg.DATASET_TARGET.TYPE:
                data_batch_brg['img'] = data_batch_brg['img'].cuda()
                data_batch_brg['x'][1] = data_batch_brg['x'][1].cuda()
                data_batch_brg['seg_label'] = data_batch_brg['seg_label'].cuda()
            else:
                raise NotImplementedError('Only SCN is supported for now.')

            optimizer_3d.zero_grad()

            # ------------------------------------------------------------------------ #
            # Train on bridge
            # ------------------------------------------------------------------------ #
            preds_3d_brg = model_3d(data_batch_brg)
            with torch.no_grad():
                preds_2d = model_2d(data_batch_brg)
                pseudo_label = torch.argmax(preds_2d['seg_logit'], dim=1)

            # Monitoring only: agreement of 2D pseudo-labels with bridge ground
            # truth. Labels equal to -100 (PyTorch's default cross-entropy
            # ignore_index, presumably the dataset's "ignore" label) are excluded
            # instead of being counted as wrong.
            seg_label_brg = data_batch_brg['seg_label']
            valid_mask = seg_label_brg != -100
            num_valid = int(valid_mask.sum().item())
            if num_valid > 0:
                num_correct = int(((pseudo_label == seg_label_brg) & valid_mask).sum().item())
                train_metric_logger.update(brg_accuracy=round(num_correct / num_valid * 100, 2))

            # Cross-entropy against pseudo-labels; cfg.TRAIN.CLASS_WEIGHTS is not applied here.
            seg_loss_3d = F.cross_entropy(preds_3d_brg['seg_logit'], pseudo_label, ignore_index=-100)
            train_metric_logger.update(seg_loss_3d=seg_loss_3d)

            # update metric (e.g. IoU) against the pseudo-labels
            with torch.no_grad():
                train_metric_3d.update_dict_3d(preds_3d_brg, pseudo_label)

            seg_loss_3d.backward()
            optimizer_3d.step()

            cur_iter = iteration + 1
            batch_time = time.time() - end
            train_metric_logger.update(time=batch_time, data=data_time)

            # log
            if cur_iter == 1 or (cfg.TRAIN.LOG_PERIOD > 0 and cur_iter % cfg.TRAIN.LOG_PERIOD == 0):
                logger.info(
                    train_metric_logger.delimiter.join(
                        [
                            'iter: {iter:4d}',
                            '{meters}',
                            'lr: {lr:.2e}',
                            'max mem: {memory:.0f}',
                        ]
                    ).format(
                        iter=cur_iter,
                        meters=str(train_metric_logger),
                        lr=optimizer_3d.param_groups[0]['lr'],
                        memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                    )
                )

            # summary
            if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:
                write_summary('train/', train_metric_logger, cur_iter)

            # checkpoint (always written at the final iteration)
            if (ckpt_period > 0 and cur_iter % ckpt_period == 0) or cur_iter == max_iteration:
                checkpoint_data_3d['iteration'] = cur_iter
                checkpoint_data_3d[best_metric_name] = best_metric
                checkpoint_data_3d[best_iter_name] = best_metric_iter
                checkpointer_3d.save('model_3d_{:06d}'.format(cur_iter), **checkpoint_data_3d)

            # ------------------------------------------------------------------------ #
            # validate for one epoch
            # ------------------------------------------------------------------------ #
            if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):
                start_time_val = time.time()
                setup_validate()
                validate_source_only_3d(cfg, model_3d, val_dataloader, val_metric_logger)
                epoch_time_val = time.time() - start_time_val
                logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(
                    cur_iter, val_metric_logger.summary_str, epoch_time_val))

                if summary_writer is not None:
                    write_summary('val/', val_metric_logger, cur_iter)

                if metric_key in val_metric_logger.meters:
                    cur_metric = val_metric_logger.meters[metric_key].global_avg
                    if best_metric is None or best_metric < cur_metric:
                        best_metric = cur_metric
                        best_metric_iter = cur_iter
                log_best()
                # restore training
                setup_train()

            scheduler_3d.step()
            end = time.time()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        if summary_writer is not None:
            summary_writer.close()

    log_best()

def _prepare_stage(stage_name, output_dir, args, cfg):
    """Create ``output_dir`` and a named logger for one training stage.

    Returns:
        tuple: ``(logger, run_name)`` where ``run_name`` is
        ``'<month-day_h-m-s>.<hostname>'``.
    """
    if osp.isdir(output_dir):
        warnings.warn('Output directory exists.')
    os.makedirs(output_dir, exist_ok=True)

    # run name
    timestamp = time.strftime('%m-%d_%H-%M-%S')
    hostname = socket.gethostname()
    run_name = '{:s}.{:s}'.format(timestamp, hostname)

    logger = setup_logger(stage_name, output_dir, comment='train.{:s}'.format(run_name))
    logger.info('{:d} GPUs available'.format(torch.cuda.device_count()))
    logger.info(args)
    logger.info('Loaded configuration file {:s}'.format(args.config_file))
    logger.info('Running with config:\n{}'.format(cfg))
    return logger, run_name


def _select_2d_checkpoint_iter(cfg, checkpoint_dir_2d, best_metric_iter, logger):
    """Choose which 2D checkpoint iteration the 3D stage should load.

    Prefers the best-validation iteration reported by ``train_2d``. Falls back to
    the final iteration, which ``train_2d`` always checkpoints, when there was no
    validation result or no checkpoint file exists for the best iteration. That
    can happen when VAL.PERIOD is not a multiple of TRAIN.CHECKPOINT_PERIOD, or
    after old checkpoints were pruned. The file-name pattern matches the one
    ``train_3d`` loads.
    """
    final_iter = cfg.SCHEDULER.MAX_ITERATION
    if best_metric_iter > 0:
        best_path = osp.join(checkpoint_dir_2d, 'model_2d_{:06d}.pth'.format(best_metric_iter))
        if osp.isfile(best_path):
            return best_metric_iter
        logger.warning('No 2D checkpoint for best iteration {} ({}); falling back to final iteration {}'.format(
            best_metric_iter, best_path, final_iter))
    return final_iter


def main():
    """Run both stages.

    Stage one trains the 2D model on the source domain. Stage two trains the 3D
    model on the bridge domain with pseudo-labels from the selected 2D checkpoint.
    """
    args = parse_args()
    from common.config import purge_cfg
    from config.LSB import cfg
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    purge_cfg(cfg)
    cfg.freeze()

    # -------------------------------train on 2d------------------------------------------------------------
    # osp.join avoids producing '/source_only_2d' (filesystem root) when --output_dir is omitted.
    output_dir_2d = osp.join(args.output_dir, 'source_only_2d')
    _, run_name_2d = _prepare_stage('source_only_2d', output_dir_2d, args, cfg)
    best_metric_iter = train_2d(cfg, output_dir_2d, run_name_2d)

    # -------------------------------train on 3d------------------------------------------------------------
    output_dir_3d = osp.join(args.output_dir, 'source_only_3d')
    logger_3d, run_name_3d = _prepare_stage('source_only_3d', output_dir_3d, args, cfg)

    # Use the iteration selected from train_2d's result rather than a hard-coded one.
    ckpt_iter_2d = _select_2d_checkpoint_iter(cfg, output_dir_2d, best_metric_iter, logger_3d)
    logger_3d.info('Using 2D checkpoint from iteration {}'.format(ckpt_iter_2d))
    train_3d(cfg, ckpt_iter_2d, output_dir_3d, output_dir_2d, run_name_3d)


if __name__ == '__main__':
    main()