import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random
from PIL import Image
import numpy as np
from tqdm import tqdm
try:
    from Data.corruptions import apply_corruption
except ImportError:
    from corruptions import apply_corruption

class CorruptedDataset(Dataset):
    """Dataset wrapper that corrupts a pre-selected subset of samples in place.

    The wrapper has the same length as ``base_dataset``. At construction a
    fraction ``corruption_prob`` of the indices is drawn, either per class
    (``balance_by_class=True``) or over the whole dataset. Samples at those
    indices are passed through ``apply_corruption`` every time they are
    fetched; the severity is re-drawn on each access.

    ``__getitem__`` returns ``(image, label, is_corrupted)`` when the base
    sample carries a non-``None`` label and ``(image, is_corrupted)``
    otherwise.

    Args:
        base_dataset: Indexable dataset with ``__len__``. Samples are either
            a bare image or a tuple whose first two items are
            ``(image, label)``; extra trailing items are ignored.
        corruption_prob: Fraction of samples (per class or globally) to
            corrupt.
        corruption_types: Corruption names handed to ``apply_corruption``.
            ``None`` or an empty sequence means ``['jpeg_compression']``.
        severity_range: Inclusive ``(low, high)`` bounds for the severity.
        balance_by_class: Select the fraction inside every class instead of
            globally. Falls back to global selection (and is set to
            ``False``) when the base dataset yields non-tuple samples.
    """

    def __init__(self, base_dataset, corruption_prob=0.3, corruption_types=None,
                 severity_range=(1, 3), balance_by_class=True):
        self.base_dataset = base_dataset
        self.corruption_prob = corruption_prob
        if corruption_types is None or len(corruption_types) == 0:
            corruption_types = ['jpeg_compression']
        self.corruption_types = corruption_types
        self.severity_range = severity_range
        self.balance_by_class = balance_by_class
        self.corrupted_indices = set()

        if self.balance_by_class:
            self._init_class_balance()
        else:
            self._init_global_corruption()

    @staticmethod
    def _split_sample(sample):
        """Return ``(image, label)``; ``label`` is ``None`` for bare images."""
        if not isinstance(sample, tuple):
            return sample, None
        if len(sample) < 2:
            raise ValueError(
                f"Expected an (image, label, ...) tuple, got a tuple of length {len(sample)}"
            )
        return sample[0], sample[1]

    @staticmethod
    def _label_key(label):
        """Turn a label into a value-hashed key usable for class grouping.

        Tensors hash by identity and arrays are unhashable, so equal labels
        would otherwise never land in the same class bucket.
        """
        if isinstance(label, (torch.Tensor, np.ndarray)):
            values = label.reshape(-1).tolist()
            return values[0] if len(values) == 1 else tuple(values)
        return label

    @staticmethod
    def _apply_corruption_to_image(image, corruption_type, severity):
        """Corrupt ``image`` and return it in the same kind of container.

        Non-tensor images are handed to ``apply_corruption`` unchanged and
        its result is returned as is. Tensors are converted to a PIL image
        first and converted back to a float tensor afterwards.
        NOTE(review): tensor input is assumed to hold floats in [0, 1] in
        (C, H, W) or (H, W) layout -- confirm against the data pipeline.
        """
        if not isinstance(image, torch.Tensor):
            return apply_corruption(image, corruption_type, severity)

        array = image.detach().cpu().numpy()
        channel_first = array.ndim == 3
        if channel_first:
            array = array.transpose(1, 2, 0)
            if array.shape[2] == 1:
                # PIL cannot build an image from an (H, W, 1) array.
                array = array[:, :, 0]
        # Round instead of truncating (x / 255 * 255 may land just below x)
        # and clip so out-of-range values do not wrap around in uint8.
        pixels = np.clip(np.rint(array * 255.0), 0, 255).astype(np.uint8)

        corrupted = apply_corruption(Image.fromarray(pixels), corruption_type, severity)

        restored = torch.from_numpy(np.asarray(corrupted) / 255.0)
        if restored.dim() == 3:
            restored = restored.permute(2, 0, 1)
        elif channel_first:
            # Give a single-channel input its channel axis back.
            restored = restored.unsqueeze(0)
        return restored.float()

    def _init_class_balance(self):
        """Group indices by label, then draw the per-class corrupted subset.

        Every sample of the base dataset is fetched once. If a non-tuple
        sample is met, class balancing is abandoned in favour of global
        selection.
        """
        self.class_counts = {}
        self.class_indices = {}
        dataset_size = len(self.base_dataset)

        for idx in tqdm(range(dataset_size), total=dataset_size,
                        bar_format="{l_bar}{bar:30}{r_bar}"):
            sample = self.base_dataset[idx]
            if not isinstance(sample, tuple):
                self.balance_by_class = False
                self._init_global_corruption()
                return

            label = self._label_key(self._split_sample(sample)[1])
            self.class_indices.setdefault(label, []).append(idx)
            self.class_counts[label] = self.class_counts.get(label, 0) + 1

        self._select_corruption_samples()

    def _init_global_corruption(self):
        """Draw the corrupted subset over the whole dataset."""
        self.corrupted_indices = set()
        self._select_global_corruption_samples()

    def _select_corruption_samples(self):
        """Pick ``int(class_size * corruption_prob)`` indices from each class."""
        selected = set()
        for indices in self.class_indices.values():
            num_to_corrupt = int(len(indices) * self.corruption_prob)
            if num_to_corrupt > 0:
                selected.update(random.sample(indices, num_to_corrupt))
        self.corrupted_indices = selected

    def _select_global_corruption_samples(self):
        """Pick ``int(len(dataset) * corruption_prob)`` indices globally."""
        total_samples = len(self.base_dataset)
        num_to_corrupt = int(total_samples * self.corruption_prob)

        # Always rebuild the set so that a re-selection which yields zero
        # samples does not leave the previous selection in place.
        if num_to_corrupt > 0:
            self.corrupted_indices = set(random.sample(range(total_samples), num_to_corrupt))
        else:
            self.corrupted_indices = set()

    def reset_corruption_selection(self):
        """Re-draw which indices are corrupted, using the current settings."""
        if self.balance_by_class:
            self._select_corruption_samples()
        else:
            self._select_global_corruption_samples()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Fetch sample ``idx``, corrupting it if it was selected.

        Returns:
            ``(image, label, is_corrupted)`` or ``(image, is_corrupted)``.
            ``is_corrupted`` is ``False`` when the index was not selected or
            when applying the corruption failed (the original image is
            returned in that case and a warning is printed).
        """
        if isinstance(idx, (int, np.integer)) and idx < 0:
            # Normalise so that membership in corrupted_indices works.
            idx += len(self)

        # Fetch exactly once: the base dataset may be slow or randomised.
        image, label = self._split_sample(self.base_dataset[idx])
        is_corrupted = idx in self.corrupted_indices

        if is_corrupted:
            types = self.corruption_types
            # With a single configured type no random draw is made, which
            # keeps the random stream identical to the default set-up.
            corruption_type = types[0] if len(types) == 1 else random.choice(types)
            severity = random.randint(self.severity_range[0], self.severity_range[1])

            try:
                image = self._apply_corruption_to_image(image, corruption_type, severity)
            except Exception as e:
                # Best effort: keep the clean image and flag it as such.
                print(f"Warning: Failed to apply corruption {corruption_type}: {e}")
                is_corrupted = False

        if label is not None:
            return image, label, is_corrupted
        return image, is_corrupted



class AugmentedCorruptedDataset(Dataset):
    """Dataset wrapper that appends corrupted copies of selected samples.

    Indices ``0 .. original_size - 1`` return the base samples untouched.
    Indices from ``original_size`` upwards return corrupted copies of base
    samples, as recorded in ``corruption_mapping``
    (``{corrupted_index: (original_index, corruption_type, severity)}``).
    Corruption type and severity are fixed when the mapping is generated.

    ``__getitem__`` returns ``(image, label, is_corrupted)`` when the base
    sample is a tuple and ``(image, is_corrupted)`` otherwise.

    Args:
        base_dataset: Indexable dataset with ``__len__``. Samples are either
            a bare image or a tuple whose first two items are
            ``(image, label)``; extra trailing items are ignored.
        corruption_prob: Fraction of samples (per class or globally) for
            which a corrupted copy is added.
        corruption_types: Corruption names handed to ``apply_corruption``.
            ``None`` or an empty sequence means ``['jpeg_compression']``.
        severity_range: Inclusive ``(low, high)`` bounds for the severity.
        balance_by_class: Select the fraction inside every class instead of
            globally. Falls back to global selection (and is set to
            ``False``) when the base dataset yields non-tuple samples.
    """

    def __init__(self, base_dataset, corruption_prob=0.05, corruption_types=None,
                 severity_range=(1, 3), balance_by_class=True):
        self.base_dataset = base_dataset
        self.corruption_prob = corruption_prob
        if corruption_types is None or len(corruption_types) == 0:
            corruption_types = ['jpeg_compression']
        self.corruption_types = corruption_types
        self.severity_range = severity_range
        self.balance_by_class = balance_by_class

        self.original_size = len(base_dataset)

        self._generate_corruption_mapping()

        self.total_size = self.original_size + len(self.corruption_mapping)

    @staticmethod
    def _split_sample(sample):
        """Return ``(image, label)``; ``label`` is ``None`` for bare images."""
        if not isinstance(sample, tuple):
            return sample, None
        if len(sample) < 2:
            raise ValueError(
                f"Expected an (image, label, ...) tuple, got a tuple of length {len(sample)}"
            )
        return sample[0], sample[1]

    @staticmethod
    def _label_key(label):
        """Turn a label into a value-hashed key usable for class grouping.

        Tensors hash by identity and arrays are unhashable, so equal labels
        would otherwise never land in the same class bucket.
        """
        if isinstance(label, (torch.Tensor, np.ndarray)):
            values = label.reshape(-1).tolist()
            return values[0] if len(values) == 1 else tuple(values)
        return label

    @staticmethod
    def _apply_corruption_to_image(image, corruption_type, severity):
        """Corrupt ``image`` and return it in the same kind of container.

        Non-tensor images are handed to ``apply_corruption`` unchanged and
        its result is returned as is. Tensors are converted to a PIL image
        first and converted back to a float tensor afterwards.
        NOTE(review): tensor input is assumed to hold floats in [0, 1] in
        (C, H, W) or (H, W) layout -- confirm against the data pipeline.
        """
        if not isinstance(image, torch.Tensor):
            return apply_corruption(image, corruption_type, severity)

        array = image.detach().cpu().numpy()
        channel_first = array.ndim == 3
        if channel_first:
            array = array.transpose(1, 2, 0)
            if array.shape[2] == 1:
                # PIL cannot build an image from an (H, W, 1) array.
                array = array[:, :, 0]
        # Round instead of truncating (x / 255 * 255 may land just below x)
        # and clip so out-of-range values do not wrap around in uint8.
        pixels = np.clip(np.rint(array * 255.0), 0, 255).astype(np.uint8)

        corrupted = apply_corruption(Image.fromarray(pixels), corruption_type, severity)

        restored = torch.from_numpy(np.asarray(corrupted) / 255.0)
        if restored.dim() == 3:
            restored = restored.permute(2, 0, 1)
        elif channel_first:
            # Give a single-channel input its channel axis back.
            restored = restored.unsqueeze(0)
        return restored.float()

    def _generate_corruption_mapping(self):
        """Rebuild ``corruption_mapping`` from scratch."""
        self.corruption_mapping = {}  # {corrupted_index: (original_index, corruption_type, severity)}

        if self.balance_by_class:
            self._generate_class_balanced_mapping()
        else:
            self._generate_global_mapping()

    def _append_corruptions(self, original_indices):
        """Register one corrupted copy per index, at the next free positions."""
        next_idx = self.original_size + len(self.corruption_mapping)
        for original_idx in original_indices:
            corruption_type = random.choice(self.corruption_types)
            severity = random.randint(self.severity_range[0], self.severity_range[1])
            self.corruption_mapping[next_idx] = (original_idx, corruption_type, severity)
            next_idx += 1

    def _generate_class_balanced_mapping(self):
        """Add ``int(class_size * corruption_prob)`` copies for every class.

        Every base sample is fetched once to read its label. If a non-tuple
        sample is met, class balancing is abandoned in favour of global
        selection.
        """
        class_indices = {}

        for idx in range(self.original_size):
            sample = self.base_dataset[idx]
            if not isinstance(sample, tuple):
                self.balance_by_class = False
                self._generate_global_mapping()
                return

            label = self._label_key(self._split_sample(sample)[1])
            class_indices.setdefault(label, []).append(idx)

        for indices in class_indices.values():
            num_to_corrupt = int(len(indices) * self.corruption_prob)
            if num_to_corrupt > 0:
                self._append_corruptions(random.sample(indices, num_to_corrupt))

    def _generate_global_mapping(self):
        """Add ``int(original_size * corruption_prob)`` copies drawn globally."""
        num_to_corrupt = int(self.original_size * self.corruption_prob)

        if num_to_corrupt > 0:
            self._append_corruptions(
                random.sample(range(self.original_size), num_to_corrupt)
            )

    def regenerate_corruptions(self):
        """Re-draw the corrupted copies and update the dataset length."""
        self.corruption_mapping = {}
        self._generate_corruption_mapping()
        self.total_size = self.original_size + len(self.corruption_mapping)

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        """Fetch an original sample or a corrupted copy.

        Returns:
            ``(image, label, is_corrupted)`` or ``(image, is_corrupted)``.
            ``is_corrupted`` is ``True`` only when the corruption was
            actually applied; if it fails a warning is printed and the
            clean image is returned flagged ``False``.

        Raises:
            IndexError: If ``idx`` is outside the augmented dataset.
        """
        requested = idx
        if isinstance(idx, (int, np.integer)) and idx < 0:
            # Negative indices count from the end of the *augmented* dataset.
            idx += self.total_size
            if idx < 0:
                raise IndexError(f"Index {requested} out of range")

        if idx < self.original_size:
            # Fetch exactly once: the base dataset may be slow or randomised.
            sample = self.base_dataset[idx]
            image, label = self._split_sample(sample)
            is_corrupted = False
        else:
            if idx not in self.corruption_mapping:
                raise IndexError(f"Index {requested} out of range")

            original_idx, corruption_type, severity = self.corruption_mapping[idx]
            sample = self.base_dataset[original_idx]
            image, label = self._split_sample(sample)

            try:
                image = self._apply_corruption_to_image(image, corruption_type, severity)
                is_corrupted = True
            except Exception as e:
                # Best effort: keep the clean image and flag it as such.
                print(f"Warning: Failed to apply corruption {corruption_type}: {e}")
                is_corrupted = False

        # Tuple-ness decides the arity, so originals and corrupted copies of
        # the same dataset always have the same shape.
        if isinstance(sample, tuple):
            return image, label, is_corrupted
        return image, is_corrupted

    def get_corruption_stats(self):
        """Return sample counts and ratios describing the augmentation."""
        original_size = self.original_size
        corrupted_size = len(self.corruption_mapping)
        total_size = self.total_size

        return {
            'original_samples': original_size,
            'corrupted_samples': corrupted_size,
            'total_samples': total_size,
            'corruption_rate': corrupted_size / original_size if original_size > 0 else 0,
            'augmentation_ratio': total_size / original_size if original_size > 0 else 1,
            'balance_by_class': self.balance_by_class
        }
