

import os
import io
import re
import copy
import random
import json
import wget
from PIL import Image
from glob import glob
import webdataset as wds

import numpy as np
import pandas as pd
from google.cloud import storage
from packaging import version
import dataclasses
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List, Union
import tokenizers

import transformers
from transformers import AutoTokenizer, AutoProcessor

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import IterableDataset

from bifrost.train.utils import image_transform
from bifrost.models.mar.utils import sample_orders, random_masking

from data.parquet import RefinedWebDataset
from datasets import load_dataset, interleave_datasets





@dataclass
class DataCollatorForImageNet(object):
    """Collate ImageNet ``{'image', 'label'}`` records into a text-to-image batch.

    Each instance is expected to look like
    ``{'image': 'n03877845/n03877845_10816.JPEG', 'label': 'palace'}`` (the
    format returned by ``LazyImageNetDataset``). The ``label`` string is used as
    the text prompt and the image is read from
    ``<dataset_path>/imagenet/train/<image>``.
    """

    def __init__(self,
            t2i_image_size: int,
            uni_prompting,
            image_transform,
            training_args = None,
        ):
        """
        Args:
            t2i_image_size: Target image resolution for the image pipeline.
            uni_prompting: Object whose ``t2i_prompt`` builds the text tensors.
            image_transform: Callable used to preprocess images when no
                Qwen2.5-VL image processor is configured.
            training_args: Training configuration. Despite the ``None`` default
                it is required: its attributes are read unconditionally.
        """
        self.training_args = training_args

        self.uni_prompting = uni_prompting
        self.transform = image_transform
        self.random_flip = training_args.random_flip

        # Optional Qwen2.5-VL image processor; when set it replaces `image_transform`.
        self.t2i_image_processor = None
        if training_args.vision_language_model == 'Qwen2_5_VLForConditionalGeneration' and training_args.use_clip_visual_encoder:
            self.t2i_image_processor = AutoProcessor.from_pretrained(training_args.vision_language_model_name)

        self.t2i_image_size = t2i_image_size
        self.dataset_path = json.loads(training_args.dataset_path_list)[0]

        print("########## dataset_path: ", self.dataset_path)

        if self.training_args.batch_size_t2i > 0:
            self.batch_size_t2i = self.training_args.batch_size_t2i
            import scipy.stats as stats
            mask_ratio_min = 0.7
            # Truncated normal centred at 1.0 (scale 0.25) restricted to
            # [mask_ratio_min, 1.0]; handed to `random_masking` (presumably to
            # draw the mask ratio).
            self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

    def _num_visual_gen_tokens(self):
        """Return the visual-generation sequence length.

        It is ``num_visual_gen_tokens``, multiplied by 4 when a VAE (with or
        without ControlNet) is being trained and ``vae_scale_by_4`` is False.
        """
        args = self.training_args
        if (args.vae_wo_ctrlnet_training or args.vae_w_ctrlnet_training) and not args.vae_scale_by_4:
            return args.num_visual_gen_tokens * 4
        return args.num_visual_gen_tokens

    def prepare_t2i(self, texts):
        """Tokenize ``texts`` into text-to-image prompts via ``uni_prompting.t2i_prompt``.

        Returns:
            (input_ids, labels, attention_mask, image_position_mask, position_ids)
            exactly as returned by ``t2i_prompt``.
        """
        return self.uni_prompting.t2i_prompt(
            texts,
            img_h=self.training_args.t2i_resolution,
            img_w=self.training_args.t2i_resolution,
            num_visual_gen_tokens=self._num_visual_gen_tokens(),
        )

    def _load_rgb_image(self, relative_path):
        """Load ``<dataset_path>/imagenet/train/<relative_path>`` as RGB and close the file."""
        path = os.path.join(self.dataset_path, "imagenet/train", relative_path)
        # `convert` returns a new, fully loaded image, so closing the source is safe.
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def _resize_short_side_and_center_crop(pil_image, resolution):
        """Resize so the short side equals ``resolution``, then center-crop a square."""
        w, h = pil_image.size
        scale = resolution / min(w, h)
        # round(), not int(): float error can yield e.g. 223.999 for 224, and
        # truncation would leave the short side too small, making CenterCrop pad.
        new_w, new_h = round(w * scale), round(h * scale)
        pil_image = pil_image.resize((new_w, new_h), Image.BICUBIC)
        return transforms.CenterCrop((resolution, resolution))(pil_image)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build ``{'t2i_flow': {...}}`` from up to ``batch_size_t2i`` instances.

        Instances beyond ``training_args.batch_size_t2i`` are ignored (they are
        not decoded). The input sequence is not modified.
        """
        args = self.training_args
        # Only the first batch_size_t2i samples are used, so don't load the rest.
        instances = list(instances)[:args.batch_size_t2i]

        pixel_values_list, prompts, image_paths, grid_thws = [], [], [], []
        for instance in instances:
            pil_image = self._load_rgb_image(instance['image'])
            if self.t2i_image_processor is not None:
                pil_image = self._resize_short_side_and_center_crop(pil_image, self.t2i_image_size)
                if self.random_flip and random.random() < 0.5:
                    pil_image = transforms.RandomHorizontalFlip(p=1.0)(pil_image)
                inputs = self.t2i_image_processor.image_processor(images=pil_image, videos=None)
                pixel_values = torch.tensor(inputs['pixel_values']) # torch.Size([1152, 1176])
                image_grid_thw = torch.tensor(inputs['image_grid_thw']) # tensor([[ 1, 36, 32]])
                grid_thws.append(image_grid_thw[0])
            else:
                pixel_values = self.transform(pil_image, processor_type=args.vision_gen_vae, resolution=self.t2i_image_size * 16 // 14, resize_short_side_to_resolution=True, random_flip = args.random_flip)
            pixel_values_list.append(pixel_values.to(torch.float32))
            prompts.append(instance['label'])
            image_paths.append(instance['image'])

        input_ids_t2i, labels_t2i, attention_mask_t2i, image_position_mask_t2i, position_ids_t2i = self.prepare_t2i(prompts)

        num_visual_gen_tokens = self._num_visual_gen_tokens()
        # Size the AR mask by the samples actually present (a short final batch
        # would otherwise get more mask rows than images).
        bsz = len(prompts)
        orders = sample_orders(bsz=bsz, seq_len=num_visual_gen_tokens)
        mask = random_masking(
            bsz=bsz,
            seq_len=num_visual_gen_tokens,
            orders=orders,
            mask_ratio_generator=self.mask_ratio_generator
        )

        batch = {}
        batch['t2i_flow'] = {
            'pixel_values': torch.stack(pixel_values_list),
            "input_ids": input_ids_t2i,
            "labels": labels_t2i,
            "attention_mask": attention_mask_t2i,
            'image_position_mask': image_position_mask_t2i,
            "position_ids": position_ids_t2i,
            "t2i_image_grid_thw": torch.stack(grid_thws) if len(grid_thws) > 0 else None,
            'ar_mask': mask,
            'image_labels': prompts,
            "image_paths": image_paths,
            }

        return batch






@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate image/caption records (webdataset or LLaVA-style) into a text-to-image batch.

    Each instance must supply an image under ``'jpg'`` or ``'image'`` (checked in
    that order) and a caption either under ``'txt'`` or as the first ``'gpt'``
    turn among the first two entries of ``'conversations'``.
    """

    def __init__(self,
            t2i_image_size: int,
            uni_prompting,
            image_transform,
            training_args = None,
        ):
        """
        Args:
            t2i_image_size: Unused; kept for signature compatibility.
            uni_prompting: Object whose ``t2i_prompt`` builds the text tensors.
            image_transform: Unused; kept for signature compatibility.
            training_args: Training configuration (required despite the default).
        """
        self.training_args = training_args
        self.uni_prompting = uni_prompting

        if self.training_args.batch_size_t2i > 0:
            self.batch_size_t2i = self.training_args.batch_size_t2i
            import scipy.stats as stats
            mask_ratio_min = 0.7
            # Truncated normal centred at 1.0 (scale 0.25) restricted to
            # [mask_ratio_min, 1.0]; handed to `random_masking`.
            self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

        # clip image processor (required by __call__)
        self.t2i_image_processor = None
        if self.training_args.vision_language_model == 'Qwen2_5_VLForConditionalGeneration' and self.training_args.use_clip_visual_encoder:
            self.t2i_image_processor = AutoProcessor.from_pretrained(self.training_args.vision_language_model_name)

    def _num_visual_gen_tokens(self):
        """Return ``num_visual_gen_tokens``, times 4 when a VAE (with or without ControlNet) is trained."""
        # NOTE(review): unlike DataCollatorForImageNet, `vae_scale_by_4` is not
        # consulted here — confirm this is intended.
        args = self.training_args
        if args.vae_wo_ctrlnet_training or args.vae_w_ctrlnet_training:
            return args.num_visual_gen_tokens * 4
        return args.num_visual_gen_tokens

    def prepare_t2i(self, texts):
        """Tokenize ``texts`` via ``uni_prompting.t2i_prompt``.

        Returns:
            (input_ids, labels, attention_mask, image_position_mask, position_ids).
        """
        return self.uni_prompting.t2i_prompt(
            texts,
            img_h=self.training_args.t2i_resolution,
            img_w=self.training_args.t2i_resolution,
            num_visual_gen_tokens=self._num_visual_gen_tokens(),
        )

    @staticmethod
    def _extract_image(instance):
        """Return the instance's image as RGB; raise ValueError if none is present."""
        for key in ('jpg', 'image'):
            image = instance.get(key)
            if image is not None:
                return image.convert('RGB')
        raise ValueError("ERROR: no image provided (expected a non-None 'jpg' or 'image' field)")

    @staticmethod
    def _extract_text(instance):
        """Return the caption: ``'txt'``, else the first 'gpt' turn among the first two conversation turns."""
        text = instance.get('txt')
        if text is not None:
            return text
        for turn in (instance.get('conversations') or [])[:2]:
            if turn['from'] == 'gpt':
                return turn['value']
        raise ValueError("ERROR: no text prompt provided (expected 'txt' or a 'gpt' turn in 'conversations')")

    @staticmethod
    def _resize_short_side_and_center_crop(pil_image, resolution):
        """Resize so the short side equals ``resolution``, then center-crop a square."""
        w, h = pil_image.size
        scale = resolution / min(w, h)
        # round(), not int(): float error can yield e.g. 223.999 for 224 and
        # truncation would make CenterCrop pad instead of crop.
        new_w, new_h = round(w * scale), round(h * scale)
        pil_image = pil_image.resize((new_w, new_h), Image.BICUBIC)
        return transforms.CenterCrop((resolution, resolution))(pil_image)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build ``{'t2i_flow': {...}}`` from up to ``batch_size_t2i`` instances.

        Raises:
            ValueError: if the Qwen2.5-VL image processor is not configured, or an
                instance lacks an image or a caption (instead of silently reusing
                the previous sample's data).
        """
        args = self.training_args
        if self.t2i_image_processor is None:
            raise ValueError(
                "DataCollatorForSupervisedDataset requires the Qwen2.5-VL image processor "
                "(vision_language_model='Qwen2_5_VLForConditionalGeneration' and use_clip_visual_encoder=True)"
            )

        # Only the first batch_size_t2i samples are used; the input is not mutated.
        instances = list(instances)[:args.batch_size_t2i]
        resolution = args.t2i_resolution
        # Grid side is 2 * sqrt(num_visual_gen_tokens) (presumably the processor
        # merges 2x2 patches into one token — confirm).
        h_w = int(args.num_visual_gen_tokens ** 0.5 * 2)

        pixel_values_list, prompts, grid_thws = [], [], []
        for instance in instances:
            pil_image = self._extract_image(instance)
            text_prompt = self._extract_text(instance)

            pil_image = self._resize_short_side_and_center_crop(pil_image, resolution)
            inputs = self.t2i_image_processor.image_processor(images=pil_image, videos=None)
            pixel_values_list.append(torch.tensor(inputs['pixel_values'])) # torch.Size([1152, 1176])
            prompts.append(text_prompt)
            grid_thws.append(torch.tensor([1, h_w, h_w]))

        input_ids_t2i, labels_t2i, attention_mask_t2i, image_position_mask_t2i, position_ids_t2i = self.prepare_t2i(prompts)

        num_visual_gen_tokens = self._num_visual_gen_tokens()
        # Mask rows must match the number of samples actually in the batch.
        bsz = len(prompts)
        orders = sample_orders(bsz=bsz, seq_len=num_visual_gen_tokens)
        mask = random_masking(
            bsz=bsz,
            seq_len=num_visual_gen_tokens,
            orders=orders,
            mask_ratio_generator=self.mask_ratio_generator
        )

        batch = {}
        batch['t2i_flow'] = {
            'pixel_values': torch.stack(pixel_values_list),
            "input_ids": input_ids_t2i,
            "labels": labels_t2i,
            "attention_mask": attention_mask_t2i,
            'image_position_mask': image_position_mask_t2i,
            "position_ids": position_ids_t2i,
            "t2i_image_grid_thw": torch.stack(grid_thws) if len(grid_thws) > 0 else None,
            'ar_mask': mask,
            }

        return batch








@dataclass
class DataCollatorForCLIPImageNet(object):
    """Collate ImageNet rows with precomputed CLIP features into a text-to-image batch.

    Instances are row tuples (as returned by ``LazyImageNetCLIPDataset``). This
    collator uses index 0 as the text prompt, index 1 as the image path and
    index 4 as the path of an ``.npz`` file whose ``'data'`` entry holds the
    image's CLIP features.
    """

    def __init__(self,
            t2i_image_size: int,
            uni_prompting,
            image_transform,
            training_args = None,
        ):
        """
        Args:
            t2i_image_size: Unused; kept for signature compatibility.
            uni_prompting: Object whose ``t2i_prompt`` builds the text tensors.
            image_transform: Unused; ``__call__`` uses the module-level
                ``image_transform`` instead.
            training_args: Training configuration (required despite the default).
        """
        self.training_args = training_args
        self.uni_prompting = uni_prompting

        if self.training_args.batch_size_t2i > 0:
            self.batch_size_t2i = self.training_args.batch_size_t2i
            import scipy.stats as stats
            mask_ratio_min = 0.7
            # Truncated normal centred at 1.0 (scale 0.25) restricted to
            # [mask_ratio_min, 1.0]; handed to `random_masking`.
            self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

    def _num_visual_gen_tokens(self):
        """Return ``num_visual_gen_tokens``, times 4 when a VAE (with or without ControlNet) is trained."""
        args = self.training_args
        if args.vae_wo_ctrlnet_training or args.vae_w_ctrlnet_training:
            return args.num_visual_gen_tokens * 4
        return args.num_visual_gen_tokens

    def prepare_t2i(self, texts):
        """Tokenize ``texts`` via ``uni_prompting.t2i_prompt``.

        Returns:
            (input_ids, labels, attention_mask, image_position_mask, position_ids).
        """
        return self.uni_prompting.t2i_prompt(
            texts,
            img_h=self.training_args.t2i_resolution,
            img_w=self.training_args.t2i_resolution,
            num_visual_gen_tokens=self._num_visual_gen_tokens(),
        )

    def _load_clip_features(self, npz_path):
        """Load ``'data'`` from ``npz_path`` as float32, reshaped to (num_visual_gen_tokens, -1)."""
        # NOTE(review): allow_pickle=True lets a crafted file execute code on
        # load; only point this at trusted, locally generated feature files.
        with np.load(npz_path, allow_pickle=True) as npz:
            data = npz['data']
            return torch.tensor(data, dtype=torch.float32).reshape(self.training_args.num_visual_gen_tokens, -1)

    @staticmethod
    def _load_rgb_image(path):
        """Open ``path`` as an RGB image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build ``{'t2i_flow': {...}}`` from up to ``batch_size_t2i`` instances.

        Instances beyond ``training_args.batch_size_t2i`` are ignored and not
        loaded; the input sequence is not modified.
        """
        args = self.training_args
        instances = list(instances)[:args.batch_size_t2i]
        # Grid side is 2 * sqrt(num_visual_gen_tokens) (presumably 2x2 patch
        # merging — confirm).
        h_w = int(args.num_visual_gen_tokens ** 0.5 * 2)

        pixel_values_list, clip_feats, prompts, image_paths, grid_thws = [], [], [], [], []
        for instance in instances:
            prompt, image_path, npz_path = instance[0], instance[1], instance[4]
            clip_feats.append(self._load_clip_features(npz_path))

            pil_image = self._load_rgb_image(image_path)
            pixel_values = image_transform(
                pil_image,
                processor_type='black-forest-labs/FLUX.1-dev',
                resolution=int(args.t2i_resolution * 16 / 14),
                resize_short_side_to_resolution=True,
                resize_long_side_to_resolution=False,
                random_flip=False
            )
            pixel_values_list.append(pixel_values)
            prompts.append(prompt)
            grid_thws.append(torch.tensor([1, h_w, h_w]))
            # Keep only the last two path components ("<synset dir>/<file>").
            image_paths.append("/".join(image_path.split("/")[-2:]))

        input_ids_t2i, labels_t2i, attention_mask_t2i, image_position_mask_t2i, position_ids_t2i = self.prepare_t2i(prompts)

        num_visual_gen_tokens = self._num_visual_gen_tokens()
        # Mask rows must match the number of samples actually in the batch.
        bsz = len(prompts)
        orders = sample_orders(bsz=bsz, seq_len=num_visual_gen_tokens)
        mask = random_masking(
            bsz=bsz,
            seq_len=num_visual_gen_tokens,
            orders=orders,
            mask_ratio_generator=self.mask_ratio_generator
        )

        batch = {}
        batch['t2i_flow'] = {
            'pixel_values': torch.stack(pixel_values_list),
            'image_clip_embs': torch.stack(clip_feats),
            "input_ids": input_ids_t2i,
            "labels": labels_t2i,
            "attention_mask": attention_mask_t2i,
            'image_position_mask': image_position_mask_t2i,
            "position_ids": position_ids_t2i,
            "t2i_image_grid_thw": torch.stack(grid_thws) if len(grid_thws) > 0 else None,
            'ar_mask': mask,
            'image_labels': prompts,
            'image_paths': image_paths,
            }

        return batch






def make_supervised_data_module(tokenizer, uni_prompting, training_args) -> Dict:
    """Build the training dataset and matching collator from ``training_args``.

    ``training_args.dataset_list`` and ``training_args.dataset_path_list`` are
    JSON-encoded lists. If ``'imagenet1k'`` is listed, an ImageNet dataset is
    used (the CLIP-feature variant when ``training_args.lambda_gpu`` is truthy)
    and other entries are ignored. Otherwise each listed streaming caption
    dataset is loaded; two or more are interleaved with equal probability.

    Args:
        tokenizer: Unused; kept for API compatibility.
        uni_prompting: Passed to the collator for prompt construction.
        training_args: Training configuration.

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` (always ``None``) and
        ``data_collator``.

    Raises:
        ValueError: if no supported dataset is listed, or a listed dataset has
            no matching entry in ``dataset_path_list``.
    """
    dataset_list = json.loads(training_args.dataset_list)
    dataset_path_list = json.loads(training_args.dataset_path_list)

    data_collator_kwargs = {
            "t2i_image_size": training_args.t2i_resolution,
            "uni_prompting": uni_prompting,
            "image_transform": image_transform,
            "training_args": training_args,
        }

    if 'imagenet1k' in dataset_list:
        if training_args.lambda_gpu:
            train_dataset = LazyImageNetCLIPDataset(training_args)
            data_collator = DataCollatorForCLIPImageNet(**data_collator_kwargs)
        else:
            train_dataset = LazyImageNetDataset(training_args)
            data_collator = DataCollatorForImageNet(**data_collator_kwargs)
    else:
        def _path_for(name):
            """Return the first entry of dataset_path_list containing ``name``."""
            matches = [p for p in dataset_path_list if name in p]
            if not matches:
                raise ValueError(f"{name!r} is in dataset_list but no entry of dataset_path_list contains it")
            return matches[0]

        streams = []
        if 'InternVL-SA1B-Caption-WebDataset' in dataset_list:
            path = _path_for('InternVL-SA1B-Caption-WebDataset')
            streams.append(load_dataset("webdataset", data_files={"train": path + "/*.tar"}, split="train", streaming=True))

        if 'LLaVA-ReCap-CC12M' in dataset_list:
            path = _path_for('LLaVA-ReCap-CC12M')
            streams.append(load_dataset(path, streaming=True)['train'])

        if not streams:
            raise ValueError(f"no supported dataset found in dataset_list: {dataset_list!r}")
        if len(streams) == 1:
            train_dataset = streams[0]
        else:
            # Equal mixing weights; continue until every stream is exhausted.
            train_dataset = interleave_datasets(
                streams,
                probabilities=[1.0 / len(streams)] * len(streams),
                stopping_strategy="all_exhausted",
            )

        data_collator = DataCollatorForSupervisedDataset(**data_collator_kwargs)

    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)




class LazyImageNetDataset(Dataset):
    """ImageNet-1k training split as a list of ``{'image', 'label'}`` records.

    Records are read once from ``<dataset_path_list[0]>/imagenet/image_label.json``
    (its ``'train'`` entry) and returned unchanged; images are not opened here.
    """

    def __init__(self, training_args):
        """
        Args:
            training_args: Object whose ``dataset_path_list`` attribute is a
                JSON-encoded list of paths; only the first path is used.
        """
        super().__init__()

        dataset_root = json.loads(training_args.dataset_path_list)[0]
        json_path = os.path.join(dataset_root, "imagenet/image_label.json")
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(json_path, "r", encoding="utf-8") as f:
            self.image_label_pairs = json.load(f)['train']

        self.length = len(self.image_label_pairs)

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return self.length

    def __getitem__(self, i) -> Dict[str, str]:
        """Return the i-th record.

        format: {'image': 'n03877845/n03877845_10816.JPEG', 'label': 'palace'}
        """
        return self.image_label_pairs[i]
       




class LazyImageNetCLIPDataset(Dataset):
    """ImageNet-1k metadata rows paired with precomputed CLIP feature paths.

    The metadata CSV is chosen by ``training_args.t2i_resolution`` and read from
    ``dataset_path_list[0]``. Index-like columns are dropped, and a
    ``clip_path`` column (``<root>/data_ind/<image_path with .npz>``) is
    appended. ``image_path`` is rewritten to
    ``<dirname(root)>/imagenet/train/<image_path>``. Items are row tuples in
    column order.
    """

    # Metadata CSV file for each supported `t2i_resolution`.
    _CSV_BY_RESOLUTION = {
        112: "in1k_clip_qwen25vl_3b_112res_16tokens.csv",
        224: "in1k_clip_qwen25vl_3b_224res_64tokens.csv",
        448: "in1k_clip_qwen25vl_3b_448res_256tokens.csv",
    }

    def __init__(self, training_args):
        """
        Args:
            training_args: Needs ``dataset_path_list`` (JSON list; first entry is
                used) and ``t2i_resolution`` (112, 224 or 448).

        Raises:
            ValueError: if ``t2i_resolution`` is not supported.
            KeyError: if the CSV lacks the ``pt_idx`` or ``row_idx`` column.
        """
        super().__init__()

        dataset_root = json.loads(training_args.dataset_path_list)[0]
        resolution = training_args.t2i_resolution
        try:
            csv_name = self._CSV_BY_RESOLUTION[resolution]
        except KeyError:
            raise ValueError(
                f"unsupported t2i_resolution {resolution!r}; expected one of {sorted(self._CSV_BY_RESOLUTION)}"
            ) from None

        metadata = pd.read_csv(os.path.join(dataset_root, csv_name))
        # Saved-index columns may or may not be present; drop them if they are.
        metadata = metadata.drop(columns=['Unnamed: 0', 'index'], errors='ignore')
        # These are required (KeyError if missing), matching the original `del`.
        metadata = metadata.drop(columns=['pt_idx', 'row_idx'])
        metadata = metadata.reset_index(drop=True)

        # clip_path must be derived from the relative image_path before that
        # column is rewritten to an absolute path below.
        metadata["clip_path"] = metadata["image_path"].str.replace(".JPEG", ".npz", regex=False)
        metadata["clip_path"] = dataset_root + "/data_ind/" + metadata["clip_path"]
        metadata['image_path'] = os.path.dirname(dataset_root) + "/imagenet/train/" + metadata["image_path"]

        self.metadata = metadata
        self.length = len(self.metadata)

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return self.length

    def __getitem__(self, i) -> tuple:
        """Return row ``i`` as a tuple of its values in column order."""
        return tuple(self.metadata.loc[i].values)
       