import os
import os.path as osp
import shutil
import numpy as np
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.data import InMemoryDataset


from sklearn.model_selection import train_test_split
from typing import Tuple
import random


def split_tu_dataset(ds_name: str, root: str, test_size: float = 0.2, force_reload: bool = True, 
                     seed=1111, mode='single', 
                     k_shot: int = 5, 
                     return_indices=False) -> Tuple[str, str]:
    """
    Split a TU-format graph dataset into separate training and test datasets.

    The raw text files under ``<root>/<ds_name>/raw`` are read and re-written
    as two new raw datasets, ``<root>/<ds_name>_train/raw`` and
    ``<root>/<ds_name>_test/raw``, with graph and node ids renumbered from 1.

    Parameters:
        `ds_name`: Dataset name
        `root`: Directory that holds the dataset and receives the two splits
        `test_size`: Test set ratio used in 'single' mode, default 0.2
        `force_reload`: Re-create the split even if one already exists,
            default True
        `seed`: Random state passed to `train_test_split` in 'single' mode,
            default 1111. It is NOT used in 'fewshot' mode, which draws from
            the global numpy RNG (seed it with `np.random.seed` beforehand).
        `mode`: 'single' (random ratio split) or 'fewshot' (`k_shot` training
            graphs per class, all remaining graphs for testing),
            default 'single'
        `k_shot`: Number of training graphs per class in 'fewshot' mode,
            default 5
        `return_indices`: Also return the 0-based original graph indices of
            both splits, default False

    Returns:
        `(train_path, test_path)`, or
        `(train_path, test_path, train_indices, test_indices)` when
        `return_indices` is True (both index lists sorted ascending).

    Raises:
        ValueError: `mode` is unknown, or in 'single' mode the training split
            does not contain every class.
        RuntimeError: the graph indicator file is missing, or
            `return_indices` is True while an existing split is being reused
            (the indices of that split are not stored anywhere).
        Exception: an edge connects nodes that belong to different splits.
    """
    if mode not in ('single', 'fewshot'):
        raise ValueError(f"Unknown mode {mode!r}, expected 'single' or 'fewshot'")

    print(f'numpy seed: {np.random.get_state()[1][0]}')
    # Original dataset location and its raw directory
    dataset_path = osp.join(root, ds_name)
    raw_path = osp.join(dataset_path, 'raw')
    # Final training and test dataset directories and their raw directories
    train_data_path = osp.join(root, f"{ds_name}_train")
    test_data_path = osp.join(root, f"{ds_name}_test")
    train_raw_path = osp.join(train_data_path, 'raw')
    test_raw_path = osp.join(test_data_path, 'raw')

    if not osp.exists(dataset_path):
        print(f"Downloading dataset {ds_name}...")
        # Instantiating TUDataset is only used for its download side effect
        TUDataset(root=root, name=ds_name, force_reload=True)
        print(f"Dataset {ds_name} download completed")

    split_exists = osp.exists(train_raw_path) and osp.exists(test_raw_path)
    if split_exists and not force_reload:
        # Reusing a previous split: nothing is recomputed, so there are no
        # indices to hand back.
        if return_indices:
            raise RuntimeError(
                'return_indices=True requires the split to be (re)computed; '
                'call with force_reload=True'
            )
        return train_data_path, test_data_path

    print(f"Splitting dataset {ds_name}...")

    # Read graph labels for splitting
    graph_labels_file = osp.join(raw_path, f"{ds_name}_graph_labels.txt")
    with open(graph_labels_file, 'r') as f:
        graph_labels = np.array([int(line.strip()) for line in f], dtype=int)

    if mode == 'single':
        # Plain random split (not stratified); the class check below guards
        # against a training set that misses a class entirely.
        train_indices, test_indices = train_test_split(
            np.arange(len(graph_labels)), test_size=test_size, random_state=seed,
        )
        train_indices = np.array(train_indices, dtype=int)
        test_indices = np.array(test_indices, dtype=int)

        train_classes = set(graph_labels[train_indices])
        classes = set(graph_labels)
        if train_classes != classes:
            raise ValueError(f"Train dataset does not contain all classes: {train_classes} != {classes}")
    else:
        train_indices, test_indices = _fewshot_split_indices(graph_labels, k_shot)
        print(f"Few-shot sampling: {k_shot} shots per class")
        print(f"Train samples: {len(train_indices)}, Test samples: {len(test_indices)}")

    train_indices = sorted(train_indices)
    test_indices = sorted(test_indices)

    # The graph indicator file tells which graph every node belongs to
    graph_indicator_file = osp.join(raw_path, f"{ds_name}_graph_indicator.txt")
    if not osp.exists(graph_indicator_file):
        raise RuntimeError('graph indicators do not exist')
    with open(graph_indicator_file, 'r') as f:
        graph_indicator = [int(line.strip()) for line in f]  # 1-based graph ids

    # Original 0-based graph index -> new 1-based graph index, per split
    train_graph_map = {int(old_idx): new_idx + 1 for new_idx, old_idx in enumerate(train_indices)}
    test_graph_map = {int(old_idx): new_idx + 1 for new_idx, old_idx in enumerate(test_indices)}

    # Collect the 0-based node indices of each split, in ascending order.
    # Membership is tested against the dicts above (O(1) per node).
    train_node_indices = []
    test_node_indices = []
    for node_idx, graph_id in enumerate(graph_indicator):
        graph_idx = graph_id - 1  # Adjust to 0-based index
        if graph_idx in train_graph_map:
            train_node_indices.append(node_idx)
        elif graph_idx in test_graph_map:
            test_node_indices.append(node_idx)

    # Original 0-based node index -> new 1-based node index, per split
    train_node_map = {old_idx: new_idx + 1 for new_idx, old_idx in enumerate(train_node_indices)}
    test_node_map = {old_idx: new_idx + 1 for new_idx, old_idx in enumerate(test_node_indices)}

    # Only touch the output directories once all inputs were read and the
    # split is known, so a failure above never leaves empty raw directories
    # that a later call with force_reload=False would mistake for a split.
    os.makedirs(train_raw_path, exist_ok=True)
    os.makedirs(test_raw_path, exist_ok=True)
    # Cached processed data would no longer match the new raw files
    for data_path in (train_data_path, test_data_path):
        processed_path = osp.join(data_path, 'processed')
        if osp.exists(processed_path):
            shutil.rmtree(processed_path)

    file_types = ['graph_labels', 'graph_attributes', 
                  'graph_indicator', 'A', 
                  'node_attributes', 'node_labels']
    for file_type in file_types:
        src_file = osp.join(raw_path, f"{ds_name}_{file_type}.txt")
        if not osp.exists(src_file):
            continue  # optional file not shipped with this dataset

        with open(src_file, 'r') as f:
            lines = f.readlines()

        if file_type in ('graph_labels', 'graph_attributes'):
            # One line per graph: select directly by graph index
            train_lines = [lines[i] for i in train_indices]
            test_lines = [lines[i] for i in test_indices]
        elif file_type == 'graph_indicator':
            # One line per node: write the renumbered graph id
            train_lines = [f"{train_graph_map[graph_indicator[node_idx] - 1]}\n"
                           for node_idx in train_node_indices]
            test_lines = [f"{test_graph_map[graph_indicator[node_idx] - 1]}\n"
                          for node_idx in test_node_indices]
        elif file_type == 'A':
            # One line per edge: renumber both endpoints
            train_lines = []
            test_lines = []
            for line in lines:
                if not line.strip():
                    continue  # tolerate blank lines
                src, dst = map(int, line.strip().split(','))
                src_idx, dst_idx = src - 1, dst - 1  # Adjust to 0-based index

                if src_idx in train_node_map and dst_idx in train_node_map:
                    train_lines.append(f"{train_node_map[src_idx]},{train_node_map[dst_idx]}\n")
                elif src_idx in test_node_map and dst_idx in test_node_map:
                    test_lines.append(f"{test_node_map[src_idx]},{test_node_map[dst_idx]}\n")
                else:
                    src_known = src_idx in train_node_map or src_idx in test_node_map
                    dst_known = dst_idx in train_node_map or dst_idx in test_node_map
                    if src_known or dst_known:
                        # Edge crosses splits or touches an unassigned node
                        raise Exception('wrong index!')
                    # Both endpoints belong to no split: drop the edge
        else:
            # Node attributes or labels: one line per node
            if file_type == 'node_labels':
                lines = _zero_base_node_labels(lines)
            train_lines = [lines[idx] for idx in train_node_indices]
            test_lines = [lines[idx] for idx in test_node_indices]

        # Save split files
        train_file = osp.join(train_raw_path, f"{ds_name}_train_{file_type}.txt")
        test_file = osp.join(test_raw_path, f"{ds_name}_test_{file_type}.txt")
        with open(train_file, 'w') as f:
            f.writelines(train_lines)
        with open(test_file, 'w') as f:
            f.writelines(test_lines)

    print(f"Dataset {ds_name} splitting completed")
    if not return_indices:
        return train_data_path, test_data_path
    return train_data_path, test_data_path, train_indices, test_indices


def _fewshot_split_indices(graph_labels, k_shot):
    """Pick `k_shot` graphs per class for training; the rest of each class is test data."""
    train_indices = []
    test_indices = []
    for cls in np.unique(graph_labels):
        cls_indices = np.where(graph_labels == cls)[0]
        if len(cls_indices) <= k_shot:
            # Not enough samples to hold any out: the whole class goes to training
            print(f"Warning: Class {cls} has only {len(cls_indices)} samples, using all for training")
            train_indices.extend(cls_indices)
            continue

        train_idx = np.random.choice(cls_indices, k_shot, replace=False)
        remain_idx = np.setdiff1d(cls_indices, train_idx)
        # Every remaining graph is used for testing. The draw only shuffles
        # them, but it is kept so the global RNG stream (and therefore the
        # graphs sampled for the following classes) stays the same as before.
        test_idx = np.random.choice(remain_idx, size=len(remain_idx), replace=False)
        train_indices.extend(train_idx)
        test_indices.extend(test_idx)
    return np.array(train_indices, dtype=int), np.array(test_indices, dtype=int)


def _zero_base_node_labels(lines):
    """Shift every comma-separated label column so that its minimum becomes 0."""
    if not lines:
        return lines
    rows = [line.strip().split(',') for line in lines]
    if len(rows[0]) == 1:
        # Single label column
        labels = np.array([int(row[0]) for row in rows], dtype=int)
        labels = labels - labels.min()
        return [str(label) + '\n' for label in labels.tolist()]
    # Several label columns: each column is shifted independently
    labels = np.array(rows, dtype=int)
    labels = labels - labels.min(axis=0)
    return [','.join(map(str, row)) + '\n' for row in labels.tolist()]






