import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset
import torch
from scipy.sparse import issparse
import random
from torch_geometric.utils import dropout_adj
from transformers import BertTokenizer
from torch_geometric.utils import dropout_edge
from collections import defaultdict


def collect_cross_dataset_samples(source_dataset_name: str,
                                  target_dataset_name: str,
                                  mode: str = 'zero-shot',
                                  seed: int = 0,
                                  k: int = 5,
                                  drop_edge_ratio: float = 0.0,
                                  drop_node_ratio: float = 0.0,
                                  text_mask_ratio: float = 0.0,
                                  use_text: bool = True):
    """Build train/test datasets for a cross-dataset transfer experiment.

    Both the source and the target graph are loaded with identical
    augmentation settings. The training dataset contains the *same* graph
    object repeated once per selected training node; the test dataset always
    contains the target graph exactly once.

    Args:
        source_dataset_name: Dataset name forwarded to ``load_dataset`` for
            the source graph.
        target_dataset_name: Dataset name forwarded to ``load_dataset`` for
            the target graph.
        mode: ``'zero-shot'`` trains on the source graph,
            ``'full-supervised'`` trains on the target graph, and
            ``'few-shot'`` trains on the target graph with at most ``k``
            training nodes per label.
        seed: Seeds ``torch`` and ``random`` and is forwarded to
            ``load_dataset``.
        k: Maximum number of training nodes sampled per label in
            ``'few-shot'`` mode.
        drop_edge_ratio: Forwarded to ``load_dataset``.
        drop_node_ratio: Forwarded to ``load_dataset``.
        text_mask_ratio: Forwarded to ``load_dataset``.
        use_text: Forwarded to ``load_dataset``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``ListDataset`` objects.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    # Reject a bad mode up front, before paying for two dataset loads.
    if mode not in ('zero-shot', 'full-supervised', 'few-shot'):
        raise ValueError(f"Unsupported mode: {mode}")

    torch.manual_seed(seed)
    random.seed(seed)

    load_kwargs = dict(use_text=use_text,
                       seed=seed,
                       drop_edge_ratio=drop_edge_ratio,
                       drop_node_ratio=drop_node_ratio,
                       text_mask_ratio=text_mask_ratio)
    source_data = load_dataset(source_dataset_name, **load_kwargs)
    target_data = load_dataset(target_dataset_name, **load_kwargs)

    if mode == 'zero-shot':
        source_train_idx = source_data.train_mask.nonzero(as_tuple=True)[0]
        train_dataset = ListDataset([source_data for _ in source_train_idx])

    elif mode == 'full-supervised':
        target_train_idx = target_data.train_mask.nonzero(as_tuple=True)[0]
        train_dataset = ListDataset([target_data for _ in target_train_idx])

    else:  # 'few-shot' (mode was validated above)
        y = target_data.y
        train_mask = target_data.train_mask

        # Group the training node indices by label, in ascending index order.
        label_to_indices = defaultdict(list)
        for idx in train_mask.nonzero(as_tuple=True)[0].tolist():
            label_to_indices[y[idx].item()].append(idx)

        # Sample up to k nodes per label; labels with fewer nodes keep all.
        selected_indices = []
        for indices in label_to_indices.values():
            selected_indices.extend(
                random.sample(indices, min(k, len(indices))))

        fewshot_mask = torch.zeros_like(train_mask)
        fewshot_mask[selected_indices] = True
        # The test dataset holds this same object, so it also sees the
        # reduced training mask.
        target_data.train_mask = fewshot_mask

        train_dataset = ListDataset([target_data for _ in selected_indices])

    test_dataset = ListDataset([target_data])
    return train_dataset, test_dataset


def dropedge(data, drop_ratio=0.0, mask=None):
    """Randomly drop edges from ``data.edge_index`` (modifies ``data``).

    With ``mask`` given, only edges that have at least one endpoint among the
    nonzero positions of ``mask`` are candidates for dropping. All remaining
    edges are kept untouched and appended after the surviving candidates.
    Without a mask every edge is a candidate.

    Args:
        data: Graph object whose ``edge_index`` is a two-row tensor of
            source/destination node ids.
        drop_ratio: Drop probability passed as ``p`` to ``dropout_adj``.
            ``0.0`` returns ``data`` unchanged.
        mask: Optional node mask; its nonzero positions identify the nodes
            whose incident edges may be dropped.

    Returns:
        The same ``data`` object with ``edge_index`` replaced.
    """
    if drop_ratio == 0.0:
        return data
    edge_index = data.edge_index

    if mask is not None:
        node_ids = mask.nonzero(as_tuple=True)[0].to(edge_index.device)
        src, dst = edge_index

        # Vectorized membership test. The result is always a bool tensor,
        # even for a graph with zero edges (a mask built from an empty
        # Python list would be float and unusable as an index).
        touches_masked = torch.isin(src, node_ids) | torch.isin(dst, node_ids)

        candidate_edges = edge_index[:, touches_masked]
        untouched_edges = edge_index[:, ~touches_masked]

        # NOTE(review): `dropout_adj` is deprecated in recent PyG releases in
        # favour of `dropout_edge` -- consider migrating.
        candidate_edges, _ = dropout_adj(candidate_edges, p=drop_ratio, training=True)

        edge_index = torch.cat([candidate_edges, untouched_edges], dim=1)

    else:
        edge_index, _ = dropout_adj(edge_index, p=drop_ratio, training=True)

    data.edge_index = edge_index
    return data


def dropnode(data, drop_ratio=0.0, mask=None):
    """Randomly remove nodes from the training set (modifies ``data``).

    No node is actually deleted from the graph: each candidate node is
    selected independently with probability ``drop_ratio`` and its entry in
    ``data.train_mask`` is set to ``False``.

    Args:
        data: Graph object with a feature tensor ``x``; ``train_mask`` is
            created (all ``True``) when the object does not have one.
        drop_ratio: Per-node drop probability. ``0.0`` returns ``data``
            unchanged.
        mask: Optional node mask; only its nonzero positions are candidates.
            Without it every node is a candidate.

    Returns:
        The same ``data`` object, with ``train_mask`` updated in place and
        the dropped node ids stored in ``data.drop_indices``.
    """
    if drop_ratio == 0.0:
        return data

    device = data.x.device
    num_nodes = data.x.size(0)

    if mask is None:
        candidates = torch.arange(num_nodes, device=device)
    else:
        candidates = mask.nonzero(as_tuple=True)[0]

    # One uniform draw per candidate; a draw below the ratio drops the node.
    is_dropped = torch.rand(candidates.size(0), device=device) < drop_ratio
    dropped = candidates[is_dropped]

    if not hasattr(data, 'train_mask'):
        data.train_mask = torch.ones(num_nodes, dtype=torch.bool, device=device)
    data.train_mask[dropped] = False

    data.drop_indices = dropped
    return data


def textmask(data, mask_ratio=0.0, mask=None):
    """Replace a random fraction of words in node texts with the mask token.

    Each eligible text is split on whitespace, ``int(len(words) * mask_ratio)``
    randomly chosen words are replaced by the tokenizer's mask token, and the
    words are re-joined with single spaces. Texts that are not eligible are
    copied through unchanged.

    Args:
        data: Object with an ``x_text`` sequence of strings.
        mask_ratio: Fraction of words to mask in each eligible text. ``0.0``
            returns ``data`` unchanged.
        mask: Optional node mask; only texts at its nonzero positions are
            eligible. Without it every text is eligible.

    Returns:
        The same ``data`` object with ``x_text`` replaced by a new list.
    """
    if mask_ratio == 0.0:
        return data
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    mask_token = tokenizer.mask_token

    # A set gives O(1) membership tests; `None` means every text is eligible.
    if mask is not None:
        eligible = set(mask.nonzero(as_tuple=True)[0].tolist())
    else:
        eligible = None

    masked_texts = []
    for i, text in enumerate(data.x_text):
        if eligible is not None and i not in eligible:
            masked_texts.append(text)
            continue
        words = text.split()
        mask_num = int(len(words) * mask_ratio)
        for j in random.sample(range(len(words)), mask_num):
            words[j] = mask_token
        masked_texts.append(' '.join(words))

    data.x_text = masked_texts
    return data


# Maps each supported dataset name to the suffix shared by its loader module
# (``core.data_utils.dataset_<suffix>``) and loader function
# (``get_raw_text_<suffix>``).
_DATASET_LOADER_SUFFIXES = {
    'pubmed': 'pubmed',
    'cora': 'cora',
    'ogbn-arxiv': 'arxiv',
    'ogbn-products': 'products',
    'arxiv_2023': 'arxiv_2023',
    'citeseer': 'citeseer',
    'wikics': 'wikics',
    'history': 'history',
    'computer': 'computer',
    'photo': 'photo',
}


def _few_shot_train_mask(y, train_mask, shots=5):
    """Return a mask keeping ``shots`` random training nodes per class.

    Classes with fewer than ``shots`` training nodes contribute no nodes.
    """
    new_train_mask = torch.zeros_like(train_mask, dtype=torch.bool)
    for cls in torch.unique(y):
        # `as_tuple=True` always yields a 1-D index tensor, including when a
        # class has exactly one training node (where `.squeeze()` would
        # collapse to 0-d and break `len`).
        cls_indices = ((y == cls) & train_mask).nonzero(as_tuple=True)[0]
        if cls_indices.numel() >= shots:
            order = torch.randperm(cls_indices.numel())[:shots]
            new_train_mask[cls_indices[order]] = True
    return new_train_mask


def load_dataset(dataset_name, use_text=True, seed=0,
                 drop_edge_ratio=0.0,
                 drop_node_ratio=0.0,
                 text_mask_ratio=0.0,
                 shot_mode='full-supervised'):
    """Load a dataset, wrap it in a ``Data`` object and apply augmentations.

    The dataset-specific loader is imported lazily and called with
    ``use_text`` and ``seed``. Its result is indexed as ``[0]`` graph object,
    ``[1]`` node texts (stored as ``x_text``) and ``[2]`` label information
    (stored as ``label``).

    Args:
        dataset_name: One of the keys of ``_DATASET_LOADER_SUFFIXES``.
        use_text: Forwarded to the loader; text masking is only applied when
            this is true.
        seed: Forwarded to the loader.
        drop_edge_ratio: When positive, edges are dropped from ``edge_index``
            and from ``train_edge_index``.
        drop_node_ratio: When positive, training nodes are removed from
            ``train_mask``.
        text_mask_ratio: When positive (and ``use_text``), words in the
            training nodes' texts are masked.
        shot_mode: ``'few-shot'`` keeps 5 random training nodes per class
            (classes with fewer are skipped), ``'zero-shot'`` clears the
            training mask, any other value leaves it as loaded. The new mask
            is written back onto the loaded graph object.

    Returns:
        The assembled (and possibly augmented) ``Data`` object.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    import importlib  # local import keeps this block self-contained

    suffix = _DATASET_LOADER_SUFFIXES.get(dataset_name)
    if suffix is None:
        raise ValueError(f"Dataset {dataset_name} not supported.")
    loader_module = importlib.import_module(f"core.data_utils.dataset_{suffix}")
    loader = getattr(loader_module, f"get_raw_text_{suffix}")
    raw_data = loader(use_text=use_text, seed=seed)
    graph = raw_data[0]

    # NOTE(review): for 'ogbn-arxiv' the connectivity is read from `adj_t`;
    # confirm it is compatible with the edge dropping applied below.
    if dataset_name == 'ogbn-arxiv':
        edge_index = graph.adj_t
    else:
        edge_index = graph.edge_index

    if shot_mode == 'few-shot':
        graph['train_mask'] = _few_shot_train_mask(graph['y'], graph['train_mask'])
    elif shot_mode == 'zero-shot':
        graph['train_mask'] = torch.zeros_like(graph['train_mask'], dtype=torch.bool)

    data = Data(
        # Fall back to a single zero feature per node when none are provided.
        x=graph.x if graph.x is not None else torch.zeros((graph.num_nodes, 1)),
        edge_index=edge_index,
        y=graph.y,
        x_text=raw_data[1],
        train_mask=graph.train_mask,
        val_mask=graph.val_mask,
        test_mask=graph.test_mask,
        train_edge_index=graph.train_edge_index,
        val_edge_index=graph.val_edge_index,
        test_edge_index=graph.test_edge_index,
        train_edge_label_index=graph.train_edge_label_index,
        val_edge_label_index=graph.val_edge_label_index,
        test_edge_label_index=graph.test_edge_label_index,
        train_edge_label=graph.train_edge_label,
        val_edge_label=graph.val_edge_label,
        test_edge_label=graph.test_edge_label,
        label=raw_data[2]
    )

    if drop_edge_ratio > 0:
        data = dropedge(data, drop_edge_ratio, mask=data.train_mask)
        data.train_edge_index, edge_mask = dropout_edge(data.train_edge_index, p=drop_edge_ratio)
        # NOTE(review): assumes train_edge_label_index has one column per
        # column of train_edge_index -- confirm against the loaders.
        data.train_edge_label_index = data.train_edge_label_index[:, edge_mask]

    if drop_node_ratio > 0:
        data = dropnode(data, drop_node_ratio, mask=data.train_mask)

    if text_mask_ratio > 0 and use_text:
        data = textmask(data, text_mask_ratio, mask=data.train_mask)

    return data


class ListDataset(Dataset):
    """Map-style dataset that serves items straight from an indexable container."""

    def __init__(self, data):
        # The container is stored by reference; no copy is made.
        self.data = data

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        items = self.data
        return items[idx]

    def __len__(self):
        """Return the number of stored items."""
        items = self.data
        return len(items)