import argparse
import os
from omegaconf import OmegaConf
from train import initialize, test
from datetime import datetime
from utils.log_utils import get_logger, add_file_handler
import logging
import torch
import wandb
from copy import deepcopy
import time

from utils.evaluator import evaluator_registry, DummyEvaluator, eval_to_print, eval_to_wandb

# Compute device for this script: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Module logger, emitting at INFO level and above.
log = get_logger(__name__, level=logging.INFO)

if __name__ == "__main__":
    # parse 3 arguments
    parser = argparse.ArgumentParser(description='Evaluation arguments')
    parser.add_argument('--group', type=str, help='wandb group')
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--model', type=str)
    parser.add_argument('--timestamp', type=str)
    parser.add_argument('--log_model_dir', type=str)
    parser.add_argument('--wandb', action='store_true', help='To use wandb logging or not')

    args = parser.parse_args()

    fdir = os.path.join(args.log_model_dir, "models", args.group, args.dataset, args.model, args.timestamp)

    cfg_path = os.path.join(fdir, "cfg.yaml")
    model_path = os.path.join(fdir, "ckpt.pt")
    model_state = torch.load(model_path)

    # compose cfg
    cfg = OmegaConf.load(cfg_path)
    # change timestamp
    curr_t = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    cfg.wandb.name = f"FINAL_{cfg.model.wandb_name}_{cfg.dataset.name}_{curr_t}"
    cfg.only_evaluation = True

    #wandb
    if args.wandb:
        log.enable_wandb()
        wandb_config = OmegaConf.to_container(
            cfg, resolve=True, throw_on_missing=True
        )
        if cfg.wandb.tags is not None:
            run = wandb.init(
                project=cfg.wandb.project,
                group=cfg.wandb.group,
                name=cfg.wandb.name,
                config=wandb_config,
                tags=list(cfg.wandb.tags))
        else:
            run = wandb.init(
                project=cfg.wandb.project,
                group=cfg.wandb.group,
                name=cfg.wandb.name,
                config=wandb_config) 

    model, train_loader, test_loader, optimizer, scheduler, train_evaluator, test_evaluator, loss_type = initialize(cfg)

    log.info("Loading pretrained weights")
    model.load_state_dict(model_state)

    log.info("Final evaluation")
    step_final_eval = deepcopy(cfg.step)
    step_final_eval.update(cfg.step.final_eval)
    _, test_time = test(
        step_final_eval, test_loader, model, loss_type, evaluator=test_evaluator)
    log.info(f"Final evaluation time: {test_time:.2f} s")

    test_evaluation = test_evaluator.get_evaluation()
    test_evaluator.reset()

    log.info(eval_to_print(test_evaluation, is_train=False))
    log.wandb(eval_to_wandb(test_evaluation, is_train=False, is_final = True), step=cfg.num_epochs)
    # log.wandb_summary(eval_to_wandb(test_evaluation, is_train=False, is_final = True))

    # wain until wandb uploads
    # time.sleep(10)









