# original code: https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py

import torch
import random
import cv2
from torchvision import datasets, transforms
from PIL import Image
import numpy as np
import albumentations as A
import glob
from collections import defaultdict
import os
from tqdm import tqdm
import re
import torch.nn.functional as F

def rgba_loader(path):
    """
    Open the image stored at ``path`` and return an RGBA version of it.

    The source file is opened inside a context manager. The returned image is
    the result of ``convert('RGBA')``, which Pillow documents as a converted
    copy, so it stays usable after the source image is closed.

    :param path: Path to the image file.
    :return: The image converted to RGBA mode.
    """
    with Image.open(path) as source:
        rgba = source.convert('RGBA')
    return rgba

def crop_cdp(image):
    """
    Crop a four-band (RGBA) image to the bounding box of its non-zero alpha.

    :param image: Image whose ``split()`` yields exactly four bands. Any other
        band count raises ``ValueError`` when the bands are unpacked.
    :return: The cropped image. If the alpha band is entirely zero,
        ``getbbox()`` yields ``None``, and ``crop(None)`` is returned, which
        Pillow treats as the whole image.
    """
    _red, _green, _blue, alpha_band = image.split()
    return image.crop(alpha_band.getbbox())

def find_position(cip_width, cip_height, fg_width, fg_height):
    """
    Pick a random top-left corner for pasting a foreground onto a background.

    :param cip_width: Background width.
    :param cip_height: Background height.
    :param fg_width: Foreground width.
    :param fg_height: Foreground height.
    :return: ``(x, y)`` with ``0 <= x <= max(cip_width - fg_width, 0)`` and
        ``0 <= y <= max(cip_height - fg_height, 0)``. ``x`` is drawn before ``y``.
    """
    slack_x = cip_width - fg_width
    slack_y = cip_height - fg_height
    # A foreground larger than the background is pinned to the origin on that
    # axis instead of handing randint an empty range.
    upper_x = slack_x if slack_x > 0 else 0
    upper_y = slack_y if slack_y > 0 else 0
    return random.randint(0, upper_x), random.randint(0, upper_y)

def one_hot_encode(label, num_classes):
    """
    Build a one-hot vector of length ``num_classes``.

    :param label: Position to set to one (any index a tensor accepts,
        including negative indices).
    :param num_classes: Length of the returned vector.
    :return: A ``torch.zeros(num_classes)`` tensor (default dtype) with a 1 at ``label``.
    """
    encoded = torch.zeros(num_classes)
    encoded[label] = 1.0
    return encoded

class DDG(datasets.ImageFolder):
    """ImageFolder dataset that builds augmented samples from pre-extracted image parts.

    For every original image found under ``root_orig``, the dataset looks up files
    by class sub-directory and filename prefix (see ``create_mappings``):

    * ``root_cdp``: class-dependent part (CDP) images. They are loaded as RGBA,
      and their alpha channel decides which pixels are pasted.
    * ``root_cip``: class-independent part (CIP) images, used as backgrounds.
    * ``root_syncdp``: synthetic alternative CDPs. Only used when ``prob_syn > 0``,
      in which case ``num_syn`` files are kept per image.

    ``__getitem__`` returns ``(sample, label1, label2, lam1, lam2)``:

    * With probability ``1 - prob_aug``, the sample is the original image and
      ``lam1, lam2 = 1.0, 0.0``.
    * Otherwise a random CIP is used as background:
      * With probability ``1 - prob_mix``, only the sample's own CDP is pasted
        onto it (``lam1, lam2 = 1.0, 0.0``).
      * With probability ``prob_mix``, the CDP of a second, randomly drawn sample
        is pasted too. ``lam1``/``lam2`` are then the two CDPs' shares of their
        combined bounding-box area.
    """

    def __init__(self, root_orig, root_cdp, root_cip, root_syncdp=None, prob_aug=0.5, prob_syn=0.25, prob_mix=0.5, num_syn=3, transform=None, beta_alpha=1.0):
        """
        :param root_orig: ImageFolder root holding the original images.
        :param root_cdp: Root of the CDP images (same class sub-directories).
        :param root_cip: Root of the CIP images (same class sub-directories).
        :param root_syncdp: Root of the synthetic CDP images. Required when
            ``prob_aug > 0`` and ``prob_syn > 0``.
        :param prob_aug: Probability of returning an augmented sample.
        :param prob_syn: Probability of using a synthetic CDP instead of the real one.
        :param prob_mix: Probability that an augmented sample mixes two CDPs.
        :param num_syn: Number of synthetic CDPs kept per original image.
        :param transform: Transform applied to the final PIL sample.
        :param beta_alpha: Both shape parameters of the Beta distribution used to
            draw the CDP scale in ``'full'`` mode.
        :raises ValueError: see ``create_mappings`` (only called when ``prob_aug > 0``).
        """
        super().__init__(root=root_orig)
        self.rgba_loader = rgba_loader
        self.transform = transform
        self.prob_aug = prob_aug
        self.prob_syn = prob_syn
        self.prob_mix = prob_mix
        self.num_syn = num_syn
        self.beta_alpha = beta_alpha
        self.cdp_mapping = defaultdict(list)
        self.cip_mapping = defaultdict(list)
        self.syncdp_mapping = defaultdict(list)
        # The part lookups are only needed (and only built) when augmentation can happen.
        if self.prob_aug > 0:
            self.cdp_mapping, self.cip_mapping, self.syncdp_mapping = self.create_mappings(root_cdp, root_syncdp, root_cip)
        # Random flips applied to every CDP before it is cropped and resized.
        self.augmentations = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
        ])
        self.num_classes = len(self.classes)

    def calculate_lambdas(self, cdp1_width, cdp1_height, cdp2_width, cdp2_height):
        """Return each CDP's share of the two CDPs' combined bounding-box area.

        The two values sum to 1. Raises ``ZeroDivisionError`` if both areas are zero.
        """
        area_cdp1 = cdp1_width * cdp1_height
        area_cdp2 = cdp2_width * cdp2_height
        total_area = area_cdp1 + area_cdp2
        return area_cdp1 / total_area, area_cdp2 / total_area

    def find_position(self, bg_size, fg_size):
        """Return a random top-left ``(x, y)`` at which ``fg_size`` fits inside ``bg_size``.

        Unlike the module-level ``find_position``, this does not clamp. A
        foreground larger than the background makes ``random.randint`` raise
        ``ValueError``.
        """
        bg_width, bg_height = bg_size
        fg_width, fg_height = fg_size
        x = random.randint(0, bg_width - fg_width)
        y = random.randint(0, bg_height - fg_height)
        return x, y

    def paste_image(self, base_image, to_paste_image, position):
        """Return a copy of ``base_image`` with ``to_paste_image`` pasted at ``position``.

        The paste uses no mask, so the pasted pixels (alpha included) replace the
        base pixels. ``base_image`` itself is left unchanged.
        """
        base_image_copy = base_image.copy()
        base_image_copy.paste(to_paste_image, position)
        return base_image_copy

    def create_mappings(self, root_cdp, root_syncdp, root_cip):
        """Map every original sample path to its CDP, CIP and synthetic CDP files.

        Files are matched as ``<root>/<class>/<original filename stem>*.png`` and
        sorted by name.

        :return: ``(cdp_mapping, cip_mapping, syncdp_mapping)``, each a
            ``defaultdict(list)`` keyed by original sample path.
        :raises ValueError: if ``root_syncdp`` is None while ``prob_syn > 0``,
            if a sample has no CDP or no CIP file, or if it has fewer than
            ``num_syn`` synthetic CDPs.
        """
        # Fail fast with a clear message instead of a TypeError from
        # os.path.join(None, ...) deep inside the (slow) scan below.
        if self.prob_syn > 0 and root_syncdp is None:
            raise ValueError("root_syncdp must be given when prob_syn > 0")
        cdp_mapping = defaultdict(list)
        cip_mapping = defaultdict(list)
        syncdp_mapping = defaultdict(list)
        for orig_path, _ in tqdm(self.samples, desc="Creating image mappings"):
            classname = os.path.basename(os.path.dirname(orig_path))
            filename_base, _ = os.path.splitext(os.path.basename(orig_path))
            # Escape glob metacharacters (e.g. '[' in file or directory names) so
            # that only the trailing '*' acts as a wildcard.
            # NOTE(review): this is a prefix match, so stem "img1" also matches
            # "img10_*.png"; confirm the naming scheme rules that out.
            escaped_class = glob.escape(classname)
            suffix = f"{glob.escape(filename_base)}*.png"
            cdp_pattern = os.path.join(glob.escape(root_cdp), escaped_class, suffix)
            cip_pattern = os.path.join(glob.escape(root_cip), escaped_class, suffix)
            cdp_files = sorted(glob.glob(cdp_pattern))
            cip_files = sorted(glob.glob(cip_pattern))
            if not cdp_files:
                raise ValueError(f"No matching cdp images found for {cdp_pattern}")
            if not cip_files:
                raise ValueError(f"No matching cip images found for {cip_pattern}")
            cdp_mapping[orig_path] = cdp_files
            cip_mapping[orig_path] = cip_files
            if self.prob_syn > 0:
                syn_cdp_pattern = os.path.join(glob.escape(root_syncdp), escaped_class, suffix)
                syn_cdp_files = sorted(glob.glob(syn_cdp_pattern))
                # The explicit emptiness check also rejects num_syn == 0 with no
                # files; an empty list would later break random.choice.
                if not syn_cdp_files or len(syn_cdp_files) < self.num_syn:
                    raise ValueError(f"Less than {self.num_syn} images found for {syn_cdp_pattern}")
                syncdp_mapping[orig_path].extend(syn_cdp_files[:self.num_syn])
        return cdp_mapping, cip_mapping, syncdp_mapping

    def album_augment_image(self, image):
        """Apply ``self.augmentations`` to a PIL image (via NumPy) and return a PIL image."""
        image_np = np.array(image)
        augmented = self.augmentations(image=image_np)
        return Image.fromarray(augmented['image'])

    def ddg_class_dependent_part(self, path, target_size, mode='bbox'):
        """Load, flip, alpha-crop and resize one CDP belonging to the sample at ``path``.

        With probability ``prob_syn``, one of the sample's synthetic CDPs is used.
        Otherwise its first (name-sorted) real CDP is used.

        :param path: Original sample path (key of the mappings).
        :param target_size: ``(width, height)``. In ``'bbox'`` mode the cropped CDP
            is resized to exactly this size. In ``'full'`` mode it bounds the result.
        :param mode: ``'bbox'`` or ``'full'``.
        :return: The resized RGBA CDP.
        :raises ValueError: if ``mode`` is not ``'bbox'`` or ``'full'``.
        """
        if mode not in ('bbox', 'full'):
            raise ValueError(f"mode must be 'bbox' or 'full', got {mode!r}")
        if random.random() < self.prob_syn:
            path_cdp = random.choice(self.syncdp_mapping[path])
        else:
            path_cdp = self.cdp_mapping[path][0]
        rgba_image_cdp = self.album_augment_image(self.rgba_loader(path_cdp))
        # Size before cropping away the transparent border.
        cdp_width, cdp_height = rgba_image_cdp.size
        rgba_image_cdp_cropped = crop_cdp(rgba_image_cdp)
        if mode == 'full':
            cip_width, cip_height = target_size
            # Largest scale at which the un-cropped CDP still fits the background.
            lam_max = min(cip_width / cdp_width, cip_height / cdp_height)
            # Beta(a, a) sample mapped onto [0.75, 1] of that maximum.
            lam_random = (np.random.beta(self.beta_alpha, self.beta_alpha) * 0.25 + 0.75) * lam_max
            # NOTE(review): the scale is derived from the un-cropped size but the
            # cropped image is resized to it, so the crop takes the un-cropped
            # aspect ratio. Confirm this stretch is intended.
            new_width = int(cdp_width * lam_random)
            new_height = int(cdp_height * lam_random)
        else:
            new_width, new_height = target_size
        return rgba_image_cdp_cropped.resize((new_width, new_height), Image.Resampling.LANCZOS)

    def ddg_class_independent_part(self, path):
        """Load the first CIP of a uniformly random sample as the background.

        ``path`` is unused: the background is chosen independently of the
        current sample.
        """
        # Draw the key from ``self.samples`` (the same keys, in the same order, as
        # ``cip_mapping``) instead of rebuilding a list of all mapping values on
        # every call, which costs O(dataset size) per sample.
        cip_key, _ = random.choice(self.samples)
        return self.loader(self.cip_mapping[cip_key][0])

    def rand_combine(self, rgb_image_cip, rgba_image_cdp):
        """Paste a CDP at a random position onto the background and return it as RGB.

        The CDP's alpha channel is the paste mask. ``rgb_image_cip`` is modified
        in place. A CDP larger than the background is placed at the origin on
        that axis.
        """
        max_x = max(rgb_image_cip.width - rgba_image_cdp.width, 0)
        max_y = max(rgb_image_cip.height - rgba_image_cdp.height, 0)
        position = (random.randint(0, max_x), random.randint(0, max_y))
        rgb_image_cip.paste(rgba_image_cdp, position, rgba_image_cdp.split()[3])
        return rgb_image_cip.convert('RGB')

    def mix_combine(self, rgb_image_cip, rgba_image_cdp1, rgba_image_cdp2):
        """Paste two CDPs onto the background, blending them 50/50 where they overlap.

        Both CDPs must fit inside ``rgb_image_cip``; otherwise ``find_position``
        raises ``ValueError``. ``rgb_image_cip`` is modified in place.

        :return: ``(sample, lam1, lam2)``: the RGB result and the two CDPs' shares
            of their combined bounding-box area.
        """
        cip_width, cip_height = rgb_image_cip.size
        cdp1_width, cdp1_height = rgba_image_cdp1.size
        cdp2_width, cdp2_height = rgba_image_cdp2.size
        x1, y1 = self.find_position((cip_width, cip_height), (cdp1_width, cdp1_height))
        x2, y2 = self.find_position((cip_width, cip_height), (cdp2_width, cdp2_height))
        lam1, lam2 = self.calculate_lambdas(cdp1_width, cdp1_height, cdp2_width, cdp2_height)

        # Fully transparent canvas of the background's size. paste_image works on
        # a copy, so this canvas stays empty and doubles as the composite fallback.
        rgba_empty_image_cip = Image.new('RGBA', (cip_width, cip_height), (0, 0, 0, 0))
        rgba_cdp1_with_empty_cip = self.paste_image(rgba_empty_image_cip, rgba_image_cdp1, (x1, y1))
        rgba_cdp2_with_empty_cip = self.paste_image(rgba_empty_image_cip, rgba_image_cdp2, (x2, y2))

        # Count, per background pixel, how many CDPs are fully opaque there
        # (alpha // 255 is 1 only when alpha == 255), so the mask holds 0, 1 or 2.
        mask = np.zeros((cip_height, cip_width), dtype=np.uint8)
        mask[y1:y1 + cdp1_height, x1:x1 + cdp1_width] += np.array(rgba_image_cdp1.split()[-1]) // 255
        mask[y2:y2 + cdp2_height, x2:x2 + cdp2_width] += np.array(rgba_image_cdp2.split()[-1]) // 255

        # Keep each placed CDP only inside the overlap, then average the two there.
        mask_overlap = np.where(mask == 2, 255, 0).astype(np.uint8)
        mask_overlap_image = Image.fromarray(mask_overlap, mode='L')
        mask_overlap_rgba = Image.merge('RGBA', (mask_overlap_image,) * 4)
        rgb_image_cdp1_with_overlap = Image.composite(rgba_cdp1_with_empty_cip, rgba_empty_image_cip, mask_overlap_rgba)
        rgb_image_cdp2_with_overlap = Image.composite(rgba_cdp2_with_empty_cip, rgba_empty_image_cip, mask_overlap_rgba)
        mixup_image = Image.blend(rgb_image_cdp1_with_overlap, rgb_image_cdp2_with_overlap, 0.5)

        # cdp2 is pasted after cdp1, so it covers cdp1 outside the overlap.
        # The blended overlap is pasted last, on top of both.
        rgb_image_cip.paste(rgba_image_cdp1, (x1, y1), rgba_image_cdp1.split()[3])
        rgb_image_cip.paste(rgba_image_cdp2, (x2, y2), rgba_image_cdp2.split()[3])
        rgb_image_cip.paste(mixup_image, (0, 0), mixup_image.split()[3])

        sample = rgb_image_cip.convert('RGB')
        return sample, lam1, lam2

    def __getitem__(self, index):
        """Return ``(sample, label1, label2, lam1, lam2)`` for ``index``.

        ``label1`` is always this sample's target. ``label2`` equals it unless two
        CDPs were mixed, in which case it is the second sample's target. Both go
        through ``target_transform`` when one is set. ``lam1`` and ``lam2`` are
        always floats.
        """
        path, target = self.samples[index]
        label1 = label2 = target
        lam1, lam2 = 1.0, 0.0
        # ``r >= p`` makes each branch probability exactly ``p``. random.random()
        # can return 0.0, so ``r > p`` would still augment (and, with empty
        # mappings, crash) occasionally when prob_aug == 0.
        if random.random() >= self.prob_aug:
            sample = self.loader(path)
        else:
            rgb_image_cip = self.ddg_class_independent_part(path)
            if random.random() >= self.prob_mix:
                rgba_image_cdp = self.ddg_class_dependent_part(path, target_size=rgb_image_cip.size, mode='full')
                sample = self.rand_combine(rgb_image_cip, rgba_image_cdp)
            else:
                path_2, label2 = random.choice(self.samples)
                rgba_image_cdp1 = self.ddg_class_dependent_part(path, target_size=rgb_image_cip.size, mode='full')
                rgba_image_cdp2 = self.ddg_class_dependent_part(path_2, target_size=rgb_image_cip.size, mode='full')
                sample, lam1, lam2 = self.mix_combine(rgb_image_cip, rgba_image_cdp1, rgba_image_cdp2)

        if self.transform:
            sample = self.transform(sample)

        if self.target_transform:
            label1 = self.target_transform(label1)
            label2 = self.target_transform(label2)

        return sample, label1, label2, lam1, lam2