import warnings

warnings.filterwarnings("ignore")
import argparse
import os

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import shutil
from PIL import Image
import torch
from PIL import Image
from lang_sam import LangSAM

import torchvision.datasets as datasets
from tqdm import tqdm
import numpy as np
from typing import Any, Tuple
import PIL.Image
import cv2
from pascal_voc import PASCALVoc2007
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def build_alpha_pyramid(color, alpha, dk=1.2):
    """Build a coarse-to-fine pyramid of premultiplied colour and alpha.

    Starting from the full-resolution inputs, each level is produced by
    area-resizing the previous one by a factor of ``1 / dk`` until the
    shorter side reaches one pixel.

    Args:
        color: Colour array that broadcasts against ``alpha`` when multiplied.
        alpha: 3-D array (its shape is unpacked as ``H, W, C``); presumably
            ``C == 1`` since a single channel axis is re-added after every
            resize -- TODO confirm against callers.
        dk: Per-level shrink factor. Must be greater than 1.

    Returns:
        List of ``(premultiplied_color, alpha)`` tuples ordered from the
        smallest level to the full-resolution level.

    Raises:
        ValueError: If ``dk`` is not greater than 1 (the image would never
            shrink and the loop below would not terminate).
    """
    # Written by lvmin at Stanford
    # Massive iterative Gaussian filters are mathematically consistent to pyramid.
    if not dk > 1:
        raise ValueError(f"dk must be greater than 1, got {dk!r}")

    pyramid = []
    current_premultiplied_color = color * alpha
    current_alpha = alpha
    while True:
        pyramid.append((current_premultiplied_color, current_alpha))
        H, W, C = current_alpha.shape
        if min(H, W) == 1:
            break
        # cv2.resize takes the target size as (width, height).
        next_size = (int(W / dk), int(H / dk))
        current_premultiplied_color = cv2.resize(
            current_premultiplied_color, next_size, interpolation=cv2.INTER_AREA
        )
        # The alpha result is indexed with a new trailing axis so that its
        # shape can again be unpacked as (H, W, C) on the next iteration.
        current_alpha = cv2.resize(
            current_alpha, next_size, interpolation=cv2.INTER_AREA
        )[:, :, None]
    # Collected fine-to-coarse; callers receive coarse-to-fine.
    return pyramid[::-1]


def pad_rgb(np_rgba_hwc_uint8):
    """Fill the transparent regions of an RGBA image with blended colour.

    The image is scaled to floats in [0, 1], decomposed with
    ``build_alpha_pyramid`` and then recomposed from the first returned level
    to the last: at every level the running estimate is resized to that
    level's size and shows through wherever the level's alpha is below 1.

    Args:
        np_rgba_hwc_uint8: Array whose last axis holds R, G, B, A (the first
            three channels are used as colour, the rest as alpha).

    Returns:
        ``uint8`` array with the filled colour, scaled back to 0-255.
    """
    # Written by lvmin at Stanford
    # Massive iterative Gaussian filters are mathematically consistent to pyramid.
    rgba = np_rgba_hwc_uint8.astype(np.float32) / 255.0
    levels = build_alpha_pyramid(color=rgba[..., :3], alpha=rgba[..., 3:])

    # Seed: alpha-weighted mean colour of the first level; the clip guards
    # against division by zero when that level is fully transparent.
    seed_color, seed_alpha = levels[0]
    color_total = np.sum(seed_color, axis=(0, 1), keepdims=True)
    alpha_total = np.sum(seed_alpha, axis=(0, 1), keepdims=True)
    filled = color_total / alpha_total.clip(1e-8, 1e32)

    for level_color, level_alpha in levels:
        height, width, _ = level_color.shape
        upsampled = cv2.resize(filled, (width, height), interpolation=cv2.INTER_LINEAR)
        # Premultiplied "over" compositing: level colour on top of the estimate.
        filled = level_color + upsampled * (1.0 - level_alpha)

    return (filled * 255.0).clip(0, 255).astype(np.uint8)


class CustomDataset(datasets.ImageFolder):
    """``ImageFolder`` variant that reports unreadable images instead of raising.

    NOTE(review): ``CustomDataset`` is defined a second time later in this
    file; the later definition is the one the name refers to at call time.
    """

    def __init__(self, root: str = None, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class,
            or ``(None, None)`` when the image file could not be loaded.
        """
        path, target = self.samples[index]
        try:
            # The context manager closes the underlying file as soon as the
            # RGB copy has been made, instead of leaving it to garbage collection.
            with Image.open(path) as source:
                img = source.convert('RGB')
        except OSError:
            # IOError is an alias of OSError, and Pillow documents
            # UnidentifiedImageError as an OSError subclass, so this single
            # clause covers both of the previously listed exception types.
            print(f"Warning: Could not load image {path}. Skipping.")
            return None, None
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
    
    
# Function to apply a mask to an image
def apply_mask(image, mask):
    """Return an RGBA copy of ``image`` that uses ``mask`` as its alpha channel.

    Args:
        image: PIL image; it is converted to RGBA and the input is not modified.
        mask: NumPy array that is multiplied by 255 and cast to ``uint8``;
            presumably 0/1 or boolean values with the image's height and
            width -- TODO confirm against callers.

    Returns:
        The RGBA image with the scaled mask installed as alpha.
    """
    masked = image.convert("RGBA")
    alpha_values = (mask * 255).astype(np.uint8)
    # Convert to single-channel "L" so the mask is grayscale for putalpha.
    alpha_channel = Image.fromarray(alpha_values).convert("L")
    masked.putalpha(alpha_channel)
    return masked


# Function to split a list into parts
def split_list_into_parts(lst, num_parts):
    """Split ``lst`` into ``num_parts`` contiguous slices of near-equal length.

    The first ``len(lst) % num_parts`` slices receive one extra element, so
    slice lengths differ by at most one. Slices may be empty when
    ``num_parts`` exceeds ``len(lst)``.

    Args:
        lst: Sequence supporting ``len`` and slicing.
        num_parts: Number of slices to produce.

    Returns:
        List of ``num_parts`` slices of ``lst`` in their original order.
    """
    base_size, extra = divmod(len(lst), num_parts)
    chunks = []
    start = 0
    for part_index in range(num_parts):
        # The leading `extra` parts absorb the remainder, one element each.
        stop = start + base_size + (1 if part_index < extra else 0)
        chunks.append(lst[start:stop])
        start = stop
    return chunks


# Argument parser function
def parse_arguments():
    """Parse the command-line options of this script.

    Options:
        --dataset        (str, required) dataset type to process.
        --split          (int, default 0) index of the split for processing.
        --nsplits        (int, default 1) total number of splits.
        --dataset_split  ('train' or 'test', default 'train').

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # NOTE(review): the description text mentions fractal patterns, which
    # none of the options below relate to -- confirm it is still accurate.
    cli = argparse.ArgumentParser(
        description="Generate an augmented dataset from original images and fractal patterns.")

    option_specs = (
        ('--dataset', dict(type=str, required=True,
                           help='Dataset type to process.')),
        ('--split', dict(type=int, default=0,
                         help='Index of the split for processing.')),
        ('--nsplits', dict(type=int, default=1,
                           help='Total number of splits.')),
        ('--dataset_split', dict(type=str, default='train', choices=['train', 'test'],
                                 help='Dataset split to use (train or test).')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def load_classname_mapping(filepath):
    """Read a ``<key> <value>`` text file into a dictionary.

    Each non-blank line is split at its first space: the text before it is
    the key and everything after it is the value (which may itself contain
    spaces). Lines without a space are reported and skipped; blank lines are
    skipped silently. When a key occurs more than once the last line wins.

    Args:
        filepath: Path of the mapping file, read as UTF-8.

    Returns:
        dict mapping each key to its value string.
    """
    classname_mapping = {}
    # Explicit encoding so the result does not depend on the platform default.
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            entry = line.strip()
            if not entry:
                continue
            key, separator, value = entry.partition(' ')  # Split by the first space only
            if not separator:
                print(f"Warning: Line '{entry}' is not in the expected format.")
                continue
            classname_mapping[key] = value
    return classname_mapping


def process_dataset(args, data_dir, save_dir, prompt):
    """Segment the images of this worker's class folders and save masked variants.

    The sorted entries of ``data_dir`` are divided into ``args.nsplits`` parts
    and only part ``args.split`` is handled. Every image is passed to LangSAM
    together with a text prompt; the union of the returned masks is used to
    write three PNG files per image:

    * ``<save_dir>/cdp/<class>/<name>.png``     -- RGBA, alpha = mask
      (class-dependent part)
    * ``<save_dir>/cip/<class>/<name>.png``     -- RGBA, alpha = 1 - mask
      (class-independent part)
    * ``<save_dir>/cip_pad/<class>/<name>.png`` -- RGB, ``pad_rgb`` applied to
      the class-independent image

    Images for which no mask is returned are copied to
    ``<save_dir>/no_detected``. Images that cannot be opened, or that fail
    while being masked or saved, are reported and skipped.

    Args:
        args: Parsed CLI namespace; ``nsplits`` and ``split`` are read.
        data_dir: Directory containing one sub-directory per class.
        save_dir: Root directory for all outputs.
        prompt: Text prompt used for every class. Ignored when ``data_dir``
            contains "imagenet": the prompt is then looked up per class
            folder in ``./imagenet_classnames.txt``.

    Raises:
        KeyError: If an ImageNet class folder has no entry in the mapping file.
        IndexError: If ``args.split`` is not a valid part index.
    """
    no_detected_dir = os.path.join(save_dir, "no_detected")
    os.makedirs(no_detected_dir, exist_ok=True)
    model = LangSAM()
    # Informational only: the device is printed but not passed to the model here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    class_entries = sorted(os.listdir(data_dir))
    print(class_entries)
    classnames_thissplit = split_list_into_parts(class_entries, args.nsplits)[args.split]

    # Only ImageNet-style runs need the id -> name file, so other datasets
    # must not fail just because it is missing.
    use_class_prompts = "imagenet" in data_dir
    classname_mapping = (
        load_classname_mapping("./imagenet_classnames.txt") if use_class_prompts else {}
    )

    for class_name in classnames_thissplit:
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue

        print(f"Now processing {class_name}")
        text_prompt = classname_mapping[class_name] if use_class_prompts else prompt
        print(f"Using prompt {text_prompt}")

        class_save_dir = os.path.join(save_dir, "cdp", class_name)
        class_indep_save_dir = os.path.join(save_dir, "cip", class_name)
        class_pad_indep_save_dir = os.path.join(save_dir, "cip_pad", class_name)
        for output_dir in (class_save_dir, class_indep_save_dir, class_pad_indep_save_dir):
            os.makedirs(output_dir, exist_ok=True)

        for image_filename in tqdm(sorted(os.listdir(class_dir)), desc=f"Processing images in {class_name}"):
            if not image_filename.lower().endswith(IMG_EXTENSIONS):
                continue
            img_path = os.path.join(class_dir, image_filename)

            # A single unreadable file should not abort the whole run.
            try:
                with Image.open(img_path) as source:
                    image_pil = source.convert('RGB')
            except OSError as e:
                print(f"Error processing image {img_path}: {e}")
                continue

            masks, _boxes, _phrases, _logits = model.predict(image_pil, text_prompt)

            if len(masks) == 0:
                print(f"No objects of the '{text_prompt}' prompt detected in {img_path}.")
                # NOTE(review): files are copied into one flat directory, so
                # equal filenames from different classes overwrite each other.
                shutil.copy(img_path, no_detected_dir)
                continue

            try:
                # Union of all returned masks.
                full_mask = np.zeros((image_pil.height, image_pil.width), dtype=np.uint8)
                for mask in masks:
                    full_mask = np.maximum(full_mask, mask.squeeze().cpu().numpy())

                class_dependent_image = apply_mask(image_pil, full_mask)
                class_independent_image = apply_mask(image_pil, 1 - full_mask)

                base_filename = os.path.splitext(os.path.basename(img_path))[0]
                output_name = f"{base_filename}.png"
                class_dependent_image.save(os.path.join(class_save_dir, output_name))
                class_independent_image.save(os.path.join(class_indep_save_dir, output_name))

                padded_class_independent_image = pad_rgb(np.array(class_independent_image))
                Image.fromarray(padded_class_independent_image, 'RGB').save(
                    os.path.join(class_pad_indep_save_dir, output_name))
            except Exception as e:
                # Deliberate best-effort: report the failure and move on.
                print(f"Error processing image {img_path}: {e}")

class CustomDataset(datasets.ImageFolder):
    """``ImageFolder`` variant that reports unreadable images instead of raising.

    NOTE(review): this is the second definition of ``CustomDataset`` in this
    file; being the later one, it is the definition the name refers to at
    call time. Consider removing one of the two.
    """

    def __init__(self, root: str = None, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class,
            or ``(None, None)`` when the image file could not be loaded.
        """
        path, target = self.samples[index]
        try:
            # The context manager closes the underlying file as soon as the
            # RGB copy has been made, instead of leaving it to garbage collection.
            with Image.open(path) as source:
                img = source.convert('RGB')
        except OSError:
            # IOError is an alias of OSError, and Pillow documents
            # UnidentifiedImageError as an OSError subclass, so this single
            # clause covers both of the previously listed exception types.
            print(f"Warning: Could not load image {path}. Skipping.")
            return None, None
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

# Main function
def main():
    """Entry point: choose the dataset configuration and run the processing.

    ``--dataset`` selects a hard-coded ``(data_dir, save_dir, prompt)`` triple
    that is handed to ``process_dataset``. The 'pascal' choice only
    constructs the PASCAL VOC datasets and processes nothing.

    NOTE(review): ``--dataset_split`` is parsed but not consulted here; the
    folder paths below are fixed.

    Raises:
        ValueError: If ``--dataset`` names an unknown dataset.
    """
    args = parse_arguments()

    # dataset name -> (data_dir, save_dir, prompt)
    folder_datasets = {
        'imagenet-200': ("/dockerdata/imagenet-200/train",
                         "cdp_and_cip/imagenet-200/train",
                         'imagenet'),
        'imagenet': ("/dockerdata/imagenet/train",
                     "cdp_and_cip/imagenet/train",
                     'imagenet'),
        'bird': ("/datasets/CUB_200_2011/train/",
                 "cdp_and_cip/CUB_200_2011/",
                 'bird'),
        'aircraft': ("path/to/datasets/Aircraft/train/",
                     "cdp_and_cip/Aircraft/",
                     'aircraft'),
        'car': ("path/to/datasets/StandfordCar/train",
                "cdp_and_cip/StandfordCar/",
                'car'),
    }

    if args.dataset in folder_datasets:
        data_dir, save_dir, prompt = folder_datasets[args.dataset]
        # The instance itself is not used afterwards; it is still constructed
        # so that any error raised while indexing data_dir surfaces before the
        # processing starts. NOTE(review): confirm this eager scan is wanted.
        CustomDataset(data_dir)
        process_dataset(args, data_dir, save_dir, prompt)
    elif args.dataset == 'pascal':
        # NOTE(review): both splits are constructed but nothing is done with
        # them; a "test" split was previously sketched here as well.
        PASCALVoc2007(root="path/to/public-dataset/PascalVoc", split="train")
        PASCALVoc2007(root="path/to/public-dataset/PascalVoc", split="val")
    else:
        raise ValueError(f"Invalid dataset: {args.dataset}")
    
    

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
