import os
from pathlib import Path

from dotenv import load_dotenv
import hydra
from hydra.utils import instantiate
import json
import torch
from tqdm import tqdm

from text_ood import EmbeddingType
from text_ood.task import EmbeddingCreator
from text_ood.utils import set_seed
from text_ood.utils.dataset_util import convert_to_ragged_format


def save_text(texts, model_name, dataset_name, dataset_split, embedding_type, embedding_root=None, seed=None):
    """Dump ``texts`` as indented JSON into the embedding directory tree.

    Output path::

        <embedding_root>/<model_name>/<dataset_name>/<dataset_split>/ragged_<type>_text_seed=<seed>.json

    ``/`` in ``model_name`` and ``dataset_name`` is replaced by ``_`` so
    hub-style ids ("org/name") do not create extra directory levels, and
    ``<type>`` is ``embedding_type.name`` lower-cased. Missing directories
    are created.

    Args:
        texts: JSON-serialisable object to write.
        model_name: Model identifier used as the first path component.
        dataset_name: Dataset identifier used as the second path component.
        dataset_split: Split name used as the third path component.
        embedding_type: Enum member whose ``name`` goes into the filename.
        embedding_root: Root directory; when ``None`` it is read from the
            ``EMBEDDING_ROOT`` environment variable (``KeyError`` if unset).
        seed: Value interpolated into the filename.
    """
    root = Path(os.environ['EMBEDDING_ROOT'] if embedding_root is None else embedding_root)
    safe_dataset = dataset_name.replace('/', '_')
    safe_model = model_name.replace('/', '_')
    type_tag = embedding_type.name.lower()

    out_dir = root.joinpath(safe_model, safe_dataset, dataset_split)
    out_file = out_dir / f'ragged_{type_tag}_text_seed={seed}.json'

    out_dir.mkdir(parents=True, exist_ok=True)
    with out_file.open('w') as fh:
        json.dump(texts, fh, indent=4)

def save_tensor(tensor, model_name, dataset_name, dataset_split, embedding_type, embedding_root=None, seed=None):
    """Serialise ``tensor`` with ``torch.save`` into the embedding directory tree.

    Output path::

        <embedding_root>/<model_name>/<dataset_name>/<dataset_split>/ragged_<type>_seed=<seed>.pt

    ``/`` in ``model_name`` and ``dataset_name`` is replaced by ``_`` and
    ``<type>`` is ``embedding_type.name`` lower-cased. Missing directories
    are created.

    Args:
        tensor: Any object ``torch.save`` accepts (this script passes a dict of tensors).
        model_name: Model identifier used as the first path component.
        dataset_name: Dataset identifier used as the second path component.
        dataset_split: Split name used as the third path component.
        embedding_type: Enum member whose ``name`` goes into the filename.
        embedding_root: Root directory; when ``None`` it is read from the
            ``EMBEDDING_ROOT`` environment variable (``KeyError`` if unset).
        seed: Value interpolated into the filename.
    """
    root = Path(os.environ['EMBEDDING_ROOT'] if embedding_root is None else embedding_root)
    safe_dataset = dataset_name.replace('/', '_')
    safe_model = model_name.replace('/', '_')
    type_tag = embedding_type.name.lower()

    out_dir = root.joinpath(safe_model, safe_dataset, dataset_split)
    out_file = out_dir / f'ragged_{type_tag}_seed={seed}.pt'

    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(tensor, out_file)


def create_embeddings_for_dataset(dataset, task, embedding_type, n_embeddings, dataset_name, dataset_split, config):
    """Embed up to ``n_embeddings`` samples of ``dataset`` and save them to disk.

    Seeds via ``set_seed(config.seed, '<dataset_name>/<dataset_split>')``,
    then iterates the dataset in shuffled batches, passing each batch to
    ``task.input_embeddings`` or ``task.output_embeddings`` depending on
    ``embedding_type``. Results are moved to the CPU (embeddings cast to
    ``bfloat16``) until at least ``n_embeddings`` samples are collected, and
    everything is then truncated to ``n_embeddings``. Texts are written with
    ``save_text``; the tensors go through ``convert_to_ragged_format`` and are
    written with ``save_tensor``.

    Args:
        dataset: Sized dataset usable by ``torch.utils.data.DataLoader``
            (``len(dataset)`` is required).
        task: Object exposing ``input_embeddings`` / ``output_embeddings``;
            each is called with one batch and must return a 4-tuple
            ``(texts, input_ids, embeddings, masks)``.
        embedding_type: ``EmbeddingType.INPUT`` or ``EmbeddingType.OUTPUT``.
        n_embeddings: Maximum number of samples to keep; fewer are saved if
            the dataset is smaller.
        dataset_name: Used for seeding, the progress-bar label and the output path.
        dataset_split: Used for seeding and the output path.
        config: Reads ``seed``, ``input_batch_size`` / ``output_batch_size``
            and ``model.model_name``.

    Raises:
        ValueError: If ``embedding_type`` is neither INPUT nor OUTPUT, or if
            the dataset yields no batches.
    """
    set_seed(config.seed, f'{dataset_name}/{dataset_split}')

    if embedding_type == EmbeddingType.INPUT:
        embedding_fn = task.input_embeddings
        batch_size = config.input_batch_size
    elif embedding_type == EmbeddingType.OUTPUT:
        embedding_fn = task.output_embeddings
        batch_size = config.output_batch_size
    else:
        raise ValueError(f'Unsupported embedding type: {embedding_type!r}')
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    texts = []
    input_ids = []
    embeddings = []
    masks = []
    # Running total of collected samples; avoids re-summing every stored
    # batch length on each iteration (which was quadratic in #batches).
    n_collected = 0

    # The progress bar can never exceed what the dataset actually holds.
    n_samples = min(n_embeddings, len(dataset))

    with tqdm(desc=f'Processing {dataset_name}', total=n_samples) as pbar:
        for batch in loader:
            text, input_id, embedding, mask = embedding_fn(batch)
            batch_len = len(embedding)
            texts.extend(text)
            input_ids.append(input_id.cpu())
            # Clamp so the final, partially-used batch does not overshoot the bar.
            pbar.update(min(batch_len, n_samples - n_collected))
            embeddings.append(embedding.bfloat16().cpu())
            masks.append(mask.cpu())
            n_collected += batch_len

            # `>=` (not `>`): once the budget is met, stop instead of running
            # one more forward pass whose output would be truncated away.
            if n_collected >= n_embeddings:
                break

    if not embeddings:
        # torch.concat([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError(f'Dataset {dataset_name}/{dataset_split} yielded no batches')

    # Drop the overshoot of the last batch. Lengths may still be below
    # n_embeddings when the dataset itself is smaller.
    input_ids = torch.concat(input_ids, dim=0)[:n_embeddings]
    embeddings = torch.concat(embeddings, dim=0)[:n_embeddings]
    masks = torch.concat(masks, dim=0)[:n_embeddings]
    texts = texts[:n_embeddings]

    save_text(
        texts=texts,
        model_name=config.model.model_name,
        dataset_name=dataset_name,
        dataset_split=dataset_split,
        embedding_type=embedding_type,
        seed=config.seed
    )

    # NOTE(review): convert_to_ragged_format presumably uses `masks` to strip
    # padding and returns per-sample start offsets — confirm in dataset_util.
    embeddings, input_ids, start_idxs = convert_to_ragged_format(embeddings, input_ids, masks)

    save_tensor(
        tensor={'embeddings': embeddings, 'input_ids': input_ids, 'start_idxs': start_idxs},
        model_name=config.model.model_name,
        dataset_name=dataset_name,
        dataset_split=dataset_split,
        embedding_type=embedding_type,
        seed=config.seed
    )



@torch.no_grad
@hydra.main(config_path='config_create_embeddings', version_base='1.2')
def main(config):
    """Embed every configured in-distribution, auxiliary and OOD dataset.

    Loads variables from a ``.env`` file, instantiates ``config.task`` on
    ``config.device`` in eval mode, and resolves ``config.embedding_type``
    to an ``EmbeddingType`` member. Then, for each of the groups ``id``,
    ``aux`` and ``ood`` (in that order) whose ``config.run_<group>`` flag is
    truthy, calls ``create_embeddings_for_dataset`` on every dataset config
    in ``config.data.<group>`` with budget ``config.n_embeddings_<group>``.
    """
    load_dotenv()

    task: EmbeddingCreator = instantiate(config.task).to(config.device).eval()

    embedding_type: EmbeddingType = EmbeddingType[config.embedding_type]

    # The dataset lists of all groups are looked up eagerly and in order;
    # the run flag and per-dataset budget are only read for enabled groups.
    dataset_groups = (
        ('id', config.data.id),
        ('aux', config.data.aux),
        ('ood', config.data.ood),
    )

    for group, dataset_confs in dataset_groups:
        if not getattr(config, f'run_{group}'):
            continue
        for dataset_conf in dataset_confs:
            create_embeddings_for_dataset(
                dataset=instantiate(dataset_conf),
                task=task,
                embedding_type=embedding_type,
                n_embeddings=getattr(config, f'n_embeddings_{group}'),
                dataset_name=dataset_conf.path,
                dataset_split=dataset_conf.split,
                config=config,
            )


if __name__ == '__main__':
    # No arguments: the @hydra.main decorator on `main` builds and passes `config`.
    main()
