# scripts/data_setup.py

import os
import re
import random
import numpy as np
import torch
import open_clip
import webdataset as wds
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets
from collections import defaultdict
import scripts.config as config
import scripts.datasets_classes as ds_classes

# --- NEW IMPORTS FOR ALIGN FIX ---
from transformers import XLMRobertaTokenizer
from open_clip.tokenizer import HFTokenizer

# --- HELPER FOR ALIGN TEXT CLEANING ---
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim both ends.

    Args:
        text: String to normalise.

    Returns:
        The normalised string.
    """
    return re.sub(r'\s+', ' ', text).strip()

class HFBatchEncoding:
    """Thin holder around a dict of tokenizer outputs that supports `.to(device)`."""

    def __init__(self, data):
        # Mapping of field name -> value; tensor values are the ones moved by `to`.
        self.data = data

    def to(self, device):
        """Move every tensor value to `device`, leaving other values untouched.

        The stored dict is replaced by the moved copy.

        Returns:
            The updated dict itself (not this wrapper object).
        """
        moved = {}
        for key, value in self.data.items():
            if isinstance(value, torch.Tensor):
                moved[key] = value.to(device)
            else:
                moved[key] = value
        self.data = moved
        return self.data

def set_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs and set the cuDNN flags.

    Args:
        seed: Value passed unchanged to every seeding call.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    # CUDA generators are only seeded when a CUDA device is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Turn off cuDNN benchmark mode and switch on its deterministic flag.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

class RandomImageDataset(Dataset):
    """Synthetic dataset yielding Gaussian-noise images with fixed random labels.

    Labels are drawn once at construction; the image tensor is freshly sampled
    on every `__getitem__` call.
    """

    def __init__(self, n_samples, size=224, n_classes=100):
        self.n_samples = n_samples
        self.shape = (3, size, size)  # 3 x size x size tensor per item
        self.n_classes = n_classes
        # One integer label in [0, n_classes) per sample.
        self.labels = torch.randint(0, self.n_classes, (self.n_samples,))

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        image = torch.randn(self.shape)
        return image, self.labels[idx]

def get_model_and_tokenizer():
    """Create the model, tokenizer and image preprocess described by config.MODEL_CONFIG.

    Only `type == 'open_clip'` is handled. For architectures whose name contains
    "roberta", an open_clip HFTokenizer is assembled by hand around a slow
    XLMRobertaTokenizer; if that fails, the regular open_clip tokenizer is used.

    Returns:
        (model, tokenizer, preprocess), with the model moved to
        config.TARGET_DEVICE and put in eval mode, or (None, None, None) for
        any other model type.
    """
    cfg = config.MODEL_CONFIG

    # Anything that is not an OpenCLIP model is unsupported here.
    if cfg['type'] != 'open_clip':
        return None, None, None

    # 1. Model creation is identical for every architecture.
    model, _, preprocess = open_clip.create_model_and_transforms(cfg['arch'], pretrained=cfg['data'])

    # 2. Tokenizer selection.
    arch_lower = cfg['arch'].lower()
    if "roberta" not in arch_lower:
        # Standard path for SigLIP, CLIP, CoCa, ConvNeXt, etc.
        tokenizer = open_clip.get_tokenizer(cfg['arch'])
    else:
        # Per the original author's note: ALIGN uses RoBERTa, and the tokenizer is
        # injected manually to avoid offline-mode crashes in transformers.
        print(f"INFO: Applying ALIGN/RoBERTa Manual Fix (Slow Tokenizer Injection).")

        # Base vs. large checkpoint is inferred from the architecture name.
        hf_name = "xlm-roberta-base" if "base" in arch_lower else "xlm-roberta-large"

        # A local folder (e.g. ./xlm-roberta-base) wins over the hub name when present.
        local_dir = f"./{hf_name}"
        path_to_load = local_dir if os.path.exists(local_dir) else hf_name

        try:
            # Slow (pure-Python) tokenizer, chosen by the author to sidestep
            # fast-tokenizer/config mismatches.
            internal_tokenizer = XLMRobertaTokenizer.from_pretrained(path_to_load)

            # Allocate the wrapper without running HFTokenizer.__init__, then set
            # by hand the attributes this code relies on.
            tokenizer = HFTokenizer.__new__(HFTokenizer)
            tokenizer.tokenizer = internal_tokenizer
            tokenizer.context_length = 77
            tokenizer.clean_fn = whitespace_clean
            tokenizer.tokenizer_mode = 'slow' 
            tokenizer.strip_sep_token = False
        except Exception as e:
            print(f"WARNING: Manual ALIGN fix failed: {e}. Falling back to standard method.")
            tokenizer = open_clip.get_tokenizer(cfg['arch'])

    return model.to(config.TARGET_DEVICE).eval(), tokenizer, preprocess

def create_stratified_subset(dataset, ratio):
    """Randomly keep roughly `ratio` of the samples of every class.

    Labels are read from `dataset.targets`, else `dataset.labels`, else by
    iterating the dataset and collecting the second element of each item.
    At least one sample is kept per class.

    Args:
        dataset: Source dataset.
        ratio: Fraction to keep; values >= 1.0 return `dataset` unchanged.

    Returns:
        `dataset` itself, or a `Subset` over the selected indices.
    """
    if ratio >= 1.0:
        return dataset

    # Locate the per-sample labels, cheapest source first.
    if hasattr(dataset, 'targets'):
        labels = dataset.targets
    elif hasattr(dataset, 'labels'):
        labels = dataset.labels
    else:
        labels = [y for _, y in dataset]
    if isinstance(labels, torch.Tensor):
        labels = labels.tolist()

    # Group sample positions by class label.
    by_class = defaultdict(list)
    for position, label in enumerate(labels):
        by_class[label].append(position)

    # Classes are visited in sorted order so the draw sequence is reproducible
    # for a given `random` state.
    chosen = []
    for label in sorted(by_class):
        members = by_class[label]
        quota = max(1, int(round(len(members) * ratio)))
        chosen += random.sample(members, min(quota, len(members)))
    return Subset(dataset, chosen)

def _get_dataset_instance(name, root, split_mode, transform):
    """Instantiate a torchvision dataset for `name` rooted at `root`.

    'cifar10', 'cifar100' and 'dtd' use their dedicated torchvision classes
    (never downloading). Any other name is loaded as an ImageFolder from
    `root/train` or `root/val` when that sub-folder exists, otherwise from
    `root` itself.

    Args:
        name: Dataset key (case-insensitive).
        root: Dataset root directory.
        split_mode: 'train' selects the training split; any other value
            selects the evaluation split.
        transform: Transform forwarded to the dataset.

    Raises:
        FileNotFoundError: If the resolved ImageFolder path does not exist.
    """
    key = name.lower()
    is_train = split_mode == 'train'

    if key == 'cifar10':
        return datasets.CIFAR10(root=root, train=is_train, transform=transform, download=False)
    if key == 'cifar100':
        return datasets.CIFAR100(root=root, train=is_train, transform=transform, download=False)
    if key == 'dtd':
        dtd_split = 'train' if is_train else 'test'
        return datasets.DTD(root=root, split=dtd_split, transform=transform, download=False)

    # Generic folder dataset: prefer the split sub-folder, fall back to the root.
    split_dir = os.path.join(root, 'train' if is_train else 'val')
    target_path = split_dir if os.path.exists(split_dir) else root

    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Path not found: {target_path}")

    return datasets.ImageFolder(root=target_path, transform=transform)

# --- Remapped Subset for Correct Eval ---
class RemappedSubset(Dataset):
    """
    A Subset that remaps labels from [Original_Index] to [0, ..., N-1].
    Required for accurate small-subset evaluation.

    When the wrapped dataset exposes a `targets` sequence whose entries can be
    looked up in `label_map`, a remapped `targets` list (one entry per kept
    sample) is precomputed. Label consumers such as stratified subsampling can
    then read labels without loading any sample. If that is not possible, no
    `targets` attribute is set and labels are only available via `__getitem__`.

    NOTE(review): the precomputed `targets` assume `dataset.targets[i]` equals
    the label returned by `dataset[i]` (i.e. no target_transform) - confirm
    for any new dataset wrapped here.

    Args:
        dataset: Underlying dataset returning (sample, original_label).
        indices: Positions in `dataset` that belong to this subset.
        label_map: Dict {Original_Label_Idx: New_Label_Idx}.
    """
    def __init__(self, dataset, indices, label_map):
        self.dataset = dataset
        self.indices = indices
        self.label_map = label_map # Dict: {Original_Label_Idx: New_Label_Idx}

        remapped = self._remap_source_targets()
        if remapped is not None:
            self.targets = remapped

    def _remap_source_targets(self):
        """Return remapped labels for the kept samples, or None if unavailable."""
        source_targets = getattr(self.dataset, 'targets', None)
        if source_targets is None:
            return None
        try:
            return [self.label_map[source_targets[i]] for i in self.indices]
        except (KeyError, IndexError, TypeError):
            # Labels not usable as label_map keys (e.g. tensor elements) or
            # inconsistent indices: keep the lazy, per-item behaviour only.
            return None

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        return x, self.label_map[y]

    def __len__(self):
        return len(self.indices)

def _filter_dataset(dataset, all_class_names, target_subset_name):
    """
    Filters dataset and REMAPS labels to 0..N-1.

    Args:
        dataset: Source dataset; must expose `targets` or `samples` to be filtered.
        all_class_names: Class names indexed by the dataset's original label index.
        target_subset_name: 'imagenet500id' (class indices read from
            config.ID_DATA_PATH), 'imagenet10' / 'imagenet20' (classes matched by
            name against ds_classes.DATASET_METADATA), or anything else (no filtering).

    Returns: (RemappedSubset, list_of_subset_class_names)
        (dataset, all_class_names) unchanged when the subset name is not one of
        the above or when no target class could be resolved.
        (None, None) when the index file is missing, when the metadata lists no
        classes, or when the dataset exposes neither `targets` nor `samples`.
    """
    valid_class_indices = [] # Keep order if provided by list
    
    # 1. ImageNet-500: Load indices from id_data.txt
    if target_subset_name == 'imagenet500id':
        # getattr (not config.ID_DATA_PATH) so that a config without this setting
        # reaches the error return below instead of raising AttributeError.
        id_data_path = getattr(config, 'ID_DATA_PATH', None)
        if not id_data_path or not os.path.exists(id_data_path):
            shown_path = id_data_path if id_data_path else "config.ID_DATA_PATH (unset)"
            print(f"  [Error] {shown_path} not found for ImageNet-500 filtering.")
            return None, None
            
        with open(id_data_path, 'r') as f:
            # We sort to ensure deterministic label mapping 0..499
            valid_class_indices = sorted({int(line.strip()) for line in f if line.strip()})
        print(f"  -> Filter Mode: Indices. Loaded {len(valid_class_indices)} target classes from file.")

    # 2. ImageNet-10 / 20: Filter by Class Name matching
    elif target_subset_name in ['imagenet10', 'imagenet20']:
        target_names_str = ds_classes.DATASET_METADATA.get(target_subset_name, {}).get("classes", [])
        if not target_names_str: return None, None
        
        # The order of target_names_str defines the order of the new class list;
        # valid_class_indices stores the original index for each matched name.
        
        # Build lookup: normalised name -> original index
        name_to_idx = {}
        for idx, full_name in enumerate(all_class_names):
             name_to_idx[full_name.lower().strip()] = idx

        found_count = 0
        for t_name in target_names_str:
            clean_t = t_name.lower().strip()
            if clean_t in name_to_idx:
                valid_class_indices.append(name_to_idx[clean_t])
                found_count += 1
            else:
                print(f"  [Warning] Target class '{t_name}' not found in source dataset.")
                         
        print(f"  -> Filter Mode: Exact Names. Matched {found_count}/{len(target_names_str)} target classes.")

    else:
        return dataset, all_class_names

    # Nothing resolved: hand the dataset back unfiltered.
    if not valid_class_indices:
        return dataset, all_class_names

    # Create mapping: Old_Index -> New_Index (0..N-1)
    label_map = {old_idx: new_idx for new_idx, old_idx in enumerate(valid_class_indices)}
    
    # Filter Samples (set for O(1) membership tests)
    valid_set_indices = set(valid_class_indices)
    
    if hasattr(dataset, 'targets'):
        indices_to_keep = [i for i, label in enumerate(dataset.targets) if label in valid_set_indices]
    elif hasattr(dataset, 'samples'):
        indices_to_keep = [i for i, (_, label) in enumerate(dataset.samples) if label in valid_set_indices]
    else:
        return None, None

    print(f"  -> Subset created: {len(indices_to_keep)} samples kept. Labels remapped to 0-{len(valid_class_indices)-1}.")
    
    # Create the New Class List (Subset)
    subset_class_names = [all_class_names[i] for i in valid_class_indices]
    
    return RemappedSubset(dataset, indices_to_keep, label_map), subset_class_names

def get_dataset_loaders(dataset_name, preprocess, get_train=True):
    """
    Build the evaluation loader (and optionally calibration/training loaders).

    Args:
        dataset_name: Dataset key (case-insensitive). 'imagenet10', 'imagenet20'
            and 'imagenet500id' are loaded through the 'imagenet1kval' entry of
            config.DATASET_PATHS and then reduced with _filter_dataset.
        preprocess: Transform passed to the underlying datasets.
        get_train: When True, also try to build calibration and training loaders
            from the training split.

    Returns:
        ((cal_loader, train_loader), test_loader, class_names, template).
        cal_loader and train_loader are None when get_train is False or when the
        training split could not be prepared (a warning is printed in that case).
        (None, None, None, None) when the dataset path is not configured, the
        test split cannot be loaded, no class names can be determined, or subset
        filtering fails.
    """
    dataset_name = dataset_name.lower()
    is_imagenet_subset = dataset_name in ['imagenet10', 'imagenet20', 'imagenet500id']

    # ImageNet subsets share the 'imagenet1kval' path and loader.
    loader_key = "imagenet1kval" if is_imagenet_subset else dataset_name
    root_path = config.DATASET_PATHS.get(loader_key, None)

    if root_path is None: return None, None, None, None

    metadata = ds_classes.DATASET_METADATA.get(dataset_name, {})
    template = metadata.get("template", ds_classes.SIMPLE_TEMPLATE)
    defined_class_names = metadata.get("classes", None) 

    try:
        test_dset = _get_dataset_instance(loader_key, root_path, split_mode='test', transform=preprocess)
    except Exception as e:
        print(f"Skipping {dataset_name}: {e}")
        return None, None, None, None

    final_class_names = None
    
    # Handle WNIDs vs Human Names: a first class like 'n01440764' marks WNID folders.
    is_wnid = False
    if hasattr(test_dset, 'classes') and len(test_dset.classes) > 0:
        sample_class = str(test_dset.classes[0])
        if sample_class.startswith('n') and any(char.isdigit() for char in sample_class):
            is_wnid = True
    
    if is_wnid:
        final_class_names = ds_classes.IMAGENET1K_CLASSES
    elif hasattr(test_dset, 'classes'):
        final_class_names = [str(c).replace('_', ' ').split(',')[0] for c in test_dset.classes]
    elif defined_class_names:
        final_class_names = defined_class_names
    
    if not final_class_names: return None, None, None, None

    # Keep the unfiltered class list: filtering the train split must resolve
    # class names/indices against the ORIGINAL label space, not the subset's.
    all_class_names = final_class_names

    # Apply Subset Filtering
    if is_imagenet_subset:
        test_dset, final_class_names = _filter_dataset(test_dset, all_class_names, dataset_name)
        if test_dset is None: return None, None, None, None

    if config.EVAL_DATASET_SUBSAMPLE_RATIO < 1.0:
        test_dset = create_stratified_subset(test_dset, config.EVAL_DATASET_SUBSAMPLE_RATIO)

    test_loader = DataLoader(test_dset, batch_size=config.EVAL_BATCH_SIZE, shuffle=False, num_workers=8, persistent_workers=False, prefetch_factor=2)

    cal_loader, train_loader = None, None
    if get_train:
        try:
            # NOTE(review): for folder datasets without a 'train' sub-folder,
            # _get_dataset_instance falls back to the root path, so this split may
            # overlap with the evaluation data - confirm this is intended.
            train_dset = _get_dataset_instance(loader_key, root_path, split_mode='train', transform=preprocess)
            
            if is_imagenet_subset:
                 train_dset, _ = _filter_dataset(train_dset, all_class_names, dataset_name)
            
            if train_dset is not None:
                indices = list(range(len(train_dset)))
                random.shuffle(indices)
                n_cal = config.NUM_CALIBRATION_SAMPLES_REAL
                n_trn = config.NUM_TRAINING_SAMPLES_REAL
                
                if len(indices) < n_cal + n_trn:
                    # Not enough samples for both budgets: split what exists in half.
                    mid = len(indices) // 2
                    cal_idx, trn_idx = indices[:mid], indices[mid:]
                else:
                    cal_idx, trn_idx = indices[:n_cal], indices[n_cal:n_cal+n_trn]
                    
                cal_loader = DataLoader(Subset(train_dset, cal_idx), batch_size=config.BATCH_SIZE, shuffle=True)
                train_loader = DataLoader(Subset(train_dset, trn_idx), batch_size=config.BATCH_SIZE, shuffle=True)
        except Exception as e:
            # Best effort: evaluation can proceed without these loaders, but say why
            # they are missing instead of failing silently.
            cal_loader, train_loader = None, None
            print(f"  [Warning] Could not build calibration/train loaders for {dataset_name}: {e}")

    return (cal_loader, train_loader), test_loader, final_class_names, template

def create_train_iterable(shard_pattern, map_fn, batch_size):
    """Build a resampled, shuffled and batched WebDataset over image/text shards.

    Args:
        shard_pattern: Shard path pattern; the directory of the text before the
            first '{' must exist.
        map_fn: Callable applied to each decoded image.
        batch_size: Batch size passed to `.batched`.

    Returns:
        The WebDataset pipeline, or None when the shard directory does not exist.
    """
    # Only the part before the brace expansion is a real filesystem prefix.
    prefix = str(shard_pattern).split('{')[0]
    shard_dir = os.path.dirname(prefix)
    # NOTE: a pattern without a directory component gives '' here, which
    # os.path.exists reports as missing, so None is returned for it too.
    if not os.path.exists(shard_dir):
        return None

    def preprocess_webdataset(sample):
        mapped_image = map_fn(sample[0])
        return mapped_image, sample[1]

    pipeline = wds.WebDataset(shard_pattern, resampled=True, shardshuffle=True)
    pipeline = pipeline.decode("pil").to_tuple("jpg", "txt")
    pipeline = pipeline.map(preprocess_webdataset)
    return pipeline.shuffle(1000).batched(batch_size)