import os
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from transformers import BertTokenizer

from copy import deepcopy

class MultiNLIDataset(Dataset):
    """MultiNLI split backed by pre-tokenized BERT feature caches.

    Rows are selected from ``metadata_random.csv`` by its integer ``split``
    column (0=train, 1=val, 2=test). The label comes from ``gold_label`` and the
    spurious attribute ("place") from ``sentence2_has_negation``; the group id
    is ``y * n_places + p``.

    Each item is ``(x, y, g, p)`` where ``x`` stacks the cached input ids,
    input mask and segment ids along dim 2 of the feature tensor.

    Args:
        basedir: Directory holding ``metadata_random.csv`` and the three
            ``cached_*_bert-base-uncased_128_mnli*`` feature files.
        split: One of ``"train"``, ``"val"``, ``"test"``.
        transform: Stored on the instance; not applied by this class.
        balance_groups: Accepted for signature parity with the other
            datasets; unused here.
        seed: Stored on the instance.
        rate: Stored on the instance; not used by this class.

    Raises:
        ValueError: If ``split`` is not a known split name.
    """

    def __init__(self, basedir, split="train", transform=None, balance_groups=False, seed=0, rate=1.0):
        try:
            split_i = ["train", "val", "test"].index(split)
        except ValueError:
            # Raising a plain string is itself a TypeError in Python 3; raise a
            # proper exception with a readable message instead.
            raise ValueError(f"Unknown split {split}") from None

        metadata_df = pd.read_csv(os.path.join(basedir, "metadata_random.csv"))
        splits = np.asarray(metadata_df["split"].values)
        # Row positions of this split within the full metadata table. The
        # concatenated feature tensors below are indexed with these, which
        # presumes the feature caches are row-aligned with
        # metadata_random.csv — TODO confirm.
        self.data_indices = np.argwhere(splits == split_i).flatten()

        metadata_df = metadata_df[metadata_df["split"] == split_i]

        self.transform = transform
        self.basedir = basedir
        self.seed = seed
        self.rate = rate
        self.tokenizer = BertTokenizer.from_pretrained("./bert-base-uncased")

        self.metadata_df = metadata_df.copy()

        # Labels and spurious-attribute (group) information.
        self.y_array = self.metadata_df['gold_label'].values
        self.p_array = self.metadata_df['sentence2_has_negation'].values
        self.confounder_array = self.p_array
        self.n_classes = np.unique(self.y_array).size
        self.n_places = np.unique(self.p_array).size
        self.group_array = (self.y_array * self.n_places + self.p_array).astype('int')
        self.n_groups = self.n_classes * self.n_places

        # Load all three feature caches and concatenate them in this fixed
        # order; this split's rows are then picked out via data_indices.
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted cache files. torch>=2.6 defaults to weights_only=True, which
        # may reject these pickled feature objects; pass weights_only=False if
        # loading fails.
        self.features_array = []
        for feature_file in [
            'cached_train_bert-base-uncased_128_mnli',
            'cached_dev_bert-base-uncased_128_mnli',
            'cached_dev_bert-base-uncased_128_mnli-mm'
        ]:
            features = torch.load(os.path.join(basedir, feature_file))
            self.features_array += features

        self.all_input_ids = torch.tensor([f.input_ids for f in self.features_array], dtype=torch.long)
        self.all_input_masks = torch.tensor([f.input_mask for f in self.features_array], dtype=torch.long)
        self.all_segment_ids = torch.tensor([f.segment_ids for f in self.features_array], dtype=torch.long)
        self.all_label_ids = torch.tensor([f.label_id for f in self.features_array], dtype=torch.long)

        self.x_array = torch.stack((
            self.all_input_ids,
            self.all_input_masks,
            self.all_segment_ids), dim=2)[self.data_indices]

        # Per-group / per-class / per-place sample counts for this split.
        self.group_counts = (
            torch.arange(self.n_groups).unsqueeze(1) == torch.from_numpy(self.group_array)
        ).sum(1).float()
        self.y_counts = (
            torch.arange(self.n_classes).unsqueeze(1) == torch.from_numpy(self.y_array)
        ).sum(1).float()
        self.p_counts = (
            torch.arange(self.n_places).unsqueeze(1) == torch.from_numpy(self.p_array)
        ).sum(1).float()

        self.filename_array = self.metadata_df['Unnamed: 0'].values

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Return ``(x, y, g, p)`` for the ``idx``-th example of this split."""
        y = self.y_array[idx]
        g = self.group_array[idx]
        p = self.confounder_array[idx]
        x = self.x_array[idx, ...]

        return x, y, g, p

class JigsawDataset(Dataset):
    """Jigsaw toxicity comments grouped by (toxic label, identity mention).

    Rows of ``all_data_with_identities.csv`` whose ``split`` column equals the
    split name are kept. The label is ``metadata[target_name] >= 0.5``; each
    confounder column is binarized with ``>= 0.4`` and the binary confounders
    are packed into one integer. Comments are tokenized lazily in
    ``__getitem__``.

    Args:
        basedir: Directory holding ``all_data_with_identities.csv``.
        split: One of ``"train"``, ``"val"``, ``"test"``.
        transform: Stored on the instance; not applied by this class.
        balance_groups: Accepted for signature parity; unused here.
        seed: Stored on the instance.
        target_name: Column holding the score used to derive the label.

    Raises:
        ValueError: If ``split`` is not a known split name.
    """

    def __init__(self, basedir, split="train", transform=None, balance_groups=False, seed=0, target_name="toxicity"):
        if split not in ("train", "val", "test"):
            raise ValueError(f"Unknown split {split}")

        metadata_df = pd.read_csv(os.path.join(basedir, "all_data_with_identities.csv"))
        # The split column is compared against the split *name*, not its index.
        metadata_df = metadata_df[metadata_df["split"] == split]

        self.transform = transform
        self.basedir = basedir
        self.seed = seed
        self.target_name = target_name
        self.confounder_names = ["identity_any"]
        self.metadata_df = metadata_df.copy()

        self.y_array = (self.metadata_df[self.target_name].values >= 0.5).astype("long")
        self.n_classes = len(np.unique(self.y_array))

        if self.confounder_names[0] == "only_label":
            # Groups are just the labels. Give every sample place 0 so that
            # __getitem__ and the place counts below still work.
            self.n_groups = self.n_classes
            self.group_array = self.y_array
            self.confounder_array = np.zeros_like(self.y_array)
            n_confounder_values = 1
        else:
            # Confounders are all binary: map them to 0 .. 2**n_confounders - 1.
            self.n_confounders = len(self.confounder_names)
            confounders = (self.metadata_df.loc[:, self.confounder_names] >= 0.4).values
            self.confounder_array = confounders @ np.power(
                2, np.arange(self.n_confounders)
            )
            n_confounder_values = pow(2, self.n_confounders)

            # Label-major group id, in integer arithmetic. Scaling by
            # 2**n_confounders (rather than n_groups / 2) keeps the encoding
            # valid for any number of classes.
            self.n_groups = self.n_classes * n_confounder_values
            self.group_array = (
                self.y_array * n_confounder_values + self.confounder_array
            ).astype("int")

        self.p_array = self.confounder_array
        self.n_places = np.unique(self.p_array).size

        self.group_counts = (
            torch.arange(self.n_groups).unsqueeze(1) == torch.from_numpy(self.group_array)
        ).sum(1).float()
        self.y_counts = (
            torch.arange(self.n_classes).unsqueeze(1) == torch.from_numpy(self.y_array)
        ).sum(1).float()
        # Counted over every possible confounder value (not just those present),
        # so indexing p_counts with p_array is always in range.
        self.p_counts = (
            torch.arange(n_confounder_values).unsqueeze(1) == torch.from_numpy(self.p_array)
        ).sum(1).float()

        self.text_array = list(metadata_df["comment_text"])
        self.tokenizer = BertTokenizer.from_pretrained("./bert-base-uncased")

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Tokenize comment ``idx`` and return ``(x, y, g, p)``.

        ``x`` stacks input ids, attention mask and token type ids along the
        last dimension after dropping the tokenizer's leading batch dim of 1.
        """
        y = self.y_array[idx]
        g = self.group_array[idx]
        p = self.confounder_array[idx]

        text = str(self.text_array[idx])
        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=300,
            return_tensors="pt",
        )  # 220
        x = torch.stack(
            (tokens["input_ids"], tokens["attention_mask"], tokens["token_type_ids"]),
            dim=2,
        )
        x = torch.squeeze(x, dim=0)  # First shape dim is always 1

        return x, y, g, p



class WaterBirdsDataset(Dataset):
    """Waterbirds image dataset with (label, place) group annotations.

    Reads ``metadata.csv`` under ``basedir`` and keeps the rows whose integer
    ``split`` column matches the requested split (0=train, 1=val, 2=test).
    Each sample's group id is ``y * n_places + place``.

    Args:
        basedir: Directory containing ``metadata.csv`` and the image files
            named by its ``img_filename`` column.
        split: One of ``"train"``, ``"val"``, ``"test"``.
        transform: Optional callable applied to each loaded RGB PIL image.
        balance_groups: If True, randomly truncate every group to the size of
            the smallest group.
        seed: Seed for the truncation shuffle.

    Raises:
        ValueError: If ``split`` is not a known split name.
    """

    def __init__(self, basedir, split="train", transform=None, balance_groups=False, seed=0):
        try:
            split_i = ["train", "val", "test"].index(split)
        except ValueError:
            # Raising a plain string is itself a TypeError in Python 3; raise a
            # proper exception with a readable message instead.
            raise ValueError(f"Unknown split {split}") from None

        metadata_df = pd.read_csv(os.path.join(basedir, "metadata.csv"))
        metadata_df = metadata_df[metadata_df["split"] == split_i]

        self.transform = transform
        self.basedir = basedir
        self.seed = seed

        self.metadata_df = metadata_df.copy()

        # Labels and spurious-attribute (group) information.
        self.y_array = self.metadata_df['y'].values
        self.p_array = self.metadata_df['place'].values
        self.confounder_array = self.p_array
        self.n_classes = np.unique(self.y_array).size
        self.n_places = np.unique(self.p_array).size
        self.group_array = (self.y_array * self.n_places + self.p_array).astype('int')
        self.n_groups = self.n_classes * self.n_places

        # Balance by truncating every group to the smallest group's size.
        if balance_groups:
            print("Balancing groups by truncation...")
            # A private RandomState seeded with `seed` yields exactly the
            # shuffles that np.random.seed(seed) would, without clobbering the
            # global NumPy RNG state as a side effect.
            rng = np.random.RandomState(self.seed)
            group_indices = []
            for g in range(self.n_groups):
                idx_g = np.where(self.group_array == g)[0]
                rng.shuffle(idx_g)
                group_indices.append(idx_g)
            # NOTE: if any group is empty this is 0 and the dataset becomes empty.
            min_size = min(len(idx_g) for idx_g in group_indices)
            print(f"Using {min_size} samples per group")

            balanced_idx = np.concatenate([idx_g[:min_size] for idx_g in group_indices])
            self.metadata_df = self.metadata_df.iloc[balanced_idx].reset_index(drop=True)
            self.y_array = self.y_array[balanced_idx]
            self.p_array = self.p_array[balanced_idx]
            self.confounder_array = self.confounder_array[balanced_idx]
            self.group_array = self.group_array[balanced_idx]

        # Per-group / per-class / per-place sample counts (after balancing).
        self.group_counts = (
            torch.arange(self.n_groups).unsqueeze(1) == torch.from_numpy(self.group_array)
        ).sum(1).float()
        self.y_counts = (
            torch.arange(self.n_classes).unsqueeze(1) == torch.from_numpy(self.y_array)
        ).sum(1).float()
        self.p_counts = (
            torch.arange(self.n_places).unsqueeze(1) == torch.from_numpy(self.p_array)
        ).sum(1).float()

        self.filename_array = self.metadata_df['img_filename'].values

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` and return ``(img, y, g, p)``."""
        y = self.y_array[idx]
        g = self.group_array[idx]
        p = self.confounder_array[idx]

        img_path = os.path.join(self.basedir, self.filename_array[idx])
        # Context manager guarantees the underlying file handle is released;
        # convert() returns an independent image.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform:
            img = self.transform(img)
        return img, y, g, p


class CombinedWaterBirdsDataset(Dataset):
    """Concatenation of several pre-featurized, group-annotated datasets.

    Every element of ``datasets`` must expose ``x``, ``y_array``, ``p_array``,
    ``confounder_array`` and ``group_array``; each attribute is concatenated
    across the parts in list order. Class/place/group cardinalities and counts
    are recomputed over the combined arrays.

    Items are ``(x, y, g, p)`` where ``p`` is taken from ``confounder_array``.
    """

    def __init__(self, datasets):
        def gather(attr):
            # Concatenate one per-sample attribute across all parts, in order.
            return np.concatenate([getattr(part, attr) for part in datasets])

        def tally(values, size):
            # Occurrences of each integer 0 .. size-1 in `values`, as floats.
            hits = torch.arange(size).unsqueeze(1) == torch.from_numpy(values)
            return hits.sum(1).float()

        self.x = gather("x")
        self.y_array = gather("y_array")
        self.p_array = gather("p_array")
        self.confounder_array = gather("confounder_array")
        self.group_array = gather("group_array")

        # Cardinalities of the combined data.
        self.n_classes = np.unique(self.y_array).size
        self.n_places = np.unique(self.p_array).size
        self.n_groups = self.n_classes * self.n_places

        self.group_counts = tally(self.group_array, self.n_groups)
        self.y_counts = tally(self.y_array, self.n_classes)
        self.p_counts = tally(self.p_array, self.n_places)

    def __len__(self):
        return self.y_array.shape[0]

    def __getitem__(self, idx):
        """Return ``(x, y, g, p)`` for combined sample ``idx``."""
        x, y = self.x[idx], self.y_array[idx]
        g, p = self.group_array[idx], self.confounder_array[idx]
        return x, y, g, p

def get_transform_cub(target_resolution, train, augment_data):
    """Build the torchvision preprocessing pipeline for CUB/Waterbirds images.

    Args:
        target_resolution: ``(height, width)`` of the output image.
        train: Whether the transform is for the training split.
        augment_data: Whether to apply random augmentation; honoured only
            when ``train`` is True.

    Returns:
        A ``transforms.Compose`` that yields a normalized image tensor.
    """
    # Per-channel mean/std (the values commonly used with ImageNet-pretrained
    # backbones), shared by both pipelines.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    if train and augment_data:
        return transforms.Compose([
            transforms.RandomResizedCrop(
                target_resolution,
                scale=(0.7, 1.0),
                ratio=(0.75, 1.3333333333333333),
                # Bilinear; the enum replaces the deprecated integer code 2.
                interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    # Evaluation or no augmentation: resize to a slightly larger size
    # (by 256/224), then crop the center.
    scale = 256.0 / 224.0
    return transforms.Compose([
        transforms.Resize((int(target_resolution[0] * scale), int(target_resolution[1] * scale))),
        transforms.CenterCrop(target_resolution),
        transforms.ToTensor(),
        normalize,
    ])


def get_loader(data, train, reweight_groups, reweight_classes, reweight_places, **kwargs):
    """Build a DataLoader, optionally resampling to balance groups, classes or places.

    Args:
        data: Dataset exposing ``group_counts``/``group_array``,
            ``y_counts``/``y_array`` and ``p_counts``/``p_array`` as needed by
            the chosen reweighting.
        train: If False, the loader is sequential and unweighted.
        reweight_groups: Resample so each (y, place) group has equal total weight.
        reweight_classes: Resample so each class has equal total weight
            (used only when ``reweight_groups`` is falsy).
        reweight_places: Resample so each place has equal total weight
            (used only when the two options above are falsy).
        **kwargs: Forwarded to ``DataLoader`` (e.g. ``batch_size``).

    Returns:
        A ``torch.utils.data.DataLoader`` over ``data``.

    Raises:
        ValueError: If ``train`` is False and any reweighting option is not None.
    """
    if not train:
        # Validation/testing: every example exactly once, in order.
        if reweight_groups is not None or reweight_classes is not None or reweight_places is not None:
            raise ValueError("Reweighting options must be None when train is False")
        shuffle, sampler = False, None
    elif not (reweight_groups or reweight_classes or reweight_places):
        # Plain ERM training.
        shuffle, sampler = True, None
    else:
        # Reweighted ERM: priority is groups, then classes, then places.
        if reweight_groups:
            counts, labels = data.group_counts, data.group_array
        elif reweight_classes:
            counts, labels = data.y_counts, data.y_array
        else:
            counts, labels = data.p_counts, data.p_array
        # Inverse-frequency weight for each sample's bucket.
        weights = (len(data) / counts)[labels]
        # Replacement must be True, otherwise minority samples run out.
        sampler = WeightedRandomSampler(weights, len(data), replacement=True)
        shuffle = False  # DataLoader forbids shuffle together with a sampler.

    return DataLoader(data, shuffle=shuffle, sampler=sampler, **kwargs)


def log_data(logger, train_data, test_data, val_data=None, get_yp_func=None):
    """Write per-group sample counts for the train, test and (optional) val splits.

    Args:
        logger: Object with a ``write(str)`` method.
        train_data, test_data: Datasets exposing ``__len__``, ``n_groups`` and
            ``group_counts`` (and ``n_places`` when ``get_yp_func`` is None).
        val_data: Optional validation dataset; skipped when None.
        get_yp_func: Maps a group index to ``(y, place)``. When None, groups
            are decoded as ``group_id = y * n_places + place_id`` using each
            dataset's own ``n_places``.
    """
    def _write_split(title, data):
        # One header line, then one line per group.
        logger.write(f'{title} (total {len(data)})\n')
        for group_idx in range(data.n_groups):
            if get_yp_func is None:
                y_idx, p_idx = divmod(group_idx, data.n_places)
            else:
                y_idx, p_idx = get_yp_func(group_idx)
            logger.write(f'    Group {group_idx} (y={y_idx}, p={p_idx}): n = {data.group_counts[group_idx]:.0f}\n')

    _write_split('Training Data', train_data)
    _write_split('Test Data', test_data)
    if val_data is not None:
        _write_split('Validation Data', val_data)
