import argparse
import datetime
import gorilla
import os
import os.path as osp
import shutil
import time
import torch
from tensorboardX import SummaryWriter
from tqdm import tqdm

from maft.dataset import build_dataloader, build_dataset
from maft.evaluation import ScanNetEval
from maft.utils import AverageMeter, get_root_logger, rle_decode, write_obj#colors_cityscapes, rle_decode, write_obj
import numpy as np
import torch.nn as nn
from transformers import Trainer, TrainingArguments
import torch.distributed as dist
import hydra
import numpy as np
import wandb
import yaml
import torch
from easydict import EasyDict
from hydra.utils import to_absolute_path
from omegaconf import OmegaConf
import multiprocessing


from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

class MyTrainer(Trainer):
    """``Trainer`` subclass that lets the model compute its own loss.

    The wrapped model is called as ``model(inputs, mode='loss')`` and is
    expected to return a ``(loss, log_vars)`` pair; only ``loss`` is used.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the loss for one batch.

        Args:
            model: Callable invoked as ``model(inputs, mode='loss')``.
            inputs: Batch forwarded to the model unchanged.
            return_outputs: When true, return ``(loss, None)`` instead of
                the bare loss (no model outputs are exposed).
            num_items_in_batch: Accepted and ignored so that ``Trainer``
                versions which pass this keyword do not raise ``TypeError``.

        Returns:
            ``loss`` or ``(loss, None)`` depending on ``return_outputs``.
        """
        loss, _log_vars = model(inputs, mode='loss')

        return (loss, None) if return_outputs else loss

def get_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace with the attributes ``config`` (required
        positional), ``resume``, ``work_dir`` (both ``None`` when omitted),
        ``skip_validate`` and ``eval_only`` (both ``False`` when omitted).
    """
    parser = argparse.ArgumentParser('SPFormer')
    parser.add_argument('config', type=str, help='path to config file')
    parser.add_argument('--resume', type=str, help='path to resume from')
    parser.add_argument('--work_dir', type=str, help='working directory')
    parser.add_argument('--skip_validate', action='store_true', help='skip validation')
    # The help text used to be a copy of --skip_validate's.
    parser.add_argument('--eval_only', action='store_true', help='run evaluation only, without training')
    return parser.parse_args()

def update_ema_variables(model, ema_model, alpha, global_step):
    """Update ``ema_model`` in place as a moving average of ``model``.

    Each parameter of ``ema_model`` becomes
    ``alpha * ema_param + (1 - alpha) * param``. Parameters are paired by
    iteration order of ``parameters()``; buffers are not touched.

    Args:
        model: Module providing the current parameters.
        ema_model: Module whose parameters are overwritten with the average.
        alpha: Upper bound for the decay factor.
        global_step: Number of updates so far; caps the decay at
            ``1 - 1 / (global_step + 1)`` during the first steps.
    """
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # Keyword form of add_: the positional (scalar, tensor) overload is deprecated.
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

def train(epoch, model, dataloader, optimizer, lr_scheduler, cfg, logger, writer, global_step):
    """Train for one epoch, log the averages and save ``lastest.pth``.

    Args:
        epoch: Current epoch number, used for logging, ETA and checkpoint meta.
        model: Called as ``model(batch, mode='loss')`` and expected to return
            ``(loss, log_vars)``; must expose ``detector`` and ``ema_detector``.
        dataloader: Sized iterable yielding dict-like batches.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once at the end of the epoch.
        cfg: Config providing ``train``, ``model_name`` and ``work_dir``.
        logger: Logger that receives a progress line every 10 iterations.
        writer: Summary writer that receives per-epoch scalars.
        global_step: Number of optimisation steps done before this epoch.

    Returns:
        The global step after this epoch.
    """
    model.train()
    iter_time = AverageMeter()
    data_time = AverageMeter()
    meter_dict = {}
    num_iters = len(dataloader)
    end = time.time()

    # Wrap the dataloader itself (not the enumerate object) so tqdm knows the total.
    for i, batch in enumerate(tqdm(dataloader), start=1):
        data_time.update(time.time() - end)

        if cfg.train.get("append_epoch", False):
            batch['epoch'] = epoch

        if cfg.train.get("use_rgb", True) == False:
            # Drop the first three feature columns (presumably RGB -- confirm against the dataset).
            batch['feats'] = batch['feats'][:, 3:]

        if cfg.model_name.startswith("SPFormer"):
            batch.pop("coords_float", "")

        if (not cfg.model_name.endswith("no_superpoint")) and cfg.train.get("use_batch_points_offsets", False) == False:
            batch.pop("batch_points_offsets", "")

        loss, log_vars = model(batch, mode='loss')

        # meter_dict
        for k, v in log_vars.items():
            if k not in meter_dict:
                meter_dict[k] = AverageMeter()
            meter_dict[k].update(v)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # time and print
        iter_time.update(time.time() - end)
        end = time.time()
        if i % 10 == 0:
            # ETA and lr are only needed for the log line, so compute them here.
            remain_iter = num_iters * (cfg.train.epochs - epoch + 1) - i
            remain_time = str(datetime.timedelta(seconds=int(remain_iter * iter_time.avg)))
            lr = optimizer.param_groups[0]['lr']
            log_str = f'Epoch [{epoch}/{cfg.train.epochs}][{i}/{num_iters}]  '
            log_str += f'lr: {lr:.2g}, eta: {remain_time}, '
            log_str += f'data_time: {data_time.val:.2f}, iter_time: {iter_time.val:.2f}'
            for k, v in meter_dict.items():
                log_str += f', {k}: {v.val:.4f}'
            logger.info(log_str)

        global_step += 1
        update_ema_variables(model.detector, model.ema_detector, 0.999, global_step)

    # update lr (once per epoch)
    lr_scheduler.step()
    lr = optimizer.param_groups[0]['lr']

    # log and save
    writer.add_scalar('train/learning_rate', lr, epoch)
    for k, v in meter_dict.items():
        writer.add_scalar(f'train/{k}', v.avg, epoch)
    # File name kept as-is ("lastest"): existing runs and tooling may rely on it.
    save_file = osp.join(cfg.work_dir, 'lastest.pth')
    meta = dict(epoch=epoch)
    gorilla.save_checkpoint(model, save_file, optimizer, lr_scheduler, meta)

    return global_step


@torch.no_grad()
def eval(epoch, model, dataloader, cfg, logger, writer):
    """Run the model over ``dataloader`` and score the predictions.

    Args:
        epoch: Step index under which the validation scalars are written.
        model: Called as ``model(batch, mode='predict')``; each result must
            contain ``'pred_instances'`` and ``'gt_instances'``.
        dataloader: Sized iterable whose ``dataset`` exposes ``CLASSES``.
        cfg: Unused; kept so existing callers keep working.
        logger: Logger for progress and result messages.
        writer: Summary writer receiving ``val/AP``, ``val/AP_50``, ``val/AP_25``.

    Returns:
        The dict returned by ``ScanNetEval.evaluate``. If evaluation or
        reporting raises, a dict with ``'all_ap'``, ``'all_ap_50%'`` and
        ``'all_ap_25%'`` all set to ``0.0``.
    """
    logger.info('Validation')
    pred_insts, gt_insts = [], []
    val_dataset = dataloader.dataset

    model.eval()
    # Iterating through tqdm closes the bar even if a batch raises.
    for batch in tqdm(dataloader):
        result = model(batch, mode='predict')
        pred_insts.append(result['pred_instances'])
        gt_insts.append(result['gt_instances'])

    # evaluate
    logger.info('Evaluate instance segmentation')
    scannet_eval = ScanNetEval(val_dataset.CLASSES)
    try:
        eval_res = scannet_eval.evaluate(pred_insts, gt_insts)
        writer.add_scalar('val/AP', eval_res['all_ap'], epoch)
        writer.add_scalar('val/AP_50', eval_res['all_ap_50%'], epoch)
        writer.add_scalar('val/AP_25', eval_res['all_ap_25%'], epoch)
        logger.info('AP: {:.3f}. AP_50: {:.3f}. AP_25: {:.3f}'.format(eval_res['all_ap'], eval_res['all_ap_50%'],
                                                                    eval_res['all_ap_25%']))
    except Exception:
        # Best effort: a failed evaluation must not abort training, but it
        # should be logged as an error with its traceback.
        logger.exception('Instance segmentation evaluation failed; reporting zero AP')
        eval_res = {'all_ap': 0.0, 'all_ap_50%': 0.0, 'all_ap_25%': 0.0}

    return eval_res

def get_model(cfg, model_name, device_id = None):
    """Build the detector selected by ``model_name``.

    Args:
        cfg: Config providing ``cfg.model`` and ``cfg.train.pretrain``.
        model_name: ``'MAFT'`` or ``'SPFormer'``.
        device_id: Currently unused.

    Returns:
        For ``'MAFT'`` a ``Detector`` after ``.cuda()``; for ``'SPFormer'``
        a ``Detector_SPFormer`` exactly as constructed.

    Raises:
        NotImplementedError: For any other ``model_name``.
    """
    if model_name not in ('MAFT', 'SPFormer'):
        raise NotImplementedError()

    # Imports stay local so only the selected detector is pulled from maft.model.
    if model_name == 'MAFT':
        from maft.model import Detector
        return Detector(cfg.model, cfg.train.pretrain).cuda()

    # NOTE(review): unlike the MAFT branch this model is not moved with
    # .cuda() and device_id is ignored -- confirm that is intended.
    from maft.model import Detector_SPFormer
    return Detector_SPFormer(cfg.model, cfg.train.pretrain)

def main():
    """Script entry point: set up the run, then train and/or evaluate.

    Reads the config named on the command line, prepares the work
    directory, logger and summary writer, builds model, optimizer and
    scheduler, optionally resumes from a checkpoint, and then either runs
    one evaluation pass (``--eval_only``) or the epoch loop. During
    training, the checkpoint with the best ``all_ap`` so far is kept.
    """
    args = get_args()

    cfg = gorilla.Config.fromfile(args.config)
    if args.work_dir:
        cfg.work_dir = args.work_dir
    else:
        cfg.work_dir = osp.join('./exps', osp.splitext(osp.basename(args.config))[0])
    os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True)
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file)
    shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config)))
    writer = SummaryWriter(cfg.work_dir)

    # seed
    gorilla.set_random_seed(cfg.train.seed)

    logger.info(cfg)

    # model
    model_name = cfg.model.pop("name", "MAFT")
    model = get_model(cfg, model_name)

    cfg.model_name = model_name

    logger.info(model)

    count_parameters = gorilla.parameter_count(model)['']
    logger.info(f'Parameters: {count_parameters / 1e6:.2f}M')
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.optimizer.lr, weight_decay=cfg.optimizer.weight_decay)
    # train() calls lr_scheduler.step() once per epoch, so the one-cycle
    # schedule has to span exactly cfg.train.epochs scheduler steps.
    # (This name was previously undefined and raised NameError.)
    steps_per_epoch = 1
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=cfg.optimizer.lr, steps_per_epoch=steps_per_epoch, epochs=cfg.train.epochs)

    # pretrain or resume
    start_epoch = 1

    if args.resume:
        logger.info(f'Resume from {args.resume}')
        meta = gorilla.resume(model, args.resume, optimizer, lr_scheduler)
        # Checkpoints are written after an epoch has finished, so continue
        # with the following one instead of repeating it.
        start_epoch = meta['epoch'] + 1

    # train and val dataset
    train_dataset = build_dataset(cfg.data.train, logger)

    train_loader = build_dataloader(train_dataset, **cfg.dataloader.train)
    # --eval_only needs the validation loader even when --skip_validate is set.
    val_loader = None
    if args.eval_only or not args.skip_validate:
        val_dataset = build_dataset(cfg.data.val, logger)
        val_loader = build_dataloader(val_dataset, **cfg.dataloader.val)

    if args.eval_only:
        eval(0, model, val_loader, cfg, logger, writer)
        writer.flush()
        return

    # train and val
    logger.info('Training')
    best_AP = 0.0
    save_file = None
    global_step = 0
    for epoch in tqdm(range(start_epoch, cfg.train.epochs + 1)):
        global_step = train(epoch, model, train_loader, optimizer, lr_scheduler, cfg, logger, writer, global_step)
        if not args.skip_validate and (epoch % cfg.train.interval == 0):
            eval_res = eval(epoch, model, val_loader, cfg, logger, writer)
            if eval_res['all_ap'] > best_AP:
                # Keep only the best checkpoint of this run on disk.
                if save_file is not None:
                    os.remove(save_file)
                best_AP = eval_res['all_ap']
                save_file = osp.join(cfg.work_dir, 'epoch{:03}_AP_{:.4f}_{:.4f}_{:.4f}.pth'.format(epoch, eval_res['all_ap'], eval_res['all_ap_50%'], eval_res['all_ap_25%']))
                meta = dict(epoch=epoch)
                gorilla.save_checkpoint(model, save_file, optimizer, lr_scheduler, meta)

        writer.flush()

# Run the training entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
