import torch
# from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
import os
import argparse
# import evaluate
from tqdm.auto import tqdm
import numpy as np
import random

import matplotlib.pyplot as plt
from PIL import Image
import pickle

from multiprocessing import dummy as multiprocessing
import skimage.segmentation as segmentation

from segment_anything import sam_model_registry, SamPredictor
from segment_anything import build_sam, SamAutomaticMaskGenerator


def load_image(img_path, resize=None, pil=False):
    """Open an image file and convert it to RGB.

    Args:
        img_path: Path of the image file to open.
        resize: When not None, the image is resized to ``(resize, resize)``;
            the aspect ratio is not preserved.
        pil: When True, return the PIL image object; otherwise return the
            result of ``np.asarray`` on it.

    Returns:
        The RGB image, either as a PIL image or as a NumPy array.
    """
    rgb = Image.open(img_path).convert("RGB")
    if resize is not None:
        rgb = rgb.resize((resize, resize))
    return rgb if pil else np.asarray(rgb)


def crop_image(image, mask):
    """Zero out everything outside the bounding box of ``mask``.

    Despite the name, the output keeps the full shape of ``image``. Pixels
    inside the tight bounding box of the nonzero/True entries of ``mask``
    are copied from ``image`` and everything else is zero. Pixels inside
    the box but outside the mask itself are kept: this is a box crop, not
    a mask cut-out.

    Args:
        image: Array whose first two axes are (rows, cols), e.g. (H, W) or
            (H, W, C).
        mask: 2-D array over the same (rows, cols); nonzero/True entries
            mark the segment.

    Returns:
        A new array with the same shape and dtype as ``image``. If ``mask``
        has no nonzero entries, the result is all zeros.
    """
    cropped_image = np.zeros_like(image)

    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        # No segment means no bounding box. np.min on an empty array would
        # raise, so return the all-zero canvas instead.
        return cropped_image

    top, bottom = rows.min(), rows.max()
    left, right = cols.min(), cols.max()

    # `...` covers any trailing (channel) axes, so 2-D grayscale images work
    # as well as (H, W, C) ones.
    box = (slice(top, bottom + 1), slice(left, right + 1), Ellipsis)
    cropped_image[box] = image[box]
    return cropped_image


def save_masked_and_cropped_image(image, mask, output_path):
    """Save the masked segment of ``image``, cropped to its bounding box.

    Pixels where ``mask`` is False are set to zero. The result is cropped
    to the tight bounding box of the True entries and written to
    ``output_path``. PIL chooses the file format from the path's extension.

    Args:
        image: Array whose first two axes are (rows, cols). It is cast to
            ``uint8`` before saving, so values are presumably expected in
            [0, 255] — TODO confirm for non-uint8 inputs.
        mask: 2-D array over the same (rows, cols). It is converted to bool,
            so a 0/1 integer mask selects pixels rather than acting as row
            indices.
        output_path: Destination file path.

    Raises:
        ValueError: If ``mask`` has no True entries (nothing to crop).
    """
    # Force boolean indexing. An integer 0/1 mask would otherwise trigger
    # NumPy fancy (integer) indexing along axis 0.
    mask = np.asarray(mask, dtype=bool)

    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        raise ValueError(f'mask is empty; nothing to save to {output_path!r}')
    top, bottom = rows.min(), rows.max()
    left, right = cols.min(), cols.max()

    # Keep only the pixels covered by the mask; everything else stays zero.
    masked_image = np.zeros_like(image)
    masked_image[mask] = image[mask]

    masked_image_pil = Image.fromarray(masked_image.astype(np.uint8))

    # PIL crop boxes are (left, upper, right, lower), with right/lower exclusive.
    cropped_image_pil = masked_image_pil.crop((left, top, right + 1, bottom + 1))
    cropped_image_pil.save(output_path)


def show_anns(anns):
    """Overlay mask annotations on the current matplotlib axes.

    Each annotation is drawn as a random-colored translucent fill (alpha
    0.35) plus a blue outline of its boundary pixels (alpha 0.7).
    Annotations are drawn largest-first, so smaller ones end up on top.

    Args:
        anns: Sequence of dicts with at least an ``'area'`` entry (the sort
            key) and a ``'segmentation'`` entry (a 2-D mask array). Empty
            input draws nothing.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:, :, i] = color_mask[i]
        ax.imshow(np.dstack((img, m * 0.35)))

        # Draw boundary lines. A pixel not on the image border is a boundary
        # pixel if it lies in the mask and at least one of its 4-neighbours
        # does not. Shifted views compute this for all pixels at once,
        # instead of an interpreted loop over every pixel.
        mb = np.asarray(m, dtype=bool)
        boundary_mask = np.zeros_like(mb, dtype=bool)
        all_neighbours_in = (mb[:-2, 1:-1] & mb[2:, 1:-1]
                             & mb[1:-1, :-2] & mb[1:-1, 2:])
        boundary_mask[1:-1, 1:-1] = mb[1:-1, 1:-1] & ~all_neighbours_in
        ax.imshow(np.dstack((boundary_mask[..., None] * np.array([0, 0, 1]), boundary_mask * 0.7)))


def parse_args():
    """Build the command-line argument parser for this script.

    Despite the name, this returns the ``argparse.ArgumentParser`` itself;
    the caller is expected to call ``.parse_args()`` on it.
    """
    parser = argparse.ArgumentParser()

    # paths and info
    parser.add_argument('--input-dir', type=str, 
                        default='../datasets/flowers/jpg',
                        help='input dir')
    parser.add_argument('--output-dir', type=str, 
                        default='../datasets/flowers/jpg/sam_seg',
                        help='output dir for segmentations')
    parser.add_argument('--class-names-filepath', type=str, 
                        default=None,
                        help='class names filepath')
    parser.add_argument('--sam-model', type=str, 
                        default='vit_h',
                        choices=['vit_h', 'vit_l', 'vit_b'],
                        help='SAM backbone variant to load')
    parser.add_argument('--mode', type=str, 
                        default='subdir',
                        choices=['subdir', 'flat'],
                        help='whether the files are inside subdirs or flat directory')
    # Start/end are slice indices into the list of class subdirs (subdir mode);
    # an end of -1 means "through the last subdir".
    parser.add_argument('--start-dir', type=int, 
                        default=0, 
                        help='start dir')
    parser.add_argument('--resize', type=int, 
                        default=None, 
                        help='resize dim if specified')
    parser.add_argument('--end-dir', type=int, 
                        default=-1, 
                        help='end dir')
    # A seed of -1 disables seeding.
    parser.add_argument('--seed', type=int, 
                        default=42, 
                        help='seed')
    parser.add_argument('--points-per-side', type=int, 
                        default=32, 
                        help='points_per_side')
    parser.add_argument('--seg-method', type=str, 
                        default='sam',
                        choices=['sam', 'slic'], 
                        help='segmentation method: SAM automatic masks or SLIC superpixels')
    # SLIC-only parameters (ignored when --seg-method is sam).
    parser.add_argument('--n-segments', type=int, 
                        default=15, 
                        help='num segments')
    parser.add_argument('--compactness', type=float, 
                        default=20, 
                        help='compactness')
    parser.add_argument('--overwrite',
                        default=False, 
                        action='store_true',
                        help='overwrite previous')

    return parser

def _build_mask_generators(args, device):
    """Load the selected SAM checkpoint and return mask generators in retry order.

    The first generator uses ``args.points_per_side``; the fallbacks use 32
    and 64 points per side.
    """
    # Checkpoint paths are relative to the current working directory.
    sam_model_names = {'vit_h': 'actions/segmentation/pt_models/sam_vit_h_4b8939.pth',
                       'vit_l': 'actions/segmentation/pt_models/sam_vit_l_0b3195.pth',
                       'vit_b': 'actions/segmentation/pt_models/sam_vit_b_01ec64.pth'}
    sam = sam_model_registry[args.sam_model](checkpoint=sam_model_names[args.sam_model]).to(device)
    return [SamAutomaticMaskGenerator(sam, points_per_side=pps)
            for pps in (args.points_per_side, 32, 64)]


def _generate_sam_masks(img, generators, image_path):
    """Run the generators in order and return the first non-empty mask list.

    A generator that raises is reported and treated as having found no
    masks, so the next one is tried. Returns ``[]`` if every attempt fails
    or finds nothing.
    """
    masks = []
    for generator in generators:
        try:
            masks = generator.generate(img)
        except Exception:
            print('failed ', image_path)
            masks = []
        if len(masks) != 0:
            break
    return masks


def _segment_image(img, args, generators, image_path, dest_dir, failed_paths):
    """Compute the object to pickle for one image.

    For ``--seg-method slic``, this is the label map from
    ``segmentation.slic``. For ``sam``, it is ``np.array`` over the masks'
    ``'segmentation'`` entries; when no masks are found, the array is empty
    and ``image_path`` is recorded in ``failed_paths`` and appended to
    ``<dest_dir>/failed_paths.txt``.
    """
    if args.seg_method != 'sam':
        return segmentation.slic(
            img, n_segments=args.n_segments,
            compactness=args.compactness,
            sigma=1.)

    masks = _generate_sam_masks(img, generators, image_path)
    if len(masks) == 0:
        print('did not find masks', image_path)
        failed_paths.append(image_path)
        with open(os.path.join(dest_dir, 'failed_paths.txt'), 'a') as failed_file:
            # One path per line so the file can be read back with readlines().
            failed_file.write(image_path + '\n')
    return np.array([mask['segmentation'] for mask in masks])


def _process_image(image_path, output_path, dest_dir, args, generators, failed_paths):
    """Segment one image and pickle the result to ``output_path``.

    Skips the image when ``output_path`` already exists (unless
    ``--overwrite``) or when the image cannot be loaded.
    """
    if not args.overwrite and os.path.exists(output_path):
        return
    try:
        img = load_image(image_path, resize=args.resize)
    except Exception:
        print('load failed', image_path + '\n\n\n')
        return
    masks_seg = _segment_image(img, args, generators, image_path, dest_dir, failed_paths)
    with open(output_path, 'wb') as output_file:
        pickle.dump(masks_seg, output_file)


def main():
    """Segment every input image and pickle the result next to its name.

    For each image ``<name>``, writes ``<name>.pkl`` under ``--output-dir``.
    In 'subdir' mode the per-class subdirectory layout is mirrored; in 'flat'
    mode only ``.jpg`` files directly in ``--input-dir`` are processed.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg in vars(args):
        value = getattr(args, arg)
        print(arg, value, '\t', type(value))

    os.makedirs(args.output_dir, exist_ok=True)

    # Set the seed for reproducibility (a seed of -1 disables seeding).
    if args.seed != -1:
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generators = _build_mask_generators(args, device)

    failed_paths = []
    if args.mode == 'subdir':
        if args.class_names_filepath is None:
            dirnames = sorted(os.listdir(args.input_dir))
        else:
            with open(args.class_names_filepath, 'rt') as input_file:
                # The first character of each stripped line is dropped;
                # presumably a leading marker in the names file — TODO confirm.
                dirnames = [line.strip()[1:] for line in input_file]
        end = None if args.end_dir == -1 else args.end_dir
        dirnames = dirnames[args.start_dir:end]

        for dirname in tqdm(dirnames):
            src_cls_dir = os.path.join(args.input_dir, dirname)
            if not os.path.isdir(src_cls_dir):
                # Loose files or missing class dirs would make os.listdir raise.
                print('skipping non-directory', src_cls_dir)
                continue
            dest_cls_dir = os.path.join(args.output_dir, dirname)
            os.makedirs(dest_cls_dir, exist_ok=True)
            for filename in tqdm(os.listdir(src_cls_dir)):
                _process_image(os.path.join(src_cls_dir, filename),
                               os.path.join(dest_cls_dir, filename + '.pkl'),
                               dest_cls_dir, args, generators, failed_paths)
    else:  # flat
        for filename in tqdm(os.listdir(args.input_dir)):
            if not filename.endswith('.jpg'):
                continue
            _process_image(os.path.join(args.input_dir, filename),
                           os.path.join(args.output_dir, filename + '.pkl'),
                           args.output_dir, args, generators, failed_paths)

    if failed_paths:
        print('no masks found for %d image(s)' % len(failed_paths))

        
if __name__ == '__main__':
    # Run the segmentation pass only when executed as a script, not on import.
    main()