import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Batch

from .preprocessing import (
    preprocess_netlist_data,
    preprocess_layout_data,
    construct_cell_hypergraph,
    construct_grid_hypergraph,
    map_congestion_to_cells
)


class ISPD2015Dataset(Dataset):
    """
    Dataset class for ISPD2015 benchmark.

    Every sub-directory of ``data_dir`` is treated as a design, and every
    sub-directory of a design as one placement configuration. A placement
    contributes a sample only if it contains both ``netlist.txt`` and
    ``layout.csv``. All samples of the requested split are preprocessed
    eagerly in ``__init__`` and kept in memory.

    Args:
        data_dir (str): Directory containing ISPD2015 dataset
        split (str): Data split ('train', 'val', or 'test')
        split_variant (str): Dataset variant ('ISPD2015-B' or 'ISPD2015-F')
        grid_size (tuple): Size of the grid (M, N)
        transform (callable, optional): Optional transform to be applied on a sample

    Raises:
        ValueError: If ``split_variant`` is not a known variant.
        KeyError: If ``split`` is not one of 'train', 'val' or 'test'.
    """

    def __init__(self, data_dir, split='train', split_variant='ISPD2015-B',
                 grid_size=(64, 64), transform=None):
        self.data_dir = data_dir
        self.split = split
        self.split_variant = split_variant
        self.grid_size = grid_size
        self.transform = transform

        # Get list of designs based on the variant
        self.designs = self._get_designs(split_variant)

        # Split designs into train/val/test
        self.design_splits = self._split_designs()

        # Get designs for the current split
        self.current_designs = self.design_splits[split]

        # Preprocess data
        self.samples = self._preprocess_data()

    def _get_designs(self, split_variant):
        """
        Get the sorted list of design directory names for the variant.

        'ISPD2015-B' excludes every directory whose name contains
        'superblue'; 'ISPD2015-F' keeps all directories.

        Raises:
            ValueError: If ``split_variant`` is neither of the two variants.
        """
        if split_variant not in ('ISPD2015-B', 'ISPD2015-F'):
            raise ValueError(f"Unknown split variant: {split_variant}")

        keep_superblue = split_variant == 'ISPD2015-F'

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # seeded shuffle in _split_designs() reproducible across machines.
        return [
            name for name in sorted(os.listdir(self.data_dir))
            if os.path.isdir(os.path.join(self.data_dir, name))
            and (keep_superblue or 'superblue' not in name)
        ]

    def _split_designs(self):
        """
        Split designs into train/val/test.

        Designs are shuffled with a fixed seed and cut at 70% and 80% of the
        list, so the three splits are disjoint at design level.

        Returns:
            dict: Maps 'train', 'val' and 'test' to lists of design names.
        """
        designs = np.array(self.designs)

        # A private generator keeps the split reproducible without reseeding
        # the process-wide NumPy RNG that other code may depend on.
        rng = np.random.RandomState(42)
        rng.shuffle(designs)

        n = len(designs)
        train_idx = int(0.7 * n)
        val_idx = int(0.8 * n)

        return {
            'train': designs[:train_idx].tolist(),
            'val': designs[train_idx:val_idx].tolist(),
            'test': designs[val_idx:].tolist()
        }

    def _preprocess_data(self):
        """
        Preprocess all designs for the current split.

        Returns:
            list: One dict per usable placement with the keys
            'cell_hypergraph', 'grid_hypergraph', 'cell_congestion',
            'grid_congestion', 'design_name' and 'placement_name'.
        """
        samples = []

        for design in self.current_designs:
            design_dir = os.path.join(self.data_dir, design)

            # Sorted so that sample indices are stable between runs.
            placement_dirs = [
                d for d in sorted(os.listdir(design_dir))
                if os.path.isdir(os.path.join(design_dir, d))
            ]

            for placement in placement_dirs:
                placement_dir = os.path.join(design_dir, placement)

                # Path to netlist and layout files
                netlist_path = os.path.join(placement_dir, 'netlist.txt')
                layout_path = os.path.join(placement_dir, 'layout.csv')

                # Placements missing either input file are skipped.
                if not (os.path.exists(netlist_path) and os.path.exists(layout_path)):
                    continue

                # Process netlist and layout data
                netlist_data = preprocess_netlist_data(netlist_path)
                layout_data = preprocess_layout_data(layout_path, self.grid_size)

                # Construct hypergraphs
                cell_hypergraph = construct_cell_hypergraph(netlist_data)
                grid_hypergraph = construct_grid_hypergraph(
                    layout_data, netlist_data, self.grid_size)

                # Map congestion from grid to cells
                cell_congestion = map_congestion_to_cells(
                    netlist_data, layout_data, self.grid_size)
                grid_congestion = layout_data['grid_congestion_gt']

                samples.append({
                    'cell_hypergraph': cell_hypergraph,
                    'grid_hypergraph': grid_hypergraph,
                    'cell_congestion': torch.FloatTensor(cell_congestion),
                    'grid_congestion': torch.FloatTensor(grid_congestion),
                    'design_name': design,
                    'placement_name': placement
                })

        return samples

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, passed through ``transform`` if one is set."""
        sample = self.samples[idx]

        # Compare against None so a callable that happens to be falsy
        # (e.g. one defining __len__) is still applied.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample


def collate_fn(batch):
    """
    Merge a list of dataset samples into a single batched dict.

    The hypergraphs are combined with PyTorch Geometric's
    ``Batch.from_data_list``; the congestion tensors are joined with
    ``torch.cat``; design and placement names are gathered into plain lists.

    Args:
        batch (list): List of samples

    Returns:
        dict: Batched data with the keys 'cell_hypergraph',
        'grid_hypergraph', 'cell_congestion', 'grid_congestion',
        'design_names' and 'placement_names'.
    """
    def _gather(key):
        # Pull one field out of every sample, preserving batch order.
        return [item[key] for item in batch]

    return {
        'cell_hypergraph': Batch.from_data_list(_gather('cell_hypergraph')),
        'grid_hypergraph': Batch.from_data_list(_gather('grid_hypergraph')),
        'cell_congestion': torch.cat(_gather('cell_congestion')),
        'grid_congestion': torch.cat(_gather('grid_congestion')),
        'design_names': _gather('design_name'),
        'placement_names': _gather('placement_name'),
    }


def get_dataloaders(config):
    """
    Create dataloaders for train, val, and test sets.

    Args:
        config (dict): Configuration dictionary. ``config['dataset']`` must
            provide 'data_dir', 'split_variant', 'batch_size' and
            'num_workers'. It may also provide 'grid_size' as a pair
            ``(M, N)``; when that key is absent or empty, ``(64, 64)`` is used.

    Returns:
        dict: Dictionary containing train, val, and test dataloaders
    """
    dataset_cfg = config['dataset']
    data_dir = dataset_cfg['data_dir']
    split_variant = dataset_cfg['split_variant']
    batch_size = dataset_cfg['batch_size']
    num_workers = dataset_cfg['num_workers']

    # Optional override; tuple() normalises a list loaded from YAML/JSON.
    grid_size = tuple(dataset_cfg.get('grid_size') or (64, 64))

    splits = ('train', 'val', 'test')

    # Build every dataset before any loader so preprocessing errors surface first.
    datasets = {
        split: ISPD2015Dataset(data_dir, split, split_variant, grid_size)
        for split in splits
    }

    loaders = {}
    for split in splits:
        loaders[split] = DataLoader(
            datasets[split],
            batch_size=batch_size,
            shuffle=(split == 'train'),  # only the training set is shuffled
            num_workers=num_workers,
            collate_fn=collate_fn
        )

    return loaders