import torch
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms import InterpolationMode
from torch.utils.data import random_split, ConcatDataset, Dataset, Subset, DataLoader
import numpy as np
import random
import os
from utils.config import *
from utils.utils import load_model
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

class ImageNet100(Dataset):
    """ImageNet subset restricted to the classes named in a text file.

    The class list file holds one class-folder name per line. A class's
    0-based line position becomes its label, so labels follow file order
    rather than the order ``ImageFolder`` assigns. Only paths and labels are
    kept in memory; images are decoded on access.

    Args:
        root: Directory that contains one sub-directory per split.
        split: Name of the split sub-directory under ``root``.
        transform: Optional callable applied to every loaded image.
        class_list_path: Path of the class list file. When ``None``,
            ``DEFAULT_CLASS_LIST_PATH`` is used.
    """

    DEFAULT_CLASS_LIST_PATH = '/proj/mars_yizhou/users/x_antad/data/imagenet100/imagenet100.txt'

    def __init__(self, root, split='train', transform=None, class_list_path=None):
        self.root = root
        self.split = split
        self.transform = transform

        if class_list_path is None:
            class_list_path = self.DEFAULT_CLASS_LIST_PATH

        with open(class_list_path, 'r') as f:
            self.selected_classes = [line.strip() for line in f]

        full_dataset = datasets.ImageFolder(os.path.join(root, split))

        # Label mapping of the unfiltered folder, kept for reference.
        self.class_to_idx_full = full_dataset.class_to_idx

        self.class_to_idx = {cls: i for i, cls in enumerate(self.selected_classes)}
        self.idx_to_class = {i: cls for cls, i in self.class_to_idx.items()}

        self.samples = []
        self.targets = []

        # Filter through the dict instead of scanning the class list for
        # every image: same membership result, constant time per sample.
        for path, full_label in full_dataset.samples:
            new_label = self.class_to_idx.get(full_dataset.classes[full_label])
            if new_label is None:
                continue
            self.samples.append((path, new_label))
            self.targets.append(new_label)

    def __len__(self):
        """Return the number of images that belong to the selected classes."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, remapped_label)``."""
        path, target = self.samples[idx]
        sample = datasets.folder.default_loader(path)

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target


class BufferedDataset(Dataset):
    """Wrap a raw dataset and keep a mutable soft-target buffer per sample.

    Attributes set at construction:
        targets: Hard labels as returned by ``extract_targets``.
        targets_tensor: The same labels as a ``torch.long`` tensor.
        num_classes: Size of the label space (largest label + 1).
        buffered_targets: Float one-hot encoding of ``targets_tensor``.
        img_size: Last two dimensions of the first transformed sample.
        effective_counts: Float tensor holding 1.0 for every sample.

    Args:
        raw_dataset: Dataset yielding ``(image, label)`` pairs.
        transform: Callable applied to each image on access.
    """

    def __init__(self, raw_dataset, transform):
        self.dataset = raw_dataset
        self.transform = transform
        self.targets = extract_targets(self.dataset)
        self.targets_tensor = torch.tensor(self.targets, dtype=torch.long)
        # Derive the label-space size from the largest label, not from the
        # number of distinct labels. A distinct count is too small when a
        # class is absent from this split (one_hot then raises) and is
        # meaningless when the labels are tensors, which hash by identity.
        self.num_classes = int(self.targets_tensor.max()) + 1
        self.buffered_targets = F.one_hot(self.targets_tensor, num_classes=self.num_classes).float()
        self.img_size = self.transform(self.dataset[0][0]).shape[-2:]
        self.effective_counts = torch.full((len(self),), 1.0)

    def __getitem__(self, index):
        """Return ``(transformed_image, buffered_target, hard_target, index)``."""
        data, _ = self.dataset[index]
        data = self.transform(data)
        target = self.targets_tensor[index]
        buffered_target = self.buffered_targets[index]
        return data, buffered_target, target, index

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)


class MultiViewBufferedDataset(Dataset):
    """Buffered dataset that yields one simple view and ``K`` strong views.

    Attributes set at construction:
        targets: Hard labels as returned by ``extract_targets``.
        targets_tensor: The same labels as a ``torch.long`` tensor.
        num_classes: Size of the label space (largest label + 1).
        buffered_targets: Float one-hot encoding of ``targets_tensor``.
        img_size: Last two dimensions of the first strongly transformed sample.
        effective_counts: Float tensor holding 1.0 for every sample.

    Args:
        raw_dataset: Dataset yielding ``(image, label)`` pairs.
        simple_transform: Callable producing the single simple view.
        strong_transform: Callable producing each strong view.
        K: Number of strong views generated per sample.
    """

    def __init__(self, raw_dataset, simple_transform, strong_transform, K=2):
        self.dataset = raw_dataset
        self.simple_transform = simple_transform
        self.strong_transform = strong_transform
        self.K = K
        self.targets = extract_targets(self.dataset)
        self.targets_tensor = torch.tensor(self.targets, dtype=torch.long)
        # Derive the label-space size from the largest label, not from the
        # number of distinct labels. A distinct count is too small when a
        # class is absent from this split (one_hot then raises) and is
        # meaningless when the labels are tensors, which hash by identity.
        self.num_classes = int(self.targets_tensor.max()) + 1
        self.buffered_targets = F.one_hot(self.targets_tensor, num_classes=self.num_classes).float()
        self.img_size = self.strong_transform(self.dataset[0][0]).shape[-2:]
        self.effective_counts = torch.full((len(self),), 1.0)

    def __getitem__(self, index):
        """Return ``(simple_view, [strong_view] * K, buffered_target, index)``."""
        img, _ = self.dataset[index]
        view_simple = self.simple_transform(img)
        # The strong transform is called once per view, so a random
        # transform produces K different views.
        views_strong = [self.strong_transform(img) for _ in range(self.K)]
        soft_target = self.buffered_targets[index]
        return view_simple, views_strong, soft_target, index

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

def extract_targets(ds):
    """Return the labels of ``ds`` as a flat list, in dataset order.

    ``Subset`` and ``ConcatDataset`` wrappers are unwrapped recursively.
    Otherwise the labels come from a ``targets`` attribute or, failing
    that, from the second element of each entry in ``samples``.

    Raises:
        ValueError: If ``ds`` is none of the supported kinds.
    """
    # A Subset selects positions out of its parent's label list.
    if isinstance(ds, Subset):
        parent_labels = extract_targets(ds.dataset)
        return [parent_labels[position] for position in ds.indices]

    # A ConcatDataset is its children's labels laid end to end.
    if isinstance(ds, ConcatDataset):
        return [label for part in ds.datasets for label in extract_targets(part)]

    if hasattr(ds, 'targets'):
        return [*ds.targets]

    if hasattr(ds, 'samples'):
        return [entry[1] for entry in ds.samples]

    raise ValueError(f"Can't extract targets from {type(ds)}")


class SubsetWithTransform(Dataset):
    """Apply a transform on top of a dataset that yields untransformed items.

    Args:
        subset: Indexable collection of ``(image, label)`` pairs.
        transform: Callable applied to the image of each pair.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for item ``idx``."""
        raw_image, label = self.subset[idx]
        return self.transform(raw_image), label

    def __len__(self):
        """Return the size of the wrapped collection."""
        return len(self.subset)


def _small_image_transforms(cfg, is_vit):
    """Build (simple, strong, val) transforms for the non-ImageNet datasets."""
    if is_vit:
        # ViT models get 224x224 inputs and the fixed normalisation below.
        train_geometry = [transforms.Resize((224, 224)),
                          transforms.RandomCrop(224, padding=28)]
        val_geometry = [transforms.Resize((224, 224))]
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        train_geometry = [transforms.RandomCrop(cfg['size'], padding=cfg['padding'])]
        val_geometry = []
        mean, std = cfg['mean'], cfg['std']

    simple_transform = transforms.Compose([
        *train_geometry,
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # The strong pipeline is the simple one plus AutoAugment at position 2
    # (still ahead of ToTensor) and RandomErasing just before Normalize.
    strong_steps = simple_transform.transforms.copy()
    strong_steps.insert(2, transforms.AutoAugment(policy=AutoAugmentPolicy.CIFAR10,
                                                  interpolation=InterpolationMode.BILINEAR))
    strong_steps.insert(-1, transforms.RandomErasing(p=0.1))
    strong_transform = transforms.Compose(strong_steps)

    val_transform = transforms.Compose([
        *val_geometry,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return simple_transform, strong_transform, val_transform


def _imagenet_transforms(cfg):
    """Build (simple, strong, val) transforms for imagenet1k / imagenet100."""
    simple_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg['mean'], std=cfg['std'])
    ])

    # NOTE(review): the strong view is cropped to 176 while the simple view
    # uses 224 -- confirm the size mismatch is intended.
    strong_transform = transforms.Compose([
        transforms.RandomResizedCrop(176),
        transforms.RandomHorizontalFlip(),
        transforms.TrivialAugmentWide(),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.1),
        transforms.Normalize(mean=cfg['mean'], std=cfg['std'])
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg['mean'], std=cfg['std'])
    ])
    return simple_transform, strong_transform, val_transform


def _hold_out_validation(train_dataset, holdout_size, val_transform):
    """Split ``holdout_size`` samples off the training set with a fixed seed."""
    train_part, val_part = random_split(
        train_dataset, [len(train_dataset) - holdout_size, holdout_size],
        generator=torch.Generator().manual_seed(42)
    )
    return train_part, SubsetWithTransform(val_part, val_transform)


def load_datasets(args):
    """Build the train, validation and test datasets selected by ``args``.

    Args:
        args: Namespace providing ``dataset``, ``model`` and ``method``.
            Optional attributes: ``val`` (truthy to hold out a validation
            split from the training data), ``noise_rate`` (> 0 injects label
            noise) and ``teacher_model_paths`` (used when ``method`` is
            ``'kd'``).

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``. ``train_dataset`` is
        a ``MultiViewBufferedDataset`` for method ``'bsd+'`` and a
        ``BufferedDataset`` otherwise. ``val_dataset`` is ``None`` unless
        ``args.val`` is truthy.
    """
    dataset = args.dataset.lower()
    cfg = DATA_CONFIG[dataset]
    is_vit = "vit" in args.model.lower()
    use_val = bool(getattr(args, 'val', False))

    if dataset in ('imagenet1k', 'imagenet100'):
        simple_transform, strong_transform, val_transform = _imagenet_transforms(cfg)
    else:
        simple_transform, strong_transform, val_transform = _small_image_transforms(cfg, is_vit)

    # Training data is loaded untransformed; the buffered wrappers below
    # apply the training transforms themselves.
    if "cifar" in dataset:
        train_dataset = cfg['cls'](root=cfg['path'], train=True, download=True)
        test_dataset = cfg['cls'](root=cfg['path'], train=False, download=True, transform=val_transform)
        holdout_size = 5000
    elif dataset == 'imagenet100':
        # Built once and reused for the optional split, so the image
        # folder is not scanned a second time.
        train_dataset = ImageNet100(root=cfg['path'], split='train', transform=None,
                                    class_list_path=cfg.get('class_list_path'))
        test_dataset = ImageNet100(root=cfg['path'], split='val', transform=val_transform,
                                   class_list_path=cfg.get('class_list_path'))
        holdout_size = 10000
    else:
        train_dataset = cfg['cls'](root=os.path.join(cfg['path'], 'train'), transform=None)
        test_dataset = cfg['cls'](root=os.path.join(cfg['path'], 'val'), transform=val_transform)
        holdout_size = 10000

    val_dataset = None
    if use_val:
        train_dataset, val_dataset = _hold_out_validation(train_dataset, holdout_size, val_transform)

    if args.method.lower() == 'bsd+':
        train_dataset = MultiViewBufferedDataset(raw_dataset=train_dataset, simple_transform=simple_transform, strong_transform=strong_transform)
    else:
        train_dataset = BufferedDataset(raw_dataset=train_dataset, transform=simple_transform)

    if hasattr(args, 'noise_rate') and args.noise_rate > 0.0:
        train_dataset = add_noise(args, train_dataset)

    if args.method.lower() == 'kd' and hasattr(args, 'teacher_model_paths') and len(args.teacher_model_paths) > 0:
        train_dataset = assign_kd_targets(args, train_dataset)

    return train_dataset, val_dataset, test_dataset


def _overwrite_label(train_dataset, idx, new_label):
    """Set sample ``idx`` to ``new_label`` in the hard and buffered targets."""
    idx = int(idx)
    train_dataset.targets_tensor[idx] = new_label
    if hasattr(train_dataset, 'buffered_targets'):
        train_dataset.buffered_targets[idx] = F.one_hot(
            torch.tensor(new_label), num_classes=train_dataset.num_classes
        ).float()


def _asymmetric_noise_mapping(dataset_name):
    """Return the source-label -> flipped-label table for asymmetric noise."""
    if dataset_name == 'cifar10':
        return {9: 1, 1: 9, 2: 0, 0: 2, 4: 7, 7: 4, 3: 5, 5: 3}
    if dataset_name == 'cifar100':
        # Each label moves to the next one within its block of five
        # consecutive ids, wrapping around at the end of the block.
        # NOTE(review): presumably each block stands for a superclass --
        # confirm against the label ordering actually in use.
        mapping = {}
        for block_start in range(0, 100, 5):
            for offset in range(5):
                mapping[block_start + offset] = block_start + (offset + 1) % 5
        return mapping
    raise ValueError(f"Asymmetric label noise is not defined for dataset '{dataset_name}'")


def add_noise(args, train_dataset):
    """Corrupt training labels in place and return the dataset.

    ``targets_tensor`` and, when present, ``buffered_targets`` are updated
    together so they stay consistent. Nothing happens when ``args`` has no
    positive ``noise_rate`` or when ``noise_type`` is neither ``"sym"`` nor
    ``"asym"``.

    Args:
        args: Namespace with ``noise_rate``, ``noise_type``, ``noise_seed``
            and ``dataset``; ``distributed`` is read when present.
        train_dataset: Object exposing ``targets_tensor`` and
            ``num_classes``, optionally ``buffered_targets``.

    Returns:
        The same ``train_dataset`` object.

    Raises:
        ValueError: For ``noise_type == "asym"`` on a dataset other than
            cifar10 / cifar100.
    """
    if not (hasattr(args, 'noise_rate') and args.noise_rate > 0.0):
        return train_dataset

    # Report from a single process only. The rank comes from the process
    # group, and counts as 0 when none is initialised, so this function
    # does not depend on a module-level ``local_rank`` being defined.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    is_main_process = not getattr(args, 'distributed', False) or rank == 0

    if is_main_process:
        print(f"Introducing label noise with rate {args.noise_rate}, type {args.noise_type}, using noise_seed {args.noise_seed}")

    # Two generators seeded alike: NumPy picks which samples to corrupt,
    # the stdlib generator picks the replacement labels.
    noise_np_rng = np.random.RandomState(args.noise_seed)
    noise_py_random = random.Random(args.noise_seed)

    num_samples_total = len(train_dataset.targets_tensor)

    if args.noise_type == "sym":
        num_noisy = int(args.noise_rate * num_samples_total)
        noisy_indices = noise_np_rng.choice(num_samples_total, num_noisy, replace=False)
        for idx in noisy_indices:
            # The replacement is drawn over all classes, so it can equal
            # the original label.
            noisy_label = noise_py_random.randint(0, train_dataset.num_classes - 1)
            _overwrite_label(train_dataset, idx, noisy_label)
        if is_main_process:
            print(f"Symmetric noise: Introduced noise in {len(noisy_indices)} samples.")

    elif args.noise_type == "asym":
        noise_mapping = _asymmetric_noise_mapping(args.dataset.lower())

        num_actually_flipped = 0
        for idx, original_label in enumerate(train_dataset.targets_tensor.tolist()):
            if original_label not in noise_mapping:
                continue
            # A random number is drawn only for labels that can be
            # flipped, which keeps the sequence tied to the seed.
            if noise_py_random.random() < args.noise_rate:
                _overwrite_label(train_dataset, idx, noise_mapping[original_label])
                num_actually_flipped += 1
        if is_main_process:
            print(f"Asymmetric noise: Flipped {num_actually_flipped} labels for relevant classes.")

    return train_dataset

def _strip_wrapper_prefixes(state_dict):
    """Drop leading ``module.`` / ``_orig_mod.`` wrapper prefixes from keys."""
    prefixes = ('module.', '_orig_mod.')
    cleaned = {}
    for key, value in state_dict.items():
        # Repeat until no prefix is left, so stacked wrappers are removed
        # in whichever order they were applied.
        while key.startswith(prefixes):
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


def assign_kd_targets(args, train_dataset):
    """Replace ``buffered_targets`` with averaged teacher soft predictions.

    Every checkpoint in ``args.teacher_model_paths`` is loaded in turn and
    run over the training set. Its temperature-softened softmax outputs are
    summed per sample and finally divided by the number of teachers. When a
    process group is initialised, each rank scores a disjoint slice of the
    dataset and the slices are combined with a summing all-reduce.

    Args:
        args: Namespace with ``teacher_model_paths``, ``teacher_model``,
            ``batch_size``, ``num_workers`` and ``temp``.
        train_dataset: Dataset exposing ``num_classes`` whose items are
            4-tuples with the model input first and the sample index last.

    Returns:
        The same ``train_dataset`` with ``buffered_targets`` replaced by a
        CPU float tensor of shape ``(len(train_dataset), num_classes)``.
        With an empty teacher list the dataset is returned untouched.
    """
    teacher_model_paths = args.teacher_model_paths
    num_teachers = len(teacher_model_paths)

    # Without a teacher there is nothing to distil; keep the existing
    # targets instead of overwriting them with zeros.
    if num_teachers == 0:
        return train_dataset

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{torch.cuda.current_device()}")
    else:
        device = torch.device("cpu")

    is_distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_distributed else 0
    world_size = dist.get_world_size() if is_distributed else 1

    if rank == 0:
        print(f"Processing predictions from {num_teachers} teacher model(s) distributed...")

    total_samples = len(train_dataset)
    num_classes = train_dataset.num_classes

    # Give each rank a strided, non-overlapping slice of the indices.
    # DistributedSampler is avoided on purpose: by default it pads the
    # index list with repeated samples to equalise the ranks, and those
    # repeats would be counted twice by the summing all-reduce below.
    shard_indices = range(rank, total_samples, world_size)

    teacher_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        sampler=shard_indices,
        num_workers=args.num_workers,
        pin_memory=device.type == 'cuda'
    )

    # Buffer for the whole dataset; rows owned by other ranks stay at zero
    # on this rank until the all-reduce.
    accumulated_predictions = torch.zeros((total_samples, num_classes), device=device, dtype=torch.float)

    for path in teacher_model_paths:
        if rank == 0:
            print(f"Loading teacher model from {path}")

        teacher_model_instance = load_model(args.teacher_model, num_classes)
        checkpoint = torch.load(path, map_location=device)

        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        teacher_model_instance.load_state_dict(_strip_wrapper_prefixes(state_dict))
        teacher_model_instance.to(device)
        teacher_model_instance.eval()

        with torch.no_grad():
            for batch in teacher_loader:
                data, _, _, indices = batch
                data = data.to(device)
                indices = indices.to(device)

                outputs = teacher_model_instance(data)
                teacher_probs = F.softmax(outputs / args.temp, dim=1)

                # Indices are unique within a rank, so indexed accumulation
                # adds each sample exactly once per teacher.
                accumulated_predictions[indices] += teacher_probs

        # Release the teacher before loading the next one.
        del teacher_model_instance
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    # Sums were accumulated; turn them into the ensemble mean.
    if num_teachers > 1:
        accumulated_predictions /= num_teachers

    if is_distributed:
        if rank == 0:
            print("Synchronizing teacher predictions across GPUs...")
        # Every row is non-zero on exactly one rank, so the sum
        # reconstructs the full table on all ranks.
        dist.all_reduce(accumulated_predictions, op=dist.ReduceOp.SUM)

    train_dataset.buffered_targets = accumulated_predictions.cpu()

    if rank == 0:
        print("Ensemble teacher predictions have been set as buffered_targets.")

    return train_dataset
