import logging
import torch
import torch.nn.functional as F
from training.dist_utils import all_gather
from tqdm import tqdm
from .distributed import is_master
from open_clip import get_cast_dtype
from .precision import get_autocast


def _cat_or_empty(tensors, *empty_shape, dtype=torch.float32, device=None):
    """Concatenate ``tensors`` along dim 0, or return zeros of ``empty_shape`` if the list is empty.

    ``torch.cat`` raises on an empty list, which happens when this process saw
    no batch containing a valid box.
    """
    if tensors:
        return torch.cat(tensors)
    return torch.zeros(empty_shape, dtype=dtype, device=device)


def run(model, dataloader, args):
    """Evaluate region-level zero-shot classification over ``dataloader``.

    Every valid box is classified against the L2-normalised class embeddings
    ``dataloader.dataset.embeddings``. Three kinds of visual features are scored:

    * ``rois``: features from ``module.encode_pseudo_boxes`` on the full image.
    * ``crops``: features of the box crop, from ``module.encode_image`` or, when
      ``args.image_ave_pool`` is set, the mean-pooled ``module.visual.encode_dense``.
    * ``maskpool``: currently computed exactly like ``rois`` (see NOTE below).

    Args:
        model: The model, possibly wrapped (``.module``) when ``args.distributed``
            is set without Horovod.
        dataloader: Yields ``(images, bboxes, image_crops, gt_masks,
            masked_image_crops)`` batches.
        args: Run configuration (``device``, ``precision``, ``distributed``,
            ``horovod``, ``extract_type``, ``image_ave_pool``, ...).

    Returns:
        A 9-tuple ``(correct_rois, correct_crops, correct_maskpool,
        similarity_rois, similarity_crops, similarity_maskpool,
        all_box_sizes, all_is_thing, all_cls_labels)``.

        * ``correct_*`` are float ``(N, k)`` hit matrices over the ``N`` valid
          boxes. Entry ``[i, j]`` is 1 when the j-th ranked class of box ``i``
          equals its label; ``k = min(5, number of classes)``.
        * ``similarity_*`` are ``(N,)`` logits of the ground-truth class.

        When distributed without Horovod, every output is gathered across
        processes.
    """
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    device = args.device

    cls_embeddings = torch.from_numpy(dataloader.dataset.embeddings).float()
    cls_embeddings = F.normalize(cls_embeddings, dim=-1).to(device)
    if cast_dtype is not None:
        cls_embeddings = cls_embeddings.to(dtype=cast_dtype)
    # Rank at most 5 classes; topk(5) would raise with fewer classes.
    k = min(5, cls_embeddings.shape[0])

    # Unwrap the distributed wrapper once rather than on every batch.
    module = model.module if args.distributed and not args.horovod else model

    with torch.no_grad():
        correct_rois, correct_crops, correct_maskpool = [], [], []
        similarity_rois, similarity_crops, similarity_maskpool = [], [], []
        all_box_sizes, all_is_thing, all_cls_labels = [], [], []
        for images, bboxes, image_crops, gt_masks, masked_image_crops \
                in tqdm(dataloader, disable=not is_master(args)):
            # ``bboxes`` is deliberately NOT cast to ``cast_dtype``. Besides box
            # coordinates, it carries class ids (col 4) and box sizes (col 6).
            # bfloat16 cannot represent integers above 256 exactly, which
            # corrupts labels; float16 overflows to inf above 65504.
            bboxes = bboxes.to(device)
            images = images.to(device)
            image_crops = image_crops.to(device)
            masked_image_crops = masked_image_crops.to(device)
            gt_masks = gt_masks.to(device)
            if cast_dtype is not None:
                images = images.to(dtype=cast_dtype)
                image_crops = image_crops.to(dtype=cast_dtype)
                masked_image_crops = masked_image_crops.to(dtype=cast_dtype)
                gt_masks = gt_masks.to(dtype=cast_dtype)

            rois, cls_labels, image_crops_list, gt_masks_list = [], [], [], []
            box_sizes, is_thing = [], []
            for bboxes_per_image, crops_per_image, gt_mask, _masked_crops \
                    in zip(bboxes, image_crops, gt_masks, masked_image_crops):
                # Columns read here: 0-3 box, 4 class id, 5 validity flag,
                # 6 box size, 7 thing/stuff flag.
                valid = bboxes_per_image[:, 5] > 0.5
                boxes = bboxes_per_image[valid, :4]
                if cast_dtype is not None:
                    # Box coordinates still reach the model in its dtype.
                    boxes = boxes.to(dtype=cast_dtype)
                rois.append(boxes)
                cls_labels.append(bboxes_per_image[valid, 4])
                image_crops_list.append(crops_per_image[valid])
                gt_masks_list.append(gt_mask[valid])
                box_sizes.append(bboxes_per_image[valid, 6])
                is_thing.append(bboxes_per_image[valid, 7])
            cls_labels = torch.cat(cls_labels, dim=0).to(torch.long)
            if cls_labels.shape[0] == 0:
                continue
            image_crops = torch.cat(image_crops_list)
            all_box_sizes.append(torch.cat(box_sizes, dim=0).float())
            all_is_thing.append(torch.cat(is_thing, dim=0))

            with autocast():
                roi_features = module.encode_pseudo_boxes(
                    images, rois, normalize=True, extract_type=args.extract_type)
                # NOTE(review): "maskpool" features currently come from the same
                # RoI extractor as "rois", so both report identical metrics.
                # The mask-pooling call this slot was written for is:
                #   module.encode_masks(images, gt_masks_list, normalize=True,
                #                       mask_attn=args.extract_type == "v1")
                maskpool_features = module.encode_pseudo_boxes(
                    images, rois, normalize=True, extract_type=args.extract_type)
                if args.image_ave_pool:
                    # Average the dense feature map of each crop over its spatial dims.
                    feature_map = module.visual.encode_dense(image_crops, keep_shape=True)
                    crop_features = F.normalize(feature_map.mean(dim=(-2, -1)), dim=-1)
                else:
                    crop_features = module.encode_image(image_crops, normalize=True)

                if cast_dtype is not None:
                    roi_features = roi_features.to(dtype=cast_dtype)
                    crop_features = crop_features.to(dtype=cast_dtype)
                    maskpool_features = maskpool_features.to(dtype=cast_dtype)

                roi_logits = roi_features @ cls_embeddings.T
                crop_logits = crop_features @ cls_embeddings.T
                maskpool_logits = maskpool_features @ cls_embeddings.T

            targets = cls_labels.view(-1, 1)
            for logits, correct, similarity in (
                    (roi_logits, correct_rois, similarity_rois),
                    (crop_logits, correct_crops, similarity_crops),
                    (maskpool_logits, correct_maskpool, similarity_maskpool)):
                _, topk_inds = logits.topk(k)
                correct.append(topk_inds == targets)
                similarity.append(torch.gather(logits, dim=1, index=targets)[:, 0])
            all_cls_labels.append(cls_labels)

        # If this process saw no valid box, fall back to empty tensors instead of
        # raising, so every rank still reaches the cross-process gather below.
        correct_rois = _cat_or_empty(correct_rois, 0, k, device=device).float()
        correct_crops = _cat_or_empty(correct_crops, 0, k, device=device).float()
        correct_maskpool = _cat_or_empty(correct_maskpool, 0, k, device=device).float()
        similarity_rois = _cat_or_empty(similarity_rois, 0, device=device).float()
        similarity_crops = _cat_or_empty(similarity_crops, 0, device=device).float()
        similarity_maskpool = _cat_or_empty(similarity_maskpool, 0, device=device).float()
        all_box_sizes = _cat_or_empty(all_box_sizes, 0, device=device)
        all_is_thing = _cat_or_empty(all_is_thing, 0, device=device)
        all_cls_labels = _cat_or_empty(all_cls_labels, 0, dtype=torch.long, device=device)
        if args.distributed and not args.horovod:
            correct_rois = multi_gpu_sync(correct_rois)
            correct_crops = multi_gpu_sync(correct_crops)
            correct_maskpool = multi_gpu_sync(correct_maskpool)
            all_box_sizes = multi_gpu_sync(all_box_sizes)
            all_is_thing = multi_gpu_sync(all_is_thing)
            similarity_rois = multi_gpu_sync(similarity_rois)
            similarity_crops = multi_gpu_sync(similarity_crops)
            similarity_maskpool = multi_gpu_sync(similarity_maskpool)
            all_cls_labels = multi_gpu_sync(all_cls_labels)

    return correct_rois, correct_crops, correct_maskpool, \
        similarity_rois, similarity_crops, similarity_maskpool, \
        all_box_sizes, all_is_thing, all_cls_labels


def multi_gpu_sync(x):
    """Gather ``x`` through ``all_gather`` and concatenate the pieces along dim 0.

    The tensor is moved to CPU before being handed to ``all_gather``. Every
    gathered piece is moved back to ``x``'s original device before concatenation.
    Presumably ``all_gather`` returns one tensor per process; confirm against
    ``training.dist_utils``.
    """
    home_device = x.device
    gathered = all_gather(x.cpu())
    return torch.cat(tuple(piece.to(home_device) for piece in gathered))


def macc_with_is_thing(correct_matrix, is_thing, all_cls_labels, prefix):
    """Compute class-balanced (mean per-class) top-1 / top-5 accuracy.

    Accuracy is first averaged over the boxes of each class present in the
    labels, then averaged over those classes. This is done for three subsets:
    all boxes, "thing" boxes (``is_thing > 0``) and "stuff" boxes
    (``is_thing < 1``).

    Args:
        correct_matrix: ``(N, k)`` hit matrix. ``correct_matrix[i, j]`` is
            nonzero when the j-th ranked prediction for box ``i`` is correct.
            Column 0 gives top-1 and the row sum gives top-k.
        is_thing: ``(N,)`` thing/stuff flag per box.
        all_cls_labels: ``(N,)`` ground-truth class id per box.
        prefix: Prefix for the returned metric names.

    Returns:
        Dict with keys ``'{prefix}.{box|thing|stuff}.macc{1|5}'``. A subset
        with no boxes (e.g. a dataset without stuff boxes) yields ``nan``
        instead of raising.
    """
    def _macc(corrects, cls_labels):
        # Mean of the per-class mean of ``corrects`` over classes present in ``cls_labels``.
        if cls_labels.numel() == 0:
            return float('nan')
        classes, inverse = torch.unique(cls_labels, return_inverse=True)
        num_classes = classes.shape[0]
        hits = torch.zeros(num_classes, dtype=torch.float64, device=corrects.device)
        hits.index_add_(0, inverse.to(corrects.device), corrects.to(torch.float64))
        counts = torch.bincount(inverse, minlength=num_classes).to(hits)
        # Kept in float64: the former ``.half()`` truncated each per-class accuracy.
        return (hits / counts).mean().item()

    thing_mask = is_thing > 0
    stuff_mask = is_thing < 1
    subsets = (
        ('box', correct_matrix, all_cls_labels),
        ('thing', correct_matrix[thing_mask], all_cls_labels[thing_mask].long()),
        ('stuff', correct_matrix[stuff_mask], all_cls_labels[stuff_mask].long()),
    )

    results = {}
    for name, matrix, labels in subsets:
        results[f'{prefix}.{name}.macc1'] = _macc(matrix[:, 0], labels)
        results[f'{prefix}.{name}.macc5'] = _macc(matrix.sum(-1), labels)
    return results


def zero_shot_eval(model, data, epoch, args):
    """Run the region classifier evaluation on ``data['val']`` when scheduled.

    Returns ``{}`` in any of these cases:

    * there is no ``'val'`` entry in ``data``;
    * ``args.zeroshot_frequency`` is 0;
    * ``epoch`` is neither a multiple of ``args.zeroshot_frequency`` nor equal
      to ``args.epochs``.

    Otherwise returns the merged mean per-class accuracy metrics for the
    ``'rois'``, ``'crops'`` and ``'maskpool'`` features.
    """
    if 'val' not in data or args.zeroshot_frequency == 0:
        return {}
    on_schedule = epoch % args.zeroshot_frequency == 0
    if not (on_schedule or epoch == args.epochs):
        return {}

    logging.info('Region classifier')
    (correct_rois, correct_crops, correct_maskpool,
     _sim_rois, _sim_crops, _sim_maskpool,
     _box_sizes, is_thing, cls_labels) = run(model, data['val'].dataloader, args)

    metrics = {}
    for tag, correct in (('rois', correct_rois),
                         ('crops', correct_crops),
                         ('maskpool', correct_maskpool)):
        metrics.update(macc_with_is_thing(correct, is_thing, cls_labels, tag))
    return metrics
