from __future__ import print_function
import argparse
import os
import random

import cv2
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CocoDetection
import numpy as np

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

from measures.feature_measure import maha

def setup_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)
    # NOTE: cuDNN determinism is deliberately left switched off:
    #  torch.backends.cudnn.deterministic = True

# From https://www.kaggle.com/code/yogendrayatnalkar/sam-automatic-semantic-segmentation
def get_mask_embedding_using_patch_embeddings(mask, enc_emb, return_all = False, grid_size = 64):
    """Average the patch embeddings that a binary mask covers.

    The mask is cut into a ``grid_size x grid_size`` grid of equally sized
    cells (a 1024x1024 mask with the default grid of 64 gives 16x16-pixel
    cells). A cell is selected when the mask values inside it sum to more than
    1, i.e. more than one pixel for a 0/1 mask. The embeddings stored at the
    selected grid positions are then averaged.

    Args:
        mask: 2-D binary mask (bool or 0/1 values). Both sides must be
            divisible by ``grid_size``.
        enc_emb: Patch embeddings indexed as ``enc_emb[row, col]``; expected
            to have ``grid_size`` rows and columns (e.g. 64x64xC).
        return_all: When true, also return the selected grid locations and
            the individual embeddings that were averaged.
        grid_size: Number of cells along each side of the grid. The default
            of 64 reproduces the previously hard-coded behaviour.

    Returns:
        The mean embedding, or the tuple ``(mean embedding, patch_locations,
        selected embeddings)`` when ``return_all`` is true. ``patch_locations``
        is the ``(rows, cols)`` tuple produced by ``np.where``.
        If no cell qualifies, the mean is taken over an empty selection; for a
        NumPy ``enc_emb`` that yields NaN values and a RuntimeWarning.

    Raises:
        ValueError: If ``grid_size`` is not positive, ``mask`` is not 2-D, or
            its sides are not divisible by ``grid_size``.
    """
    if grid_size < 1:
        raise ValueError("grid_size must be a positive integer, got %r" % (grid_size,))

    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError("mask must be 2-D, got shape %r" % (mask.shape,))

    height, width = mask.shape
    if height % grid_size or width % grid_size:
        raise ValueError(
            "mask shape %r is not divisible into a %dx%d grid"
            % (mask.shape, grid_size, grid_size)
        )

    # View the mask as (grid_row, cell_row, grid_col, cell_col) and count the
    # mask pixels inside every cell in a single reduction.
    cell_h = height // grid_size
    cell_w = width // grid_size
    pixels_per_cell = mask.reshape(grid_size, cell_h, grid_size, cell_w).sum(axis = (1, 3))

    # Keep only the cells that contain more than one mask pixel and pick the
    # patch embedding stored at each of those (row, col) positions.
    patch_locations = np.where(pixels_per_cell > 1)
    n_patch_embeddings = enc_emb[patch_locations]
    mask_embedding = n_patch_embeddings.mean(axis = 0, keepdims = False)

    if return_all:
        return mask_embedding, patch_locations, n_patch_embeddings
    return mask_embedding



def _to_sam_image(img):
    """Return ``img`` as an HWC uint8 array when it is a CHW torch tensor; pass anything else through."""
    if not torch.is_tensor(img):
        return img
    img = img.detach().cpu()
    if img.is_floating_point():
        # Assumes float tensors are scaled to [0, 1] (as produced by
        # transforms.ToTensor); undo that scaling before casting to uint8.
        img = img.mul(255).round().clamp(0, 255).to(torch.uint8)
    return img.permute(1, 2, 0).contiguous().numpy()


def predict_maha(model, samples, trainloader, args):
    """Extract one SAM embedding per ground-truth COCO mask and save them to disk.

    For every image the SAM image encoder is run once, each annotation is
    rasterised and resized to 1024x1024, and the patch embeddings under the
    mask are averaged by ``get_mask_embedding_using_patch_embeddings``.

    The arrays ``features.npy``, ``labels.npy``, ``img_id.npy``,
    ``label_id.npy`` and ``area.npy`` (one row/entry per annotation, all in the
    same order) are written to ``<args.maha_dir>/<args.dataset_name>/<args.arch>``.

    Args:
        model: Object exposing a ``predictor`` with ``set_image`` and
            ``features`` (here a ``SamAutomaticMaskGenerator``).
        samples: Number of complete passes over ``trainloader``.
        trainloader: Loader yielding ``(images, annotations)`` tuples with one
            image per batch; its dataset must expose ``coco.annToMask``.
        args: Namespace providing ``maha_dir``, ``dataset_name`` and ``arch``.

    Raises:
        ValueError: From ``np.stack`` if no annotation was encountered.
    """
    # Hard-coded for SAM: masks are resized to this (width, height) before
    # being matched against the encoder's patch grid.
    sam_input_size = (1024, 1024)

    feature_all = []
    labels_all = []
    img_id_all = []
    label_id_all = []
    area_all = []
    coco = trainloader.dataset.coco

    for _ in range(samples):
        for imgs, labels in tqdm(trainloader):
            annotations = labels[0]
            if not annotations:
                # Nothing to embed for this image, so skip the encoder pass.
                continue

            with torch.no_grad():
                # SamPredictor.set_image is documented to take an HWC uint8
                # array. A CHW float tensor would have its first two axes
                # read as height/width, so convert it first.
                model.predictor.set_image(_to_sam_image(imgs[0]))
                enc_emb = model.predictor.features
                enc_emb = enc_emb.to("cpu").numpy()
            # Drop the batch axis and move channels last so the array can be
            # indexed as enc_emb[row, col] (presumably (1, C, 64, 64) before).
            enc_emb = enc_emb[0].transpose((1, 2, 0))

            for ann in annotations:
                original_mask = coco.annToMask(ann)
                # Nearest-neighbour keeps the resized mask strictly binary.
                resized_mask = cv2.resize(
                    original_mask, sam_input_size, interpolation=cv2.INTER_NEAREST
                ).astype(bool)

                feature_all.append(
                    get_mask_embedding_using_patch_embeddings(resized_mask, enc_emb)
                )
                labels_all.append(ann['category_id'])
                img_id_all.append(ann['image_id'])
                label_id_all.append(ann['id'])
                area_all.append(ann['area'])

    feature_all = np.stack(feature_all, axis=0)
    labels_all = np.array(labels_all)
    print(feature_all.shape)
    print(labels_all.shape)

    maha_path = os.path.join(args.maha_dir, args.dataset_name, args.arch)
    os.makedirs(maha_path, exist_ok=True)
    np.save(os.path.join(maha_path, 'img_id.npy'), np.array(img_id_all))
    np.save(os.path.join(maha_path, 'label_id.npy'), np.array(label_id_all))
    np.save(os.path.join(maha_path, 'features.npy'), feature_all)
    np.save(os.path.join(maha_path, 'labels.npy'), labels_all)
    np.save(os.path.join(maha_path, 'area.npy'), np.array(area_all))

    # NOTE: the Mahalanobis statistics themselves are currently not computed here:
    # maha_intermediate_dict = maha(feature_all,labels_all,indist_classes = args.num_classes)
    # maha_path = os.path.join(args.maha_dir, args.dataset_name, args.arch)
    # os.makedirs(maha_path, exist_ok=True)
    # np.save(os.path.join(maha_path, 'maha_dict.npy'), maha_intermediate_dict)



def _collate_fn_coco(batch):
    """Regroup a list of (image, annotations) samples into (images, annotations) tuples."""
    return tuple(zip(*batch))


def main():
    """Parse the command line, build the COCO loader and SAM model, and dump mask embeddings.

    Raises:
        NotImplementedError: If ``--dataset_name`` is not ``coco`` or
            ``--arch`` is not ``SAM``.
    """
    parser = argparse.ArgumentParser(description='Modeling data uncertainty with pre-trained detection models')
    # pretrained models setting
    # Output root; required so a missing value fails at start-up instead of
    # after the whole extraction run.
    parser.add_argument('--maha_dir', type=str, required=True)
    parser.add_argument('-a', '--arch', metavar='ARCH', default='SAM')
    # NOTE(review): left optional as before; confirm whether running without
    # a checkpoint is ever intended.
    parser.add_argument('--pretrained_model', type=str, help='large pretrained models')

    # Accepted for compatibility; the COCO path below always uses batch size 1.
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--dataset_dir', type=str, required=True)
    parser.add_argument('--dataset_name', default='coco', type=str, help='only "coco" is supported')
    parser.add_argument('--num_classes', default=80, type=int)
    parser.add_argument('--random_state', type=int, default=42)

    args = parser.parse_args()

    setup_seed(args.random_state)

    print(args)
    # Dataloader
    print('\n[Phase 1] : Data Preparation')
    if args.dataset_name != "coco":
        # Fail before any heavy work instead of hitting an undefined loader later.
        raise NotImplementedError(
            "Unsupported dataset_name: {!r}".format(args.dataset_name)
        )

    coco_transform = transforms.Compose([
        transforms.Resize([1024, 1024]),
        transforms.ToTensor(),
    ])
    dataset_root = os.path.join(args.dataset_dir, args.dataset_name)
    train_annotation_file = os.path.join(dataset_root, "annotations", "instances_train2017.json")
    train_dataset = CocoDetection(
        root=os.path.join(dataset_root, "train2017"),
        annFile=train_annotation_file,
        transform=coco_transform,
    )

    # Only supports batch_size = 1 for now
    # The collate function lives at module level so worker processes can
    # pickle it under the spawn/forkserver start methods.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        collate_fn=_collate_fn_coco,
    )

    print('\n[Phase 2] : Model Preparation')
    if args.arch == "SAM":
        sam = sam_model_registry["vit_h"](checkpoint=args.pretrained_model)
        # Use the GPU when one is present; otherwise fall back to the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sam = sam.to(device)
        mask_generator = SamAutomaticMaskGenerator(sam)
    else:
        raise NotImplementedError()

    print('\n[Phase 3] : Generate Maha Matrix')
    predict_maha(
        model=mask_generator,
        samples=1,
        trainloader=train_dataloader,
        args=args)


if __name__ == '__main__':
    main()