import torch
from torch.utils.data import Dataset,DataLoader
# from torch.serialization import FILE_LIKE
import os
from typing import (BinaryIO, IO, Union)
FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]]

import os
from typing import Any,Literal,Optional
from warnings import warn

from .utils import generateRandomGraph,getMaxCutHamiltonian

GraphDataKeys = Literal[
    'graph',
    'hamiltonian',
    'num_edges',
    'num_max_cut',
    'max_cut_bases',
    'max_cut_value'
]
GraphData = dict[GraphDataKeys,torch.Tensor|int]

class ErdosRenyiDataset(Dataset):
    def __init__(self, 
                 file: Optional[FILE_LIKE]=None,
                 generate:bool=False,
                 n:Optional[int]=None,
                 p:Optional[float]=None,
                 num_graphs:Optional[int]=None
                 ):
        if generate:
            if n is None or p is None or num_graphs is None:
                raise ValueError("When generating a dataset, 'n', 'p', and "
                "'num_graphs' must be provided.")
            
            self.p = p
            self.numVertices = n
            self.data = ErdosRenyiDataset.makeGraphDataset(num_graphs, n, p)
            
            if file:
                self.save(file)
            else:
                warn('Dataset generated but not saved to file, call the '
                '`save(file)` method to do so.')
                
        else:
            if not file:
                raise ValueError("A file or filename must be provided to load an existing dataset.")
            
            if isinstance(file, str | os.PathLike):
                if not os.path.exists(file):
                    raise FileNotFoundError(f'Dataset file `{file}` does not exist.')
                self.data = torch.load(file, weights_only=False)
            else:
                self.data = torch.load(file, weights_only=False)
            print(f'Dataset loaded')
    
    def save(self, file:FILE_LIKE):
        if isinstance(file, str | os.PathLike):  # Save to file path
            torch.save(self.data, file)
            print(f'Dataset saved to file `{file}`')
        else:  # Save to file-like object
            torch.save(self.data, file)
            print(f'Dataset saved to requested file.')

    @staticmethod
    def makeGraphDataset(N:int, n:int, p:float)->GraphData:
        assert n > 0, f'Number of vertices must be a positive integer'
        assert p <= 1.0 and p > 0.0, f'Invalid edge probability {p:0.4f}'

        dataset = []

        for i in range(N):
            G = generateRandomGraph(n, p)
            num_edges = G.count_nonzero().item()
            H = getMaxCutHamiltonian(G)
            min_energy = H.min()
            max_cut_bases = (H == min_energy).nonzero(as_tuple=False).flatten()
            num_max_cut = len(max_cut_bases)
            max_cut_bases[num_max_cut//2:] = max_cut_bases[num_max_cut//2:].flip(0)
            data_element = {
                'graph': G,
                'hamiltonian': H,
                'num_edges': num_edges,
                'num_max_cut': num_max_cut,
                'max_cut_bases': max_cut_bases,
                'max_cut_value': int(min_energy.item())
            }
            dataset.append(data_element)
        return dataset

    def __len__(self)->int:
        return len(self.data)
    
    def __getitem__(self, idx)->GraphData:
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.data[idx]
        sample['index'] = idx
        return sample


def graph_dataset_collate_fn(batch:list[dict[str,Any]])->dict[str,torch.Tensor]:
    """
    Merge a list of dataset samples into one batched dict for a DataLoader.

    'graph' and 'hamiltonian' are stacked along a new leading batch axis.
    The scalar fields become 1-D tensors ('num_edges', 'num_max_cut' and
    'max_cut_value' as int32, 'index' as int64). 'max_cut_bases' varies in
    length between samples, so it is right-padded with zeros to the longest
    entry. Use 'num_max_cut' to mask the padding, because a padded 0 cannot be
    told apart from a genuine index of 0.
    """
    def as_int_tensor(key: str, dtype: torch.dtype) -> torch.Tensor:
        # Gather one scalar field from every sample into a 1-D tensor.
        return torch.tensor([sample[key] for sample in batch], dtype=dtype)

    stacked_graphs = torch.stack([sample['graph'] for sample in batch])
    stacked_hamiltonians = torch.stack([sample['hamiltonian'] for sample in batch])

    edge_counts = as_int_tensor('num_edges', torch.int32)
    solution_counts = as_int_tensor('num_max_cut', torch.int32)
    cut_values = as_int_tensor('max_cut_value', torch.int32)

    # Ragged -> rectangular: each row holds one sample's bases, zero-padded.
    bases = torch.zeros((len(batch), solution_counts.max()), dtype=torch.long)
    for row, (sample, count) in enumerate(zip(batch, solution_counts)):
        bases[row, :count] = sample['max_cut_bases']

    batched = {
        'graph': stacked_graphs,
        'hamiltonian': stacked_hamiltonians,
        'num_edges': edge_counts,
        'num_max_cut': solution_counts,
        'max_cut_bases': bases,
        'max_cut_value': cut_values,
    }
    batched['index'] = as_int_tensor('index', torch.long)
    return batched

def ErdosRenyiDataLoader(dataset:ErdosRenyiDataset, batch_size:int,
                         **kwargs)->DataLoader:
    """Wrap *dataset* in a DataLoader that batches with graph_dataset_collate_fn.

    Any extra keyword arguments (e.g. shuffle, num_workers) are forwarded to
    torch.utils.data.DataLoader. Supplying ``collate_fn`` raises TypeError,
    since the collate function is fixed here.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=graph_dataset_collate_fn,
        **kwargs,
    )

if __name__ == '__main__':
    # Command-line entry point: generate a dataset and save it under --filepath.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate an Erdos-Renyi graph dataset."
    )
    parser.add_argument(
        '--n', type=int, required=True,
        help="Number of vertices in the graph (must be a positive integer)."
    )
    parser.add_argument(
        '--p', type=float, required=True,
        help="Edge probability for the graph (a float between 0 and 1)."
    )
    parser.add_argument(
        '--num_graphs', type=int, default=100,
        help="Number of graphs to generate (default: 100)."
    )
    parser.add_argument(
        '--filepath', type=str, default='./data/graphs',
        help="Directory path to save the dataset (default: './data/graphs')."
    )
    parser.add_argument(
        '--file_prefix', type=str, default=None,
        help="Optional prefix to add to the dataset filename."
    )

    args = parser.parse_args()

    # ErdosRenyiDataset accepts a single `file` argument (it has no
    # `filepath`/`file_prefix` parameters), so build the output path here
    # from the target directory, the optional prefix, and the generation
    # parameters.
    filename = f'erdos_renyi_n{args.n}_p{args.p}_N{args.num_graphs}.pt'
    if args.file_prefix:
        filename = f'{args.file_prefix}_{filename}'
    os.makedirs(args.filepath, exist_ok=True)
    output_file = os.path.join(args.filepath, filename)

    # Generate the dataset; passing `file` makes the constructor save it.
    dataset = ErdosRenyiDataset(
        file=output_file,
        generate=True,
        n=args.n,
        p=args.p,
        num_graphs=args.num_graphs,
    )
