import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml
import sys
import json

from ..wrapper import BaseAttackWrapper
from ..utils.morph_utils import get_image_grid, add_signs_and_triggers
from tqdm import tqdm

class RMAMorphWrapper(BaseAttackWrapper):
    """Backdoor-attack wrapper that pastes signs and triggers into training images.

    For every training image a placement grid is computed from the clean image and
    its existing annotation (``get_image_grid``). Images whose grid has a positive
    sum are poisoned: ``add_signs_and_triggers`` draws new objects into the image,
    the returned boxes are appended to the image's annotation, and the poisoned
    image is written to ``self.bd_save_dir``.

    Only the training split is supported; ``__create_dataset_test__`` raises
    ``NotImplementedError``.
    """

    def __init__(self, dataset_wrapper, base_dir, config_path, trigger_position, class_ids, is_meta=False, p_ratio=1, bbox_current_format='xyxy', data_split='train'):
        """Store attack-specific settings and initialise the base wrapper.

        Args:
            dataset_wrapper: Wrapper exposing ``annotations`` and ``__get_image__``.
            base_dir: Forwarded to ``BaseAttackWrapper``.
            config_path: Path to the YAML attack config (see ``__load_config__``).
            trigger_position: Copied into ``config['trigger_position']``.
            class_ids: Mapping from a label to a dict with an ``'id'`` entry; it is
                indexed with the ``'meta_label'`` or ``'class_label'`` of each newly
                inserted object.
            is_meta: If True, new objects are labelled via their ``'meta_label'``;
                otherwise via their ``'class_label'``.
            p_ratio: Accepted for interface compatibility; not used by this class.
            bbox_current_format: Forwarded to ``BaseAttackWrapper``.
            data_split: Forwarded to ``BaseAttackWrapper``.
        """
        # Set before super().__init__ — presumably because the base initializer calls
        # __load_config__, which reads self.trigger_position.
        # NOTE(review): confirm against BaseAttackWrapper.
        self.class_ids = class_ids
        self.is_meta = is_meta
        self.trigger_position = trigger_position

        super().__init__(dataset_wrapper, base_dir, config_path, trigger_position, bbox_current_format=bbox_current_format, data_split=data_split)

    def __load_config__(self):
        """Load and validate the YAML config, the trigger image and the sign mapping.

        On success sets:
            ``self.config``: the parsed YAML plus ``'trigger'`` (the trigger image as
                an RGB tensor from ``transforms.ToTensor``) and ``'trigger_position'``.
            ``self.target_id``: ``config['target_class']``.
            ``self.sign_mapping``: the parsed ``<sign_path>/mapping.json``.

        Raises:
            ValueError: If the config is not a YAML mapping, a required field is
                missing, ``sign_path`` does not exist, or it has no ``mapping.json``.
        """
        with open(self.config_path, 'r') as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty file; require a mapping so the field
        # checks below fail with a clear ValueError rather than a TypeError.
        if not isinstance(config, dict):
            raise ValueError(f'Config file {self.config_path} must contain a YAML mapping')

        required_fields = (
            'target_class', 'trigger_path', 'sign_path', 'trigger_ratio',
            'min_trigger_size', 'max_trigger_size', 'img_format', 'grid_size', 'prob_add_trigger',
        )
        for field in required_fields:
            if field not in config:
                raise ValueError(f'Missing required config field: {field}')

        # Load the trigger; the context manager releases the underlying file handle.
        with Image.open(config['trigger_path']) as trigger_img:
            config['trigger'] = transforms.ToTensor()(trigger_img.convert('RGB'))
        config['trigger_position'] = self.trigger_position

        sign_path = config['sign_path']
        if not os.path.exists(sign_path):
            raise ValueError(f'Sign paths {sign_path} does not exist')

        # The sign directory must contain a mapping.json describing the signs.
        mapping_path = os.path.join(sign_path, 'mapping.json')
        if not os.path.exists(mapping_path):
            raise ValueError(f'Mapping file {mapping_path} does not exist in sign_path {sign_path}')

        with open(mapping_path, 'r') as f:
            self.sign_mapping = json.load(f)

        self.config = config
        self.target_id = config['target_class']

    def __is_poisonable__(self, annotation, img_id, sub_id):
        """Compute the sign/trigger placement grid for one clean image.

        Despite the name, this returns the grid produced by ``get_image_grid``
        rather than a boolean; ``__create_dataset_train__`` treats an image as
        poisonable when the grid has a positive sum.

        Args:
            annotation: The image's existing annotation.
            img_id: Key of the image in ``dataset_wrapper``.
            sub_id: Sub-image index of the image in ``dataset_wrapper``.
        """
        # 1) Load the clean (non-backdoored) image as a tensor.
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Compute the placement grid.
        return get_image_grid(img, annotation, self.config['grid_size'], current_format=self.bbox_current_format)

    def __poison_image__(self, img_id, sub_id, image_grid):
        """Insert signs/triggers into one image and record the new objects.

        Args:
            img_id: Key into ``dataset_wrapper.annotations``.
            sub_id: Index into the per-image list of sub-image annotations.
            image_grid: Placement grid returned by ``__is_poisonable__``.

        Side effects:
            Writes the poisoned image to ``self.bd_save_dir`` as
            ``{img_id}_{sub_id}.{img_format}``, then replaces
            ``dataset_wrapper.annotations[img_id][sub_id]`` with an updated deep
            copy that also carries ``'bd_img_path'``.

        Raises:
            KeyError: If ``img_id`` or ``sub_id`` is not present in the annotations.
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError) as err:
            # annotations[img_id] is a list, so a bad sub_id raises IndexError; both
            # cases are reported as KeyError for callers.
            raise KeyError(f'Image ID {img_id} with sub-image ID {sub_id} not found in annotations') from err

        # 1) Load the clean image as a tensor.
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Add signs and triggers. Work on a copy so the stored annotation is
        #    only replaced once everything below has succeeded.
        annotation = copy.deepcopy(annotation)
        img, new_annotation = add_signs_and_triggers(
            self.config, img, image_grid, self.sign_mapping
        )

        # 3) Append every inserted object to the annotation.
        label_key = 'meta_label' if self.is_meta else 'class_label'
        for ann in new_annotation:
            is_poison = bool(ann['poison_mask'])
            annotation['bbox'].append(ann['bbox'])
            annotation['category_id'].append(self.class_ids[ann[label_key]]['id'])
            annotation['poison_mask'].append(is_poison)
            # -1 means no target class
            annotation['target_id'].append(self.target_id if is_poison else -1)

        # 4) Save the poisoned image, then store the annotation pointing at it.
        img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}.{self.config["img_format"]}')
        transforms.ToPILImage()(img).save(img_path)

        annotation['bd_img_path'] = img_path
        self.dataset_wrapper.annotations[img_id][sub_id] = annotation

    def __create_dataset_train__(self):
        """Poison every training image whose placement grid has free cells.

        Pass 1 computes a placement grid per image; pass 2 poisons the images whose
        grid has a positive sum. Afterwards ``self.bd_id_list`` holds the poisoned
        ``(img_id, sub_id)`` pairs in processing order and ``self.bd_id_set`` the
        poisoned image ids.

        Raises:
            ValueError: If an image has more than one sub-image annotation.
        """
        annotations = self.dataset_wrapper.annotations
        image_ids = list(annotations.keys())
        sub_id = 0  # only single sub-image entries are supported (checked below)

        print(f'[RMAMorphWrapper] Creating dataset for {self.data_split.value} (N={len(image_ids)})')

        # 1) Compute grids and keep only the images that have room for a sign.
        poisonable_grids = {}
        for img_id in tqdm(image_ids, desc='Creating dataset'):
            sub_annotations = annotations[img_id]
            if len(sub_annotations) > 1:
                raise ValueError(f'Image {img_id} has sub-images')

            image_grid = self.__is_poisonable__(sub_annotations[0], img_id, sub_id)
            if image_grid.sum() > 0:
                poisonable_grids[img_id] = image_grid

        print(f'[RMAMorphWrapper] Found {len(poisonable_grids)} poisonable images')

        # 2) Poison the selected images.
        bd_id_list = []
        for img_id, image_grid in tqdm(poisonable_grids.items(), desc='Poisoning images'):
            self.__poison_image__(img_id, sub_id, image_grid)
            bd_id_list.append((img_id, sub_id))

        self.bd_id_set = {img_id for img_id, _ in bd_id_list}
        self.bd_id_list = bd_id_list

    def __create_dataset_test__(self):
        """Not supported: this attack only poisons the training split.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('RMA Morph attack is not supported for test dataset. Please use the training dataset only.')

        