import torch
import numpy as np
import math
from torch.nn import functional as F
import wandb
from ssd.configs import GenerateConfig, UnconstrainedGenerationConfig, ControlledGenerationConfig
from ssd.early_exit import EarlyExit


def get_time_variables(t, total_t, device): # cosine schedule
    """Return the cumulative signal level alpha_bar for step ``t - 1``.

    Uses the cosine schedule
        f(u) = cos(((u / total_t) + s) / (1 + s) * pi / 2) ** 2,  s = 1e-4
    and normalises by f(0), i.e. returns f(t - 1) / f(0) elementwise.

    Args:
        t: float tensor of current timesteps (``decode`` passes one entry per
            batch element).
        total_t: total number of diffusion steps.
        device: device on which the f(0) reference tensor is created.

    Returns:
        Tensor with the same shape as ``t``.
    """
    def _cosine(step, horizon, s=1e-4):
        return torch.cos((step / horizon + s) / (1 + s) * math.pi / 2) ** 2

    numerator = _cosine(t - 1, total_t)
    origin = torch.zeros(t.shape).to(device)
    denominator = _cosine(origin, total_t)
    return numerator / denominator

def apply_controlling_drift(config, perturbed_inputs_diralpha):
    """Take one gradient step pushing logits toward a classifier label.

    The logits are softmaxed into a token simplex and mapped through the
    control model's input-embedding matrix. The control model then scores
    them. The loss is the batch-mean negative log-probability of class
    ``config.ctr_opt_label_idx``. Its negative gradient w.r.t. the logits,
    scaled by ``config.decode_ctr_lr``, is added to the input.

    Args:
        config: needs ``ctr_model`` (exposes ``get_input_embeddings()`` and
            accepts ``inputs_embeds=``, returning an object with ``.logits``),
            ``ctr_opt_label_idx`` and ``decode_ctr_lr``.
        perturbed_inputs_diralpha: logits tensor; its last dimension must match
            the control model's embedding-table rows.

    Returns:
        A new tensor ``perturbed_inputs_diralpha + decode_ctr_lr * (-grad)``.
        The input is not modified.
    """
    ctr_model = config.ctr_model
    # Gradients are needed even when the caller runs under torch.no_grad().
    with torch.enable_grad():
        alpha = perturbed_inputs_diralpha.clone()
        alpha.requires_grad_()
        simplex = torch.nn.functional.softmax(alpha, dim=-1)
        # linear(x, W.t()) == x @ W: a soft lookup into the embedding table.
        embed_weight = ctr_model.get_input_embeddings().weight
        soft_embeds = torch.nn.functional.linear(simplex, embed_weight.t())
        class_logits = ctr_model(inputs_embeds=soft_embeds).logits
        log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
        ctr_loss = -log_probs[:, config.ctr_opt_label_idx].mean()
        (grad,) = torch.autograd.grad(ctr_loss, alpha)
    # A fixed balancing factor (decode_ctr_lr) is used; it could be made adaptive.
    return perturbed_inputs_diralpha + config.decode_ctr_lr * (-grad)

def logits_sampling_projection(config, logits):
    """Sample one token per position with top-p filtering, as a ±one-hot.

    The smallest set of most-likely tokens whose cumulative probability
    reaches ``config.top_p`` is kept; the most likely token is always kept.
    One token per position is then sampled from the filtered distribution.

    Args:
        config: needs ``top_p`` and ``one_hot_value``.
        logits: 3-D tensor (checked by an assert) whose last dimension is the
            vocabulary.

    Returns:
        Tensor shaped like ``logits``. It holds ``+one_hot_value`` at each
        sampled token and ``-one_hot_value`` everywhere else.
    """
    assert logits.dim() == 3
    vocab_size = logits.size(2)
    sorted_probs, order = torch.softmax(logits, dim=-1).sort(dim=-1, descending=True)
    below_threshold = sorted_probs.cumsum(dim=-1) < config.top_p
    # Shift right by one so the token that crosses top_p is also kept and the
    # top-ranked token is always kept.
    keep_sorted = torch.ones_like(below_threshold)
    keep_sorted[..., 1:] = below_threshold[..., :-1]
    # `order` is a permutation of the vocab axis, so scattering restores the
    # keep flags to vocabulary order.
    keep = torch.zeros_like(keep_sorted).scatter(-1, order, keep_sorted)
    filtered_logits = logits.masked_fill(~keep, -float('Inf'))
    sampler = torch.distributions.categorical.Categorical(logits=filtered_logits)
    selected = sampler.sample()
    one_hot = torch.nn.functional.one_hot(selected, vocab_size)
    return 2 * config.one_hot_value * one_hot - config.one_hot_value

def prompts_to_input_ids(tokenizer, accelerator, prompts, context_size):
    """Tokenize ``prompts`` into a padded id tensor on the accelerator device.

    Prompts are truncated to ``context_size`` tokens and no special tokens are
    added. The call uses ``padding=True``; for a Hugging Face tokenizer
    (presumably what is passed in) that pads only to the longest prompt. The
    returned width can therefore be smaller than ``context_size``.
    """
    encoded = tokenizer(
        prompts,
        max_length=context_size,
        add_special_tokens=False,
        return_tensors='pt',
        truncation=True,
        padding=True,
    )
    return encoded['input_ids'].to(accelerator.device)

def _record_step_metrics(log_p, prev_log_p, patience_counter, accumulators):
    """Append one denoising step's diagnostics to ``accumulators``.

    Args:
        log_p: CPU log-probabilities for this step, laid out as
            (batch, position, vocab).
        prev_log_p: the previous step's ``log_p``, or None on the first step.
        patience_counter: per-sequence count of consecutive steps whose argmax
            tokens were all unchanged; updated in place.
        accumulators: ``defaultdict(list)`` keyed by metric name.

    Returns:
        ``(entropy, kl)``: ``entropy`` is the per-sequence mean token entropy.
        ``kl`` is the divergence between this step and the previous step, or
        None on the first step.
    """
    # Entropy per position, averaged over positions -> shape (batch,).
    entropy = -(torch.exp(log_p) * log_p).sum(-1).mean(1)
    accumulators['entropy'].append(entropy)
    kl = None
    if prev_log_p is not None:
        # kl_div(input, target, log_target=True) is KL(target || input), so
        # this is KL(current || previous), summed and divided by batch size.
        kl = F.kl_div(prev_log_p, log_p, log_target=True, reduction='batchmean')
        accumulators['kl'].append(kl)
        # A sequence counts as stable only if every position's argmax is unchanged.
        unchanged = (log_p.argmax(-1) == prev_log_p.argmax(-1)).all(dim=1)
        patience_counter[unchanged] += 1
        patience_counter[~unchanged] = 0
        accumulators['patience'].append(patience_counter.sum())
    accumulators['tokens'].append(log_p.argmax(-1).numpy())
    return entropy, kl


@torch.no_grad()
def decode(config, prompts, batch_size, context_size=32, strategy="logging"):
    """Run the reverse simplex-diffusion loop and decode the text at every step.

    Starting from scaled Gaussian noise, each step t = total_t, ..., 1 does
    the following:
      1. Embeds softmax(x_t) plus a timestep embedding and prepends the prompt
         context.
      2. Runs ``config.model``.
      3. Optionally applies the classifier drift (``config.controlled``).
      4. Samples a ±one_hot_value estimate via top-p sampling.
      5. Re-noises that estimate to level t-1.

    Per-step entropy / KL / patience diagnostics are accumulated and logged to
    wandb.

    Args:
        config: needs ``tokenizer``, ``accelerator``, ``model``,
            ``embedding_sum_layer``, ``timestep_layer``, ``decode_total_gen_len``,
            ``vocab_size``, ``one_hot_value``, ``total_t`` and ``controlled``.
            Also any fields that ``logits_sampling_projection`` and (when
            controlled) ``apply_controlling_drift`` read.
        prompts: prompt strings, each tokenized to at most ``context_size``
            tokens. When ``context_size > 0`` the number of prompts must equal
            ``batch_size``.
        batch_size: number of sequences generated in parallel.
        context_size: maximum number of prompt tokens used as fixed context;
            0 disables prompt conditioning.
        strategy: currently unused. Only logging is implemented; early-exit
            strategies are not.

    Returns:
        ``(sampled_sequences, accumulators)``.
        ``sampled_sequences[i]`` is the ``batch_decode`` output after the
        i-th step (prompt tokens included).
        ``accumulators`` maps 'entropy', 'kl', 'patience' and 'tokens' to
        per-step lists.
    """
    from collections import defaultdict

    device = config.accelerator.device
    input_ids = prompts_to_input_ids(config.tokenizer, config.accelerator, prompts, context_size)
    config.model.eval()

    # The previous formula
    #   max_seq_length - context_size - (max_seq_length - decode_total_gen_len - context_size)
    # reduces to the number of generated positions.
    unit_seq_len = config.decode_total_gen_len
    noise_shape = (batch_size, unit_seq_len, config.vocab_size)

    context_len = 0
    if context_size > 0:
        context_input_ids = input_ids[:, :context_size].clone()
        # With padding=True the prompt block can be narrower than context_size
        # (all prompts short). Model outputs must be split at this real width;
        # slicing at context_size would eat generated positions and crash.
        context_len = context_input_ids.size(1)
        model_embedding_lut = config.accelerator.unwrap_model(config.model).get_input_embeddings()
        context_inputs_embeds = model_embedding_lut(context_input_ids)

    # Start from pure noise in logit space, scaled by one_hot_value.
    xt = config.one_hot_value * torch.normal(0, 1, size=noise_shape).to(device)

    accumulators = defaultdict(list)
    prev_log_p = None
    patience_counter = torch.zeros(batch_size, dtype=torch.long)
    sampled_sequences = []

    for t in np.arange(1, config.total_t + 1)[::-1]:
        selected_t = torch.FloatTensor([t]).repeat(batch_size).to(device)
        perturbed_inputs_simplex = torch.nn.functional.softmax(xt, dim=-1)
        perturbed_inputs_embeds = config.embedding_sum_layer(perturbed_inputs_simplex)

        t_progress = selected_t / config.total_t
        timestep_embeds = config.timestep_layer(
            t_progress.view(batch_size, 1, 1).repeat(1, unit_seq_len, 1)
        )
        diffusion_embeds = perturbed_inputs_embeds + timestep_embeds
        if context_size > 0:
            diffusion_embeds = torch.cat((context_inputs_embeds, diffusion_embeds), dim=1)

        logits = config.model(inputs_embeds=diffusion_embeds, output_hidden_states=False).logits

        # Diagnostics are computed on CPU over ALL positions, prompt included.
        # NOTE(review): confirm the prompt positions are meant to count toward
        # entropy/KL/patience; early-exit logic may want generated positions only.
        log_p = F.log_softmax(logits.detach().cpu(), -1)
        entropy, kl = _record_step_metrics(log_p, prev_log_p, patience_counter, accumulators)
        prev_log_p = log_p
        wandb.log(
            {
                'entropy': entropy.mean(),
                'kl': kl,
                'patience': patience_counter.sum(),
            }
        )

        if context_size > 0:
            logits = logits[:, context_len:].contiguous()

        if config.controlled:
            logits = apply_controlling_drift(config, logits)

        projected_logits = logits_sampling_projection(config, logits)

        # x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat + sqrt(1 - alpha_bar_{t-1}) * z
        alpha_t_minus_bar = get_time_variables(selected_t, config.total_t, device).view(-1, 1, 1)
        zt = config.one_hot_value * torch.normal(0, 1, size=noise_shape).to(device)
        xt = torch.sqrt(alpha_t_minus_bar) * projected_logits + torch.sqrt(1 - alpha_t_minus_bar) * zt

        # Current hard token guess for each generated position.
        simplex = torch.nn.functional.softmax(xt, dim=-1)
        real_token_ids_list = torch.argmax(simplex, dim=-1).view(batch_size, unit_seq_len)

        if context_size > 0:
            generated_ids = torch.cat((context_input_ids, real_token_ids_list), dim=1)
        else:
            generated_ids = real_token_ids_list

        sampled_sequences.append(config.tokenizer.batch_decode(generated_ids.cpu()))

    return sampled_sequences, accumulators
