import numpy as np
import torch
from torch.utils.data import Dataset, Subset
from PIL import Image
import os, glob
import torch
import pandas as pd
import random

class HAPTDataset(Dataset):
    """HAPT feature-vector dataset whose features are split evenly across clients.

    Samples are shuffled once at load time with ``np.random.permutation``. Each
    item is ``(x, y)`` where ``x`` has shape ``(num_clients, partial_dim)`` and
    ``y`` is a zero-based ``torch.long`` label.

    Args:
        X_file: Text file readable by ``np.loadtxt`` (whitespace-delimited by
            default); one row per sample.
        y_file: Text file with one label per row. Labels are shifted down by 1
            so that they start at 0.
        num_clients: Number of feature blocks. Trailing features that do not
            divide evenly are dropped. ``None`` keeps every feature in a single
            block (``num_clients`` is then stored as 1).
        feature_indices: Optional ``(mean, std)`` pair, for example the
            ``mean``/``std`` attributes of a training instance or a stacked
            ``(2, D)`` array, used to standardize ``X``. When omitted, the
            statistics are computed from this file.
    """

    def __init__(self, X_file, y_file, num_clients=8, feature_indices=None):
        # ndmin keeps X 2-D even for a single-row or single-column file.
        X = np.loadtxt(X_file, ndmin=2).astype(float)
        y = np.loadtxt(y_file, ndmin=1).astype(float)

        # Shuffle the samples along axis 0 (the samples dimension).
        sample_indices = np.random.permutation(X.shape[0])
        X, y = X[sample_indices], y[sample_indices]

        # Drop trailing features so that every client gets the same block width.
        if num_clients is not None:
            num_features = X.shape[1]
            new_num_features = num_features - (num_features % num_clients)
            X = X[:, :new_num_features]

        # Compare with None: a numpy stats array has no unambiguous truth value.
        if feature_indices is not None:
            self.mean, self.std = feature_indices[0], feature_indices[1]
        else:
            self.mean = X.mean(axis=0)
            # Epsilon avoids division by zero for constant features.
            self.std = X.std(axis=0) + 1e-8
        X = (X - self.mean) / self.std

        # Labels in the file start at 1; make them zero-based.
        y = y - 1

        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

        # None means "no split": a single block holding every feature.
        self.num_clients = num_clients if num_clients is not None else 1
        self.partial_dim = self.X.shape[1] // self.num_clients
        self.classes = [f"{i}." for i in range(1, 13)]

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Reshape the flat feature row into (num_clients, partial_dim) blocks.
        x_multi = self.X[idx].view(self.num_clients, self.partial_dim)
        return x_multi, self.y[idx]

class IsoletDataset(Dataset):
    """ISOLET dataset whose feature vector is split evenly across clients.

    The file is parsed with comma and/or whitespace separators. The last column
    is a 1-based label, shifted to 0-based (0..25 maps to classes 'A'..'Z').
    Samples are shuffled once at load time with ``np.random.permutation``.

    Args:
        data_file: Path to the delimited data file (no header row).
        num_clients: Number of feature blocks. Trailing features that do not
            divide evenly are dropped. ``None`` keeps every feature in a single
            block (``num_clients`` is then stored as 1).
        feature_indices: Optional ``(mean, std)`` pair used to standardize the
            features. When omitted, the statistics are computed from this file.
    """

    def __init__(self, data_file, num_clients=8, feature_indices=None):
        df = pd.read_csv(
            data_file,
            header=None,
            sep=r'[,\s]+',
            engine='python'
        )
        data = df.values
        X = data[:, :-1].astype(np.float32)
        y = data[:, -1].astype(int) - 1  # 1-based labels -> 0-based

        perm = np.random.permutation(len(y))
        X, y = X[perm], y[perm]

        # Drop trailing features so that every client gets the same block width.
        if num_clients is not None:
            num_features = X.shape[1]
            new_num_features = num_features - (num_features % num_clients)
            X = X[:, :new_num_features]

        if feature_indices is not None:
            self.mean, self.std = feature_indices[0], feature_indices[1]
        else:
            self.mean = X.mean(axis=0)
            # Epsilon avoids division by zero for constant features.
            self.std = X.std(axis=0) + 1e-8
        X = (X - self.mean) / self.std

        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

        # None means "no split": a single block holding every feature.
        self.num_clients = num_clients if num_clients is not None else 1
        # X is already truncated to a multiple of num_clients, so this divides exactly.
        self.partial_dim = self.X.shape[1] // self.num_clients

        self.classes = [chr(ord('A') + i) for i in range(26)]

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Reshape the flat feature row into (num_clients, partial_dim) blocks.
        x_multi = self.X[idx].view(self.num_clients, self.partial_dim)
        return x_multi, self.y[idx]


class ModelNet10MultiViewDataset(Dataset):
    """Multi-view image dataset in which six "clients" each see one rendered view.

    Layout read by this class: ``root_dir/<class_name>/<split>/*.png``. Files that
    share the part of their name before the last ``"_v"`` are grouped as views of
    one object and sorted by filename. The original comments suggest names like
    ``chair_0001_v001.png`` .. ``_v012.png``; verify against the data.

    Every object contributes ``num_repeat`` samples. For each sample, one view
    index is drawn at construction time (via the global ``random`` module) from
    each of the six slot pairs ``(0,1), (2,3), ..., (10,11)``. If an object has
    too few views for a pair, the pair's first slot is used, or failing that the
    object's last view.

    Args:
        root_dir: Directory whose sub-directories are class names.
        split: Sub-folder inside each class directory (e.g. ``'train'``).
        transform: Callable applied to each grayscale PIL image. Its outputs are
            combined with ``torch.stack``, so it must return tensors.
        num_repeat: Number of independently drawn view combinations per object.
    """

    def __init__(self, root_dir, split='train', transform=None, num_repeat=4):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.num_repeat = num_repeat

        # os.listdir order is arbitrary; sorting gives stable class indices.
        self.class_names = [
            dname for dname in sorted(os.listdir(root_dir))
            if os.path.isdir(os.path.join(root_dir, dname))
        ]
        self.class_to_idx = {cname: i for i, cname in enumerate(self.class_names)}

        self.objects_per_class = {}
        for cls_idx, cname in enumerate(self.class_names):
            class_folder = os.path.join(self.root_dir, cname, split)
            if not os.path.exists(class_folder):
                self.objects_per_class[cls_idx] = []
                continue

            views_by_object = {}
            for img_path in glob.glob(os.path.join(class_folder, '*.png')):
                # e.g. "chair_0001_v003.png" -> "chair_0001"
                obj_id = os.path.basename(img_path).rsplit('_v', 1)[0]
                views_by_object.setdefault(obj_id, []).append(img_path)

            # Sort objects by id so obj_idx does not depend on filesystem order.
            self.objects_per_class[cls_idx] = [
                (obj_id, sorted(paths))
                for obj_id, paths in sorted(views_by_object.items())
            ]

        self.samples = []
        # Six clients: client k sees either view slot 2k or 2k+1.
        self.groups = [(2 * i, 2 * i + 1) for i in range(6)]

        for cls_idx, obj_list in self.objects_per_class.items():
            for obj_idx, (_, paths) in enumerate(obj_list):
                n_views = len(paths)
                for _ in range(self.num_repeat):
                    chosen_view_idxs = []
                    for g in self.groups:
                        chosen_view = random.choice(g)
                        if chosen_view >= n_views:
                            chosen_view = g[0]  # fallback to the pair's first slot
                        if chosen_view >= n_views:
                            # Object has fewer views than this pair needs: reuse its last view
                            # rather than storing an index that fails in __getitem__.
                            chosen_view = n_views - 1
                        chosen_view_idxs.append(chosen_view)

                    self.samples.append((cls_idx, obj_idx, chosen_view_idxs))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        cls_idx, obj_idx, chosen_view_idxs = self.samples[idx]
        _, paths = self.objects_per_class[cls_idx][obj_idx]

        client_images = []
        for view_i in chosen_view_idxs:
            # The context manager releases the file; convert() returns an in-memory copy.
            with Image.open(paths[view_i]) as im:
                img = im.convert('L')
            if self.transform:
                img = self.transform(img)
            client_images.append(img)

        # (6, 1, H, W) when the transform yields (1, H, W) tensors (e.g. ToTensor on 'L').
        views_tensor = torch.stack(client_images, dim=0)
        return views_tensor, cls_idx

def compute_mean_std_for_modelnet10(root_dir, split='train', num_views=12, max_samples=2000):
    """Estimate the grayscale pixel mean and std of the ModelNet10 multi-view renders.

    Images are resized to 32x32 and converted with ``ToTensor``. Statistics are
    accumulated over every view of up to ``max_samples`` randomly ordered samples
    (one view combination per object).

    Args:
        root_dir: Dataset root passed to ``ModelNet10MultiViewDataset``.
        split: Split sub-folder to read.
        num_views: Unused; kept for backward compatibility of the signature.
        max_samples: Maximum number of samples to visit (at least one is read).

    Returns:
        ``(mean, std)``: float32 tensors of shape ``(1,)``.

    Raises:
        ValueError: if the split contains no samples.
    """
    from torchvision import transforms
    temp_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])
    temp_dataset = ModelNet10MultiViewDataset(
        root_dir=root_dir,
        split=split,
        transform=temp_transform,
        num_repeat=1
    )

    loader = torch.utils.data.DataLoader(temp_dataset, batch_size=1, shuffle=True)

    n_pixels = 0
    # float64 accumulators: float32 loses precision when summing squares over many pixels.
    channel_sum = torch.zeros(1, dtype=torch.float64)
    channel_sum_sq = torch.zeros(1, dtype=torch.float64)

    for count, (views_tensor, _) in enumerate(loader, start=1):
        views = views_tensor.squeeze(0).double()  # drop batch dim -> (6, 1, H, W)
        channel_sum += views.sum(dim=[0, 2, 3])
        channel_sum_sq += (views ** 2).sum(dim=[0, 2, 3])
        n_pixels += views.size(0) * views.size(2) * views.size(3)
        if count >= max_samples:
            break

    if n_pixels == 0:
        raise ValueError(f"No images found under {root_dir!r} for split {split!r}")

    mean = channel_sum / n_pixels
    # E[x^2] - E[x]^2 can dip slightly below zero from rounding; clamp before sqrt.
    var = ((channel_sum_sq / n_pixels) - mean ** 2).clamp_min(0.0)
    std = torch.sqrt(var)

    return mean.float(), std.float()

        
def generate_mask_patterns(num_samples, batch_size, num_blocks, p_observed, p_aligned=0.0):
    """Draw one block-observation pattern per batch and expand it to every sample.

    ``int(n_batches * p_aligned)`` batches get the fully observed pattern. The
    remaining batches each draw one of the ``2**num_blocks`` boolean patterns,
    where a pattern with ``k`` observed blocks has weight
    ``p_observed**k * (1 - p_observed)**(num_blocks - k)``. Batch order is then
    shuffled with ``np.random.shuffle``.

    Returns:
        A list of length ``num_samples``. Entry ``i`` is the bool array of batch
        ``i // batch_size``; samples in the same batch share the same array object.
    """
    n_batches = (num_samples + batch_size - 1) // batch_size
    n_aligned = int(n_batches * p_aligned)
    n_masked = n_batches - n_aligned

    # Candidate k is the binary expansion of k, most significant bit = block 0.
    candidates = []
    for code in range(2 ** num_blocks):
        bits = bin(code)[2:].zfill(num_blocks)
        candidates.append(np.array([bit == '1' for bit in bits]))

    weights = []
    for cand in candidates:
        n_on = cand.sum()
        weights.append(p_observed ** n_on * (1 - p_observed) ** (num_blocks - n_on))
    weights = np.array(weights) / np.sum(weights)

    drawn = np.random.choice(len(candidates), size=n_masked, p=weights)

    all_observed = np.array([True] * num_blocks, dtype=bool)
    per_batch = [all_observed] * n_aligned + [candidates[k] for k in drawn]
    np.random.shuffle(per_batch)

    return [per_batch[s // batch_size] for s in range(num_samples)]


def collate_fn(batch):
    """Stack a list of ``(x, y, mask)`` triples into three batched tensors.

    The ``x`` tensors and the ``mask`` tensors are each stacked along a new
    leading batch dimension. The labels become a ``torch.long`` tensor.
    Presumably ``x`` is ``(num_blocks, H, W)`` and ``mask`` is ``(num_blocks,)``;
    confirm with the caller.
    """
    inputs, labels, masks = zip(*batch)
    return (
        torch.stack(inputs, dim=0),
        torch.tensor(labels, dtype=torch.long),
        torch.stack(masks, dim=0),
    )

class CustomDataset(Dataset):
    """Attach a per-sample observation mask (and optionally hide labels) to a dataset.

    Args:
        data: Indexable dataset whose items are tuples ``(*features, label)``.
            It may be a ``Subset``, possibly nested; ``classes`` is read from the
            innermost wrapped dataset.
        sample_patterns: Sequence indexed like ``data``. Entry ``idx`` is
            converted to a bool tensor and returned as the mask of sample ``idx``.
        unlabeled: If True, every item's label is replaced with ``-1``.

    Items are returned as ``(*features, label, mask)``.
    """

    def __init__(self, data, sample_patterns, unlabeled=False):
        self.data = data
        self.sample_patterns = sample_patterns
        self.unlabeled = unlabeled
        # Unwrap nested Subsets (e.g. random_split followed by Subset) to find `classes`.
        root = data
        while isinstance(root, Subset):
            root = root.dataset
        self.classes = getattr(root, 'classes', None)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        features, label = item[:-1], item[-1]
        mask = self.sample_patterns[idx]

        if self.unlabeled:
            label = -1

        return (*features, label, torch.tensor(mask, dtype=torch.bool))


def get_idx_to_partition_map(dataset: str, num_clients: int) -> dict:
    """Return the image patch assigned to each client for 28x28 image datasets.

    Each value is ``((row_start, row_end), (col_start, col_end))``, with
    half-open ranges. The patches tile the full 28x28 grid: 2 row bands by 4
    column bands for 8 clients, or 2 by 2 for 4 clients. MNIST shares
    FashionMNIST's 28x28 layout, and the MAR/MNAR masking helpers route
    ``'mnist'`` here, so both names are accepted.

    Raises:
        NotImplementedError: for any other dataset / client-count combination.
    """
    if dataset in ("mnist", "fashionmnist"):
        if num_clients == 8:
            return {
                0: ((0, 14), (0, 7)),
                1: ((0, 14), (7, 14)),
                2: ((0, 14), (14, 21)),
                3: ((0, 14), (21, 28)),
                4: ((14, 28), (0, 7)),
                5: ((14, 28), (7, 14)),
                6: ((14, 28), (14, 21)),
                7: ((14, 28), (21, 28)),
            }
        if num_clients == 4:
            return {
                0: ((0, 14), (0, 14)),
                1: ((0, 14), (14, 28)),
                2: ((14, 28), (0, 14)),
                3: ((14, 28), (14, 28)),
            }
    raise NotImplementedError(
        f"No partition map for dataset={dataset!r} with num_clients={num_clients}"
    )


def apply_missing_to_sample_mcar(x, p_mcar, n_clients):
    """Build an MCAR observation mask: each client is dropped independently.

    Args:
        x: Unused; accepted for signature parity with the MAR/MNAR helpers.
        p_mcar: Probability that a client's block is marked missing.
        n_clients: Number of client blocks.

    Returns:
        Bool tensor of length ``n_clients``; ``True`` means observed.
    """
    observed = []
    for _ in range(n_clients):
        dropped = random.random() < p_mcar
        observed.append(not dropped)
    return torch.tensor(observed, dtype=torch.bool)


def apply_missing_to_sample_mar(x, dataset, n_clients, option, init_cut_var=0.7):
    """Build a MAR observation mask driven by the variance of the observed blocks.

    Clients are visited in a random order: a random start client, then the rest
    shuffled. A threshold starts at 1.1 and drops by 0.15 after each client that
    does not trigger a stop. Once a stop triggers, the triggering client stays
    observed and every client after it in the visiting order is marked missing.

    - ``option == 0``: stop at the first client whose block variance exceeds
      the current threshold.
    - ``option == 1``: subtract ``max(var - threshold + 0.5, 0)`` from a budget
      that starts at ``init_cut_var``; stop when the budget reaches ``<= 0``.
    - Any other ``option``: nothing is masked.

    Args:
        x: Sample. For ``'modelnet10'``/``'hapt'``/``'isolet'`` it is indexed as
            ``x[client_id]``. For ``'mnist'``/``'fashionmnist'`` it is sliced as
            ``x[:, rows, cols]`` using ``get_idx_to_partition_map``.
        dataset: Dataset name selecting how client blocks are extracted.
        n_clients: Number of client blocks.
        option: Stopping rule (see above).
        init_cut_var: Initial budget for ``option == 1``.

    Returns:
        Bool tensor of length ``n_clients``; ``True`` means observed.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset in ('modelnet10', 'hapt', 'isolet'):
        partition = None
    elif dataset in ('mnist', 'fashionmnist'):
        # Built once instead of once per visited client.
        partition = get_idx_to_partition_map(dataset, n_clients)
    else:
        raise ValueError(f"Unsupported dataset for MAR masking: {dataset!r}")

    mask = [True] * n_clients
    cut_var = init_cut_var

    start = random.randrange(n_clients)
    leftover = list(range(n_clients))
    leftover.remove(start)
    random.shuffle(leftover)
    order = [start] + leftover
    var_threshold = 1.1

    for idx, client_id in enumerate(order):
        if partition is None:
            x_block = x[client_id]
        else:
            (r0, r1), (c0, c1) = partition[client_id]
            x_block = x[:, r0:r1, c0:c1]  # presumably (C, H', W')
        var_val = x_block.var().item()

        if option == 0:
            stop = var_val > var_threshold
        elif option == 1:
            cut_var -= max(var_val - var_threshold + 0.5, 0)
            stop = cut_var <= 0
        else:
            continue  # unknown option: no masking rule applies

        if stop:
            for c in order[idx + 1:]:
                mask[c] = False
            break
        var_threshold -= 0.15

    return torch.tensor(mask, dtype=torch.bool)

def apply_missing_to_sample_mnar(x, dataset, n_clients, p_miss=0.7):
    """Build an MNAR observation mask whose drop rate depends on each block's own mean.

    Each client is visited in index order and one ``random.random()`` draw is
    made per client. A block with negative mean is dropped with probability
    ``p_miss``; any other block (including mean 0 or NaN) is dropped with
    probability ``1 - p_miss``.

    Args:
        x: Sample. For ``'modelnet10'``/``'hapt'``/``'isolet'`` it is indexed as
            ``x[client_id]``. For ``'mnist'``/``'fashionmnist'`` it is sliced as
            ``x[:, rows, cols]`` using ``get_idx_to_partition_map``.
        dataset: Dataset name selecting how client blocks are extracted.
        n_clients: Number of client blocks.
        p_miss: Drop probability for negative-mean blocks.

    Returns:
        Bool tensor of length ``n_clients``; ``True`` means observed.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset in ('modelnet10', 'hapt', 'isolet'):
        partition = None
    elif dataset in ('mnist', 'fashionmnist'):
        # Built once instead of once per client.
        partition = get_idx_to_partition_map(dataset, n_clients)
    else:
        raise ValueError(f"Unsupported dataset for MNAR masking: {dataset!r}")

    mask = [True] * n_clients
    for client_id in range(n_clients):
        if partition is None:
            x_block = x[client_id]
        else:
            (r0, r1), (c0, c1) = partition[client_id]
            x_block = x[:, r0:r1, c0:c1]  # presumably (C, H', W')
        mean_val = x_block.mean().item()

        drop_prob = p_miss if mean_val < 0 else 1 - p_miss
        if random.random() < drop_prob:
            mask[client_id] = False

    return torch.tensor(mask, dtype=torch.bool)


class MissingAppliedDataset_sample(Dataset):
    """Wrap a ``(x, y)`` dataset and precompute one observation mask per sample.

    Masks are generated once in ``__init__`` by reading every sample of
    ``base_dataset`` and applying the chosen mechanism.

    Args:
        base_dataset: Indexable dataset returning ``(x, y)`` pairs.
        dataset: Dataset name forwarded to the MAR/MNAR helpers.
        n_clients: Number of client blocks per sample.
        missing_type: ``'mcar'``, ``'mar'`` or ``'mnar'``.
        p_mcar: Drop probability for MCAR.
        p_miss: Drop probability parameter for MNAR.
        option: Stopping rule for MAR.
        init_cut_var: Initial budget for MAR ``option == 1``.

    Items are returned as ``(x, y, mask)``.

    Raises:
        ValueError: if ``missing_type`` is unknown, even for an empty base dataset.
    """

    def __init__(self, base_dataset, dataset, n_clients, missing_type='mcar', p_mcar=0.5, p_miss=0.7, option=0, init_cut_var=0.7):
        super().__init__()
        # Validate up front so a bad type fails fast rather than inside the loop.
        if missing_type not in ('mcar', 'mar', 'mnar'):
            raise ValueError("Unknown missing type")
        self.base_dataset = base_dataset
        self.missing_type = missing_type

        # Precompute so masks stay fixed across epochs.
        self.masks = []
        for i in range(len(base_dataset)):
            x, _ = base_dataset[i]
            if missing_type == 'mcar':
                mask = apply_missing_to_sample_mcar(x, p_mcar, n_clients)
            elif missing_type == 'mar':
                mask = apply_missing_to_sample_mar(x, dataset, n_clients, option, init_cut_var)
            else:
                mask = apply_missing_to_sample_mnar(x, dataset, n_clients, p_miss)
            self.masks.append(mask)

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        x, y = self.base_dataset[idx]
        return x, y, self.masks[idx]


class FinalMissingDataset(Dataset):
    """Pair each ``(x, y)`` sample of ``base_ds`` with a precomputed mask.

    Args:
        base_ds: Indexable dataset returning ``(x, y)`` pairs.
        mask_array: Sequence indexed like ``base_ds``; entry ``idx`` is returned
            as the mask of sample ``idx``.

    Items are returned as ``(x, y, mask)``.
    """

    def __init__(self, base_ds, mask_array):
        super().__init__()
        self.base_ds = base_ds
        self.mask_array = mask_array

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        features, target = self.base_ds[idx]
        return features, target, self.mask_array[idx]
        