import torch
from torch.utils.data import Dataset, IterableDataset
from typing import Dict, Any, List, Optional, Union, Type
from abc import ABC, abstractmethod
from graph_tokenizers.numerical_tokenizer import NumericalTokenizer
from omegaconf import DictConfig
from itertools import cycle
import time
import json, os

class BaseDataset(Dataset, ABC):
    """Map-style dataset that turns raw items into language-model training examples.

    Subclasses implement ``_prepare_input`` and ``_prepare_target``. The two
    strings are concatenated and tokenized, then converted into ``input_ids``,
    ``attention_mask`` and ``labels``. In ``labels``, every position up to and
    including the first delimiter token, plus every padded position, is
    masked with ``-100``.
    """

    def __init__(
        self,
        data: List[Dict[str, Any]],
        tokenizer: Any,
        cfg: DictConfig,
        padding: bool = True,
        truncation: bool = False
    ):
        """Store the data and tokenizer, and read model settings from ``cfg``.

        Args:
            data: Raw items, indexed by ``__getitem__``. Each item is passed to
                ``_prepare_input`` / ``_prepare_target``.
            tokenizer: Callable tokenizer. If it is a ``NumericalTokenizer``, its
                ``delimiter_id`` is used. Otherwise the delimiter is the first
                id of ``tokenizer.encode("=")``.
            cfg: Config. Reads ``cfg.model.name``, ``cfg.model.max_length``,
                ``cfg.model.target_max_length`` and ``cfg.model.autoregressive``.
            padding: Forwarded to the tokenizer call.
            truncation: Forwarded to the tokenizer call.
        """
        self.tokenizer = tokenizer
        self.padding = padding
        self.truncation = truncation
        self.model_name = cfg.model.name
        self.max_length = cfg.model.max_length
        self.target_max_length = cfg.model.target_max_length
        self.autoregressive = cfg.model.autoregressive
        if isinstance(self.tokenizer, NumericalTokenizer):
            self.delimiter_id = self.tokenizer.delimiter_id
        else:
            # NOTE(review): assumes "=" encodes to a single token with no
            # prepended special token (e.g. BOS); only the first id is kept.
            self.delimiter_id = self.tokenizer.encode("=")[0]
        self.data = data

    @abstractmethod
    def _prepare_input(self, item: Dict[str, Any]) -> str:
        """Return the prompt string for ``item``."""
        raise NotImplementedError

    @abstractmethod
    def _prepare_target(self, item: Dict[str, Any]) -> Optional[str]:
        """Return the target string for ``item``, or ``None`` for no target."""
        raise NotImplementedError

    def _combine_input_target(self, input_str: str, target_str: Optional[str]) -> str:
        """Concatenate prompt and target; a ``None`` target yields the prompt alone."""
        if target_str is None:
            return input_str
        return f"{input_str}{target_str}"

    def _prepare_sequence(self, item: Dict[str, Any]) -> str:
        """Build the full text (prompt followed by target) that gets tokenized."""
        input_str = self._prepare_input(item)
        target_str = self._prepare_target(item)
        return self._combine_input_target(input_str, target_str)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize item ``idx`` and return model inputs, labels and ``idx``."""
        item = self.data[idx]
        input_str = self._prepare_sequence(item)
        tokenized = self.tokenizer(
            input_str,
            max_length=self.max_length,
            padding=self.padding,
            truncation=self.truncation,
            return_tensors='pt'
        )
        outputs = self._process_tokenizer_outputs(tokenized)
        outputs['idx'] = torch.tensor(idx)
        return outputs

    def _process_tokenizer_outputs(
        self,
        tokenized: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Turn one tokenized sequence into ``input_ids``, ``attention_mask`` and ``labels``.

        Args:
            tokenized: Tokenizer output with ``input_ids`` and ``attention_mask``
                (presumably shape ``(1, seq_len)``; the leading dim is squeezed).

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (the last position
            dropped) and ``labels``. In ``labels``, every position up to and
            including the first delimiter token, plus every padded position, is
            ``-100``. When ``self.autoregressive`` is set, the labels are shifted
            left by one so that ``labels[i]`` is the token following
            ``input_ids[i]``.

        Raises:
            ValueError: If the delimiter token does not occur in the sequence,
                so the start of the target cannot be located.
        """
        input_ids = tokenized['input_ids'].squeeze(0)
        attention_mask = tokenized['attention_mask'].squeeze(0)
        # The final token is never fed as input; it only serves as a target.
        processed_input_ids = input_ids[:-1].clone()
        processed_attention_mask = attention_mask[:-1].clone()
        labels = input_ids.clone()
        delimiter_positions = torch.where(labels == self.delimiter_id)[0]
        if delimiter_positions.numel() == 0:
            raise ValueError(
                f"delimiter token id {self.delimiter_id} not found in the "
                "tokenized sequence; cannot locate the start of the target "
                "(was it truncated?)"
            )
        # The first delimiter ends the prompt; a later occurrence belongs to the target.
        prompt_end = int(delimiter_positions[0])
        labels[:prompt_end + 1] = -100
        labels[attention_mask == 0] = -100
        if self.autoregressive:
            labels = labels[1:]
        # NOTE(review): without the shift, labels stay one element longer than
        # input_ids. Confirm that the non-autoregressive model expects this.
        out = {
            'input_ids': processed_input_ids,
            'attention_mask': processed_attention_mask,
            'labels': labels
        }
        return out

class StreamingBaseDataset(IterableDataset, ABC):
    """Endless iterable dataset that yields tokenized examples chunk by chunk.

    Examples are grouped into numbered chunks. When
    ``cfg.train.online.generate_chunks_first`` is set, the chunks are produced
    once (or taken from a ``data`` kwarg) and cached as JSON files under
    ``<cfg.root>/data/<cfg.wandb.run_name>``; iteration then replays them
    modulo ``num_chunks``. Otherwise ``_generate_chunk`` is called with an
    ever-increasing chunk index each time a new chunk is needed.

    Subclasses implement ``_generate_chunk``, ``_process_chunk``,
    ``_prepare_input`` and ``_prepare_target``.
    """

    def __init__(
        self,
        tokenizer: Any,
        cfg: DictConfig,
        padding: bool = True,
        truncation: bool = False,
        **kwargs,
    ):
        """Read settings from ``cfg`` and, if configured, write the chunk files.

        Args:
            tokenizer: Callable tokenizer used by ``__iter__``. If it is a
                ``NumericalTokenizer``, its ``delimiter_id`` is used. Otherwise
                the delimiter is the first id of ``tokenizer.encode("=")``.
            cfg: Config. Reads ``cfg.model.{name,max_length,target_max_length,
                autoregressive}``, ``cfg.train.online.{chunk_size,num_chunks,
                generate_chunks_first}``, ``cfg.root`` and ``cfg.wandb.run_name``.
            padding: Forwarded to the tokenizer call.
            truncation: Forwarded to the tokenizer call.
            **kwargs: Stored on ``self.kwargs``. A non-None ``data`` entry is
                written to ``chunk_0.json`` and becomes the only chunk. This
                requires ``num_chunks == 1`` and ``generate_chunks_first``.

        Raises:
            AssertionError: (via ``assert``) if ``data`` is given with
                ``num_chunks != 1`` or without ``generate_chunks_first``.

        Side effects: creates the chunk directory. When chunks are cached, it
        also writes the chunk JSON file(s).
        """
        self.cfg = cfg
        self.kwargs = kwargs
        self.tokenizer = tokenizer
        self.padding = padding
        self.truncation = truncation
        self.model_name = cfg.model.name
        self.max_length = cfg.model.max_length
        self.target_max_length = cfg.model.target_max_length
        self.autoregressive = cfg.model.autoregressive
        if isinstance(self.tokenizer, NumericalTokenizer):
            self.delimiter_id = self.tokenizer.delimiter_id
        else:
            # NOTE(review): assumes "=" encodes to a single token with no
            # prepended special token (e.g. BOS); only the first id is kept.
            self.delimiter_id = self.tokenizer.encode("=")[0]
        # Not read in this class; presumably consumed by subclasses' _generate_chunk.
        self.chunk_size = cfg.train.online.chunk_size
        # Index of the next chunk that __iter__ will retrieve.
        self.chunk_idx = 0
        self.num_chunks = cfg.train.online.num_chunks
        self.generate_chunks_first = cfg.train.online.generate_chunks_first
        self.chunk_dir = os.path.join(cfg.root, 'data', cfg.wandb.run_name)
        os.makedirs(self.chunk_dir, exist_ok=True)
        data = kwargs.get('data')
        if data is not None:
            print("Using provided data")
            assert self.num_chunks == 1, "num_chunks must be 1 if data is provided"
            assert self.generate_chunks_first, "generate_chunks_first must be True if data is provided"
            chunk_file = os.path.join(self.chunk_dir, 'chunk_0.json')
            with open(chunk_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            self.chunk_files = {0: chunk_file}
            print(f"Saved {len(data)} graphs to {chunk_file}")
        elif self.generate_chunks_first:
            self._pregenerate_all_chunks()

    @abstractmethod
    def _generate_chunk(self, chunk_idx: int) -> List[Dict[str, Any]]:
        """Return the raw items making up chunk ``chunk_idx``."""
        raise NotImplementedError

    def _process_chunk(self, chunk: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Turn a raw (generated or loaded) chunk into the items ``__iter__`` consumes.

        Must be overridden; the base implementation raises NotImplementedError.
        """
        raise NotImplementedError

    def _load_chunk(self, chunk_idx: int) -> List[Dict[str, Any]]:
        """Read a cached chunk from disk.

        ``chunk_idx`` is wrapped modulo ``num_chunks``, so the cached chunks are
        replayed in order. JSON stores object keys as strings, so the
        ``path_nodes`` and ``policy_nodes`` mappings of every item (both keys
        are required) are converted back to int-keyed dicts.
        """
        load_start_time = time.time()
        load_idx = chunk_idx % self.num_chunks
        chunk_file = self.chunk_files[load_idx]
        with open(chunk_file, 'r', encoding='utf-8') as f:
            raw_chunk = json.load(f)

        def _int_keys(mapping: Dict[str, Any]) -> Dict[int, Any]:
            # Undo json's str coercion of integer dict keys.
            return {int(k): v for k, v in mapping.items()}

        chunk = [
            {
                **item,
                'path_nodes': _int_keys(item['path_nodes']),
                'policy_nodes': _int_keys(item['policy_nodes']),
            }
            for item in raw_chunk
        ]
        load_time = time.time() - load_start_time
        print(f"Loaded chunk {load_idx} for idx {chunk_idx} in {load_time:.2f}s")
        return chunk

    def _pregenerate_all_chunks(self):
        """Generate chunks ``0..num_chunks-1``, write each one to JSON, and record the paths in ``self.chunk_files``."""
        print("Pre-generating all chunks...")
        self.chunk_files = {}
        for chunk_idx in range(self.num_chunks):
            chunk_start_time = time.time()
            chunk = self._generate_chunk(chunk_idx)
            filename = f"chunk_{chunk_idx}.json"
            filepath = os.path.join(self.chunk_dir, filename)
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(chunk, f, indent=2)
            self.chunk_files[chunk_idx] = filepath
            gen_time = time.time() - chunk_start_time
            print(f"Generated chunk {chunk_idx} in {gen_time:.2f}s")

    def _retrieve_chunk(self, chunk_idx: int) -> List[Dict[str, Any]]:
        """Load chunk ``chunk_idx`` from the cache or generate it, then run ``_process_chunk``."""
        if self.generate_chunks_first:
            chunk = self._load_chunk(chunk_idx)
        else:
            chunk = self._generate_chunk(chunk_idx)
        return self._process_chunk(chunk)

    def _prepare_input(self, item: Dict[str, Any]) -> str:
        """Return the prompt string for ``item``. Must be overridden."""
        raise NotImplementedError

    def _prepare_target(self, item: Dict[str, Any]) -> Optional[str]:
        """Return the target string for ``item`` (or ``None``). Must be overridden."""
        raise NotImplementedError

    def _combine_input_target(self, input_str: str, target_str: Optional[str]) -> str:
        """Concatenate prompt and target; a ``None`` target yields the prompt alone."""
        if target_str is None:
            return input_str
        return f"{input_str}{target_str}"

    def _prepare_sequence(self, item: Dict[str, Any]) -> str:
        """Build the full text (prompt followed by target) that gets tokenized."""
        input_str = self._prepare_input(item)
        target_str = self._prepare_target(item)
        return self._combine_input_target(input_str, target_str)

    def _process_tokenizer_outputs(
        self,
        tokenized: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Turn one tokenized sequence into ``input_ids``, ``attention_mask`` and ``labels``.

        Args:
            tokenized: Tokenizer output with ``input_ids`` and ``attention_mask``
                (presumably shape ``(1, seq_len)``; the leading dim is squeezed).

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (the last position
            dropped) and ``labels``. In ``labels``, every position up to and
            including the first delimiter token, plus every padded position, is
            ``-100``. When ``self.autoregressive`` is set, the labels are shifted
            left by one so that ``labels[i]`` is the token following
            ``input_ids[i]``.

        Raises:
            ValueError: If the delimiter token does not occur in the sequence.
        """
        input_ids = tokenized['input_ids'].squeeze(0)
        attention_mask = tokenized['attention_mask'].squeeze(0)
        # The final token is never fed as input; it only serves as a target.
        processed_input_ids = input_ids[:-1].clone()
        processed_attention_mask = attention_mask[:-1].clone()
        labels = input_ids.clone()
        delimiter_positions = torch.where(labels == self.delimiter_id)[0]
        if delimiter_positions.numel() == 0:
            raise ValueError(
                f"delimiter token id {self.delimiter_id} not found in the "
                "tokenized sequence; cannot locate the start of the target "
                "(was it truncated?)"
            )
        # The first delimiter ends the prompt; a later occurrence belongs to the target.
        prompt_end = int(delimiter_positions[0])
        labels[:prompt_end + 1] = -100
        labels[attention_mask == 0] = -100
        if self.autoregressive:
            labels = labels[1:]
        out = {
            'input_ids': processed_input_ids,
            'attention_mask': processed_attention_mask,
            'labels': labels
        }
        return out

    def __iter__(self):
        """Yield processed examples forever, one chunk after another.

        Each yielded dict is the output of ``_process_tokenizer_outputs``
        plus ``idx``, the item's position within its chunk (not a global
        index). ``self.chunk_idx`` advances after each completed chunk, so a
        new iterator starts from the current ``self.chunk_idx`` rather than 0.

        NOTE(review): no ``torch.utils.data.get_worker_info()`` sharding is
        done. With ``num_workers > 0``, every DataLoader worker iterates its own
        copy of this dataset over the same chunk indices. That yields duplicate
        data when chunks are loaded from disk or generated deterministically.
        """
        while True:
            start_time = time.time()
            chunk = self._retrieve_chunk(self.chunk_idx)
            gen_time = time.time() - start_time
            print(f"--- Retrieved chunk {self.chunk_idx} in {gen_time:.2f}s ---")
            for idx, item in enumerate(chunk):
                input_str = self._prepare_sequence(item)
                tokenized = self.tokenizer(
                    input_str,
                    max_length=self.max_length,
                    padding=self.padding,
                    truncation=self.truncation,
                    return_tensors='pt'
                )
                outputs = self._process_tokenizer_outputs(tokenized)
                outputs['idx'] = torch.tensor(idx)
                yield outputs
            self.chunk_idx += 1

def _repeat_dataloader(dataloader):
    """Yield batches from ``dataloader`` forever, re-iterating it on every pass.

    Unlike ``itertools.cycle``, this does not keep a copy of the first pass's
    batches. Each pass calls ``iter(dataloader)`` afresh, so a shuffling
    sampler draws a new order every epoch and old batches can be freed.
    Like ``cycle``, it stops instead of spinning forever if a pass yields
    no batches.
    """
    while True:
        produced_any = False
        for batch in dataloader:
            produced_any = True
            yield batch
        if not produced_any:
            return

def get_dataloader(
    XDataset: type,
    tokenizer: Any,
    cfg: DictConfig,
    batch_size: int,
    shuffle_loader: bool = True,
    use_cycle: bool = False,
    num_workers: int = 0,
    pin_memory: bool = True,
    **kwargs,
):
    """Build an ``XDataset`` and a batch iterator over it.

    Args:
        XDataset: Dataset class, constructed as
            ``XDataset(tokenizer=tokenizer, cfg=cfg, **kwargs)``.
        tokenizer: Forwarded to the dataset constructor.
        cfg: Forwarded to the dataset constructor.
        batch_size: DataLoader batch size.
        shuffle_loader: Shuffle map-style datasets. Ignored for streaming datasets.
        use_cycle: For map-style datasets, return an endless iterator that
            re-iterates the DataLoader each epoch. Ignored for streaming
            datasets, whose own iterator never ends.
        num_workers: DataLoader worker processes.
        pin_memory: DataLoader ``pin_memory`` flag.
        **kwargs: Extra dataset constructor arguments. Streaming datasets
            require ``eval_mode``.

    Returns:
        ``(dataset, loader)``. For a ``StreamingBaseDataset`` subclass,
        ``loader`` is ``iter(DataLoader)``. Otherwise it is the DataLoader
        itself, or an endless batch generator when ``use_cycle`` is set.
    """
    if issubclass(XDataset, StreamingBaseDataset):
        assert 'eval_mode' in kwargs, "eval_mode must be provided for online datasets"
        dataset = XDataset(
            tokenizer=tokenizer,
            cfg=cfg,
            **kwargs,
        )
        # No shuffle argument: DataLoader rejects shuffle=True for an IterableDataset.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        return dataset, iter(dataloader)

    dataset = XDataset(
        tokenizer=tokenizer,
        cfg=cfg,
        **kwargs,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle_loader,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    # itertools.cycle would cache every batch of the first epoch and replay that
    # exact order forever; re-iterate the DataLoader so each epoch reshuffles.
    return dataset, _repeat_dataloader(dataloader) if use_cycle else dataloader