import logging

import torch

from tqdm import tqdm
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
import json, os
from inference import inference, inference_ocl_attr
from dataloader import get_dataloader
from collections import defaultdict

# Object categories whose post-task checkpoints are saved (kept for demo purposes).
_DEMO_CHECKPOINT_TASKS = frozenset([
    'car', 'door', 'fence', 'chair', 'jacket', 'hat', 'shoe', 'plate', 'flower', 'bag',
    'roof', 'bottle', 'coat', 'kite',
])


def _evaluate_seen_objects(model, seen_objects, total_iter, epoch, task_id, metric_dict):
    """Evaluate on the val split restricted to ``seen_objects`` and dump the running metrics.

    Appends the object accuracy, attribute accuracy and number of seen objects to
    ``metric_dict`` (mutated in place). It then writes the whole dict as indented JSON to
    ``OUTPUT_DIR/metric_{EXTERNAL.BATCH_SIZE * total_iter}_epoch_{epoch}_task_{task_id}.json``.
    """
    # HNET is evaluated with batch_size=1; other algorithms pass None
    # (presumably "use the loader's default" -- TODO confirm in get_dataloader).
    batch_size = 1 if model.cfg.EXTERNAL.OCL.ALGO == 'HNET' else None
    seen_data_loader = get_dataloader(model.cfg, 'val', False, False, filter_obj=seen_objects,
                                      batch_size=batch_size)
    obj_acc, attr_acc, _inst_num = inference(model, 0, total_iter, 0, seen_data_loader, 'gqa',
                                             max_instance=-1, mute=False)
    metric_dict['obj_acc'].append(obj_acc)
    metric_dict['attr_acc'].append(attr_acc)
    metric_dict['seen_obj_num'].append(len(seen_objects))

    out_path = os.path.join(
        model.cfg.OUTPUT_DIR,
        'metric_{}_epoch_{}_task_{}.json'.format(
            model.cfg.EXTERNAL.BATCH_SIZE * total_iter, epoch, task_id))
    # `with` guarantees the file is closed even if serialization fails.
    with open(out_path, 'w') as wf:
        json.dump(metric_dict, wf, indent=4)


def ocl_train(model,local_rank,data_loader,val_data_loader,optimizer,scheduler,checkpointer,
        device,checkpoint_period,arguments,writer,epoch):
    """Train ``model`` online, one object category ("task") at a time.

    For each of the first ``cfg.EXTERNAL.OBJECT_TOP_K`` object ids in
    ``data_loader.dataset.object_name_to_id``, the dataset is remapped to that object.
    Every batch is fed to ``model.observe`` and the scheduler is stepped once per batch.
    Whenever ``task_id % 10 == 9``, the model is evaluated on all objects seen so far and
    the accumulated metrics are dumped to JSON. A checkpoint is also saved after tasks
    whose name is in ``_DEMO_CHECKPOINT_TASKS``.

    Args:
        model: model exposing ``cfg``, ``train()`` and ``observe(feat, attr_labels, obj_labels)``.
        local_rank: process rank; used for the progress-bar position/label.
        data_loader: training loader whose dataset supports ``remap_index(objects=...)``.
        val_data_loader: unused here; kept for interface compatibility.
        optimizer: only read, to log the current learning rate.
        scheduler: stepped once per training batch.
        checkpointer: used to save demo checkpoints.
        device: device the batch tensors are moved to.
        checkpoint_period: unused here; kept for interface compatibility.
        arguments: mutable dict; must contain ``'epoch'`` and ``'global_step'``.
            ``'global_step'`` is incremented per batch and ``'iteration'`` is set.
        writer: optional summary writer (``add_scalar``); skipped when ``None``.
        epoch: epoch number, used in the metric file names.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training @ epoch {:02d}".format(arguments['epoch']))
    # NOTE(review): `meters` is never updated in this function, so the per-step
    # 'Loss/...' logging below iterates an empty collection unless MetricLogger
    # pre-populates it.
    meters = MetricLogger(delimiter="  ")
    model.train()

    # NOTE(review): `total` is taken before any remap_index() call, so it may not
    # match the number of batches actually run across all tasks.
    pbar = tqdm(
        total=len(data_loader),
        position=local_rank,
        desc="GPU {}".format(local_rank)
    )

    seen_objects = []
    all_objects = list(data_loader.dataset.object_name_to_id.values())
    metric_dict = defaultdict(list)
    total_iter = 0

    try:
        for task_id in all_objects[:model.cfg.EXTERNAL.OBJECT_TOP_K]:
            # Presumably restricts the dataset to instances of this object -- TODO confirm.
            data_loader.dataset.remap_index(objects=[task_id])
            seen_objects.append(task_id)

            for out_dict in data_loader:
                images = torch.stack(out_dict['images'])
                obj_labels = torch.cat(out_dict['object_labels'], -1)
                attr_labels = torch.cat(out_dict['attribute_labels'], -1)
                cropped_image = torch.stack(out_dict['cropped_image'])

                arguments["global_step"] += 1
                arguments["iteration"] = total_iter
                total_iter += 1

                images = images.to(device)
                cropped_image = cropped_image.to(device)
                obj_labels = obj_labels.to(device)
                attr_labels = attr_labels.to(device)

                # Input feature = flattened crop concatenated with the flattened full image.
                feat = torch.cat([cropped_image.flatten(1), images.flatten(1)], -1)
                model.observe(feat, attr_labels, obj_labels)

                scheduler.step()

                if writer is not None:
                    for name, meter in meters.meters.items():
                        writer.add_scalar('Loss/' + name, meter.median, arguments['global_step'])
                    writer.add_scalar('LR/learning_rate', optimizer.param_groups[0]['lr'],
                                      arguments['global_step'])

                pbar.update(1)

            # NOTE(review): this keys on the object *id*, not the task position; it means
            # "every 10 tasks" only if ids follow the iteration order of all_objects.
            if task_id % 10 == 9:
                _evaluate_seen_objects(model, seen_objects, total_iter, epoch, task_id, metric_dict)

            # for demo purpose
            task_name = data_loader.dataset.object_names[task_id]
            if task_name in _DEMO_CHECKPOINT_TASKS:
                checkpointer.save('model_{}_{}'.format(task_name, task_id), **arguments)
    finally:
        # Release the progress bar even if training is interrupted.
        pbar.close()


def draw_curves(model, data_loader, iteration, local_rank):
    """Evaluate ``model`` on three validation splits and write the scores to JSON.

    Splits are evaluated in this order, each capped at ``max_instance=3200``:

    * ``seen``: ``use_unseen_obj=False, use_unseen_combination=False``
    * ``novel_combination``: ``use_unseen_obj=False, use_unseen_combination=True``
    * ``novel_objects``: ``use_unseen_obj=True, use_unseen_combination=False``

    The second element returned by ``inference`` for each split (named attribute F1 in
    the original code -- TODO confirm) is written as
    ``{'seen': ..., 'novel_combination': ..., 'novel_objects': ...}`` to
    ``OUTPUT_DIR/metric_{EXTERNAL.BATCH_SIZE * iteration}.json``.

    Args:
        model: model with a ``cfg`` attribute; passed to ``inference``.
        data_loader: unused; kept for interface compatibility.
        iteration: training iteration, multiplied by the batch size to name the file.
        local_rank: forwarded to ``inference``.
    """
    # (result key, use_unseen_obj, use_unseen_combination), in evaluation order.
    splits = (
        ('seen', False, False),
        ('novel_combination', False, True),
        ('novel_objects', True, False),
    )

    results = {}
    for key, use_unseen_obj, use_unseen_combination in splits:
        split_loader = get_dataloader(model.cfg, split='val', distributed=False,
                                      feature_extractor=False, filter_obj=None,
                                      use_unseen_obj=use_unseen_obj,
                                      use_unseen_combination=use_unseen_combination)
        _, results[key], _ = inference(model, 0, 0, local_rank, split_loader, 'gqa',
                                       max_instance=3200)

    out_path = os.path.join(model.cfg.OUTPUT_DIR,
                            'metric_{}.json'.format(model.cfg.EXTERNAL.BATCH_SIZE * iteration))
    # `with` guarantees the file is closed even if serialization fails.
    with open(out_path, 'w') as wf:
        json.dump(results, wf)
