import os
import pandas as pd
import torch
import clip

from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np 
from random import randrange

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from torchvision.transforms import Compose, TenCrop

# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

class MemesLogoDataset(Dataset):
    """Memes with a logo pasted on top, one item per (meme, logo) pair.

    A fixed random sample of the requested split is crossed with one shard of
    the ``.jpg`` files found in ``root_folder_logos``.  Each item holds the
    meme image with that logo composited onto it, the meme text and its label.
    """

    # Number of logo files covered by one ``logos_idx`` shard.
    LOGOS_PER_SHARD = 5000
    # Corner used when no paste location is configured.
    DEFAULT_PASTE_LOCATION = "top_left"

    def __init__(self, args, root_folder, root_folder_logos, dataset, split='train', image_size=224, fast=True, factor_shrink=5, transparency = 1.0, logos_idx=0, pos_logo_x=0.0, pos_logo_y=0.0, past_attack_file_locations=None, past_attack_file=None):
        """Read the label file, sample the memes and enumerate (meme, logo) pairs.

        Args:
            args: Run configuration object; stored, and consulted for
                ``past_attack_file_locations`` when that keyword is not given.
            root_folder: Directory containing ``<dataset>/labels`` and
                ``<dataset>/img``.
            root_folder_logos: Directory whose ``.jpg`` files are the logos.
            dataset: Dataset name (``'hmc'`` uses the ``img`` column, every
                other name uses the ``image`` column).
            split: Value of the ``split`` column to keep.
            image_size: Accepted for interface compatibility; images are
                always resized to the fixed ``(512, 450)`` (width, height).
            fast: Stored but not used by this class.
            factor_shrink: The logo is resized to ``1 / factor_shrink`` of the
                meme's width and height.
            transparency: Logo opacity in ``[0, 1]``.
            logos_idx: Index of the shard of ``LOGOS_PER_SHARD`` logos to use.
            pos_logo_x, pos_logo_y: Stored but not used by this class.
            past_attack_file_locations: Optional corner name, or sequence of
                corner names whose first entry is used, for placing the logo.
            past_attack_file: Stored only; the logos pasted by this class come
                from ``root_folder_logos``.
        """
        super(MemesLogoDataset, self).__init__()
        self.root_folder_logos = root_folder_logos
        self.factor_shrink = factor_shrink
        self.transparency = transparency
        self.logos_idx = logos_idx
        self.pos_logo_x = pos_logo_x
        self.pos_logo_y = pos_logo_y
        self.args = args

        if past_attack_file_locations is None:
            past_attack_file_locations = getattr(args, 'past_attack_file_locations', None)
        self.past_attack_file_locations = past_attack_file_locations
        self.past_attack_file = past_attack_file

        # NOTE(review): the original code never passed a location to
        # ``past_attack``; using the first configured location (or the
        # top-left corner) is an assumption -- confirm the intended placement.
        if isinstance(past_attack_file_locations, str):
            self.past_attack_loc = past_attack_file_locations
        elif past_attack_file_locations:
            self.past_attack_loc = past_attack_file_locations[0]
        else:
            self.past_attack_loc = self.DEFAULT_PASTE_LOCATION

        self.root_folder = root_folder
        self.dataset = dataset
        self.split = split

        # Fixed (width, height); the ``image_size`` argument is ignored.
        self.image_size = (512, 450)
        self.fast = fast

        self.info_file = os.path.join(root_folder, dataset, f'labels/{dataset}_info.csv')
        self.df = pd.read_csv(self.info_file)
        self.df = self.df[self.df['split'] == self.split].reset_index(drop=True)
        # Float columns are turned into nullable integers, with NaN as -1.
        float_cols = self.df.select_dtypes(float).columns
        self.df[float_cols] = self.df[float_cols].fillna(-1).astype('Int64')

        self.choose_random_sample()
        self.build_dataset()

    def build_dataset(self):
        """Fill ``all_filenames``, ``texts`` and ``labels`` for every (meme, logo) pair.

        ``labels_unique`` receives one label per sampled meme.
        """
        logos_fns = sorted(
            os.path.join(self.root_folder_logos, file)
            for file in os.listdir(self.root_folder_logos)
            if file.endswith(".jpg")
        )
        start_idx = self.logos_idx * self.LOGOS_PER_SHARD
        end_idx = min(start_idx + self.LOGOS_PER_SHARD, len(logos_fns))
        logos_fns = logos_fns[start_idx:end_idx]

        print("="*20)
        print("Starting index:", start_idx, "Ending index:", end_idx)
        print("="*20)

        self.all_filenames = []
        self.texts = []
        self.labels = []
        self.labels_unique = []

        for _, row in tqdm(self.df.iterrows(), total=self.df.shape[0]):
            self.labels_unique.append(row['label'])

            # Everything below depends only on the meme, so it is computed
            # once per row rather than once per logo.
            if self.dataset == 'hmc':
                image_fn = row['img'].split('/')[1]
            else:
                image_fn = row['image']

            txt = 'null' if row['text'] == 'nothing' else row['text']
            image_path = os.path.join(self.root_folder, self.dataset, 'img', image_fn)

            for logo_fn in logos_fns:
                self.all_filenames.append((image_path, logo_fn))
                self.labels.append(row['label'])
                self.texts.append(txt)

    def choose_random_sample(self, num=128):
        """Replace ``self.df`` with a reproducible random sample of at most ``num`` rows."""
        # Sampling more rows than exist raises in pandas, so cap the request.
        num = min(num, len(self.df))
        self.df = self.df.sample(n=num, random_state=0)

    def load_attack_file(self, paste_attack_file):
        """Open a logo as RGBA and make its near-white pixels fully transparent."""
        img = Image.open(paste_attack_file).convert("RGBA")

        image_array = np.array(img)
        # A pixel counts as "off-white" when all three colour channels exceed 200.
        offwhite_condition = (image_array[:, :, :3] > 200).all(axis=2)
        image_array[offwhite_condition] = [255, 255, 255, 0]
        img = Image.fromarray(image_array)

        return img

    def past_attack(self, img, past_attack_f, past_attack_loc):
        """Composite a logo onto one corner of an image.

        Args:
            img: PIL image to paste onto.
            past_attack_f: PIL logo image, e.g. from ``load_attack_file``.
            past_attack_loc: ``"top_left"``, ``"top_right"``,
                ``"bottom_left"`` or ``"bottom_right"``.

        Returns:
            A new RGB PIL image.

        Raises:
            ValueError: If ``past_attack_loc`` is not one of the four corners.
        """
        transparency = int(self.transparency * 255)

        image = img.convert('RGBA')
        watermark = past_attack_f.resize((image.size[0]//self.factor_shrink, image.size[1]//self.factor_shrink))
        layer = Image.new('RGBA', image.size, (0, 0, 0, 0))

        if past_attack_loc == "top_left":
            img_w = 0
            img_h = 0

        elif past_attack_loc == "top_right":
            img_w = image.size[0] - watermark.size[0]
            img_h = 0

        elif past_attack_loc == "bottom_left":
            img_w = 0
            img_h = image.size[1] - watermark.size[1]

        elif past_attack_loc == "bottom_right":
            img_w = image.size[0] - watermark.size[0]
            img_h = image.size[1] - watermark.size[1]

        else:
            raise ValueError(f"Invalid past_attack_loc: {past_attack_loc}")

        layer.paste(watermark, (img_w, img_h))

        # A copy of the layer with a uniform alpha is pasted back using the
        # layer itself as mask, so only the logo's visible pixels receive the
        # requested opacity.
        layer2 = layer.copy()
        layer2.putalpha(transparency)
        layer.paste(layer2, layer)
        result = Image.alpha_composite(image, layer)

        return result.convert("RGB")

    def __len__(self):
        """Return the number of (meme, logo) pairs."""
        return len(self.all_filenames)

    def __getitem__(self, idx):
        """Return the ``idx``-th (meme, logo) pair as a dict.

        ``idx_meme`` is the pair index ``idx`` itself, not the meme's id in
        the label file.
        """
        subject_fn, logo_fn = self.all_filenames[idx]
        text = self.texts[idx]
        label = self.labels[idx]

        # ``self.image_size`` is already a (width, height) tuple.
        subject_image = Image.open(subject_fn).convert('RGB').resize(self.image_size)

        # ``past_attack`` needs the loaded logo image (not its path), the
        # meme as first argument, and a paste location.
        logo_image = self.load_attack_file(logo_fn)
        final_image = self.past_attack(subject_image, logo_image, self.past_attack_loc)

        item = {
            'image': final_image,
            'text': text,
            'label': label,
            'origin_text': text,
            'idx_meme': idx,
            'logo_fns': logo_fn,
            # Same keys as MemesDataset items, which MemesCollator reads.
            'image_fn': os.path.basename(subject_fn),
            'image_fn_full': subject_fn,
        }

        return item


class MemesDataset(Dataset):
    """Meme dataset yielding either pre-computed CLIP embeddings or raw images.

    With ``fast=True`` items carry the embeddings loaded from
    ``clip_embds/<split>_no-proj_output.pt``; otherwise they carry the resized
    PIL image (with any configured paste-attack images composited on top) and
    the raw text.
    """

    def __init__(self, args, root_folder, dataset, split='train', image_size=224, fast=True):
        """Read the label file and, depending on the mode, embeddings or attack images.

        Args:
            args: Configuration object; ``factor_shrink``, ``transparency``,
                ``past_attack_file_locations`` and ``paste_attack_file`` are read.
            root_folder: Directory containing ``<dataset>/labels``,
                ``<dataset>/img`` and ``<dataset>/clip_embds``.
            dataset: Dataset name (``'hmc'`` uses the ``img`` column, every
                other name uses the ``image`` column).
            split: Value of the ``split`` column to keep.
            image_size: Side length images are resized to in non-fast mode.
            fast: Whether to serve pre-computed embeddings instead of images.
        """
        super(MemesDataset, self).__init__()
        self.root_folder = root_folder
        self.dataset = dataset
        self.split = split
        self.factor_shrink = args.factor_shrink
        self.transparency = args.transparency

        self.past_attack_file_locations = args.past_attack_file_locations
        self.paste_attack_file = []
        self.args = args

        self.image_size = image_size
        self.fast = fast

        self.info_file = os.path.join(root_folder, dataset, f'labels/{dataset}_info.csv')
        self.df = pd.read_csv(self.info_file)
        self.df = self.df[self.df['split'] == self.split].reset_index(drop=True)
        # Float columns are turned into nullable integers, with NaN as -1.
        float_cols = self.df.select_dtypes(float).columns
        self.df[float_cols] = self.df[float_cols].fillna(-1).astype('Int64')

        if self.fast:
            self.embds = torch.load(f'{self.root_folder}/{self.dataset}/clip_embds/{split}_no-proj_output.pt', map_location='cuda')
            # DataFrame view of the embeddings, used to look entries up by meme id.
            self.embdsDF = pd.DataFrame(self.embds)

            assert len(self.embds) == len(self.df), \
                "number of pre-computed embeddings does not match the label file"

        if args.paste_attack_file is not None:
            for past_attack_f in args.paste_attack_file:
                self.paste_attack_file.append(self.load_attack_file(past_attack_f))

    def load_attack_file(self, paste_attack_file):
        """Open an attack image as RGBA and make its near-white pixels fully transparent."""
        img = Image.open(paste_attack_file).convert("RGBA")

        image_array = np.array(img)
        # A pixel counts as "off-white" when all three colour channels exceed 200.
        offwhite_condition = (image_array[:, :, :3] > 200).all(axis=2)
        image_array[offwhite_condition] = [255, 255, 255, 0]

        return Image.fromarray(image_array)

    def past_attack(self, img, past_attack_f, past_attack_loc):
        """Composite an attack image near one corner of an image.

        The attack image is placed flush with the left or right edge and
        inset vertically by 20% of the image height from the top or bottom.

        Args:
            img: PIL image to paste onto.
            past_attack_f: PIL attack image, e.g. from ``load_attack_file``.
            past_attack_loc: ``"top_left"``, ``"top_right"``,
                ``"bottom_left"`` or ``"bottom_right"``.

        Returns:
            A new RGB PIL image.

        Raises:
            ValueError: If ``past_attack_loc`` is not one of the four corners.
        """
        transparency = int(self.transparency * 255)

        image = img.convert('RGBA')
        watermark = past_attack_f.resize((image.size[0]//self.factor_shrink, image.size[1]//self.factor_shrink))
        layer = Image.new('RGBA', image.size, (0, 0, 0, 0))

        vertical_inset = int(0.2 * image.size[1])

        if past_attack_loc == "top_left":
            img_w = 0
            img_h = vertical_inset

        elif past_attack_loc == "top_right":
            img_w = image.size[0] - watermark.size[0]
            img_h = vertical_inset

        elif past_attack_loc == "bottom_left":
            img_w = 0
            img_h = (image.size[1] - watermark.size[1]) - vertical_inset

        elif past_attack_loc == "bottom_right":
            img_w = image.size[0] - watermark.size[0]
            img_h = (image.size[1] - watermark.size[1]) - vertical_inset

        else:
            raise ValueError(f"Invalid past_attack_loc: {past_attack_loc}")

        layer.paste(watermark, (img_w, img_h))

        # A copy of the layer with a uniform alpha is pasted back using the
        # layer itself as mask, so only the attack image's visible pixels
        # receive the requested opacity.
        layer2 = layer.copy()
        layer2.putalpha(transparency)
        layer.paste(layer2, layer)
        result = Image.alpha_composite(image, layer)

        return result.convert("RGB")

    def __len__(self):
        """Return the number of memes in the selected split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``idx``-th meme as a dict.

        ``image`` / ``text`` are the stored embeddings in fast mode and the PIL
        image / raw text otherwise; ``origin_text`` is always the raw text.
        """
        row = self.df.iloc[idx]

        txt = 'null' if row['text'] == 'nothing' else row['text']

        # The file name is reported in the item in both modes, so it has to be
        # resolved before branching (it used to be undefined in fast mode).
        if self.dataset == 'hmc':
            image_fn = row['img'].split('/')[1]
        else:
            image_fn = row['image']
        image_path = f"{self.root_folder}/{self.dataset}/img/{image_fn}"

        if self.fast:
            embd_idx = self.embdsDF.loc[self.embdsDF['idx_meme'] == row['id']].index[0]
            embd_row = self.embds[embd_idx]

            # use CLIP pre-calculated embeddings as image and text inputs
            image = embd_row['image']
            text = embd_row['text']

        else:
            # use raw image and text inputs
            image = Image.open(image_path).convert('RGB')\
                .resize((self.image_size, self.image_size))

            if self.paste_attack_file:
                for past_attack_f, past_attack_loc in zip(self.paste_attack_file, self.past_attack_file_locations):
                    image = self.past_attack(image, past_attack_f, past_attack_loc)
            text = txt

        item = {
            'image': image,
            'text': text,
            'label': row['label'],
            'idx_meme': row['id'],
            'origin_text': txt,
            'image_fn': image_fn,
            'image_fn_full': image_path
        }

        return item


class MemesCollator(object):
    """Collate dataset items into a batch dict of tensors for a CLIP model.

    The image side of the batch depends on the flags in ``args``, checked in
    this order: ``fast_process`` (pre-computed embeddings), ``random_crop``
    (ten crops per image), ``owlv2`` / ``kosmos_mask`` (detected logo regions
    are blacked out before CLIP preprocessing), otherwise plain CLIP
    preprocessing.
    """

    # NOTE(review): the checkpoints are not specified anywhere in this file;
    # these defaults are assumptions -- confirm, or set
    # ``args.owlv2_checkpoint`` / ``args.kosmos_checkpoint``.
    OWLV2_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
    KOSMOS2_CHECKPOINT = "microsoft/kosmos-2-patch14-224"

    def __init__(self, args):
        """Store ``args`` and load the CLIP preprocessing unless in fast mode."""
        self.args = args
        if not args.fast_process:
            _, self.clip_preprocess = clip.load("ViT-L/14", device="cuda", jit=False)

        self.crop_transform = Compose([
            TenCrop(int(args.image_size*0.82)),
        ])

    def _ensure_owl(self):
        """Load the OWLv2 processor and model on first use, unless already attached."""
        if getattr(self, 'owl_processor', None) is None or getattr(self, 'owl_model', None) is None:
            from transformers import Owlv2ForObjectDetection, Owlv2Processor

            checkpoint = getattr(self.args, 'owlv2_checkpoint', self.OWLV2_CHECKPOINT)
            self.owl_processor = Owlv2Processor.from_pretrained(checkpoint)
            self.owl_model = Owlv2ForObjectDetection.from_pretrained(checkpoint).cuda().eval()

    def _ensure_kosmos(self):
        """Load the Kosmos-2 processor and model on first use, unless already attached."""
        if getattr(self, 'kosmos_processor', None) is None or getattr(self, 'kosmos_model', None) is None:
            checkpoint = getattr(self.args, 'kosmos_checkpoint', self.KOSMOS2_CHECKPOINT)
            self.kosmos_processor = AutoProcessor.from_pretrained(checkpoint)
            self.kosmos_model = Kosmos2ForConditionalGeneration.from_pretrained(checkpoint).cuda().eval()

    @staticmethod
    def _mask_boxes(image, boxes, max_area_ratio):
        """Zero out pixel boxes in ``image`` in place and return it.

        Args:
            image: ``numpy`` array indexed as ``[row, column, ...]``.
            boxes: Iterable of ``(x1, y1, x2, y2)`` in pixels.
            max_area_ratio: Boxes covering more than this fraction of the
                image are left untouched.
        """
        height, width = image.shape[0], image.shape[1]
        image_area = height * width

        for x1, y1, x2, y2 in boxes:
            # Clamp to the image so negative coordinates cannot wrap around
            # when used as slice bounds.
            x1 = min(max(int(x1), 0), width)
            x2 = min(max(int(x2), 0), width)
            y1 = min(max(int(y1), 0), height)
            y2 = min(max(int(y2), 0), height)

            area = (y2 - y1) * (x2 - x1)
            if area > max_area_ratio * image_area:
                continue

            image[y1:y2, x1:x2] = 0

        return image

    def _random_crop_pixel_values(self, batch):
        """Return ``{'pixel_values_<k>': tensor}`` for the ten crops of every image."""
        per_crop = [[] for _ in range(10)]
        # The crops are dumped to ``crops/`` as before; set
        # ``args.save_crops = False`` to skip the disk writes.
        save_crops = getattr(self.args, 'save_crops', True)
        if save_crops:
            os.makedirs("crops", exist_ok=True)

        for i, sample_content in enumerate(batch):
            img = sample_content['image']
            cropped_images = self.crop_transform(img)

            if save_crops:
                ss = sample_content.get("image_fn", "unknown")
                img.save(f"crops/crop_{i}_{ss}.jpg")
                for crop_i, crop in enumerate(cropped_images):
                    crop.save(f"crops/crop_{i}_{crop_i}_{ss}.jpg")

            for crop_i, crop in enumerate(cropped_images):
                per_crop[crop_i].append(self.clip_preprocess(crop))

        return {f'pixel_values_{k}': torch.stack(crops) for k, crops in enumerate(per_crop)}

    def _owlv2_masked_pixel_values(self, batch):
        """Black out regions OWLv2 detects for "A photo of a logo", then CLIP-preprocess."""
        self._ensure_owl()

        images = [item["image"] for item in batch]
        # One (height, width) row per image.  No ``squeeze``: it would drop
        # the batch dimension for a batch of one.
        target_sizes = torch.Tensor([list(image.size[::-1]) for image in images])

        texts = [["A photo of a logo"] for _ in images]
        inputs = self.owl_processor(images = images, text = texts, return_tensors="pt")

        inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.owl_model(**inputs)
        results = self.owl_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)

        final_images = []
        for image, result in zip(images, results):
            image = np.array(image)

            kept_boxes = [box for box, score in zip(result["boxes"], result["scores"])
                          if not score < 0.1]
            # Detections covering more than 20% of the image are not masked.
            self._mask_boxes(image, kept_boxes, 0.2)

            final_images.append(self.clip_preprocess(Image.fromarray(image)).unsqueeze(0))

        return torch.cat(final_images, dim=0)

    def _kosmos_masked_pixel_values(self, batch):
        """Black out regions Kosmos-2 grounds for a logo question, then CLIP-preprocess."""
        self._ensure_kosmos()

        images = [item["image"] for item in batch]
        texts = ["<grounding> Is there a logo in the image? if there is, where is it?" for _ in images]
        inputs = self.kosmos_processor(text=texts, images=images, return_tensors="pt")

        generated_ids = self.kosmos_model.generate(
            pixel_values=inputs["pixel_values"].cuda(),
            input_ids=inputs["input_ids"].cuda(),
            attention_mask=inputs["attention_mask"].cuda(),
            image_embeds=None,
            image_embeds_position_mask=inputs["image_embeds_position_mask"].cuda(),
            use_cache=True,
            max_new_tokens=64,
        )

        generated_texts = self.kosmos_processor.batch_decode(generated_ids, skip_special_tokens=True)
        images_entities = [self.kosmos_processor.post_process_generation(generated_text)[1] for generated_text in generated_texts]

        final_images = []
        for image, image_entity in zip(images, images_entities):
            image = np.array(image)
            height, width = image.shape[0], image.shape[1]

            pixel_boxes = []
            for _entity_name, (start, end), bbox in image_entity:
                if start == end:
                    # skip bounding bbox without a `phrase` associated
                    continue
                # The coordinates are multiplied by the image size, i.e.
                # treated as fractions of width and height.
                pixel_boxes.extend(
                    (x1 * width, y1 * height, x2 * width, y2 * height)
                    for x1, y1, x2, y2 in bbox
                )

            # Boxes covering more than 80% of the image are not masked.
            self._mask_boxes(image, pixel_boxes, 0.8)

            final_images.append(self.clip_preprocess(Image.fromarray(image)).unsqueeze(0))

        return torch.cat(final_images, dim=0)

    def __call__(self, batch):
        """Collate a list of dataset items into a batch dict.

        Always present: ``labels``, ``idx_memes``, ``enhanced_texts``,
        ``simple_prompt``, ``image_fn_full`` and ``texts``.  Fast mode adds
        ``images``; ``random_crop`` adds ``pixel_values_0`` ..
        ``pixel_values_9``; every other mode adds ``pixel_values``.
        ``logo_fns`` is copied through when the items carry it.
        """
        labels = torch.LongTensor([item['label'] for item in batch])
        idx_memes = torch.LongTensor([item['idx_meme'] for item in batch])

        enh_texts = torch.cat(
            [clip.tokenize(f'a photo of $ , {el["origin_text"]}', context_length=77, truncate=True)
             for el in batch],
            dim=0)

        simple_prompt = clip.tokenize('a photo of $', context_length=77).repeat(labels.shape[0], 1)

        batch_new = {'labels': labels,
                     'idx_memes': idx_memes,
                     'enhanced_texts': enh_texts,
                     'simple_prompt': simple_prompt,
                     # Not every dataset's items carry a full path; None then.
                     "image_fn_full": [item.get('image_fn_full') for item in batch]
                     }

        if self.args.fast_process:
            # Items already hold embeddings: concatenate them and do NOT
            # tokenize ``item['text']``, which is a tensor in this mode.
            batch_new['images'] = torch.cat([item['image'] for item in batch], dim=0)
            batch_new['texts'] = torch.cat([item['text'] for item in batch], dim=0)

        else:
            if self.args.random_crop:
                batch_new.update(self._random_crop_pixel_values(batch))

            elif self.args.owlv2:
                batch_new['pixel_values'] = self._owlv2_masked_pixel_values(batch)

            elif self.args.kosmos_mask:
                batch_new['pixel_values'] = self._kosmos_masked_pixel_values(batch)

            else:
                batch_new['pixel_values'] = torch.cat(
                    [self.clip_preprocess(item['image']).unsqueeze(0) for item in batch],
                    dim=0)

            batch_new['texts'] = torch.cat(
                [clip.tokenize(item['text'], context_length=77, truncate=True) for item in batch],
                dim=0)

        if 'logo_fns' in batch[0]:
            batch_new['logo_fns'] = [item['logo_fns'] for item in batch]

        return batch_new


def load_dataset(args, split, logo=None):
    """Build the ``MemesDataset`` for ``split`` from the settings in ``args``.

    ``args.dataset``, ``args.image_size`` and ``args.fast_process`` are
    forwarded; the data root is ``./resources/datasets``.  ``logo`` is
    accepted but not used.
    """
    return MemesDataset(
        args=args,
        root_folder='./resources/datasets',
        dataset=args.dataset,
        split=split,
        image_size=args.image_size,
        fast=args.fast_process,
    )

def load_logos_dataset(args, split, logos_dir):
    """Build the ``MemesLogoDataset`` for ``split`` with logos from ``logos_dir``.

    ``args.dataset``, ``args.image_size``, ``args.fast_process``,
    ``args.factor_shrink``, ``args.transparency`` and ``args.logos_idx`` are
    forwarded; the data root is ``./resources/datasets``.
    """
    # The paste-attack file settings are not forwarded as keywords: the logos
    # to paste come from ``logos_dir``, ``args`` itself is already handed to
    # the dataset, and passing keywords a constructor does not declare raises
    # TypeError.
    return MemesLogoDataset(
        args=args,
        root_folder='./resources/datasets',
        root_folder_logos=logos_dir,
        dataset=args.dataset,
        split=split,
        image_size=args.image_size,
        fast=args.fast_process,
        factor_shrink=args.factor_shrink,
        transparency=args.transparency,
        logos_idx=args.logos_idx,
    )