import os
import json
import pandas as pd
from datasets import Dataset, DatasetDict, Image, load_dataset, concatenate_datasets
from constants.const import class_available,theme_available
from collections import defaultdict

def NudityKeywordDataset(image_dir, keyword_path):

    # Load keywords
    with open(keyword_path, 'r') as f:
        keywords = json.load(f)
    assert isinstance(keywords, list)

    # Get sorted list of image paths
    all_images = sorted([
        f for f in os.listdir(image_dir) if f.endswith('.png')
    ])
    assert len(all_images) >= len(keywords), "Not enough images for keywords"

    # Trim to make number of images divisible by number of keywords
    num_keywords = len(keywords)
    max_images = (len(all_images) // num_keywords) * num_keywords
    all_images = all_images[:max_images]

    # Evenly assign keywords to images (deterministically)
    per_keyword = len(all_images) // num_keywords
    data = []
    for i, keyword in enumerate(keywords):
        assigned_images = all_images[i * per_keyword : (i + 1) * per_keyword]
        for img_name in assigned_images:
            data.append({
                "image": os.path.join(image_dir, img_name),
                "text": keyword,
                'dist': 'forget'
            })

    forget_dataset = Dataset.from_pandas(pd.DataFrame(data)).cast_column("image", Image())

    retain_dataset = load_dataset('sjki928/coco2017caption',split='train')
    retain_dataset = retain_dataset.add_column('dist',['retain' for _ in range(len(retain_dataset))])

    dataset = concatenate_datasets([forget_dataset,retain_dataset])
    return dataset


def UnlearnCanvasDataset(image_dir, target_name):
    """
    Creates the forget and retain datasets for a given unlearning task.
    
    Args:
        unlearn_type (str): 'class' or 'style'
        target_name (str): The specific class or style name to unlearn.
    
    Returns:
        A tuple of (forget_dataset, retain_dataset).
    """
    if target_name in class_available:
        unlearn_type = 'class'  
    elif target_name in theme_available: 
        unlearn_type = 'style'
    else:
        ValueError('you suck')

    # Initialize lists to store our data
    forget_data = defaultdict(list)
    retain_data = defaultdict(list)
    
    # Get the list of all styles and classes from your directory structure
    
    # Loop through each style and each class to build the datasets
    for style in theme_available:
        
        for class_name in class_available:
            target_dir = os.path.join(image_dir, style, f'{class_name}')
            if style != 'Seed_Images':
                prompt = f"An {class_name} image in {style.replace('_', ' ')} style."
            else:
                prompt = f"An {class_name} image in Photo style."
            
            # --- Logic to separate data into forget and retain sets ---
            if unlearn_type == 'style':
                for image in os.listdir(target_dir):
                    image_path = os.path.join(target_dir,image)
                    if style == target_name:
                        forget_data['image'].append(image_path)
                        forget_data['text'].append(prompt)
                        forget_data['dist'].append('forget')
                    else:
                        retain_data['image'].append(image_path)
                        retain_data['text'].append(prompt)
                        retain_data['dist'].append('retain')
            elif unlearn_type == 'class':
                for image in os.listdir(target_dir):
                    image_path = os.path.join(target_dir,image)
                    if class_name == target_name:
                        forget_data['image'].append(image_path)
                        forget_data['text'].append(prompt)
                        forget_data['dist'].append('forget')
                    else:
                        retain_data['image'].append(image_path)
                        retain_data['text'].append(prompt)
                        retain_data['dist'].append('retain')

            
    # Create the Hugging Face Dataset objects
    forget_dataset = Dataset.from_dict(forget_data).cast_column("image", Image())
    retain_dataset = Dataset.from_dict(retain_data).cast_column("image", Image())

    dataset = concatenate_datasets([forget_dataset,retain_dataset])
    return  dataset

def UnlearnCanvasDataset_classifier(image_dir,unlearn_type):
    """Build a labelled image dataset covering every UnlearnCanvas style/class pair.

    Every file in ``<image_dir>/<style>/<class_name>/`` becomes one row with its
    prompt, its style index and its class index. Indices are positions in
    ``theme_available`` and ``class_available``.

    Args:
        image_dir: Root directory of the UnlearnCanvas images.
        unlearn_type: 'style' selects the style label mapping; any other value
            selects the class label mapping.

    Returns:
        ``(dataset, label)``. ``dataset`` has columns ``image`` (cast to
        ``Image()``), ``text``, ``style_dist`` and ``class_dist``. ``label`` maps
        names to indices for the selected attribute.
    """
    style_label = {name: idx for idx, name in enumerate(theme_available)}
    class_label = {name: idx for idx, name in enumerate(class_available)}

    records = defaultdict(list)
    for style in theme_available:
        style_idx = style_label[style]
        for class_name in class_available:
            folder = os.path.join(image_dir, style, f'{class_name}')
            # 'Seed_Images' is rendered as "Photo" style in the prompt.
            if style == 'Seed_Images':
                style_text = 'Photo'
            else:
                style_text = style.replace('_', ' ')
            prompt = f"A {class_name} image in {style_text} style."
            class_idx = class_label[class_name]

            for file_name in os.listdir(folder):
                records['image'].append(os.path.join(folder, file_name))
                records['text'].append(prompt)
                records['style_dist'].append(style_idx)
                records['class_dist'].append(class_idx)

    labelled = Dataset.from_dict(records).cast_column("image", Image())
    if unlearn_type == 'style':
        return labelled, style_label
    return labelled, class_label


def UnlearnCanvasDataset_classifier_text_only(image_dir, unlearn_type):

    # Initialize lists to store our data
    data = defaultdict(list)
    
    style_label = {k:v for v,k in enumerate(theme_available)}
    
    class_label = {k:v for v,k in enumerate(class_available)}
    
    # Loop through each style and each class to build the datasets
    for style in theme_available:
        for class_name in class_available:
            if style != 'Seed_Images':
                prompt = f"A {class_name} image in {style.replace('_', ' ')} style."
            else:
                prompt = f"A {class_name} image in Photo style."
            
            # --- Logic to separate data into forget and retain sets ---
            data['text'].append(prompt)
            data['style_dist'].append(style_label[style])
            data['class_dist'].append(class_label[class_name])
      
    # Create the Hugging Face Dataset objects
    dataset = Dataset.from_dict(data).cast_column("image", Image())
    label = style_label if unlearn_type == 'style' else class_label
    return  dataset, label

# Example usage for unlearning the 'Cats' class. The unlearning type is inferred
# from whether the target is listed in class_available or theme_available.
# cats_ds = UnlearnCanvasDataset('/path/to/unlearncanvas', 'Cats')

# Example usage for unlearning the 'Bricks' style:
# bricks_ds = UnlearnCanvasDataset('/path/to/unlearncanvas', 'Bricks')

# Each returned dataset holds both forget and retain rows; the 'dist' column
# ('forget' / 'retain') tells them apart.
# Example usage
if __name__ == "__main__":
    # Smoke test: build an UnlearnCanvas forget/retain dataset, attach the
    # Stable Diffusion preprocessing, and pull one batch through the collate
    # function that splits a batch into retain/forget sub-batches.
    import random

    import numpy as np
    import torch
    from torchvision import transforms
    from transformers import CLIPTokenizer

    # concat_collate_fn routes examples by the 'dist' column. Only the
    # forget/retain builders (UnlearnCanvasDataset, NudityKeywordDataset)
    # produce that column. UnlearnCanvasDataset_classifier does not, so feeding
    # it here would raise KeyError for every batch.
    # Alternative forget/retain source:
    # dataset = NudityKeywordDataset(
    #     image_dir="/workspace/ICCV_dataset/OPC/nudity",
    #     keyword_path="/workspace/Unlearning/SD/generalized_nudity_prompt_keywords.json",
    # )
    dataset = UnlearnCanvasDataset('/workspace/unlearncanvas', 'Bricks')

    tokenizer = CLIPTokenizer.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="tokenizer", revision=None
    )

    image_column = 'image'
    caption_column = 'text'

    def tokenize_captions(examples, is_train=True):
        """Tokenize the caption column, padded/truncated to the model max length.

        A caption given as a list/array contributes one entry: a random choice
        when ``is_train`` is true, otherwise its first element.
        """
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    train_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(512),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        """Add 'pixel_values' and 'input_ids' to a batch, keeping all existing columns."""
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    dataset = dataset.with_transform(preprocess_train)

    def collate_fn(examples):
        """Stack a batch without splitting it; keep each example's 'dist' label."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids, "dist": [example["dist"] for example in examples]}

    def _stack_split(split):
        """Stack one split's tensors into a batch dict, or return None if it is empty."""
        if not split['input_ids']:
            return None
        pixel_values = torch.stack(split['pixel_values'])
        return {
            'pixel_values': pixel_values.to(memory_format=torch.contiguous_format).float(),
            'input_ids': torch.stack(split['input_ids']),
        }

    def concat_collate_fn(examples):
        """Split a mixed batch by 'dist' into ``(retain_batch, forget_batch)``.

        Examples whose 'dist' is not 'retain' go to the forget batch. Either
        element is None when the batch holds no examples of that kind.
        """
        splits = {'retain': defaultdict(list), 'forget': defaultdict(list)}
        for example in examples:
            split = splits['retain' if example['dist'] == 'retain' else 'forget']
            split['pixel_values'].append(example["pixel_values"])
            split['input_ids'].append(example['input_ids'])
        return _stack_split(splits['retain']), _stack_split(splits['forget'])

    # NOTE(review): the collate functions are nested, so worker processes can only
    # use them under the 'fork' start method; use num_workers=0 where 'spawn' is
    # the default.
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=concat_collate_fn,
        batch_size=2,
        num_workers=2,
    )

    # Fetch a single batch to exercise the whole pipeline.
    for step, (retain_batch, forget_batch) in enumerate(train_dataloader):
        break

    print("Dataset loaders created successfully!")