import os
import argparse
from pathlib import Path
from glob import glob
from tqdm import tqdm
import torch
from PIL import Image
import time
from utils.annotations_worker import CocoAnnotationsWorker
import json
import numpy as np
from scipy import ndimage
from crf import densecrf
import torch.nn.functional as F
import matplotlib.pyplot as plt

def load_eig_vecs(eigenvec_dir, num_eig_vecs, image_name):
    """
    Load the leading eigenvectors saved for one image and reshape them onto the patch grid.

    The file ``<eigenvec_dir>/<image_name>.pt`` is loaded and transposed, so each row of the
    transposed tensor is one eigenvector. The first ``num_eig_vecs`` rows are kept. Each row
    is reshaped to a square ``side x side`` grid, where ``side * side`` equals the number of
    patches (e.g. 900 -> 30x30, 1156 -> 34x34, 3600 -> 60x60).

    :param eigenvec_dir: directory containing one ``<image_name>.pt`` file per image
    :param num_eig_vecs: number of eigenvectors (rows after transposing) to keep
    :param image_name: name of the image without the extension
    :return: tensor of shape (k, side, side), with k = min(num_eig_vecs, available vectors)
    :raises ValueError: if the number of patches is not a (non-zero) perfect square
    """
    eig_vec_path = os.path.join(eigenvec_dir, f"{image_name}.pt")
    # Load onto the CPU: downstream code feeds the result to numpy / scipy.ndimage.
    eig_vec = torch.load(eig_vec_path, map_location="cpu")
    eig_vec = eig_vec.T[:num_eig_vecs]
    num_patches = eig_vec.shape[1]
    side = int(round(num_patches ** 0.5))
    if num_patches == 0 or side * side != num_patches:
        raise ValueError("Invalid eig vec shape")
    return eig_vec.reshape(eig_vec.shape[0], side, side)

def parse_image_file(image_full_path):
    """
    Derive an image name and a numeric image id from an image path.

    The name is the file stem cut at its first '.'. The id is the integer made of every
    digit in the name, after a leading "ILSVRC2012" (ImageNet validation prefix) has been
    dropped so those digits do not leak into the id.

    :param image_full_path: path to the image file
    :return: tuple (image_name, image_id)
    """
    image_name = Path(image_full_path).stem.partition(".")[0]
    validation_prefix = "ILSVRC2012"
    id_source = image_name
    if id_source.startswith(validation_prefix):
        id_source = id_source[len(validation_prefix):]
    digits = "".join(ch for ch in id_source if ch.isdigit())
    return image_name, int(digits)

def num_corners_on_border_mask(mask):
    """
    Count how many of the four corner elements of a 2-D binary mask are set.

    :param mask: binary mask of shape (H, W), indexed as ``mask[row, col]``
    :return: integer in [0, 4] (a 1x1 mask counts its single element four times)
    """
    corner_positions = ((0, 0), (0, -1), (-1, 0), (-1, -1))
    return sum(int(mask[row, col]) for row, col in corner_positions)

def compute_neighbor_diff_sum(tensor, neighbors=4):
    """
    Mean absolute difference between each element of a 2-D tensor and its neighbours.

    Borders use replicate padding: an out-of-range neighbour takes the value of the
    nearest edge element. Despite the name, the per-element *average* over the
    neighbourhood is returned (the sum divided by the number of neighbours).

    :param tensor: 2-D torch tensor of shape (H, W)
    :param neighbors: 4 (up/down/left/right) or 8 (additionally the four diagonals)
    :return: tensor of shape (H, W) with the average absolute neighbour difference
    :raises ValueError: if ``neighbors`` is neither 4 nor 8
    """
    # (row, col) offsets, listed in the order the differences are accumulated.
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    if neighbors == 8:
        offsets += [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    elif neighbors != 4:
        raise ValueError(f"neighbors must be 4 or 8, got {neighbors}")

    h, w = tensor.shape
    # Replicate padding via clamped indices: padded[i, j] == tensor[clamp(i - 1), clamp(j - 1)].
    rows = torch.arange(-1, h + 1, device=tensor.device).clamp(0, h - 1)
    cols = torch.arange(-1, w + 1, device=tensor.device).clamp(0, w - 1)
    padded = tensor[rows][:, cols]

    diff_sum = sum(
        torch.abs(tensor - padded[1 + dy:1 + dy + h, 1 + dx:1 + dx + w])
        for dy, dx in offsets
    )
    return diff_sum / len(offsets)

def filter_arrays_by_threshold(sort_sums, sort_counts, sort_labels, threshold=0.95):
    """
    Keep the shortest prefix of the (pre-sorted) arrays that explains the data.

    For both ``sort_sums`` and ``sort_counts``, find the first position where the running
    total divided by the grand total strictly exceeds ``threshold``. The arrays are cut at
    the smaller of the two positions. An array whose running ratio never exceeds the
    threshold contributes its full length.

    :param sort_sums: per-component sums, already in the desired order
    :param sort_counts: per-component sizes, in the same order
    :param sort_labels: per-component labels, in the same order
    :param threshold: cumulative fraction that must be exceeded
    :return: tuple (sums, counts, labels) truncated to the same prefix length
    """
    def _prefix_length(values):
        # Position (1-based) at which the running ratio first exceeds the threshold.
        grand_total = sum(values)
        running = 0
        for position, value in enumerate(values, start=1):
            running += value
            if running / grand_total > threshold:
                return position
        return len(values)

    keep = min(_prefix_length(sort_sums), _prefix_length(sort_counts))
    return sort_sums[:keep], sort_counts[:keep], sort_labels[:keep]

def get_masks(eigen_vec):
    """
    Split a 2-D eigenvector map into connected foreground regions.

    Steps:
    - Orient the eigenvector so the foreground is the "above mean" side.
    - Subtract the local 8-neighbour roughness.
    - Threshold at the mean.
    - Label the connected components.
    - Keep the highest-response components selected by ``filter_arrays_by_threshold``.

    :param eigen_vec: 2-D torch tensor (one eigenvector laid out on the patch grid)
    :return: list of uint8 numpy masks with the same shape as ``eigen_vec`` (1 inside the
        component, 0 elsewhere), ordered by decreasing summed response; an empty list when
        thresholding yields no foreground component.
    """
    avg = torch.mean(eigen_vec)
    bipartition = eigen_vec > avg

    # Flip the sign when the "above mean" side covers 3+ corners (likely background), or
    # when the negative extreme dominates the positive one by more than 6x.
    if num_corners_on_border_mask(bipartition) >= 3:
        eigen_vec = eigen_vec * -1
    elif torch.abs(torch.min(eigen_vec)) > 6 * torch.abs(torch.max(eigen_vec)):
        eigen_vec = eigen_vec * -1

    # Penalise noisy patches by their mean absolute difference to the 8 neighbours.
    eigen_vec = eigen_vec - compute_neighbor_diff_sum(eigen_vec, neighbors=8)
    avg = torch.mean(eigen_vec)
    bipartition = eigen_vec > avg

    objects, num_objects = ndimage.label(bipartition)
    if num_objects < 1:
        return []

    # ndimage.label numbers components 1..num_objects; label 0 is the background.
    # Exclude it explicitly instead of assuming it ends up with the lowest sum.
    labels = np.arange(1, num_objects + 1)
    counts = np.bincount(objects.ravel(), minlength=num_objects + 1)[1:]
    sums = np.asarray(ndimage.sum(eigen_vec, labels=objects, index=labels))

    order = np.argsort(sums)[::-1]
    sort_sums = sums[order]
    sort_counts = counts[order]
    sort_labels = labels[order]
    sort_sums, sort_counts, sort_labels = filter_arrays_by_threshold(sort_sums, sort_counts, sort_labels)

    return [(objects == label).astype(np.uint8) for label in sort_labels]

def bbox_from_mask(mask: np.ndarray):
    """
    Return the tight bounding box of the occupied area of a 2-D mask.

    A column (row) counts as occupied when its sum is non-zero.

    :param mask: 2-D array
    :return: numpy array [x, y, width, height] with inclusive extent
    :raises ValueError: if the mask has no occupied column/row
    """
    occupied_cols = np.flatnonzero(mask.sum(axis=0))
    occupied_rows = np.flatnonzero(mask.sum(axis=1))
    left, right = occupied_cols.min(), occupied_cols.max()
    top, bottom = occupied_rows.min(), occupied_rows.max()
    return np.array([left, top, right - left + 1, bottom - top + 1])

def _crf_refine(patches_mask, image_rgb, factor=0.33):
    """Refine a full-resolution mask with densecrf inside a padded bounding-box crop.

    The crop is the mask's bounding box, grown by ``factor`` of its width/height on each
    side and clipped to the image. The hole-filled CRF output is pasted back into an
    otherwise empty mask of the full image size. Any error propagates to the caller.
    """
    height, width = patches_mask.shape
    x, y, w, h = bbox_from_mask(patches_mask)
    x0, x1 = max(x - int(w * factor), 0), min(x + w + int(w * factor), width)
    y0, y1 = max(y - int(h * factor), 0), min(y + h + int(h * factor), height)
    img = np.asarray(image_rgb).copy()
    refined_crop = densecrf(img[y0:y1, x0:x1, :], patches_mask[y0:y1, x0:x1])
    refined_crop = ndimage.binary_fill_holes(refined_crop)
    refined = np.zeros_like(patches_mask)
    refined[y0:y1, x0:x1] = refined_crop
    return refined


def mask_post_processing_new(mask, image_rgb, device='cpu', use_crf=False):
    """
    Resize a patch-level mask to the image size, fill its holes and optionally refine it with a CRF.

    mask: numpy array of shape [height, width] with [0,1] values (patch-grid resolution)
    image_rgb: PIL image; only ``size`` (width, height) is read unless ``use_crf`` is True
    device: unused; kept for call compatibility
    use_crf: when True, refine the mask with ``densecrf`` around its bounding box. The
        default False keeps the previous behaviour, where the CRF step was unreachable.
    return: tuple (mask, success). The mask is a boolean numpy array of shape
        [image height, image width]. success is False only when CRF refinement was
        requested and failed; the resized, hole-filled mask is returned in that case.
    """
    # PIL reports (width, height); F.interpolate expects (height, width).
    rescale_size = (image_rgb.size[1], image_rgb.size[0])
    # Nearest-neighbour upsampling keeps the mask binary.
    patches_mask = F.interpolate(torch.from_numpy(mask[None, None, :, :]), size=rescale_size, mode='nearest')[0][0].numpy()
    patches_mask = ndimage.binary_fill_holes(patches_mask)
    if not use_crf:
        return patches_mask, True
    try:
        return _crf_refine(patches_mask, image_rgb), True
    except Exception:
        # The CRF is best-effort: fall back to the un-refined mask and report the failure.
        return patches_mask, False

def cutonce(image_rgb, eig_vec, device):
    """
    Produce scored pseudo masks for one image from its first eigenvector.

    Candidate masks from ``get_masks`` are post-processed to image resolution. Candidates
    whose post-processing reports failure are dropped. Scores fall linearly with the
    candidate's rank: 1.0 for the first, down to 0.5 for the last (1.0 if there is only
    one candidate).

    :param image_rgb: PIL image the eigenvector belongs to
    :param eig_vec: stack of eigenvector grids; only ``eig_vec[0]`` is used
    :param device: unused
    :return: list of dicts ``{'data': uint8 mask, 'score': float}``
    """
    candidates = get_masks(eig_vec[0])
    total = len(candidates)

    results = []
    for rank, candidate in enumerate(candidates):
        refined, ok = mask_post_processing_new(candidate, image_rgb)
        if ok:
            # Rank counts every candidate, including dropped ones.
            confidence = 1.0 if total == 1 else 1.0 - rank / (2 * total - 2)
            results.append({'data': refined.astype(np.uint8), 'score': confidence})
    return results

def create_cutonce_annotations(eigenvec_dir, img_files, worker_dir,
                               num_eig_vecs=1, save_period=100, device="cpu", resume=False):
    """
    Create CutOnce pseudo labels for a list of images and save them to temporary files.

    Annotations are flushed to temp files in ``worker_dir`` every ``save_period`` images
    (and once more for any leftover images), so several jobs can run in parallel and the
    results can be aggregated later without keeping everything in memory.

    :param eigenvec_dir: directory containing one ``<image_name>.pt`` eigenvector file per image
    :param img_files: list of image files to process
    :param worker_dir: directory for the temporary annotation files and ``skipped_images.txt``
    :param num_eig_vecs: number of eigen vectors to load per image
    :param save_period: flush the buffered annotations every ``save_period`` images
    :param device: forwarded to ``cutonce``
    :param resume: if True, ``ann_worker.resume(img_files)`` decides which images are still
        to be processed; otherwise the worker directory is cleared with ``ann_worker.cleanup()``
    :return: None
    """
    ts = time.time()
    ann_worker = CocoAnnotationsWorker(worker_dir)
    # if the worker directory exists and we are not resuming the process clear it
    if resume:
        img_files = ann_worker.resume(img_files)
    else:
        ann_worker.cleanup()

    num_files = len(img_files)
    if num_files == 0:
        print("No images left to process, exiting...")
        return

    Path(worker_dir).mkdir(parents=True, exist_ok=True)
    # just for tracking the skipped images
    skipped_images_file = os.path.join(worker_dir, "skipped_images.txt")

    for ind, img_file in enumerate(tqdm(img_files, desc="Creating pseudo labels")):
        try:
            image_name, image_id = parse_image_file(img_file)
            with Image.open(img_file) as img:
                image_rgb = img.convert("RGB")
            # load all eigen vectors for the image
            eig_vec = load_eig_vecs(eigenvec_dir, num_eig_vecs, image_name)
            image_masks = cutonce(image_rgb, eig_vec, device=device)
            success = ann_worker.add_image_ann(image_id=image_id,
                                               file_name=img_file,
                                               height=image_rgb.size[1],
                                               width=image_rgb.size[0],
                                               image_masks=image_masks)
            if not success:
                print(f"Failed to add image {img_file} to the annotations")
                with open(skipped_images_file, "a") as f:
                    f.write(f"{image_name}\n")
        except Exception as e:
            print(f"Error in cutonce: {e}")
        # Flush on every period boundary, even if this particular image was skipped or
        # failed. Skipping the flush here would leave the batch buffered, and the
        # leftover check below assumes every boundary was flushed.
        if (ind + 1) % save_period == 0:
            ann_worker.flush_and_save_anns()
    # save the leftover annotations
    if num_files % save_period != 0:
        print(f'Save last {num_files % save_period} files')
        ann_worker.flush_and_save_anns()
    te = time.time()
    print(f"Running Time: {te - ts}")
    print("Done!")

# Root directory of each supported dataset, keyed by the --dataset choice.
# NOTE(review): the "/data/xxx" prefix looks like a placeholder - adjust to the local data location.
DATASET_PATH = dict(
    imagenet_train="/data/xxx/datasets/imagenet/train",
    imagenet_val="/data/xxx/datasets/imagenet/val",
    coco_val2017="/data/xxx/datasets/coco/val2017",
)

if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean. argparse's ``type=bool`` would turn any non-empty string, even "False", into True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser("Create pseudo labels mask coco annotation file")
    parser.add_argument("--dataset", type=str, default="imagenet_val", choices=["imagenet_train", "imagenet_val", "coco_val2017"], help="Dataset")
    parser.add_argument("--split", type=str, default="val", choices=["train", "val"], help="Split to use")
    parser.add_argument("--models", nargs='+',
                        default=["dino_s16", "dinov2_b14", "dinov2_s14", "dino_b16", "dino_s8", "dino_b8"],
                        help="List of models to use")
    parser.add_argument("--eig-vec-dir", type=str, default="", help="Directory of images eigen vectors for each model")
    parser.add_argument("--num-eig-vecs", type=int, default=1, help="Number of eigen vectors to use")
    parser.add_argument("--save-period", type=int, default=100, help="saving period for the annotations in temp files")
    parser.add_argument("--tmp-folder", type=str, default="tmp_imagenet", help="Directory to save temp files")
    parser.add_argument("--save-tmp-files", action="store_true", help="Save temp files")
    # Accepts "--resume True/False" as before, and a bare "--resume" means True.
    parser.add_argument("--resume", type=_str2bool, nargs="?", const=True, default=False, help="Resume from previous run")
    parser.add_argument("--device", type=str, default="cpu", choices=["cuda", "cpu"])
    args = parser.parse_args()
    print(args)

    dataset_root = DATASET_PATH[args.dataset]
    if args.dataset == "imagenet_train":
        img_files = glob(f"{dataset_root}/*/*.JPEG")
    elif args.dataset == "imagenet_val":
        img_files = glob(f"{dataset_root}/*.JPEG")
    elif args.dataset == "coco_val2017":
        img_files = glob(f"{dataset_root}/*.jpg")
    else:
        raise ValueError(f"Invalid dataset: {args.dataset} provided.")

    tmp_folder = args.tmp_folder
    os.makedirs(tmp_folder, exist_ok=True)

    # --eig-vec-dir overrides the default eigenvector location when it is given.
    eigenvec_dir = args.eig_vec_dir or f"eigen_vecs_no_binary/{args.dataset}/dino_b8"
    print(f"eigenvec_dir: {eigenvec_dir}")
    create_cutonce_annotations(eigenvec_dir, img_files, tmp_folder, args.num_eig_vecs,
                               args.save_period, args.device, args.resume)

    # The worker may return early (nothing to process) after clearing the folder.
    os.makedirs(tmp_folder, exist_ok=True)
    # skipped_images.txt is the plain-text skip log written by create_cutonce_annotations,
    # not an annotations file, so it must not be collected.
    anns_paths = [os.path.join(tmp_folder, fname)
                  for fname in sorted(os.listdir(tmp_folder))
                  if fname != "skipped_images.txt"]
    anns = CocoAnnotationsWorker.collect_to_single_ann_dict(anns_paths)

    out_file = f"pseudo_labels/{args.dataset}_cutonce_improve.json"
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    with open(out_file, "w") as f:
        json.dump(anns, f, indent=2)
    print(f'dump {out_file}')

    from eval_coco_json import eval_coco_json
    eval_dataset = "coco" if args.dataset == "coco_val2017" else "imagenet"
    eval_coco_json(eval_dataset, out_file)

    # Honour --save-tmp-files: previously an early exit(0) made this cleanup unreachable.
    if not args.save_tmp_files:
        print('cleanup_tmp_files')
        CocoAnnotationsWorker.cleanup_tmp_files(tmp_folder)
