import argparse
import datetime
import gorilla
import os
import os.path as osp
import shutil
import time
import torch
from tensorboardX import SummaryWriter
from tqdm import tqdm

from maft.dataset import build_dataloader, build_dataset
from maft.evaluation import ScanNetEval
from maft.utils import AverageMeter, get_root_logger, rle_decode, write_obj#colors_cityscapes, rle_decode, write_obj
import numpy as np
import torch.nn as nn
from transformers import Trainer, TrainingArguments, TrainerCallback
import torch.distributed as dist
import hydra
import numpy as np
import wandb
import yaml
import torch
from easydict import EasyDict
from hydra.utils import to_absolute_path
from omegaconf import OmegaConf
import multiprocessing


from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

class MyTrainer(Trainer):
    """HuggingFace ``Trainer`` whose loss is produced by the model's own ``mode='loss'`` forward."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss for one batch.

        Args:
            model: the model being trained; it is called as ``model(inputs, mode='loss')``
                and must return a ``(loss, log_vars)`` pair.
            inputs: the batch produced by the data collator, forwarded unchanged.
            return_outputs: when True, return ``(loss, None)`` instead of only the loss.
            num_items_in_batch: keyword passed by recent ``transformers`` versions
                (4.46+) from ``training_step``. It is accepted so those versions do not
                fail with a TypeError, and it is unused here.

        Returns:
            ``loss[0]``, or ``(loss[0], None)`` if ``return_outputs`` is True.
        """
        # log_vars is produced by the model but not consumed by the Trainer here.
        loss, log_vars = model(inputs, mode='loss')

        # NOTE(review): the model's first return value is indexed with [0];
        # presumably it is a sequence whose first entry is the total loss — confirm.
        return (loss[0], None) if return_outputs else loss[0]


def get_args(argv=None):
    """Parse the command-line arguments for this script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``config``, ``resume``, ``work_dir``,
        ``skip_validate`` and ``eval_only``.
    """
    parser = argparse.ArgumentParser('SPFormer')
    parser.add_argument('config', type=str, help='path to config file')
    parser.add_argument('--resume', type=str, help='path to resume from')
    parser.add_argument('--work_dir', type=str, help='working directory')
    parser.add_argument('--skip_validate', action='store_true', help='skip validation')
    parser.add_argument('--eval_only', action='store_true',
                        help='only run evaluation on the validation set (no training)')
    return parser.parse_args(argv)

def update_ema_variables(model, ema_model, alpha, global_step):
    """Update ``ema_model``'s parameters in place as an exponential moving average of ``model``'s.

    Each EMA parameter becomes ``a * ema + (1 - a) * param`` with
    ``a = min(1 - 1 / (global_step + 1), alpha)``.

    Args:
        model: module whose current parameters are averaged in.
        ema_model: module holding the running average; parameters are paired
            with ``model``'s positionally via ``zip``.
        alpha: upper bound on the decay factor.
        global_step: current optimizer step. At step 0 the decay is 0, so the
            EMA is initialised to a copy of ``model``.
    """
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # add_(tensor, alpha=...) replaces the deprecated add_(Number, Tensor) overload.
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

def get_model(cfg, model_name, device_id = None):
    """Instantiate the detector selected by ``model_name``.

    Args:
        cfg: config object; ``cfg.model`` and ``cfg.train.pretrain`` are passed
            to the detector constructor.
        model_name: ``'MAFT'`` or ``'SPFormer'``.
        device_id: unused; kept for interface compatibility.

    Returns:
        The constructed detector module.

    Raises:
        NotImplementedError: if ``model_name`` is not a known model.
    """
    # Resolve the detector class lazily so only the requested model is imported.
    if model_name == 'SPFormer':
        from maft.model import Detector_SPFormer as detector_cls
    elif model_name == 'MAFT':
        from maft.model import Detector as detector_cls
    else:
        raise NotImplementedError()
    return detector_cls(cfg.model, cfg.train.pretrain)

def main():
    """Parse CLI args, build the model, then either evaluate it or train it.

    Training goes through ``MyTrainer`` (a HuggingFace ``Trainer``). It uses an
    AdamW optimizer, a OneCycle LR schedule, and a callback that updates
    ``model.ema_detector`` from ``model.detector`` after every step.
    """
    args = get_args()

    # Initialise the NCCL process group only when launched with more than one
    # process (WORLD_SIZE / LOCAL_RANK are set by launchers such as torchrun).
    if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
        dist.init_process_group("nccl")
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        rank = dist.get_rank()
        print(f"Start running basic DDP example on rank {rank}.")

    # dist.get_world_size() raises when no process group exists, so a plain
    # single-process run must fall back to a world size of 1.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1

    cfg = gorilla.Config.fromfile(args.config)

    if args.work_dir:
        cfg.work_dir = args.work_dir
    else:
        cfg.work_dir = osp.join('./exps', osp.splitext(osp.basename(args.config))[0])
    os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True)
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file)
    # Keep a copy of the config next to the logs/checkpoints.
    shutil.copy(args.config, osp.join(cfg.work_dir, osp.basename(args.config)))
    writer = SummaryWriter(cfg.work_dir)

    # seed
    gorilla.set_random_seed(cfg.train.seed)

    logger.info(cfg)

    # model
    model_name = cfg.model.pop("name", "MAFT")
    model = get_model(cfg, model_name)
    cfg.model_name = model_name

    logger.info(model)

    count_parameters = gorilla.parameter_count(model)['']
    logger.info(f'Parameters: {count_parameters / 1e6:.2f}M')

    if args.resume:
        logger.info(f'Resume from {args.resume}')
        # The model has not been moved to a GPU yet. Mapping to CPU avoids every
        # rank materialising the checkpoint on the device it was saved from.
        model.load_state_dict(torch.load(args.resume, map_location='cpu'))

    if args.eval_only:
        val_dataset = build_dataset(cfg.data.val, logger)
        val_loader = build_dataloader(val_dataset, **cfg.dataloader.val)
        # NOTE(review): no `eval` function is defined in this file, so this
        # resolves to the builtin eval() and raises TypeError. An evaluation
        # routine (presumably built on ScanNetEval) needs to be restored.
        eval(0, model.cuda(), val_loader, cfg, logger, writer)
        writer.close()
        return

    # train dataset
    train_dataset = build_dataset(cfg.data.train, logger)
    # The Trainer's dataloader keeps the last partial batch (dataloader_drop_last
    # defaults to False), so an epoch has ceil(N / global_batch) optimizer steps.
    # Flooring here made OneCycleLR's total_steps too small; it then raises
    # ValueError once the Trainer steps past it. max(1, ...) keeps the
    # scheduler/save/logging intervals positive for tiny datasets.
    global_batch_size = cfg.dataloader.train.batch_size * world_size
    steps_per_epoch = max(1, -(-len(train_dataset) // global_batch_size))
    save_steps = steps_per_epoch * 2

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.optimizer.lr, weight_decay=cfg.optimizer.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=cfg.optimizer.lr, steps_per_epoch=steps_per_epoch, epochs=cfg.train.epochs)

    training_args = TrainingArguments(
        # cfg.work_dir is always set above; args.work_dir is None when the
        # default ./exps/<config> directory is used.
        output_dir=cfg.work_dir,
        overwrite_output_dir=True,
        num_train_epochs=cfg.train.epochs,
        per_device_train_batch_size=cfg.dataloader.train.batch_size,
        save_strategy="steps",
        save_steps=save_steps,
        eval_strategy="no",
        logging_steps=max(1, steps_per_epoch // 3),
        ddp_find_unused_parameters=True,
        remove_unused_columns=False,
        disable_tqdm=False,
        save_safetensors=False,
        save_total_limit=5,
        dataloader_num_workers=8,
    )

    class MyCallback(TrainerCallback):
        """Update the EMA detector after every optimizer step."""

        def __init__(self, update_ema_variables):
            self.update_ema_variables = update_ema_variables

        def on_step_end(self, args, state, control, **kwargs):
            # Assumes the model exposes `detector` and `ema_detector` submodules.
            model = kwargs["model"]
            self.update_ema_variables(model.detector, model.ema_detector, 0.999, state.global_step)

    trainer = MyTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=train_dataset.collate_fn,
        optimizers=(optimizer, lr_scheduler),
        callbacks=[MyCallback(update_ema_variables=update_ema_variables)],
    )

    trainer.train()
    writer.close()


if __name__ == '__main__':
    # Run the CLI entry point only when executed as a script, not on import.
    main()
