import os
import argparse
from pathlib import Path
from glob import glob
from tqdm import tqdm
import torch
from PIL import Image
import time
from utils.annotations_worker import CocoAnnotationsWorker
import json
import numpy as np
from scipy import ndimage
from crf import densecrf
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


def load_eig_vecs(eigenvec_dir, num_eig_vecs, image_name):
    """
    Load the saved eigenvector tensor of one image and return its first slice along the last axis.

    :param eigenvec_dir: directory containing one ``<image_name>.pt`` file per image
    :param num_eig_vecs: currently unused; only the slice at index 0 of the last axis is returned
    :param image_name: name of the image without the extension
    :return: ``torch.Tensor`` equal to ``saved[:, :, 0]`` (on CPU). The saved tensor is assumed to be
             3-D (presumably height x width x number-of-eigenvectors — TODO confirm), which yields
             the 2-D map the mask extraction expects.
    """
    eig_vec_path = os.path.join(eigenvec_dir, f"{image_name}.pt")
    # map_location="cpu": masks are later built with numpy/scipy, which need CPU tensors, and this
    # also lets eigenvectors that were saved from a GPU load on a CPU-only machine.
    eig_vec = torch.load(eig_vec_path, map_location="cpu")
    return eig_vec[:, :, 0]

def check_num_fg_corners(bipartition, dims):
    """
    Count the foreground values among the four corners of a (flattened) bipartition.

    :param bipartition: array/tensor with a ``reshape`` method whose elements convert with ``int()``
    :param dims: target shape passed to ``reshape`` (assumed 2-D, e.g. ``(H, W)``)
    :return: sum of ``int()`` of the four corner values (0-4 for a binary input)
    """
    grid = bipartition.reshape(dims)
    first_row, last_row = grid[0], grid[-1]
    corner_values = (first_row[0], first_row[-1], last_row[0], last_row[-1])
    return sum(int(value) for value in corner_values)

def parse_image_file(image_full_path):
    """
    Derive an image name and a numeric image id from an image path.

    The name is the file stem cut at its first ``.``. The id is the integer formed by the digits
    of the name, after a leading ``ILSVRC2012`` (the validation-set prefix) is dropped so that its
    digits do not end up in the id.

    :param image_full_path: path to the image file
    :return: tuple ``(image_name, image_id)``
    :raises ValueError: if the name (without the prefix) contains no digits
    """
    image_name = Path(image_full_path).stem.partition(".")[0]

    prefix = "ILSVRC2012"
    id_source = image_name
    if id_source.startswith(prefix):
        id_source = id_source[len(prefix):]

    digits = "".join(ch for ch in id_source if ch.isdigit())
    return image_name, int(digits)

def num_corners_on_border_mask(mask):
    """
    Count how many of the four corner pixels of a 2-D mask are set.

    :param mask: binary mask of shape (H, W); numpy array or torch tensor
    :return: sum of ``int()`` of the four corner values, i.e. 0-4 for a binary mask
    """
    corner_pixels = (mask[0, 0], mask[0, -1], mask[-1, 0], mask[-1, -1])
    return sum(int(pixel) for pixel in corner_pixels)

def compute_neighbor_diff_sum(tensor, neighbors=4):
    """
    Mean absolute difference between each element of a 2-D tensor and its neighbours.

    Borders are handled by edge replication: an out-of-range neighbour takes the value of the
    nearest in-range element. Despite the function name, the result is the *average* of the
    absolute differences, not their sum.

    :param tensor: 2-D ``torch.Tensor`` of shape (H, W)
    :param neighbors: 4 (up/down/left/right) or 8 (additionally the four diagonals)
    :return: tensor of shape (H, W) with the per-element mean absolute neighbour difference
    :raises ValueError: if ``neighbors`` is not 4 or 8
    """
    # Offsets (dy, dx); the order matches the original summation order so results are unchanged.
    offsets_by_connectivity = {
        4: ((-1, 0), (1, 0), (0, -1), (0, 1)),
        8: ((-1, 0), (1, 0), (0, -1), (0, 1),
            (-1, -1), (-1, 1), (1, -1), (1, 1)),
    }
    if neighbors not in offsets_by_connectivity:
        raise ValueError(f"neighbors must be 4 or 8, got {neighbors!r}")
    offsets = offsets_by_connectivity[neighbors]

    h, w = tensor.shape
    # Replicate-pad by one element on every side: indices -1 and h (resp. w) clamp to the edges.
    rows = torch.arange(-1, h + 1, device=tensor.device).clamp(0, h - 1)
    cols = torch.arange(-1, w + 1, device=tensor.device).clamp(0, w - 1)
    padded = tensor[rows][:, cols]

    diff_sum = None
    for dy, dx in offsets:
        neighbour = padded[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
        term = torch.abs(tensor - neighbour)
        diff_sum = term if diff_sum is None else diff_sum + term
    return diff_sum / len(offsets)

def filter_arrays_by_threshold(sort_sums, sort_counts, sort_labels, threshold=0.95):
    """
    Truncate three parallel sequences to a common prefix.

    For ``sort_sums`` and for ``sort_counts`` separately, find the length of the shortest prefix
    whose running total exceeds ``threshold`` times the full total (or the whole length if none
    does). All three sequences are then cut to the shorter of the two lengths.

    :param sort_sums: per-component sums, already in the desired order
    :param sort_counts: per-component sizes, same order
    :param sort_labels: per-component labels, same order
    :param threshold: fraction of the total that the kept prefix must exceed
    :return: tuple of the three truncated sequences
    """
    def prefix_length(values):
        # Length of the first prefix whose cumulative share exceeds the threshold.
        grand_total = sum(values)
        running = 0
        for length, value in enumerate(values, start=1):
            running += value
            if running / grand_total > threshold:
                return length
        return len(values)

    keep = min(prefix_length(sort_sums), prefix_length(sort_counts))
    return sort_sums[:keep], sort_counts[:keep], sort_labels[:keep]

def get_masks(eigen_vec):
    """
    Split an eigenvector map into connected foreground components and return one mask per kept component.

    Steps:
      1. Threshold at the mean. If 3 or more of the 4 corner pixels fall on the foreground side,
         flip the sign of the map so that the side touching most corners becomes the background.
      2. Subtract the 8-neighbour mean absolute difference (``compute_neighbor_diff_sum``), which
         lowers values where neighbours differ, then threshold again at the new mean.
      3. Label connected foreground regions with ``scipy.ndimage.label`` (default structure).
      4. Discard the background (label 0), sort the components by the sum of map values inside
         them (descending), and keep the prefix chosen by ``filter_arrays_by_threshold``.

    :param eigen_vec: 2-D ``torch.Tensor`` (H, W); 2-D is required by ``compute_neighbor_diff_sum``
    :return: list of ``np.uint8`` arrays of shape (H, W) with values {0, 1}, ordered by descending
             component sum; empty if no foreground component is found
    """
    bipartition = eigen_vec > torch.mean(eigen_vec)
    # Flip so that the side occupying most image corners is treated as background.
    if num_corners_on_border_mask(bipartition) >= 3:
        eigen_vec = eigen_vec * -1

    eigen_vec = eigen_vec - compute_neighbor_diff_sum(eigen_vec, neighbors=8)
    bipartition = eigen_vec > torch.mean(eigen_vec)

    objects, num_objects = ndimage.label(bipartition)
    if num_objects < 1:
        return []

    labels, counts = np.unique(objects, return_counts=True)
    sums = np.asarray(ndimage.sum(eigen_vec, labels=objects, index=labels))

    # Remove the background (label 0) explicitly. Dropping the lowest-sum component instead only
    # works when the background has the smallest sum; otherwise a real object would be discarded
    # and the background returned as a mask.
    is_object = labels != 0
    labels, counts, sums = labels[is_object], counts[is_object], sums[is_object]

    order = np.argsort(sums)[::-1]
    _, _, sort_labels = filter_arrays_by_threshold(sums[order], counts[order], labels[order])
    return [(objects == label).astype(np.uint8) for label in sort_labels]

def bbox_from_mask(mask: np.ndarray) -> np.ndarray:
    """
    Tight bounding box of the non-zero pixels of a 2-D mask.

    :param mask: 2-D array; rows are y, columns are x
    :return: integer array ``[x, y, width, height]``; width/height count both edge pixels
    :raises ValueError: if the mask has no non-zero pixel
    """
    cols = np.flatnonzero(mask.sum(axis=0))
    rows = np.flatnonzero(mask.sum(axis=1))
    if cols.size == 0 or rows.size == 0:
        raise ValueError("cannot compute a bounding box of an empty mask")
    # flatnonzero returns indices in ascending order, so first/last are min/max.
    x0, x1 = cols[0], cols[-1]
    y0, y1 = rows[0], rows[-1]
    return np.array([x0, y0, x1 - x0 + 1, y1 - y0 + 1])

def IoU_bbox(mask1, mask2):
    """
    Intersection-over-union of the bounding boxes of two masks (not of the masks themselves).

    :param mask1: 2-D mask accepted by ``bbox_from_mask``
    :param mask2: 2-D mask accepted by ``bbox_from_mask``
    :return: intersection area of the two boxes divided by their union area
    """
    left_a, top_a, width_a, height_a = bbox_from_mask(mask1)
    left_b, top_b, width_b, height_b = bbox_from_mask(mask2)

    # Box right/bottom edges are exclusive (x + width), so a zero or negative span means no overlap.
    overlap_w = max(0, min(left_a + width_a, left_b + width_b) - max(left_a, left_b))
    overlap_h = max(0, min(top_a + height_a, top_b + height_b) - max(top_a, top_b))
    overlap = overlap_w * overlap_h

    return overlap / (width_a * height_a + width_b * height_b - overlap)

def IoU(mask1, mask2):
    """
    Mean intersection-over-union between two masks (or two batches of masks).

    Inputs are binarised with ``> 0.5``; IoU is taken over the last two dimensions and then
    averaged over any remaining dimensions.

    :return: Python float; NaN for a pair whose union is empty
    """
    binary_a = (mask1 > 0.5).to(torch.bool)
    binary_b = (mask2 > 0.5).to(torch.bool)
    spatial_dims = [-1, -2]
    overlap = (binary_a & binary_b).sum(dim=spatial_dims).squeeze()
    combined = (binary_a | binary_b).sum(dim=spatial_dims).squeeze()
    ratio = overlap.to(torch.float) / combined
    return ratio.mean().item()

def mask_post_processing_new(mask, image_rgb, device='cpu', use_crf=False):
    """
    Resize a low-resolution mask to the original image size, optionally refining it with a CRF.

    With the default ``use_crf=False`` the mask is only upsampled with nearest-neighbour
    interpolation. Passing ``use_crf=True`` additionally runs dense-CRF refinement on an enlarged
    bounding-box crop (see ``_refine_mask_with_crf``).

    :param mask: numpy array of shape [height, width] with values in {0, 1}
    :param image_rgb: PIL image; only its ``size`` (width, height) is read unless ``use_crf`` is set
    :param device: unused; kept for backward compatibility with existing callers
    :param use_crf: if True, refine the resized mask with ``densecrf``
    :return: tuple (float32 mask of shape [image height, image width], success flag). The success
             flag is currently always True.
    """
    width, height = image_rgb.size
    mask = mask.astype(np.float32)
    # F.interpolate expects an (N, C, H, W) tensor, hence the two leading singleton axes.
    patches_mask = F.interpolate(torch.from_numpy(mask[None, None, :, :]),
                                 size=(height, width), mode='nearest')[0][0].numpy()
    if not use_crf:
        return patches_mask, True
    return _refine_mask_with_crf(patches_mask, image_rgb), True


def _refine_mask_with_crf(patches_mask, image_rgb):
    """Run dense CRF on an enlarged bounding-box crop around the mask and fill holes in the result."""
    if not patches_mask.any():
        # bbox_from_mask is undefined for an empty mask; there is nothing to refine.
        return patches_mask
    img_height, img_width = patches_mask.shape
    x, y, w, h = bbox_from_mask(patches_mask)
    # Enlarge the box by 33% of its width/height on each side to give the CRF some context.
    factor = 0.33
    crop_x = (max(x - int(w * factor), 0), min(x + w + int(w * factor), img_width))
    crop_y = (max(y - int(h * factor), 0), min(y + h + int(h * factor), img_height))

    mask_cropped = patches_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]]
    img = np.asarray(image_rgb).copy()
    img_cropped = img[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1], :]

    pseudo_mask_crop = densecrf(img_cropped, mask_cropped)
    pseudo_mask_crop = ndimage.binary_fill_holes(pseudo_mask_crop)

    # Paste the refined crop back into a full-size mask; everything outside the crop is 0.
    pseudo_mask = np.zeros_like(patches_mask)
    pseudo_mask[crop_y[0]:crop_y[1], crop_x[0]:crop_x[1]] = pseudo_mask_crop
    return pseudo_mask

def cutonce(image_rgb, eig_vec, device):
    """
    Produce full-resolution binary pseudo masks for one image from its eigenvector map.

    :param image_rgb: PIL image; forwarded to ``mask_post_processing_new``
    :param eig_vec: eigenvector map passed to ``get_masks``
    :param device: accepted for API compatibility; not used by this function
    :return: list of ``np.uint8`` masks; candidates whose post-processing reports failure are dropped
    """
    kept_masks = []
    for candidate in get_masks(eig_vec):
        resized, ok = mask_post_processing_new(candidate, image_rgb)
        if ok:
            kept_masks.append(resized.astype(np.uint8))
    return kept_masks

def create_cutonce_annotations(eigenvec_dir, img_files, worker_dir,
                               num_eig_vecs=1, save_period=100, device="cpu", resume=False):
    """
    Create CutOnce pseudo-label annotations for a list of images within a single job.

    Annotations are accumulated in a ``CocoAnnotationsWorker`` and periodically flushed to
    temporary files in ``worker_dir``, so several jobs can run in parallel and be aggregated later
    without keeping every annotation in memory.

    :param eigenvec_dir: directory containing one ``<image_name>.pt`` eigenvector file per image
    :param img_files: list of image file paths to process
    :param worker_dir: directory for the temporary annotation files; images that fail to be added
                       are also listed in ``skipped_images.txt`` there
    :param num_eig_vecs: forwarded to ``load_eig_vecs``
    :param save_period: flush buffered annotations every ``save_period`` images; must be positive
    :param device: forwarded to ``cutonce``
    :param resume: if True, ``img_files`` is filtered by ``ann_worker.resume`` (presumably dropping
                   images already saved — confirm in CocoAnnotationsWorker); otherwise the worker's
                   ``cleanup`` is called first
    :raises ValueError: if ``save_period`` is not positive
    """
    if save_period <= 0:
        raise ValueError(f"save_period must be positive, got {save_period}")

    ts = time.time()
    ann_worker = CocoAnnotationsWorker(worker_dir)
    if resume:
        img_files = ann_worker.resume(img_files)
    else:
        # Not resuming: clear whatever a previous run left behind.
        ann_worker.cleanup()

    num_files = len(img_files)
    if num_files == 0:
        print("No images left to process, exiting...")
        return

    Path(worker_dir).mkdir(parents=True, exist_ok=True)
    # Only used for tracking images that could not be added.
    skipped_images_file = os.path.join(worker_dir, "skipped_images.txt")

    for ind, img_file in enumerate(tqdm(img_files, desc="Creating pseudo labels")):
        image_name, image_id = parse_image_file(img_file)
        with Image.open(img_file) as image_file:
            image_rgb = image_file.convert("RGB")
        eig_vec = load_eig_vecs(eigenvec_dir, num_eig_vecs, image_name)
        image_masks = cutonce(image_rgb, eig_vec, device=device)
        success = ann_worker.add_image_ann(image_id=image_id,
                                           file_name=img_file,
                                           height=image_rgb.size[1],
                                           width=image_rgb.size[0],
                                           image_masks=image_masks)
        if not success:
            print(f"Failed to add image {img_file} to the annotations")
            with open(skipped_images_file, "a") as f:
                f.write(f"{image_name}\n")
        # Flush on schedule regardless of success: skipping it for a failed image would also
        # defeat the leftover-flush check below when the last image closes a full period.
        if (ind + 1) % save_period == 0:
            ann_worker.flush_and_save_anns()

    # Save whatever is left after the last full period.
    if num_files % save_period != 0:
        print(f'Save last {num_files % save_period} files')
        ann_worker.flush_and_save_anns()
    te = time.time()
    print(f"Running Time: {te - ts}")
    print("Done!")


def str_to_bool(value):
    """
    Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns any non-empty string (including "False") into True,
    so boolean options are parsed with this helper instead.

    :param value: a bool, or a string such as "true", "False", "1", "no"
    :return: the parsed bool
    :raises argparse.ArgumentTypeError: if the string is not a recognised boolean
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Create pseudo labels mask coco annotation file")
    parser.add_argument("--dataset-root", type=str, default="/data/xxx/datasets/coco/val2017", help="Path to coco dataset")
    parser.add_argument("--split", type=str, default="val", choices=["train", "val"], help="Split to use")
    parser.add_argument("--out-file", type=str, default="pseudo_labels/coco_val2017_cutonce_nofix.json", help="")
    parser.add_argument("--models", nargs='+',
                        default=["dino_s16", "dinov2_b14", "dinov2_s14", "dino_b16", "dino_s8", "dino_b8"],
                        help="List of models to use")
    parser.add_argument("--eig-vec-dir", type=str, default="eigen_vecs_nofix", help="Directory of images eigen vectors for each model")
    parser.add_argument("--num-eig-vecs", type=int, default=1, help="Number of eigen vectors to use")
    parser.add_argument("--save-period", type=int, default=100, help="saving period for the annotations in temp files")
    parser.add_argument("--tmp-folder", type=str, default="tmp", help="Directory to save temp files")
    parser.add_argument("--save-tmp-files", action="store_true", help="Save temp files")
    # Accepts "--resume", "--resume True" and "--resume False" (type=bool made "False" truthy).
    parser.add_argument("--resume", type=str_to_bool, nargs="?", const=True, default=False, help="Resume from previous run")
    parser.add_argument("--device", type=str, default="cpu", choices=["cuda", "cpu"])
    args = parser.parse_args()

    tmp_folder = args.tmp_folder
    os.makedirs(tmp_folder, exist_ok=True)
    # Only the last listed model is used; its eigenvectors live in <eig-vec-dir>/<model>.
    # With the default arguments this resolves to "eigen_vecs_nofix/dino_b8".
    args.models = args.models[-1:]
    eigenvec_dir = os.path.join(args.eig_vec_dir, args.models[-1])

    all_image_files = sorted(glob(f"{args.dataset_root}/*.jpg"))
    create_cutonce_annotations(eigenvec_dir, all_image_files, tmp_folder, args.num_eig_vecs,
                               args.save_period, args.device, args.resume)

    # The worker's cleanup may have removed the folder (e.g. when there was nothing to process),
    # so make sure it exists before listing it.
    os.makedirs(tmp_folder, exist_ok=True)
    # The worker directory also holds the skipped-images log, which is not an annotation file.
    anns_paths = [os.path.join(tmp_folder, fname)
                  for fname in sorted(os.listdir(tmp_folder))
                  if fname != "skipped_images.txt"]
    anns = CocoAnnotationsWorker.collect_to_single_ann_dict(anns_paths)

    out_dir = os.path.dirname(args.out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out_file, "w") as f:
        json.dump(anns, f)
    print(f'dump {args.out_file}')

    from eval_coco_json import eval_coco_json
    eval_coco_json('coco', args.out_file)

    # Previously unreachable behind an early exit(0); now honours --save-tmp-files.
    if not args.save_tmp_files:
        print('cleanup_tmp_files')
        CocoAnnotationsWorker.cleanup_tmp_files(tmp_folder)
