import os
import json
import torch
import io
import random
from functools import partial
from PIL import Image

from transformers import BatchFeature
from datasets import IterableDataset, Dataset
from datasets import disable_caching
from .sqlitedict import SqliteDict
from .prompt import get_prompt

# Module-level side effect: calls `datasets.disable_caching()` once, at import time.
# NOTE(review): presumably so the datasets built with `from_generator` below are
# regenerated on every run instead of being served from an on-disk cache — confirm.
disable_caching()


def get_dataloader(
    dataset_dir,
    dataset_names,
    model_name,
    seed,
    max_prompt_length,
    tokenizer,
    prefetch,
    split_id,
    batch_size=16,
    epochs=1,
    num_questions=None,
    is_eval=False,
    last_example_idx=None,
    is_vision=False,
):
    """Build a DataLoader that streams multiple-choice examples from SQLite stores.

    For every name in ``dataset_names`` the file
    ``<dataset_dir>/<name>/db/<name>.sqlite`` is opened (table ``images``, key
    column ``photoid``) and the keys matching ``split=<split_id>`` are collected.
    Examples are then generated lazily: each row is read, turned into a prompt
    with ``get_prompt`` and yielded as a dict with the keys ``idx``,
    ``dataset_name``, ``question``, ``options``, ``prompt``, ``prompt_length``,
    ``input_ids``, ``num_options``, ``answer`` and (vision only) ``image``.

    Args:
        dataset_dir: Root directory containing one sub-directory per dataset.
        dataset_names: Iterable of dataset names to load.
        model_name: Forwarded to ``get_prompt``.
        seed: Seed of the private ``random.Random`` used for shuffling.
        max_prompt_length: Examples whose tokenized prompt is longer are skipped.
        tokenizer: Forwarded to ``get_prompt`` and to the collate function. In
            vision mode the collate function treats it as a processor that
            exposes a ``.tokenizer`` attribute.
        prefetch: Passed to the DataLoader as ``prefetch_factor``.
        split_id: Value interpolated into the ``split=...`` key filter.
        batch_size: DataLoader batch size.
        epochs: Number of passes over the example list per generator run.
        num_questions: If a positive number, each epoch stops once the position
            in the example list reaches this value (skipped examples count).
            ``None``, zero or a negative value means "no limit".
        is_eval: When true, examples are kept in sorted order (no shuffling) and
            an eager ``Dataset`` is built instead of an ``IterableDataset``.
        last_example_idx: Optional ``"<dataset>_<photoid>"`` identifier used to
            resume: everything up to and including that example is skipped.
        is_vision: When true, the ``image`` column is loaded, decoded with PIL
            and the vision collate function is used.

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches are the
        ``(examples, inputs)`` tuples produced by the inference collate function.
    """
    rng = random.Random(seed)

    example_ids = []
    db_path = {}
    for dataset_name in dataset_names:
        db_path[dataset_name] = os.path.join(dataset_dir, dataset_name, 'db', f"{dataset_name}.sqlite")
        database = SqliteDict(db_path[dataset_name], table='images', key_column='photoid')
        try:
            # NOTE(review): the filter is interpolated into the query text, so
            # `split_id` must come from trusted configuration.
            photoids = database.all_keys(where=f"split={split_id}")
            example_ids.extend((dataset_name, str(photoid)) for photoid in photoids)
        finally:
            # Close the handle even if the key enumeration fails.
            database.close()

    # A single global sort yields a deterministic base order for shuffling.
    example_ids.sort()

    # `num_questions` defaults to None; normalise "no limit" so the comparison
    # inside the loop can never raise a TypeError.
    question_limit = num_questions if num_questions is not None and num_questions > 0 else None

    def get_examples():
        value_columns = ['question', 'options', 'answer']
        if is_vision:
            value_columns.append('image')

        databases = {}
        try:
            for dataset_name in dataset_names:
                databases[dataset_name] = SqliteDict(db_path[dataset_name], table='images', key_column='photoid')

            if not is_eval:
                rng.shuffle(example_ids)
            # When resuming, stay silent until the last processed example is seen.
            is_start = last_example_idx is None
            for _ in range(epochs):
                for i, (dataset_name, photoid) in enumerate(example_ids):
                    if question_limit is not None and i == question_limit:
                        break
                    example_idx = f"{dataset_name}_{photoid}"
                    if example_idx == last_example_idx:
                        is_start = True
                        print('start data at i =', i)
                        continue
                    if not is_start:
                        continue
                    row = databases[dataset_name].get(photoid, value_columns=value_columns)
                    if is_vision:
                        question, options, answer, image = row
                        image = Image.open(io.BytesIO(image))
                    else:
                        image = None
                        question, options, answer = row
                    options = json.loads(options)
                    prompt, input_ids = get_prompt(model_name, tokenizer, question, options, answer_type='mcqa', image=image)

                    prompt_length = len(input_ids)
                    if prompt_length > max_prompt_length:
                        print('skip because too prompt too long')
                        continue

                    example = {
                        'idx': example_idx,
                        'dataset_name': dataset_name,
                        'question': question,
                        'options': options,
                        'prompt': prompt,
                        'prompt_length': prompt_length,
                        'input_ids': input_ids,
                        'num_options': len(options),
                        'answer': answer,
                    }
                    if image is not None:
                        example['image'] = image

                    yield example
                # Reshuffle between epochs only in training mode, so evaluation
                # keeps the same sorted order on every pass.
                if not is_eval:
                    rng.shuffle(example_ids)
        finally:
            # Runs when the generator is exhausted, closed or garbage-collected,
            # so the per-dataset handles are not leaked.
            for database in databases.values():
                database.close()

    if is_eval:
        dataset = Dataset.from_generator(get_examples)
    else:
        dataset = IterableDataset.from_generator(get_examples)

    _collate_fn = collate_fn_inference_vision if is_vision else collate_fn_inference
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        pin_memory=True,
        num_workers=1,
        collate_fn=partial(_collate_fn, tokenizer=tokenizer),
        batch_size=batch_size,
        prefetch_factor=prefetch
    )
    return dataloader


def collate_fn_inference(examples, tokenizer):
    """Right-pad the tokenized prompts of a batch for text-only inference.

    Args:
        examples: List of example dicts; each must hold its token ids under
            ``'input_ids'``.
        tokenizer: Object whose ``pad_token_id`` is used as the padding value.

    Returns:
        A tuple ``(examples, inputs)``. ``examples`` is the untouched input list
        and ``inputs`` is a ``BatchFeature`` with ``input_ids`` (int64 tensor of
        shape ``(batch, longest_prompt)``) and ``seq_lengths`` (int32 tensor with
        the unpadded length of every prompt).
    """
    pad_id = tokenizer.pad_token_id
    lengths = [len(item['input_ids']) for item in examples]

    # Start from a matrix full of padding and copy every prompt into the
    # leading positions of its row.
    padded_ids = torch.full(
        (len(examples), max(lengths)),
        fill_value=pad_id,
        dtype=torch.int64,
    )
    for row, (item, length) in enumerate(zip(examples, lengths)):
        padded_ids[row, :length] = torch.as_tensor(item['input_ids'], dtype=torch.int64)

    inputs = BatchFeature({
        'input_ids': padded_ids,
        'seq_lengths': torch.as_tensor(lengths, dtype=torch.int32),
    })
    return examples, inputs


def collate_fn_inference_vision(examples, tokenizer):
    """Batch prompts and images for vision inference via the processor.

    Args:
        examples: List of example dicts providing ``'input_ids'``, ``'prompt'``
            and ``'image'``.
        tokenizer: Despite the name, a processor object: it must expose a
            ``.tokenizer`` attribute and be callable with ``text=`` / ``images=``.

    Returns:
        A tuple ``(examples, inputs)`` where ``inputs`` is a ``BatchFeature``
        holding the processor's ``input_ids`` and ``pixel_values`` plus
        ``seq_lengths`` (int32), the length of each example's own ``input_ids``.
    """
    processor = tokenizer
    # The function requires the wrapped tokenizer to pad with id 0.
    assert processor.tokenizer.pad_token_id == 0

    lengths = [len(item['input_ids']) for item in examples]
    prompts = [item['prompt'] for item in examples]
    # One single-image list per example.
    image_groups = [[item['image']] for item in examples]

    encoded = processor(
        text=prompts,
        images=image_groups,
        return_tensors="pt",
        padding=True,
        padding_side='right',
    )

    inputs = BatchFeature({
        'input_ids': encoded['input_ids'],
        'pixel_values': encoded['pixel_values'],
        'seq_lengths': torch.as_tensor(lengths, dtype=torch.int32),
    })
    return examples, inputs


def collate_train(examples, tokenizer):
    """Collate text-only rollout examples into a padded training batch.

    Each example dict must provide ``'input_ids'`` (prompt token ids, a list),
    ``'generated_ids'`` (list), ``'log_probs'`` (sequence of tensors, one entry
    per generated token), ``'prompt_length'``, ``'advantage'`` and ``'idx'``;
    ``'detector_label'`` is optional.

    The example dicts are NOT modified. All derived sequences live in local
    variables, so the same examples can be collated more than once (e.g. for
    several optimisation passes over one rollout).

    Args:
        examples: List of example dicts as described above.
        tokenizer: Unused; padding always uses token id 0.

    Returns:
        A dict with:
          * ``inputs``: ``BatchFeature`` with ``input_ids`` (prompt + generated
            tokens, right-padded with 0).
          * ``completion_mask``: 1 at generated-token positions, 0 elsewhere,
            covering the columns from the shortest ``prompt_length`` in the
            batch to the end of the longest sequence.
          * ``log_probs_old``: stored log-probs aligned with ``completion_mask``
            (-1.0 at prompt and padding positions).
          * ``advantage``, ``indices``, ``seq_lengths``.
          * ``detector_labels``: labels of the examples that have one (so it can
            be shorter than the batch).
          * ``is_correct_answer``: bool tensor, true where ``detector_label`` is
            present.
    """
    full_ids = []
    full_labels = []
    full_log_probs = []
    for example in examples:
        prompt_input_ids = example['input_ids']
        generated_ids = example['generated_ids']
        full_ids.append(prompt_input_ids + generated_ids)
        full_labels.append([-100] * len(prompt_input_ids) + generated_ids)

        step_log_probs = example['log_probs']
        if len(step_log_probs) > 0:
            log_probs = torch.stack(step_log_probs).flatten()
        else:
            # torch.stack rejects an empty sequence; an example without any
            # generated token simply contributes no log-probs.
            log_probs = torch.as_tensor([])
        assert len(log_probs) == len(generated_ids)
        # -1.0 is the placeholder for prompt positions.
        full_log_probs.append(torch.cat([torch.as_tensor([-1.0] * len(prompt_input_ids)), log_probs]))

    seq_lengths = [len(ids) for ids in full_ids]
    max_seq_length = max(seq_lengths)
    min_prompt_length = min(example['prompt_length'] for example in examples)
    completion_mask_length = max_seq_length - min_prompt_length

    input_size = (len(examples), max_seq_length)
    input_ids = torch.full(size=input_size, fill_value=0, dtype=torch.long)
    labels = torch.full(size=input_size, fill_value=-100, dtype=torch.long)

    mask_size = (len(examples), completion_mask_length)
    completion_mask = torch.zeros(size=mask_size)
    log_probs_old = torch.full(size=mask_size, fill_value=-1.0)
    log_probs_dummy = torch.full(size=input_size, fill_value=-1.0)

    detector_labels = []
    is_correct_answer = []
    for i, example in enumerate(examples):
        input_ids[i][:seq_lengths[i]] = torch.as_tensor(full_ids[i], dtype=torch.int64)

        labels[i][:len(full_labels[i])] = torch.as_tensor(full_labels[i], dtype=torch.int64)
        # Slice from the front offset rather than with a negative index:
        # `[-0:]` would select the whole row when the completion length is 0.
        completion_mask[i] = labels[i][min_prompt_length:] != -100

        log_probs_dummy[i][:len(full_log_probs[i])] = torch.as_tensor(full_log_probs[i])
        log_probs_old[i] = log_probs_dummy[i][min_prompt_length:]

        if example.get('detector_label') is not None:
            is_correct_answer.append(1)
            detector_labels.append(example['detector_label'])
        else:
            is_correct_answer.append(0)

    batch = dict()
    inputs = dict(input_ids=input_ids)
    batch['inputs'] = BatchFeature(inputs)
    batch['completion_mask'] = completion_mask
    batch['log_probs_old'] = log_probs_old
    batch['advantage'] = torch.as_tensor([e['advantage'] for e in examples], dtype=torch.float32)
    batch['indices'] = [e['idx'] for e in examples]
    batch['seq_lengths'] = torch.as_tensor(seq_lengths)
    batch['detector_labels'] = torch.as_tensor(detector_labels, dtype=torch.float32)
    batch['is_correct_answer'] = torch.as_tensor(is_correct_answer, dtype=torch.bool)

    return batch


def collate_train_vision(examples, tokenizer):
    """Collate vision rollout examples into a padded training batch.

    Each example dict must provide ``'input_ids'`` (prompt token ids, a list),
    ``'generated_ids'`` (list), ``'log_probs'`` (sequence of tensors, one entry
    per generated token), ``'inputs_embeds'``, ``'prompt_length'``,
    ``'advantage'`` and ``'idx'``; ``'detector_label'`` is optional.

    The example dicts are NOT modified. All derived sequences live in local
    variables, so the same examples can be collated more than once.

    Args:
        examples: List of example dicts as described above.
            NOTE(review): ``inputs_embeds`` is assumed to be a
            ``(positions, hidden)`` tensor with at most as many positions as the
            longest sequence in the batch — the code only relies on ``len()``
            and ``shape[-1]``; confirm against the producer.
        tokenizer: Unused; padding always uses token id 0.

    Returns:
        A dict with:
          * ``inputs``: ``BatchFeature`` with ``input_ids`` (prompt + generated
            tokens, right-padded with 0) and ``image_inputs_embeds`` (bfloat16,
            zero-padded copy of every example's ``inputs_embeds``).
          * ``completion_mask``: 1 at generated-token positions, 0 elsewhere,
            covering the columns from the shortest ``prompt_length`` in the
            batch to the end of the longest sequence.
          * ``log_probs_old``: stored log-probs aligned with ``completion_mask``
            (-1.0 at prompt and padding positions).
          * ``advantage``, ``indices``, ``seq_lengths``.
          * ``detector_labels``: labels of the examples that have one (so it can
            be shorter than the batch).
          * ``is_correct_answer``: bool tensor, true where ``detector_label`` is
            present.
    """
    full_ids = []
    full_labels = []
    full_log_probs = []
    for example in examples:
        prompt_input_ids = example['input_ids']
        generated_ids = example['generated_ids']
        full_ids.append(prompt_input_ids + generated_ids)
        full_labels.append([-100] * len(prompt_input_ids) + generated_ids)

        step_log_probs = example['log_probs']
        if len(step_log_probs) > 0:
            log_probs = torch.stack(step_log_probs).flatten()
        else:
            # torch.stack rejects an empty sequence; an example without any
            # generated token simply contributes no log-probs.
            log_probs = torch.as_tensor([])
        assert len(log_probs) == len(generated_ids)
        # -1.0 is the placeholder for prompt positions.
        full_log_probs.append(torch.cat([torch.as_tensor([-1.0] * len(prompt_input_ids)), log_probs]))

    seq_lengths = [len(ids) for ids in full_ids]
    max_seq_length = max(seq_lengths)
    min_prompt_length = min(example['prompt_length'] for example in examples)
    completion_mask_length = max_seq_length - min_prompt_length

    input_size = (len(examples), max_seq_length)
    # The embedding width is taken from the first example.
    input_size_embeds = (len(examples), max_seq_length, examples[0]['inputs_embeds'].shape[-1])
    input_ids = torch.full(size=input_size, fill_value=0, dtype=torch.long)
    labels = torch.full(size=input_size, fill_value=-100, dtype=torch.long)
    inputs_embeds = torch.full(size=input_size_embeds, fill_value=0.0, dtype=torch.bfloat16)

    mask_size = (len(examples), completion_mask_length)
    completion_mask = torch.zeros(size=mask_size)
    log_probs_old = torch.full(size=mask_size, fill_value=-1.0)
    log_probs_dummy = torch.full(size=input_size, fill_value=-1.0)

    detector_labels = []
    is_correct_answer = []
    for i, example in enumerate(examples):
        input_ids[i][:seq_lengths[i]] = torch.as_tensor(full_ids[i], dtype=torch.int64)
        inputs_embeds[i][:len(example['inputs_embeds'])] = example['inputs_embeds']

        labels[i][:len(full_labels[i])] = torch.as_tensor(full_labels[i], dtype=torch.int64)
        # Slice from the front offset rather than with a negative index:
        # `[-0:]` would select the whole row when the completion length is 0.
        completion_mask[i] = labels[i][min_prompt_length:] != -100

        log_probs_dummy[i][:len(full_log_probs[i])] = torch.as_tensor(full_log_probs[i])
        log_probs_old[i] = log_probs_dummy[i][min_prompt_length:]

        if example.get('detector_label') is not None:
            is_correct_answer.append(1)
            detector_labels.append(example['detector_label'])
        else:
            is_correct_answer.append(0)

    batch = dict()
    inputs = dict(input_ids=input_ids, image_inputs_embeds=inputs_embeds)
    batch['inputs'] = BatchFeature(inputs)
    batch['completion_mask'] = completion_mask
    batch['log_probs_old'] = log_probs_old
    batch['advantage'] = torch.as_tensor([e['advantage'] for e in examples], dtype=torch.float32)
    batch['indices'] = [e['idx'] for e in examples]
    batch['seq_lengths'] = torch.as_tensor(seq_lengths)
    batch['detector_labels'] = torch.as_tensor(detector_labels, dtype=torch.float32)
    batch['is_correct_answer'] = torch.as_tensor(is_correct_answer, dtype=torch.bool)

    return batch