import os
import json
import torch
import random
from collections import defaultdict
import numpy as np
from torch_geometric.utils import subgraph

def load_dataset(name, train_ratio, up_ratio=None, perturb_ratio=None):
    """Load ``processed_data/{name}.pt`` and attach the requested split to it.

    If the stored object has no ``split`` attribute, one is created with
    ``get_data_split`` and the object is written back to the same file.

    Args:
        name: Dataset name; the file ``processed_data/{name}.pt`` is loaded.
        train_ratio: Key selecting the train/valid/test index sets
            (the split builders in this file use 1, 2, 5, 10, 20, 30).
        up_ratio: If given, an out-of-distribution split is built with
            ``get_ood_split`` and the entry for this ratio is used
            (2, 3, 5 or 10). ``perturb_ratio`` is ignored in this mode.
        perturb_ratio: If given (and ``up_ratio`` is None), the training
            edges come from ``perturb_edge_index`` for this ratio
            (0.05, 0.1, 0.2, 0.3 or 0.5).

    Returns:
        The loaded data object with ``edge_index_train``, ``edge_index_test``,
        ``pretrain_mask``, ``train_mask``, ``valid_mask`` and ``test_mask``
        set.

    Raises:
        KeyError: If a ratio is not one of the keys available in the split.
    """
    data_file = f"processed_data/{name}.pt"
    # NOTE: weights_only=False unpickles arbitrary objects; only load files
    # that come from a trusted source.
    data = torch.load(data_file, weights_only=False)

    if not hasattr(data, 'split'):
        data = get_data_split(data)
        torch.save(data, data_file)

    # Needed by both modes; get_ood_split reads it from the data object.
    data.pretrain_mask = data.split['pretrain']

    if up_ratio is None:
        if perturb_ratio is None:
            data.edge_index_train = data.split['edge_index_train']
        else:
            if not hasattr(data, 'edge_index_pertubed'):
                perturb_edge_index(data)
            data.edge_index_train = data.edge_index_pertubed[perturb_ratio]

        data.edge_index_test = data.split['edge_index_test']
        masks = data.split
    else:
        # The OOD splits are rebuilt on every call and are not written back
        # to disk, so they differ between calls unless the `random` and
        # `numpy.random` generators are seeded by the caller.
        get_ood_split(data)
        ood_split = data.ood_splits[up_ratio]
        data.edge_index_train = ood_split['edge_index_train']
        data.edge_index_test = ood_split['edge_index_test']
        masks = ood_split

    data.train_mask = masks['train'][train_ratio]
    data.valid_mask = masks['valid'][train_ratio]
    data.test_mask = masks['test'][train_ratio]

    return data


def get_data_split(data):
    """Randomly partition the nodes and store the result in ``data.split``.

    The node ids ``0 .. len(data.y) - 1`` are shuffled once and cut into four
    consecutive segments: 20% pretrain, a train pool sized for the largest
    train ratio (30%), 20% validation, and the remainder as test. Each train
    ratio (1, 2, 5, 10, 20, 30 percent) takes a prefix of the train pool, so
    smaller train sets are nested in larger ones.

    ``data.split`` is a dict with:
        * ``'edge_index_train'``: edges among train-pool + validation nodes
        * ``'edge_index_test'``: edges among test nodes
        * ``'pretrain'``: tensor of pretrain node ids
        * ``'train'``: ``{ratio: tensor of node ids}``
        * ``'valid'`` / ``'test'``: ``{ratio: list of node ids}`` (the same
          list for every ratio)

    Returns:
        The same ``data`` object.
    """
    n = len(data.y)
    train_ratios = [1, 2, 5, 10, 20, 30]
    pretrain_pct = 20
    valid_pct = 20

    train_sizes = {pct: int(n * pct / 100) for pct in train_ratios}

    # Segment boundaries inside the shuffled permutation.
    train_start = int(n * pretrain_pct / 100)
    valid_start = train_start + max(train_sizes.values())
    test_start = valid_start + int(n * valid_pct / 100)

    permutation = list(range(n))
    random.shuffle(permutation)

    pretrain_idx = torch.tensor(permutation[:train_start])
    train_pool = permutation[train_start:valid_start]
    valid_idx = permutation[valid_start:test_start]
    test_idx = permutation[test_start:]

    # Node ids are kept as-is (no relabeling) in both induced subgraphs.
    edge_index_train, _ = subgraph(
        subset=train_pool + valid_idx,
        edge_index=data.edge_index,
        relabel_nodes=False,
        num_nodes=n,
    )
    edge_index_test, _ = subgraph(
        subset=test_idx,
        edge_index=data.edge_index,
        relabel_nodes=False,
        num_nodes=n,
    )

    data.split = {
        'edge_index_train': edge_index_train,
        'edge_index_test': edge_index_test,
        'pretrain': pretrain_idx,
        'train': {
            pct: torch.tensor(train_pool[:size])
            for pct, size in train_sizes.items()
        },
        'valid': {pct: valid_idx for pct in train_ratios},
        'test': {pct: test_idx for pct in train_ratios},
    }
    return data

def get_ood_split(data):
    """Build class-skewed (out-of-distribution) splits and store them on data.

    For every ``up_ratio`` in (2, 3, 5, 10) a train+validation pool of about
    50% of the nodes is drawn from the non-pretrain nodes with
    ``custom_class_balanced_sample``. The pool is shuffled; the first 20% of
    ``n`` become validation nodes and the rest form the train pool. All nodes
    that are neither pretrain nodes nor in the pool become test nodes.

    ``data.ood_splits[up_ratio]`` is a dict with:
        * ``'edge_index_train'``: edges among the train+validation pool
        * ``'edge_index_test'``: edges among the test nodes
        * ``'train'`` / ``'valid'`` / ``'test'``: ``{train_ratio: list of
          node ids}`` for train ratios 1, 2, 5, 10, 20, 30 (percent of ``n``)

    Args:
        data: Object providing ``y``, ``edge_index`` and ``pretrain_mask``
            (the pretrain node ids, as a tensor or a sequence of ints).

    Returns:
        The same ``data`` object.
    """
    if isinstance(data.y, torch.Tensor):
        n = data.y.size()[0]
        n_classes = int(data.y.max()) + 1
    else:
        n = len(data.y)
        n_classes = int(max(data.y)) + 1

    # Tensor elements hash by identity, so a set built directly from a tensor
    # never matches plain ints. Convert to Python ints first so the pretrain
    # nodes are really excluded below.
    pretrain_idx = data.pretrain_mask
    if isinstance(pretrain_idx, torch.Tensor):
        pretrain_idx = pretrain_idx.tolist()
    pretrain_set = set(pretrain_idx)

    remaining = [i for i in range(n) if i not in pretrain_set]
    random.shuffle(remaining)

    train_ratios = (1, 2, 5, 10, 20, 30)
    largest_ratio = max(train_ratios)
    total_train_valid_n = int(n * 0.5)
    n_valid = int(n * 0.2)

    ood_splits = {}
    for up_ratio in (2, 3, 5, 10):
        # Step 1: skewed pool for training + validation.
        # NOTE(review): the sampler draws with replacement when a class is
        # too small, so the pool may contain repeated node ids.
        train_valid_pool = custom_class_balanced_sample(
            remaining, data, n_classes, total_train_valid_n, up_ratio=up_ratio
        )
        edge_index_train, _ = subgraph(
            subset=train_valid_pool,
            edge_index=data.edge_index,
            relabel_nodes=False,  # keep original node ids
            num_nodes=n,
        )

        # Step 2: fixed-size validation set, rest is the train pool.
        random.shuffle(train_valid_pool)
        valid_idx = train_valid_pool[:n_valid]
        train_pool = train_valid_pool[n_valid:]

        # Step 3: everything not used so far is test.
        used_idx = pretrain_set | set(train_valid_pool)
        test_idx = [i for i in range(n) if i not in used_idx]
        edge_index_test, _ = subgraph(
            subset=test_idx,
            edge_index=data.edge_index,
            relabel_nodes=False,
            num_nodes=n,
        )

        # Step 4: one train subset per train ratio. Fresh dicts per up_ratio
        # so the splits of different up_ratios do not overwrite each other.
        train_idx_dict = {}
        valid_idx_dict = {}
        test_idx_dict = {}
        for ratio in train_ratios:
            if ratio < largest_ratio:
                # Clamp so a short pool cannot make random.sample raise.
                n_train = min(int(n * ratio / 100), len(train_pool))
                train_idx = random.sample(train_pool, n_train)
            else:
                # The largest ratio uses the whole train pool.
                train_idx = train_pool
            train_idx_dict[ratio] = train_idx
            valid_idx_dict[ratio] = valid_idx
            test_idx_dict[ratio] = test_idx

        ood_splits[up_ratio] = {
            'edge_index_train': edge_index_train,
            'edge_index_test': edge_index_test,
            'train': train_idx_dict,
            'valid': valid_idx_dict,
            'test': test_idx_dict,
        }

    data.ood_splits = ood_splits
    return data


def custom_class_balanced_sample(all_idx, data, n_classes, n_total, up_ratio):
    """Draw a class-skewed sample of node ids from ``all_idx``.

    Classes ``0 .. n_classes // 2 - 1`` are the "up" classes; each of them
    receives ``up_ratio`` times as many samples as each of the remaining
    ("down") classes. Per-class counts are rounded, so the result length is
    only approximately ``n_total``.

    A class with fewer candidates than requested is sampled with replacement
    (the result then contains repeated ids). A class with no candidates in
    ``all_idx`` is skipped.

    Args:
        all_idx: Iterable of candidate node ids.
        data: Object whose ``y`` holds one label per node (tensor or
            indexable sequence).
        n_classes: Number of classes; labels are expected in
            ``range(n_classes)``.
        n_total: Target total number of samples.
        up_ratio: Weight of an up class relative to a down class.

    Returns:
        List of sampled node ids, up classes first, in class order.
    """
    down_ratio = 1
    half = n_classes // 2
    # (class id, relative weight), in the order the classes are sampled.
    class_weights = [(c, up_ratio) for c in range(half)]
    class_weights += [(c, down_ratio) for c in range(half, n_classes)]

    total_ratio = sum(weight for _, weight in class_weights)
    if total_ratio == 0:
        # No classes (or only zero weights): nothing can be sampled.
        return []
    per_unit = n_total / total_ratio

    labels_are_tensor = isinstance(data.y, torch.Tensor)
    class_to_idx = defaultdict(list)
    for idx in all_idx:
        label = data.y[idx].item() if labels_are_tensor else data.y[idx]
        class_to_idx[label].append(idx)

    sampled_idx = []
    for c, weight in class_weights:
        idxs = class_to_idx.get(c)
        if not idxs:
            # np.random.choice raises on an empty population.
            continue
        n_c = int(round(per_unit * weight))
        sampled = np.random.choice(idxs, n_c, replace=len(idxs) < n_c)
        sampled_idx.extend(sampled.tolist())

    return sampled_idx

def perturb_edge_index(data, exclude_self_loops=True):
    """
    Build randomly perturbed copies of the training edge index.

    For each ratio in (0.05, 0.1, 0.2, 0.3, 0.5):
      1. ``int(num_edges * ratio)`` randomly chosen columns of
         ``data.split['edge_index_train']`` are dropped;
      2. up to ``int(remaining_edges * ratio)`` new random edges are
         appended. Candidates that already exist (or, optionally, are
         self-loops) are rejected, and at most ``10 * num_add`` candidates
         are tried, so fewer edges than requested may be added.

    The results are stored as ``data.edge_index_pertubed`` (a dict keyed by
    ratio). ``data.split['edge_index_train']`` itself is not modified.

    NOTE(review): edges are handled as directed (src, dst) pairs; if the
    graph stores both directions of each edge the perturbed graph may no
    longer be symmetric -- confirm this is intended.

    Args:
        data: Object providing ``split['edge_index_train']`` (2 x E tensor)
            and ``num_nodes``.
        exclude_self_loops: bool, avoid self-loops when adding edges

    Returns:
        The same ``data`` object.
    """
    base_edge_index = data.split['edge_index_train']
    num_nodes = data.num_nodes

    edge_index_dict = {}
    for ratio in (0.05, 0.1, 0.2, 0.3, 0.5):
        edge_index = base_edge_index
        num_edges = edge_index.size(1)

        # Drop edges
        num_drop = int(num_edges * ratio)
        keep_mask = torch.ones(num_edges, dtype=torch.bool)
        keep_mask[torch.randperm(num_edges)[:num_drop]] = False
        edge_index = edge_index[:, keep_mask]

        # Add edges
        num_add = int(edge_index.size(1) * ratio)
        if num_add > 0:
            # Existing + already added edges, in a set for O(1) lookup.
            seen_edges = set(map(tuple, edge_index.t().tolist()))
            added_edges = []  # list keeps the insertion order

            # Bounded number of tries to prevent infinite loops.
            for _ in range(num_add * 10):
                if len(added_edges) >= num_add:
                    break
                src = torch.randint(0, num_nodes, (1,)).item()
                dst = torch.randint(0, num_nodes, (1,)).item()
                if exclude_self_loops and src == dst:
                    continue
                e = (src, dst)
                if e in seen_edges:
                    continue
                seen_edges.add(e)
                added_edges.append(e)

            if added_edges:
                # Match dtype/device so the concatenation cannot fail.
                added_tensor = torch.tensor(
                    added_edges,
                    dtype=edge_index.dtype,
                    device=edge_index.device,
                ).t().contiguous()
                edge_index = torch.cat([edge_index, added_tensor], dim=1)

        edge_index_dict[ratio] = edge_index

    data.edge_index_pertubed = edge_index_dict
    return data
