import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch import optim as optim
from torch.backends import cudnn as cudnn
from tqdm import tqdm

import eval_util
import util
from meta_labeler import MetaLabeler
from models.util import create_model
from routines import parse_option, get_dataset


@hydra.main(config_path="config", config_name="learn_meta_repr.yaml")
def learn_meta_repr(opt: DictConfig) -> None:
    """Train a ``MetaLabeler`` and checkpoint the best validation model.

    For every epoch the function runs one pass over the meta-training
    loader, adjusts the learning rate, logs the averaged training metrics,
    evaluates on the meta-validation loader and saves a checkpoint whenever
    the validation accuracy strictly improves on the best value seen so far.

    Args:
        opt: Hydra/OmegaConf configuration. It is post-processed by
            ``parse_option`` and extended in place with ``test_C``.
    """
    # Struct mode is switched off so new keys (e.g. ``test_C``) can be added.
    OmegaConf.set_struct(opt, False)
    opt = parse_option(opt)

    # test_C is the reciprocal of the regularisation weight ``lam``.
    opt.test_C = 1/opt.lam
    print(opt.test_C)

    logger = util.get_logger(opt.logger_name, file_name=f"{opt.logger_name}_{opt.model_name}")
    logger.info(opt)

    meta_trainloader, meta_valloader, n_cls = get_dataset(opt)
    # NOTE(review): ``.cuda()`` is unconditional here and below, while the
    # cudnn flag is guarded by ``is_available()`` -- confirm whether CPU-only
    # runs are meant to be supported.
    backbone = create_model(opt.model, n_cls, dataset=opt.dataset).cuda()

    if torch.cuda.is_available():
        cudnn.benchmark = True

    model = MetaLabeler(backbone, opt, opt.feat_dim, extra_reg=opt.extra_reg).cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    avg_metric = util.AverageMeter()

    best = 0
    for epoch in range(opt.epochs):
        for batch_data in tqdm(meta_trainloader):
            optimizer.zero_grad()
            # Keep only the first entry of every batch field
            # (presumably one task per batch -- TODO confirm against the loader).
            task_data = [field[0] for field in batch_data]
            support_xs, support_ys, query_xs, query_ys = util.to_cuda_maybe(task_data)
            # Call the module itself rather than ``.forward`` so that any
            # registered nn.Module hooks are executed.
            m_loss, *metrics = model(support_xs, support_ys, query_xs, query_ys)
            m_loss.backward()
            optimizer.step()
            avg_metric.update([m_loss.item()] + list(metrics))

        util.adjust_learning_rate(epoch, opt, optimizer)

        logger.info(f"epoch {epoch}")
        info = util.print_metrics(["m_loss", "m_acc"], avg_metric.avg)
        logger.info(info)
        avg_metric.reset()
        # The first element of the meta_test result is used as the accuracy.
        acc = eval_util.meta_test(model, meta_valloader, opt)[0]
        # Switch back to training mode after evaluation.
        model.train()
        logger.info(f"validation acc: {acc}")
        if acc > best:
            best = acc
            util.save_routine(epoch, model, optimizer, f"{opt.model_path}/{opt.model_name}_meta_best")


# Script entry point: the ``hydra.main`` decorator on ``learn_meta_repr``
# supplies the ``opt`` argument, so the function is called with no arguments.
if __name__ == '__main__':
    learn_meta_repr()