import torch
from typing import Dict, Any, List
import random
from .base import BaseDataset, StreamingBaseDataset
from omegaconf import DictConfig
from generate import required_params, mp_generate_data
from itertools import product
from functools import partial
from graphs.combo_graph import shuffle_edges

def path_to_str(path):
    """Render a sequence of nodes as a comma-separated string.

    Example: ``[1, 2, 3]`` -> ``"1,2,3"``; an empty path yields ``""``.
    """
    return ','.join(map(str, path))

def edge_to_str(edge_list, source, goal):
    """Serialize an edge list plus a (source, goal) query into a prompt string.

    Output format: ``"u1,v1|u2,v2|.../source,goal="``. Only the first two
    entries of each edge are used; the edge list is read, not modified.
    """
    edges_part = '|'.join(f'{edge[0]},{edge[1]}' for edge in edge_list)
    query_part = f'/{source!s},{goal!s}='
    return edges_part + query_part

def make_seed(*parts, mod=2**32):
    """Fold integer-like ``parts`` into one deterministic seed in ``[0, mod)``.

    The seed is a polynomial hash of ``parts`` whose multiplier,
    ``1_000_000 + parts[0] - 1335``, depends on the first part (the base seed).
    Each step reduces modulo ``mod``; for power-of-two ``mod`` (including the
    default ``2**32``) this matches the previous bit-mask implementation
    exactly, while other moduli are now handled correctly too.

    Args:
        *parts: One or more values convertible with ``int()``; ``parts[0]``
            is treated as the base seed.
        mod: Positive modulus bounding the result.

    Returns:
        A plain Python ``int`` in ``[0, mod)``.

    Raises:
        ValueError: If no parts are given or ``mod`` is not positive.
    """
    if not parts:
        raise ValueError("make_seed requires at least one part")
    if mod <= 0:
        raise ValueError(f"mod must be positive, got {mod}")
    # int() keeps the arithmetic in Python ints (no fixed-width overflow for
    # numpy scalars) and makes the result a plain int for random.Random.
    multiplier = 1_000_000 + int(parts[0]) - 1335
    h = 0
    for p in parts:
        # `% mod` rather than `& (mod - 1)`: identical for powers of two, and
        # still a true modulo otherwise.
        h = (h * multiplier + int(p)) % mod
    return h

class GraphDataset(BaseDataset):
    """Map-style dataset that expands each graph prompt into one row per path.

    Raw items are expanded by ``preprocess_dataset``, giving one row per
    element of the Cartesian product of each item's per-step ``policy_nodes``.
    In eval mode the stored ``edge_list`` is serialized as-is. Otherwise each
    call to ``_prepare_input`` asks ``shuffle_edges`` for an edge list using
    the dataset RNG.

    NOTE(review): ``self.rng`` is a single ``random.Random``; if this dataset
    is read by several DataLoader worker processes, each worker presumably
    starts from the same RNG state — confirm whether that is intended.
    """

    def __init__(
        self,
        data: List[Dict[str, Any]],
        tokenizer: Any,
        cfg: DictConfig,
        base_seed: int,
        padding: bool = True,
        truncation: bool = True,
        eval_mode: bool = True,
        **kwargs,
    ):
        """Build the dataset.

        Args:
            data: Raw prompt items, expanded via ``preprocess_dataset``.
            tokenizer: Forwarded to ``BaseDataset``.
            cfg: Forwarded to ``BaseDataset``.
            base_seed: Seed fed through ``make_seed`` to create ``self.rng``.
            padding: Forwarded to ``BaseDataset``.
            truncation: Forwarded to ``BaseDataset``.
            eval_mode: If False, edge lists are re-shuffled on access.
            **kwargs: Must contain ``num_paths``; other keys are not forwarded.

        Raises:
            ValueError: If ``num_paths`` is missing from ``kwargs``.
        """
        print(f"Initializing GraphDataset with eval_mode={eval_mode} and base_seed={base_seed}")
        if 'num_paths' not in kwargs:
            raise ValueError("num_paths must be passed to the dataset")
        # Set before the base constructor (as StreamingGraphDataset does) so
        # that anything it may call, such as _prepare_input, sees the mode.
        self.eval_mode = eval_mode
        rng_seed = make_seed(base_seed)
        self.rng = random.Random(rng_seed)
        print(f"Initializing GraphDataset with rng_seed={rng_seed}, eval_mode={eval_mode}, base_seed={base_seed}")
        super().__init__(
            data=preprocess_dataset(data, kwargs['num_paths'], self.rng),
            tokenizer=tokenizer,
            cfg=cfg,
            padding=padding,
            truncation=truncation
        )

    def _prepare_input(self, item: Dict[str, Any]) -> str:
        """Return the serialized edge list + ``source,goal`` query for ``item``.

        Eval mode uses ``item['edge_list']`` unchanged. Otherwise the edge list
        comes from ``shuffle_edges(item['layer_edges'],
        item['edge_shuffle_rule'], self.rng)``.
        """
        if self.eval_mode:
            # edge_to_str only iterates the list, so no defensive copy is needed.
            edge_list = item['edge_list']
        else:
            edge_list = shuffle_edges(item['layer_edges'], item['edge_shuffle_rule'], self.rng)
        return edge_to_str(edge_list, item['source'], item['goal'])

    def _prepare_target(self, item: Dict[str, Any]) -> str:
        """Return ``item['paths']`` (a single node path) as a comma-joined string."""
        return path_to_str(item['paths'])

def preprocess_dataset(_data: List[Dict[str, Any]], num_paths: int, rng: random.Random) -> List[Dict[str, Any]]:
    """Expand each prompt into one record per candidate path.

    For every item, the first ``item['path_length']`` entries of
    ``item['policy_nodes']`` are used as per-step candidate collections. Every
    combination of them (``itertools.product`` order) becomes its own record.
    Each record is a shallow copy of the item plus two keys: ``'paths'``, the
    combination as a list, and ``'prompt_idx'``, the item's index in ``_data``.

    NOTE(review): ``num_paths`` and ``rng`` are accepted but unused — every
    path is emitted, so output size is the product of the per-step candidate
    counts. Confirm whether sampling ``num_paths`` paths was intended.
    """
    expanded = []
    for prompt_idx, item in enumerate(_data):
        per_step = [item['policy_nodes'][step] for step in range(item['path_length'])]
        for combo in product(*per_step):
            expanded.append({**item, 'paths': list(combo), 'prompt_idx': prompt_idx})
    return expanded

class StreamingGraphDataset(StreamingBaseDataset):
    """Streaming dataset whose chunks are produced on demand by ``mp_generate_data``.

    Each chunk is shuffled with ``self.rng``. In training mode, each record's
    ``'paths'`` can be resampled and its ``'edge_list'`` re-shuffled in place,
    depending on the ``cfg.train.online`` flags. Eval mode is rejected by
    ``_process_chunk``.

    NOTE(review): ``self.chunk_size`` and ``self.cfg`` are read here but not
    assigned; presumably ``StreamingBaseDataset`` sets them — confirm.
    """

    def __init__(
        self,
        tokenizer: Any,
        cfg: DictConfig,
        graph_type: str,
        eval_mode: bool,
        base_seed: int,
        padding: bool = True,
        truncation: bool = True,
        **kwargs,
    ):
        """Store generation settings, seed the RNG, then defer to the base class."""
        self.graph_type = graph_type
        self.eval_mode = eval_mode
        online_cfg = cfg.train.online
        # Per-load re-randomization is enabled only when generate_chunks_first
        # is also set.
        self.shuffle_edges_each_load = online_cfg.shuffle_input and online_cfg.generate_chunks_first
        self.sample_responses_each_load = online_cfg.sample_response and online_cfg.generate_chunks_first
        self.num_workers = cfg.train.num_workers
        self.train_base_seed = base_seed
        seed = make_seed(base_seed)
        self.rng = random.Random(seed)
        print(f"Initializing StreamingGraphDataset with rng_seed={seed}, eval_mode={eval_mode}, base_seed={base_seed}")
        super().__init__(tokenizer=tokenizer, cfg=cfg, padding=padding, truncation=truncation, **kwargs)

    def _generate_chunk(self, chunk_idx: int) -> List[Dict[str, Any]]:
        """Generate chunk ``chunk_idx`` via ``mp_generate_data`` with the stored base seed."""
        seed = self.train_base_seed
        print(f"Generating chunk {chunk_idx} with base {seed}")
        return mp_generate_data(
            self.graph_type,
            self.chunk_size,
            self.cfg.data,
            seed,
            chunk_idx,
            num_workers=self.num_workers,
        )

    def _process_chunk(self, chunk: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Shuffle ``chunk`` in place and optionally re-randomize each record.

        The chunk is shuffled (and the status line printed) before the mode
        check, so the RNG advances even when eval mode then raises.

        Raises:
            NotImplementedError: In eval mode.
        """
        self.rng.shuffle(chunk)
        print(f"Processing chunk; shuffling edges={self.shuffle_edges_each_load} and sampling responses={self.sample_responses_each_load}")
        if self.eval_mode:
            raise NotImplementedError("Eval mode Not implemented")
        rng = self.rng
        resample_paths = self.sample_responses_each_load
        reshuffle_edges = self.shuffle_edges_each_load
        for record in chunk:
            # Per record: path sampling first, then edge shuffling; this order
            # fixes the RNG consumption sequence.
            if resample_paths:
                record['paths'] = [
                    rng.sample(record['policy_nodes'][step], 1)[0]
                    for step in range(record['path_length'])
                ]
            if reshuffle_edges:
                record['edge_list'] = shuffle_edges(record['layer_edges'], record['edge_shuffle_rule'], rng)
        return chunk

    def _prepare_input(self, item: Dict[str, Any]) -> str:
        """Serialize a copy of ``item['edge_list']`` plus the ``source,goal`` query."""
        return edge_to_str(item['edge_list'].copy(), item['source'], item['goal'])

    def _prepare_target(self, item: Dict[str, Any]) -> str:
        """Return ``item['paths']`` as a comma-joined string."""
        return path_to_str(item['paths'])