import copy
import tqdm
import time
import torch
import gc

from transformers import StaticCache


def prefill(model, batches, is_vision, return_inputs_embeds):
    """Run the prompt forward pass for every batch and park each KV cache on CPU.

    Args:
        model: causal LM called as ``model(**inputs, past_key_values=...,
            use_cache=False, logits_to_keep=-1)``. The ``StaticCache`` passed in
            is read back after the call, so the model is presumably expected to
            populate it in place — TODO confirm against the model implementation.
        batches: iterable of ``(input_examples, inputs)`` pairs. ``inputs`` is a
            mapping with a ``.to(device)`` method holding at least a 2-D
            ``'input_ids'`` (one row per example) and ``'seq_lengths'`` (integer
            tensor, one prompt length per row). Prompts are treated as
            right-padded: the last real token of row ``i`` is read at
            ``seq_lengths[i] - 1``. ``inputs`` itself is not modified.
        is_vision: build the cache from ``model.config.text_config`` instead of
            ``model.config``.
        return_inputs_embeds: also store each example's unpadded prompt
            embeddings (``output.inputs_embeds`` moved to CPU) under
            ``example['inputs_embeds']``.

    Returns:
        ``(output_examples, all_past_key_values, seconds)`` where

        * ``output_examples`` maps ``(past_key_values_id, cache_row_id)`` to the
          example dicts from ``batches``. These are mutated in place to gain
          ``'inputs'`` (scalar ``'cache_position'``, ``'position_ids'`` and
          ``'input_ids'`` for the last prompt token), ``'past_key_values_id'``,
          ``'cache_row_id'`` and optionally ``'inputs_embeds'``.
        * ``all_past_key_values`` is a list with one ``StaticCache`` per batch,
          its key/value tensors moved to CPU.
        * ``seconds`` is the wall-clock prefill time.
    """
    tprefill = time.time()

    all_past_key_values = list()
    output_examples = dict()

    for batch in batches:
        input_examples, batch_inputs = batch
        max_prompt_length = batch_inputs['input_ids'].shape[1]
        num_examples = len(input_examples)

        # cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
        past_key_values = StaticCache(config=model.config.text_config if is_vision else model.config,
                                      max_batch_size=num_examples,
                                      max_cache_len=max_prompt_length,
                                      device=model.device,
                                      dtype=model.dtype)

        # Copy BEFORE mutating. Writing into / popping from the caller's batch
        # would make a second call on the same batches fail on 'seq_lengths'.
        inputs = copy.copy(batch_inputs)
        inputs['position_ids'] = torch.arange(max_prompt_length).unsqueeze(0).expand(num_examples, -1)
        inputs['cache_position'] = torch.zeros(size=(num_examples,), dtype=torch.int32)
        seq_lengths = inputs.pop('seq_lengths')
        inputs = inputs.to(model.device)

        output = model(**inputs, past_key_values=past_key_values, use_cache=False, logits_to_keep=-1)
        if return_inputs_embeds:
            inputs_embeds = output.inputs_embeds.detach().to('cpu')
        del output

        input_ids = inputs['input_ids']
        del inputs

        # Gather each row's last prompt token in one transfer instead of one
        # .item() device sync per example.
        lengths = seq_lengths.tolist()
        rows = torch.arange(num_examples, device=input_ids.device)
        last_positions = seq_lengths.to(device=input_ids.device, dtype=torch.long) - 1
        last_token_ids = input_ids[rows, last_positions].tolist()
        del input_ids

        # Move the filled prompt cache off the device so it does not hold device memory.
        past_key_values_id = len(all_past_key_values)
        for layer_id in range(len(past_key_values.key_cache)):
            past_key_values.key_cache[layer_id] = past_key_values.key_cache[layer_id].to('cpu')
            past_key_values.value_cache[layer_id] = past_key_values.value_cache[layer_id].to('cpu')
        all_past_key_values.append(past_key_values)

        for cache_row_id, example in enumerate(input_examples):
            seq_length = lengths[cache_row_id]

            if return_inputs_embeds:
                example['inputs_embeds'] = inputs_embeds[cache_row_id][:seq_length]

            # Decoding restarts from the last prompt token at position seq_length - 1.
            example['inputs'] = {
                'cache_position': seq_length - 1,
                'position_ids': seq_length - 1,
                'input_ids': last_token_ids[cache_row_id],
            }
            example['past_key_values_id'] = past_key_values_id
            example['cache_row_id'] = cache_row_id
            output_examples[(past_key_values_id, cache_row_id)] = example
        gc.collect()
        torch.cuda.empty_cache()

    sprefill = time.time() - tprefill
    return output_examples, all_past_key_values, sprefill


def set_cache_row(past_key_values, row_id, other_past_key_values, other_row_id):
    """Copy one row of a source KV cache into one row of a destination KV cache.

    For every layer, ``other_past_key_values.{key,value}_cache[layer][other_row_id]``
    is written into the leading positions of
    ``past_key_values.{key,value}_cache[layer][row_id]``. Positions past the
    source length are left untouched.

    Args:
        past_key_values: destination object exposing ``key_cache`` /
            ``value_cache`` lists of tensors. These are indexed
            ``[row, head, position, dim]``, presumably a
            ``transformers.StaticCache`` — TODO confirm.
        row_id: destination row index.
        other_past_key_values: source cache with the same layer layout. Its
            position axis (taken from layer 0) must not exceed the destination's.
        other_row_id: source row index.
    """
    num_layers = len(past_key_values.key_cache)
    if num_layers == 0:  # nothing to copy (and no tensor to read a length from)
        return
    row_length = other_past_key_values.key_cache[0].shape[2]
    for layer_id in range(num_layers):
        past_key_values.key_cache[layer_id][row_id, :, :row_length] = other_past_key_values.key_cache[layer_id][other_row_id]
        past_key_values.value_cache[layer_id][row_id, :, :row_length] = other_past_key_values.value_cache[layer_id][other_row_id]

    # Defensive sync after the copy. torch.cuda.synchronize() raises on builds or
    # machines without CUDA, so only call it for CUDA caches, on their own device.
    destination = past_key_values.key_cache[0]
    if destination.is_cuda:
        torch.cuda.synchronize(destination.device)

def generate(model, example_batches, max_new_tokens, tokenizer, temperature=1.0, num_samples=8, gen_batch_size=32, is_vision=False, pbar=False, return_inputs_embeds=False):
    """Sample ``num_samples`` completions per prompt using a fixed pool of batch rows.

    All prompt batches are prefilled first via ``prefill``. Decoding then runs
    one token at a time over a single ``StaticCache`` of ``gen_batch_size``
    rows. When a row's sample stops (EOS or ``max_new_tokens``), the row starts
    another sample of the same prompt if that prompt still has samples to
    start. Otherwise the row's cache is overwritten with the prefilled cache of
    the prompt that has the most samples left. Once no samples remain to
    start, the row idles.

    Args:
        model: causal LM, called with per-row ``input_ids``/``position_ids``/
            ``cache_position`` and the shared ``StaticCache``.
        example_batches: sequence of ``(examples, inputs)`` pairs as consumed by
            ``prefill``. Every example dict must carry an ``'idx'`` key.
        max_new_tokens: maximum number of tokens generated per sample.
        tokenizer: tokenizer providing ``eos_token_id``. When ``is_vision`` is
            true, an object whose ``.tokenizer`` attribute is used instead.
        temperature: sampling temperature; ``<= 0`` selects greedy (argmax).
        num_samples: number of completions collected per prompt.
        gen_batch_size: number of rows decoded in parallel.
        is_vision: use ``model.config.text_config`` and ``tokenizer.tokenizer``.
        pbar: show a tqdm bar that advances when a prompt has all its samples.
        return_inputs_embeds: forwarded to ``prefill``.

    Returns:
        ``(output_examples, metrics)``.

        * ``output_examples`` maps each example ``'idx'`` to a list of sample
          dicts in completion order. Each is a shallow copy of the prefilled
          example (without ``'num_samples'``) plus ``'generated_ids'`` and
          ``'stopping_reason'`` (``'eos'`` or ``'max_length'``).
        * ``metrics`` holds ``'sprefill'`` / ``'sgenerate'`` wall-clock seconds,
          rounded to 0.1.
    """
    if is_vision:
        tokenizer = tokenizer.tokenizer
    device = model.device
    eos_token_id = tokenizer.eos_token_id
    # default=0: an empty `example_batches` yields an empty result instead of a
    # ValueError from max().
    max_prompt_length = max((example_batch[1]['input_ids'].shape[-1] for example_batch in example_batches), default=0)
    # A sample's final forward pass writes position seq_len - 2 + max_new_tokens,
    # so this length always fits; the extra slot is headroom.
    cache_length = max_prompt_length + max_new_tokens + 1

    prefilled_examples, all_past_key_values, sprefill = prefill(model, example_batches, is_vision, return_inputs_embeds)

    past_key_values = StaticCache(config=model.config.text_config if is_vision else model.config,
                                  max_batch_size=gen_batch_size,
                                  max_cache_len=cache_length,
                                  device=model.device,
                                  dtype=model.dtype)

    def get_example(proto_example):
        """Return a fresh per-sample shallow copy of a prefilled example."""
        example_copy = copy.copy(proto_example)
        del example_copy['num_samples']
        example_copy['generated_ids'] = list()
        return example_copy

    # Per-row decoder inputs, updated in place every step.
    inputs = dict()
    inputs['input_ids'] = torch.ones(size=(gen_batch_size, 1), dtype=torch.int64, device=device)
    inputs['cache_position'] = torch.zeros(size=(gen_batch_size,), dtype=torch.int32, device=device)
    inputs['position_ids'] = torch.zeros(size=(gen_batch_size, 1), dtype=torch.int64, device=device)

    output_examples = dict()
    for example in prefilled_examples.values():
        example['num_samples'] = num_samples  # samples of this prompt not yet started
        output_examples[example['idx']] = list()
    # Counters replace re-scanning every prompt on each activation/finish.
    remaining_samples = sum(e['num_samples'] for e in prefilled_examples.values())
    num_completed = 0  # prompts whose sample list has reached num_samples

    progress = tqdm.tqdm(total=len(prefilled_examples), desc="Generation", unit="Example") if pbar else None

    def finish_example(example):
        """Record a finished sample; return True once every prompt is complete."""
        nonlocal num_completed
        samples = output_examples[example['idx']]
        samples.append(example)
        if len(samples) == num_samples:
            num_completed += 1
            if progress is not None:
                progress.update()
        return num_completed == len(output_examples)

    def activate_example(current_id):
        """Start the next pending sample in batch row ``current_id``, or idle the row."""
        nonlocal remaining_samples
        if remaining_samples <= 0:  # no samples left to start: idle this row
            current_examples[current_id] = None
            inputs['cache_position'][current_id] = 0
            inputs['position_ids'][current_id, 0] = 0
            return

        cache_key = current_cache_mapping[current_id]
        if cache_key[1] is None or prefilled_examples[cache_key]['num_samples'] <= 0:
            # Fresh row, or its prompt has no samples left: load the prefilled
            # cache of the prompt with the most samples still to start.
            cache_key = max(prefilled_examples, key=lambda k: prefilled_examples[k]['num_samples'])
            set_cache_row(past_key_values=past_key_values,
                          row_id=current_id,
                          other_past_key_values=all_past_key_values[cache_key[0]],
                          other_row_id=cache_key[1])
            current_cache_mapping[current_id] = cache_key

        proto_example = prefilled_examples[cache_key]
        proto_example['num_samples'] -= 1
        remaining_samples -= 1
        example = get_example(proto_example)
        current_examples[current_id] = example
        for name, tensor in inputs.items():
            tensor[current_id] = example['inputs'][name]

    current_cache_mapping = {current_id: (None, None) for current_id in range(gen_batch_size)}
    current_examples = [None] * gen_batch_size
    for current_id in range(gen_batch_size):
        activate_example(current_id)

    tgenerate = time.time()
    num_iterations = 0
    while any(current_examples):
        num_iterations += 1
        if num_iterations % 1024 == 0:
            gc.collect()
            torch.cuda.empty_cache()

        output = model(**inputs, past_key_values=past_key_values, use_cache=False)
        next_logits = output.logits
        del output

        if temperature > 0.0:
            # Only the last position matters; upcast before the softmax.
            last_logits = next_logits[:, -1, :].to(torch.float32)
            probs = torch.nn.functional.softmax(last_logits / temperature, dim=-1)
            next_token_sample = torch.multinomial(probs, num_samples=1)  # (B, 1)
        else:
            next_token_sample = torch.argmax(next_logits[:, -1:], dim=-1)  # (B, 1)

        inputs['input_ids'].copy_(next_token_sample)
        inputs['position_ids'] += 1
        # Idle rows are reset to 0 below and stopped rows are re-activated, so an
        # active row never exceeds seq_len - 2 + max_new_tokens < cache_length.
        inputs['cache_position'] += 1

        next_token_list = next_token_sample.tolist()
        # The while-condition guarantees at least one active row here.
        all_finished = False
        for i, example in enumerate(current_examples):
            if example is None:
                inputs['position_ids'][i, 0] = 0
                inputs['cache_position'][i] = 0
                continue

            next_token_example = next_token_list[i][0]
            generated_ids = example['generated_ids']
            generated_ids.append(next_token_example)

            stopping_reason = None
            if next_token_example == eos_token_id:
                stopping_reason = 'eos'
            elif len(generated_ids) >= max_new_tokens:
                stopping_reason = 'max_length'

            if stopping_reason is not None:
                example['stopping_reason'] = stopping_reason
                all_finished = finish_example(example)
                if all_finished:
                    break
                activate_example(current_id=i)
        if all_finished:
            break

    if progress is not None:
        progress.close()

    del past_key_values
    gc.collect()
    torch.cuda.empty_cache()

    sgenerate = time.time() - tgenerate
    metrics = dict()
    metrics['sprefill'] = round(sprefill, 1)
    metrics['sgenerate'] = round(sgenerate, 1)

    return output_examples, metrics
