import json
from logging import getLogger
import numpy as np
import os
from pathlib import Path
import shutil
import time

from datasets import load_dataset
import torch
from tqdm import tqdm


logger = getLogger(__name__)


class CollectionDataset(torch.utils.data.Dataset):
    """Expose several parallel indexable collections as a single dataset.

    Item ``i`` is a tuple holding ``collection[i]`` for every collection, in
    the order the collections were passed to the constructor. ``len()`` is the
    length of the first collection; the other collections' lengths are not
    checked.
    """

    def __init__(self, *args):
        self.collections = [*args]
        # Only the first collection defines the dataset length.
        self.length = len(self.collections[0])

    def __getitem__(self, i):
        return tuple([collection[i] for collection in self.collections])

    def __len__(self):
        return self.length


class RaggedDataset(torch.utils.data.Dataset):
    """Dataset over variable-length sequences packed back-to-back in flat tensors.

    Sequence ``i`` occupies rows ``start_idxs[i]:start_idxs[i + 1]`` of
    ``embeddings`` and ``input_ids``; the last sequence runs to the end of the
    flat tensors. Each item is returned padded to the longest sequence length
    as ``(embeddings, input_ids, mask, text)``, where ``mask`` is a bool tensor
    that is True on the sequence's own positions.

    Note: the returned ``embeddings`` window is a slice of the flat tensor, so
    rows where ``mask`` is False hold the following sequence's data (or zero
    padding for the last sequence); ``input_ids`` is zero-filled there.
    """

    def __init__(self, embeddings, input_ids, start_idxs, texts):
        self.embeddings = embeddings
        self.input_ids = input_ids
        self.start_idxs = start_idxs
        self.texts = texts
        sequence_lengths = [(self.start_idxs[i + 1] - self.start_idxs[i]).item() for i in range(len(self.start_idxs) - 1)]
        sequence_lengths.append((len(self.embeddings) - self.start_idxs[-1]).item())
        self.sequence_length = max(sequence_lengths)
        # __getitem__ slices a fixed window of sequence_length rows from each start
        # index, so the flat embeddings are padded until the last window fits.
        padding_size = self.sequence_length - sequence_lengths[-1]
        if padding_size > 0:
            # new_zeros keeps dtype and device; trailing feature dims are preserved.
            padding = self.embeddings.new_zeros((padding_size, *self.embeddings.shape[1:]))
            padded = torch.cat([self.embeddings, padding], dim=0)
            # torch.cat allocates a fresh tensor; keep the shared-memory property
            # the caller requested on the input.
            if self.embeddings.is_shared():
                padded.share_memory_()
            self.embeddings = padded
        self.length = len(self.start_idxs)

    def __getitem__(self, i):
        # Accept Python-style negative indices; reject anything out of range.
        index = i + self.length if i < 0 else i
        if not 0 <= index < self.length:
            raise IndexError(f'index {i} is out of range for RaggedDataset of length {self.length}')
        start_idx = self.start_idxs[index].item()
        # The last sequence has no successor start index; it ends with the flat tensors.
        end_idx = self.start_idxs[index + 1].item() if index + 1 < self.length else len(self.input_ids)
        n_valid = end_idx - start_idx
        mask = torch.zeros(self.sequence_length, dtype=torch.bool)
        mask[:n_valid] = True
        embeddings = self.embeddings[start_idx:start_idx + self.sequence_length]
        assert embeddings.shape[0] == self.sequence_length
        input_ids = torch.zeros(self.sequence_length, dtype=self.input_ids.dtype)
        input_ids[:n_valid] = self.input_ids[start_idx:end_idx]
        return embeddings, input_ids, mask, self.texts[index]

    def __len__(self):
        return self.length


def convert_to_ragged_format(embeddings, input_ids, masks):
    """Pack padded batches into flat "ragged" tensors.

    Args:
        embeddings: tensor of shape ``(B, S, ...)``.
        input_ids: tensor of shape ``(B, S)``.
        masks: tensor of shape ``(B, S)``; non-zero entries mark the positions
            to keep.

    Returns:
        ``(flat_embeddings, flat_input_ids, start_idxs)``, where the kept
        positions of all rows are concatenated in row order, and
        ``start_idxs[i]`` (int64, on CPU) is the offset at which row ``i``
        begins in the flat tensors. An empty batch yields empty tensors.
    """
    bool_masks = masks.bool()
    # Boolean indexing over the leading (B, S) dims walks the mask in row-major
    # order: row 0's kept positions first, then row 1's, and so on.
    flat_embeddings = embeddings[bool_masks]
    flat_input_ids = input_ids[bool_masks]
    lengths = bool_masks.sum(dim=1)
    # Exclusive prefix sum: each row starts where all previous rows end.
    start_idxs = (torch.cumsum(lengths, dim=0) - lengths).cpu()
    return flat_embeddings, flat_input_ids, start_idxs


def convert_from_ragged_format(embeddings, input_ids, start_idxs):
    """Unpack flat "ragged" tensors back into padded batches.

    Sequence ``i`` occupies rows ``start_idxs[i]:start_idxs[i + 1]`` of
    ``embeddings``/``input_ids``; the last sequence runs to the end.

    Returns:
        ``(embeddings_out, input_ids_out, masks_out)`` of shapes
        ``(B, S, ...)``, ``(B, S)`` and ``(B, S)``, where ``B`` is the number
        of sequences and ``S`` the longest sequence length. ``masks_out`` is
        True on real positions; padding is zero. An empty ``start_idxs``
        yields tensors with ``B == S == 0``.
    """
    num_seqs = len(start_idxs)
    starts = [int(start) for start in start_idxs]
    # Each sequence ends where the next begins; the last ends with the flat tensors.
    ends = starts[1:] + [len(embeddings)]
    max_len = max((end - start for start, end in zip(starts, ends)), default=0)
    embeddings_out = torch.zeros((num_seqs, max_len, *embeddings.shape[1:]), dtype=embeddings.dtype)
    input_ids_out = torch.zeros((num_seqs, max_len), dtype=input_ids.dtype)
    masks_out = torch.zeros((num_seqs, max_len), dtype=torch.bool)
    for i in tqdm(range(num_seqs)):
        start, end = starts[i], ends[i]
        length = end - start
        embeddings_out[i, :length] = embeddings[start:end]
        input_ids_out[i, :length] = input_ids[start:end]
        masks_out[i, :length] = True
    return embeddings_out, input_ids_out, masks_out


def load_dataset_preprocess(*args, **kwargs):
    """Call ``datasets.load_dataset`` and then optionally drop and rename columns.

    Two extra keyword arguments are consumed here and not forwarded:

    * ``remove`` -- passed to ``dataset.remove_columns`` when truthy.
    * ``rename`` -- an iterable of keyword dicts, each applied in order via
      ``dataset.rename_column(**rename)``.

    Every other argument goes to ``load_dataset`` unchanged.
    """
    removes = kwargs.pop('remove', None)
    renames = kwargs.pop('rename', None)
    dataset = load_dataset(*args, **kwargs)
    if removes:
        dataset = dataset.remove_columns(removes)
    for rename in renames or ():
        dataset = dataset.rename_column(**rename)
    return dataset


def load_dataset_embeddings(dataset, embedding_type, config, exclude_indices=None, n_data=None):
    """Load precomputed embeddings from the *dataset* directory.

    Uses the ragged files ``ragged_<type>_seed=<seed>.pt`` and
    ``ragged_<type>_text_seed=<seed>.json`` when both exist. Otherwise it falls
    back to the padded ``<type>_seed=<seed>.pt`` loader. ``<type>`` is
    ``embedding_type.name`` lower-cased and ``<seed>`` is ``config.seed``.
    ``exclude_indices`` and ``n_data`` are forwarded to the chosen loader.
    """
    type_name = embedding_type.name.lower()
    directory = Path(dataset)
    ragged_pt = directory / f'ragged_{type_name}_seed={config.seed}.pt'
    ragged_json = directory / f'ragged_{type_name}_text_seed={config.seed}.json'
    has_ragged = ragged_pt.exists() and ragged_json.exists()
    if not has_ragged:
        logger.info(f'Ragged dataset {ragged_pt} not found, falling back to {type_name}_seed={config.seed}.pt')
        return load_dataset_embeddings_full(dataset, embedding_type, config, exclude_indices, n_data)
    return load_dataset_embeddings_ragged(dataset, embedding_type, config, exclude_indices, n_data)


def _copy_to_local_cache(dataset_path):
    """Return a node-local cached copy of *dataset_path*, creating it on first use.

    The cache root comes from the ``LOCAL_CACHE_ROOT`` environment variable.
    When that is unset or empty, *dataset_path* is returned unchanged. The
    absolute source path is mirrored under the cache root. Concurrent
    processes coordinate through an exclusive ``.inprogress`` marker, so only
    one copies while the others wait.
    """
    cache_root = os.environ.get('LOCAL_CACHE_ROOT')
    if not cache_root:
        return dataset_path
    resolved = dataset_path.resolve()
    # Drop the path anchor (e.g. '/') so the absolute path nests under the cache root.
    dataset_cache_path = Path(cache_root) / resolved.relative_to(resolved.anchor)
    os.makedirs(dataset_cache_path.parent, exist_ok=True)
    in_progress_file = Path(str(dataset_cache_path) + '.inprogress')
    while True:
        # Another process is copying; wait for it to finish (or give up).
        # NOTE(review): a stale marker left by a killed process still blocks forever.
        while in_progress_file.exists():
            time.sleep(5)
        if dataset_cache_path.exists():
            logger.info(f'Found cache file {dataset_cache_path}')
            return dataset_cache_path
        try:
            # Exclusive create: exactly one process wins the right to copy.
            in_progress_file.touch(exist_ok=False)
        except FileExistsError:
            continue
        break
    # Copy under a temporary name and rename into place, so the cache path
    # only ever holds a complete file even if the copy is interrupted.
    tmp_path = Path(str(dataset_cache_path) + '.tmp')
    try:
        logger.info(f'Copying {dataset_path} to {dataset_cache_path}')
        shutil.copyfile(dataset_path, tmp_path)
        os.replace(tmp_path, dataset_cache_path)
        logger.info(f'Finished copying {dataset_path} to {dataset_cache_path}')
    finally:
        # Always release the marker so waiting processes are not stuck after a failure.
        tmp_path.unlink(missing_ok=True)
        in_progress_file.unlink(missing_ok=True)
    return dataset_cache_path


def _subset_dataset(dataset, exclude_indices=None, n_data=None):
    """Drop *exclude_indices*, then optionally keep *n_data* random items.

    Raises:
        ValueError: if *n_data* exceeds the size remaining after exclusion.
    """
    if exclude_indices:
        # A set makes each membership test O(1) instead of scanning the exclusions.
        excluded = {int(idx) for idx in exclude_indices}
        subset_idxs = [i for i in range(len(dataset)) if i not in excluded]
        dataset = torch.utils.data.Subset(dataset, subset_idxs)
    if n_data:
        if n_data > len(dataset):
            raise ValueError(f'n_data {n_data} is larger than dataset size {len(dataset)}')
        # Draws from NumPy's global RNG (sampling without replacement).
        subset_idxs = np.random.choice(len(dataset), n_data, replace=False)
        dataset = torch.utils.data.Subset(dataset, subset_idxs)
    return dataset


def load_dataset_embeddings_full(dataset, embedding_type, config, exclude_indices=None, n_data=None):
    """Load padded embeddings from the *dataset* directory as a ``CollectionDataset``.

    Reads ``<type>_seed=<seed>.pt`` (a dict with ``'embeddings'``,
    ``'input_ids'`` and ``'masks'``), through the local cache when
    ``LOCAL_CACHE_ROOT`` is set. Also reads ``<type>_text_seed=<seed>.json``.
    Items are ``(embeddings, input_ids, masks, text)`` tuples. The result is
    optionally filtered by *exclude_indices* and subsampled to *n_data* items.
    """
    type_name = embedding_type.name.lower()
    dataset_path = Path(dataset) / f'{type_name}_seed={config.seed}.pt'
    text_path = Path(dataset) / f'{type_name}_text_seed={config.seed}.json'
    dataset_cache_path = _copy_to_local_cache(dataset_path)

    logger.info(f'Loading dataset embeddings {dataset_cache_path}')
    data = torch.load(dataset_cache_path, weights_only=True)
    # texts don't have to be cached because of their very small size
    with open(text_path, encoding='utf-8') as f:
        texts = json.load(f)
    dataset = CollectionDataset(
        data['embeddings'].share_memory_(),
        data['input_ids'].share_memory_(),
        data['masks'].share_memory_(),
        texts,
    )
    return _subset_dataset(dataset, exclude_indices, n_data)


def load_dataset_embeddings_ragged(dataset, embedding_type, config, exclude_indices=None, n_data=None):
    """Load ragged embeddings from the *dataset* directory as a ``RaggedDataset``.

    Reads ``ragged_<type>_seed=<seed>.pt`` (a dict with ``'embeddings'``,
    ``'input_ids'`` and ``'start_idxs'``), through the local cache when
    ``LOCAL_CACHE_ROOT`` is set. Also reads
    ``ragged_<type>_text_seed=<seed>.json``. The result is optionally filtered
    by *exclude_indices* and subsampled to *n_data* items.
    """
    type_name = embedding_type.name.lower()
    dataset_path = Path(dataset) / f'ragged_{type_name}_seed={config.seed}.pt'
    text_path = Path(dataset) / f'ragged_{type_name}_text_seed={config.seed}.json'
    dataset_cache_path = _copy_to_local_cache(dataset_path)

    logger.info(f'Loading dataset embeddings {dataset_cache_path}')
    data = torch.load(dataset_cache_path, weights_only=True)
    # texts don't have to be cached because of their very small size
    with open(text_path, encoding='utf-8') as f:
        texts = json.load(f)
    dataset = RaggedDataset(
        data['embeddings'].share_memory_(),
        data['input_ids'].share_memory_(),
        data['start_idxs'].share_memory_(),
        texts,
    )
    return _subset_dataset(dataset, exclude_indices, n_data)
