import os

import hydra
import numpy as np
import torch
from torch import nn, optim
from omegaconf import DictConfig, OmegaConf
from torch.backends import cudnn as cudnn
from tqdm import tqdm

import util
import eval_util
from meta_labeler import MetaLabeler, SampleBuffer
from models.util import create_model
from routines import parse_option, get_dataset


@hydra.main(config_path="config", config_name="learn_labeler.yaml")
def learn_labeler_and_model(opt: DictConfig):
    """Obtain a labeler, then train a ``MetaLabeler`` on samples drawn from a buffer.

    Steps, in order:

    1. Parse the hydra config and set up logging.
    2. Build the data loaders, the backbone and the ``MetaLabeler`` model.
    3. Obtain a labeler via ``get_labeler`` and wrap it in a ``SampleBuffer``.
    4. Replace the model's classifier with a linear layer that has
       ``labeler.K`` outputs and train it with SGD on batches taken from the
       buffer.
    5. After every epoch, run ``eval_util.meta_test`` on the validation loader
       and save the model whenever the returned accuracy improves.

    Args:
        opt: Configuration provided by hydra; struct mode is switched off so
            that ``parse_option`` can process it.
    """
    OmegaConf.set_struct(opt, False)
    opt = parse_option(opt)

    logger = util.get_logger(
        opt.logger_name,
        file_name=f"{opt.logger_name}_{opt.model_name}",
    )
    logger.info(opt)

    save_name = f"{opt.model_name}_sup_best_labeler_q{opt.std_factor}"

    train_loader, val_loader, _ = get_dataset(opt)
    backbone = create_model(opt.model, -1, dataset=opt.dataset).cuda()

    if torch.cuda.is_available():
        cudnn.benchmark = True

    model = MetaLabeler(backbone, opt, opt.feat_dim, intercept=opt.use_bias).cuda()

    labeler = get_labeler(opt, train_loader, logger)
    sample_buffer = SampleBuffer([3, 84, 84], labeler)

    # Swap in a classification head sized to the labeler's cluster count.
    # This happens before the optimizer is created, so the new layer's
    # parameters are included in model.parameters() below.
    model.classifier = nn.Linear(model.feat_dim, labeler.K)
    model.cuda()
    model.train()

    loss_meter = util.AverageMeter()

    optimizer = optim.SGD(
        model.parameters(),
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
    )

    # Training only starts once the buffer holds this many samples; each step
    # then draws (and removes) `batch_size` of them.
    min_buffer_size = 5000
    batch_size = 64

    best = 0
    for epoch in range(opt.epochs):
        for batch_data in tqdm(train_loader):
            optimizer.zero_grad()

            # Take element 0 of every component of the batch, then keep the
            # first component.
            # NOTE(review): presumably the support images of the first task
            # in the batch -- confirm against the data loader.
            xs = [component[0] for component in batch_data][0]

            if sample_buffer.size() < min_buffer_size:
                sample_buffer.add_batch(xs.cuda())

            # Still filling the buffer: no optimisation step yet.
            if sample_buffer.size() < min_buffer_size:
                continue

            flat_batch = sample_buffer.sample_and_remove(batch_size)
            sup_loss = model.forward_sup(*flat_batch)[0]
            sup_loss.backward()
            optimizer.step()
            loss_meter.update([sup_loss.item()])

        util.adjust_learning_rate(epoch, opt, optimizer)

        logger.info(f"epoch {epoch}")
        logger.info(util.print_metrics(["sup_loss"], loss_meter.avg))
        loss_meter.reset()

        acc = eval_util.meta_test(model, val_loader, opt)[0]
        # Return to training mode after evaluation.
        model.train()
        logger.info(f"validation acc: {acc}")

        if acc > best:
            best = acc
            util.save_routine(epoch, model, optimizer, f"{opt.model_path}/{save_name}")


def get_labeler(opt, meta_trainloader, logger):
    """Build a ``MetaLabeler`` from a pretrained checkpoint and cluster the training tasks.

    The labeler is loaded from ``opt.pretrained_model`` under
    ``opt.model_path`` and put in eval mode. Its centroids are initialised
    from the first tasks of ``meta_trainloader``. Then whole passes of
    ``cluster_task`` followed by ``remove_cluster`` are repeated until one
    pass reduces ``labeler.K`` by fewer than 2. The resulting
    ``labeler.centroid`` is saved under ``opt.model_path``.

    Args:
        opt: Parsed options; reads ``model``, ``dataset``, ``feat_dim``,
            ``use_bias``, ``model_path``, ``pretrained_model``, ``K``,
            ``n_ways``, ``std_factor`` and ``model_name``.
        meta_trainloader: Iterable of task batches. Element 0 of each batch
            component is taken, and the first of those is used as input.
        logger: Logger used to report the cluster count after each pass.

    Returns:
        The ``MetaLabeler`` instance after clustering.
    """
    backbone = create_model(opt.model, -1, dataset=opt.dataset).cuda()

    labeler = MetaLabeler(backbone, opt, opt.feat_dim, intercept=opt.use_bias).cuda()

    # NOTE(review): torch.load unpickles the file; only point
    # opt.pretrained_model at checkpoints from a trusted source.
    save_dict = torch.load(os.path.join(opt.model_path, opt.pretrained_model))["model"]
    util.partial_reload(labeler, save_dict)

    labeler.eval()

    # init centroids: feed tasks until K / n_ways of them have been consumed
    # (presumably each task contributes n_ways centroids -- TODO confirm).
    for task_idx, batch_data in enumerate(meta_trainloader):
        if task_idx >= labeler.K / labeler.n_ways:
            break
        xs = [component[0] for component in batch_data][0]

        labeler.init_centroid(xs.cuda())

    # Start with a gap of at least 2 so the loop body runs at least once
    # (assuming labeler.K does not exceed opt.K + 8 -- TODO confirm).
    prev_k = opt.K + 10
    while prev_k - labeler.K >= 2:
        prev_k = labeler.K
        for batch_data in tqdm(meta_trainloader):
            xs = [component[0] for component in batch_data][0]
            labeler.cluster_task(xs.cuda())

        labeler.remove_cluster(opt.n_ways, opt.std_factor)
        logger.info(f"No. of clusters: {labeler.K}")

    torch.save(labeler.centroid, f"{opt.model_path}/{opt.model_name}_centroid_q{opt.std_factor}")

    return labeler


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on the function is
    # what supplies the `opt` configuration.
    learn_labeler_and_model()