from collections import defaultdict
from typing import List
import numpy as np
from torch_geometric.data import Data, Dataset
from torch_geometric.data import DataLoader as GnnDataloader
from torch_geometric.transforms import BaseTransform
from torch_geometric.datasets import ZINC

# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from torch_geometric.data import Batch  
from torch_geometric.data.collate import collate
import torch
import lightning as pl
import networkx as nx
from pathlib import Path

import numpy as np


def caterpillar_collate(batch: List[Data], debug=False) -> Data:
    """Collate layered graphs, keeping node and edge data split per layer.

    Each sample carries flat ``x``, ``edge_index`` and ``edge_attr`` tensors
    plus offset tensors ``x_sizes`` and ``edge_sizes``. Consecutive entries
    ``(begin, end)`` of ``x_sizes`` slice out the nodes of one layer;
    consecutive entries of ``edge_sizes`` slice out the edges of one edge
    group. Edge group ``l`` is offset using the sizes of node layers ``l`` and
    ``l + 1``. For every layer, the per-sample slices are concatenated across
    the batch. All other attributes are collated with
    ``Batch.from_data_list``.

    The samples in ``batch`` are not modified, so datasets that return their
    stored objects directly can be iterated any number of times.

    Args:
        batch: Samples to collate. The number of layers is taken from
            ``batch[0].num_layers``.
        debug: Unused; kept for backward compatibility.

    Returns:
        A ``Data`` object with these fields:

        - ``x``: a list with one concatenated node tensor per layer.
        - ``edge_index`` and ``edge_attr``: lists with one concatenated tensor
          per edge group.
        - Every attribute of the collated remaining data.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    import copy  # only needed here, for shallow-copying samples

    if not batch:
        raise ValueError("caterpillar_collate received an empty batch")

    num_layers = batch[0].num_layers
    layered_x = [[] for _ in range(num_layers)]
    layered_edge_index = [[] for _ in range(num_layers - 1)]
    layered_edge_attr = [[] for _ in range(num_layers - 1)]
    other_data = []

    for data in batch:
        assert len(data.x_sizes) == len(data.edge_sizes) + 1

        x_sizes = data.x_sizes.tolist()
        for layer_idx, (x_begin, x_end) in enumerate(zip(x_sizes, x_sizes[1:])):
            layered_x[layer_idx].append(data.x[x_begin:x_end])

        edge_sizes = data.edge_sizes.tolist()
        for layer_idx, (edge_begin, edge_end) in enumerate(zip(edge_sizes, edge_sizes[1:])):
            layered_edge_index[layer_idx].append(data.edge_index[:, edge_begin:edge_end])
            layered_edge_attr[layer_idx].append(data.edge_attr[edge_begin:edge_end])

        # Strip the layered attributes from a shallow copy, not from ``data``
        # itself. DatasetFromProcessed returns its stored objects directly, so
        # deleting in place would break every later access to the same sample
        # (e.g. the next epoch when num_workers=0).
        rest = copy.copy(data)
        for key in ("x", "edge_index", "edge_attr", "x_sizes", "edge_sizes"):
            delattr(rest, key)
        other_data.append(rest)

    layered_collated_edge_index = []
    layered_collated_edge_attr = []
    for layer_idx, (edge_index, edge_attr) in enumerate(zip(layered_edge_index, layered_edge_attr)):
        # NOTE(review): the two comments inside the loop below say row 0 indexes
        # layer ``layer_idx`` (source) and row 1 indexes layer ``layer_idx + 1``
        # (target). The offsets, however, are applied the other way round:
        # row 0 gets the target-layer offset and row 1 the source-layer offset.
        # This mapping is kept exactly as before. Confirm it against the code
        # that produces ``edge_index``; if the comment is right, the two
        # offsets should be swapped. Only batches with more than one sample
        # are affected, since the offsets for the first sample are zero.
        shift_source = 0
        shift_target = 0
        shifted = []
        for bi, ei in enumerate(edge_index):
            # source edge_index[0, :] in layered_x[layer_idx][bi]
            # target edge_index[1, :] in layered_x[layer_idx+1][bi]
            # Build a new tensor instead of using ``+=``: ``ei`` is a view into
            # the sample's stored edge_index, so in-place addition would
            # corrupt the dataset.
            shifted.append(torch.stack((ei[0] + shift_target, ei[1] + shift_source)))
            shift_source += len(layered_x[layer_idx][bi])
            shift_target += len(layered_x[layer_idx + 1][bi])

        layered_collated_edge_index.append(torch.cat(shifted, dim=1))
        layered_collated_edge_attr.append(torch.cat(edge_attr, dim=0))

    layered_collated_x = [torch.cat(x, dim=0) for x in layered_x]

    collated_other_data = Batch.from_data_list(other_data)

    return Data(
        x=layered_collated_x,
        edge_index=layered_collated_edge_index,
        edge_attr=layered_collated_edge_attr,
        # merge in every attribute produced by Batch.from_data_list
        **dict(collated_other_data.items()),
    )



class DataModuleTree(pl.LightningDataModule):
    """Lightning data module serving the ZINC subset through PyG data loaders.

    All three splits are loaded from ``data/ZINC`` when the module is
    constructed. Every loader shares the same batch size and worker count;
    only the training loader shuffles.

    Args:
        batch_size: Samples per batch for every loader.
        num_workers: Worker processes for every loader.
    """

    def __init__(self, batch_size=1, num_workers=0):
        super().__init__()

        self.batch_size = batch_size

        zinc_root = "data/ZINC"
        self.train_dataset = ZINC(zinc_root, subset=True, split="train")
        self.val_dataset = ZINC(zinc_root, subset=True, split="val")
        self.test_dataset = ZINC(zinc_root, subset=True, split="test")

        # keyword arguments shared by every loader built below
        self.generic_settings = {"batch_size": batch_size, "num_workers": num_workers}

    def _build_loader(self, dataset, **extra):
        """Wrap ``dataset`` in a PyG loader using the shared settings."""
        return GnnDataloader(dataset, **self.generic_settings, **extra)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (no shuffling)."""
        return self._build_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test loader (no shuffling)."""
        return self._build_loader(self.test_dataset)



class DatasetFromProcessed(Dataset):
    """In-memory dataset backed by one ``torch.save``-d collection of samples.

    Args:
        processed_dir: Path of the ``.pt`` file to load. Despite its name,
            DataModulePath passes a file path here (e.g. ``.../train.pt``).
        root: Unused.
        split: Unused.
    """

    def __init__(self, processed_dir=None, root=None, split=None):
        super().__init__()
        # weights_only=False lets torch.load unpickle arbitrary Python objects,
        # not just tensors; only point this at files you trust.
        samples = torch.load(processed_dir, weights_only=False)
        self.data = samples

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the stored sample at ``idx`` (the object itself, not a copy)."""
        samples = self.data
        return samples[idx]


class DataModulePath(pl.LightningDataModule):
    """Lightning data module over pre-processed layered datasets.

    Samples are read from
    ``data/<dataset_name>[/<subname>]/processed-H<height><suffix>``, where
    ``<suffix>`` concatenates ``"_<m>"`` for each entry ``m`` of ``mode``.
    With ``prepare_splits=True``, ``train.pt``, ``val.pt`` and ``test.pt`` are
    loaded. Otherwise a single ``data.pt`` is loaded into ``self.dataset``.
    Batches are assembled with ``caterpillar_collate``.

    Args:
        dataset_name: Sub-directory of ``data/``.
        subname: Optional further sub-directory; skipped when ``None``.
        height: Inserted into the directory name as ``H<height>``.
        batch_size: Samples per batch for every loader.
        num_workers: Worker processes for every loader.
        mode: Iterable of tags appended to the directory name; ``None`` is
            treated as empty.
        prepare_splits: Load the train/val/test files (True) or the single
            ``data.pt`` file (False).

    Raises:
        FileNotFoundError: If the processed directory does not exist.
    """

    def __init__(self, dataset_name, subname, height, batch_size=1, num_workers=0, mode=(), prepare_splits=True):
        super().__init__()

        root = Path('data') / dataset_name
        if subname is not None:
            root = root / subname

        # The default used to be a mutable ``[]``; an immutable tuple avoids
        # shared-state surprises, and None is accepted as "no tags".
        if mode is None:
            mode = ()
        suffix = "".join(f"_{m}" for m in mode)  # e.g. ["a", "b"] -> "_a_b"

        processed_dir = root / f"processed-H{height}{suffix}"
        # Raise explicitly instead of using ``assert``, which ``python -O`` strips.
        if not processed_dir.exists():
            raise FileNotFoundError(f"Directory {processed_dir} does not exist")

        if prepare_splits:
            self.train_dataset = DatasetFromProcessed(
                processed_dir=str(processed_dir / "train.pt"),
            )
            self.val_dataset = DatasetFromProcessed(
                processed_dir=str(processed_dir / "val.pt"),
            )
            self.test_dataset = DatasetFromProcessed(
                processed_dir=str(processed_dir / "test.pt"),
            )
            self.dataset = None
        else:
            self.train_dataset = None
            self.val_dataset = None
            self.test_dataset = None
            self.dataset = DatasetFromProcessed(
                processed_dir=str(processed_dir / "data.pt"),
            )

        self.loader_settings = dict(
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=caterpillar_collate,
        )

    def _make_loader(self, dataset, split, **extra):
        """Build a DataLoader for one split, failing clearly if it was not loaded."""
        if dataset is None:
            raise RuntimeError(
                f"No {split} split is available: DataModulePath was created "
                "with prepare_splits=False (use self.dataset instead)"
            )
        return DataLoader(dataset, **self.loader_settings, **extra)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, "train", shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (no shuffling)."""
        return self._make_loader(self.val_dataset, "val")

    def test_dataloader(self):
        """Return the test loader (no shuffling)."""
        return self._make_loader(self.test_dataset, "test")
