import os
from pathlib import Path
from dataclasses import dataclass

import numpy as np
from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPTJForCausalLM, OPTForCausalLM
import torch
import torch.nn as nn

import promptrl.utils as utils

@dataclass
class GPTMixinGenerateOutput:
    """Result of ``GPTPromptInputMixin.generate`` (greedy decoding)."""

    # Prefix ids (if any) followed by the generated ids, one row per sequence.
    output_ids: torch.Tensor
    # Key/value cache returned by the model's final forward pass.
    past_key_values: tuple

@dataclass
class GPTMixinGenerateBeamOutput:
    """Result of ``GPTPromptInputMixin.generate_beam`` (one row per beam)."""

    # Prefix ids (if any) followed by the generated ids for each beam.
    output_ids: torch.Tensor
    # Summed log-probability of each beam's generated tokens.
    scores: torch.Tensor
    # Per-beam token count used for length normalisation (includes prefix ids).
    seq_lengths: torch.Tensor

class GPTPromptInputMixin:
    """Scoring and decoding helpers for causal LMs driven by ``inputs_embeds``.

    The host class must provide ``forward``, ``config`` (with ``n_embd`` and
    the pad/eos token ids) and ``_get_wte()`` returning the token-embedding
    module.

    NOTE: this mixin precedes the Hugging Face base class in the MRO, so any
    method defined here overrides a same-named base method (``generate`` does).
    Keep helper names distinct from ``transformers`` internals.
    """

    def batch_score(
        self,
        input_ids,
        attention_mask=None,
        past_key_values=None,
        batch_size=8,
    ):
        """Score each row of ``input_ids`` in mini-batches.

        Args:
            input_ids: ``(N, T)`` token ids to score.
            attention_mask: optional mask over any cached positions plus the
                ``T`` new ones. Its last ``T`` columns are used as the loss
                mask. When omitted, every position contributes.
            past_key_values: optional cached prefix with batch dimension 1.
                It is broadcast with ``expand`` (no copy) to all ``N`` rows.
            batch_size: rows per forward pass.

        Returns:
            One value per row: the negated row-sum of
            ``utils.shifted_cross_ent`` (a log-likelihood, assuming that helper
            returns per-token cross-entropy).
        """
        N, T = input_ids.shape
        if past_key_values is not None:
            assert past_key_values[0][0].shape[0] == 1
            # expand() yields broadcast views, so the cache is not copied per row.
            past_key_values = [[head.expand(N, -1, -1, -1) for head in layer] for layer in past_key_values]

        log_likelihoods = []
        for start in range(0, N, batch_size):
            rows = slice(start, start + batch_size)
            inputs = {'input_ids': input_ids[rows]}
            if attention_mask is not None:
                inputs['attention_mask'] = attention_mask[rows]
                # The mask may also cover cached prefix positions; only the last
                # T columns line up with the tokens being scored.
                loss_mask = inputs['attention_mask'][:, -T:]
            else:
                # No mask supplied: every scored position counts.
                loss_mask = torch.ones_like(inputs['input_ids'])
            if past_key_values is not None:
                inputs['past_key_values'] = [[head[rows] for head in layer] for layer in past_key_values]

            logits = self.forward(**inputs).logits
            scores = utils.shifted_cross_ent(inputs['input_ids'], logits, loss_mask=loss_mask)
            log_likelihoods.append(-scores.sum(1))
        return torch.cat(log_likelihoods, dim=0)

    def generate(
        self,
        inputs_embeds=None,
        attention_mask=None,
        prefix_ids=None,
        past_key_values=None,
        max_length=100,
        eos_token_id=None,
        pad_token_id=None
    ):
        """Greedy (argmax) decoding from embeddings and/or a cached prefix.

        Args:
            inputs_embeds: ``(B, L, n_embd)`` prompt embeddings. It may be
                omitted when ``past_key_values`` is given; an empty
                ``(B, 0, n_embd)`` tensor is then used.
            attention_mask: mask over the cached and ``inputs_embeds``
                positions. All ones when omitted.
            prefix_ids: optional ``(B, P)`` token ids appended after
                ``inputs_embeds``. They also start ``output_ids``.
            past_key_values: optional key/value cache of preceding context.
            max_length: stop after this many new tokens. At least one token is
                always generated.
            eos_token_id, pad_token_id: default to ``self.config`` values.
                Rows that have emitted ``eos_token_id`` are padded with
                ``pad_token_id``.

        Returns:
            ``GPTMixinGenerateOutput`` with the prefix plus generated ids and
            the final cache.

        Raises:
            ValueError: if ``eos_token_id`` is set but ``pad_token_id`` is not.
        """
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        wte = self._get_wte()

        if inputs_embeds is None:
            assert past_key_values is not None
            cache = past_key_values[0][0]
            # Use the embedding dtype so the placeholder matches prefix embeddings
            # (and half-precision models) instead of defaulting to float32.
            inputs_embeds = torch.zeros((cache.shape[0], 0, self.config.n_embd), dtype=wte.weight.dtype, device=cache.device)

        batch = inputs_embeds.shape[0]
        device = inputs_embeds.device
        unfinished_sequences = torch.full((batch,), 1, device=device)
        if attention_mask is None:
            attn_len = inputs_embeds.shape[1]
            if past_key_values is not None:
                attn_len += past_key_values[0][0].shape[2]
            attention_mask = torch.full((batch, attn_len), 1, device=device)

        if prefix_ids is None:
            output_ids = torch.zeros((batch, 0), dtype=torch.int64, device=device)
        else:
            output_ids = prefix_ids
            inputs_embeds = torch.cat((inputs_embeds, wte(prefix_ids)), dim=1)
            attention_mask = nn.functional.pad(attention_mask, (0, prefix_ids.shape[1]), 'constant', 1)

        # Validate before running the model rather than after the first step.
        if eos_token_id is not None and pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

        cur_len = 0
        while True:
            outputs = self.forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)

            # Finished rows keep emitting the pad token.
            if eos_token_id is not None:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            output_ids = torch.cat([output_ids, next_tokens[:, None]], dim=-1)
            attention_mask = nn.functional.pad(attention_mask, (0, 1), "constant", 1)
            # Only the newest token is fed next; earlier positions are in the cache.
            inputs_embeds = wte(next_tokens[:, None])
            cur_len += 1

            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

            if unfinished_sequences.max() == 0 or cur_len >= max_length:
                break
        return GPTMixinGenerateOutput(output_ids, past_key_values)

    def generate_beam(
        self,
        inputs_embeds=None,
        prefix_ids=None,
        past_key_values=None,
        max_length=100,
        n_beams=30,
        temperature=1.,
        eos_token_id=None,
        pad_token_id=None
    ):
        """Length-normalised beam search for a single prompt (adapted from ClipCap).

        ``inputs_embeds``, ``past_key_values`` and ``prefix_ids`` must have batch
        dimension 1. The cache is not advanced between steps. Instead, the
        growing sequence of new embeddings is re-run on top of the original
        ``past_key_values`` at each step.

        Args:
            inputs_embeds: ``(1, L, n_embd)`` prompt embeddings. Optional when
                ``past_key_values`` is given.
            prefix_ids: optional ``(1, P)`` token ids forced at the start.
            past_key_values: optional cache of preceding context.
            max_length: maximum number of decoding steps.
            n_beams: number of beams kept per step.
            temperature: logits are divided by this (must be > 0).
            eos_token_id, pad_token_id: default to ``self.config`` values. If
                no pad token is available, ``eos_token_id`` is used to pad
                finished beams.

        Returns:
            ``GPTMixinGenerateBeamOutput``. ``seq_lengths`` counts the prefix
            tokens plus generated tokens up to and including the first EOS.
        """
        assert temperature > 0
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        if pad_token_id is None:
            # Finished beams are forced onto pad_token_id below. Indexing with
            # None would instead zero their whole row and let them branch.
            pad_token_id = eos_token_id
        wte = self._get_wte()

        if inputs_embeds is None:
            assert past_key_values is not None
            assert past_key_values[0][0].shape[0] == 1
            inputs_embeds = torch.zeros((1, 0, self.config.n_embd), dtype=wte.weight.dtype, device=past_key_values[0][0].device)
        else:
            assert inputs_embeds.shape[0] == 1

        device = inputs_embeds.device
        if prefix_ids is None:
            output_ids = torch.zeros((n_beams, 0), dtype=torch.int64, device=device)
        else:
            output_ids = prefix_ids.expand((n_beams, -1))
            inputs_embeds = torch.cat((inputs_embeds, wte(prefix_ids)), dim=1)

        unfinished_sequences = torch.ones((n_beams,), dtype=torch.bool, device=device)
        seq_lengths = torch.full((n_beams,), output_ids.shape[1], device=device)
        scores = None

        for _ in range(max_length):
            # No cache update: recompute over all new embeddings each step.
            outputs = self.forward(inputs_embeds=inputs_embeds, past_key_values=past_key_values)
            logits = (outputs.logits[:, -1, :] / temperature).softmax(-1).log()
            if scores is None:
                # First step: one hypothesis, whose top n_beams tokens seed the beams.
                scores, next_tokens = logits.topk(n_beams, -1)
                scores = scores.squeeze(0)
                next_tokens = next_tokens.permute(1, 0)
                output_ids = torch.cat((output_ids, next_tokens), dim=-1)
                seq_lengths[unfinished_sequences] += 1
                inputs_embeds = inputs_embeds.expand((n_beams, -1, -1))
                if past_key_values is not None:
                    past_key_values = [[kv.expand(n_beams, -1, -1, -1) for kv in layer] for layer in past_key_values]
            else:
                # Finished beams may only continue with pad at zero cost, so
                # their score and length stay fixed.
                logits[~unfinished_sequences] = float("-inf")
                logits[~unfinished_sequences, pad_token_id] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[unfinished_sequences] += 1
                # Candidates are ranked by per-token average log-probability.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, flat_idx = scores_sum_average.view(-1).topk(n_beams, -1)
                vocab_size = scores_sum.shape[1]
                next_tokens_source = torch.div(flat_idx, vocab_size, rounding_mode='floor')
                next_tokens = (flat_idx % vocab_size).unsqueeze(1)
                # Every per-beam tensor must follow the beam it was extended
                # from, including seq_lengths (otherwise scores/lengths drift).
                seq_lengths = seq_lengths[next_tokens_source]
                output_ids = torch.cat((output_ids[next_tokens_source], next_tokens), dim=1)
                inputs_embeds = inputs_embeds[next_tokens_source]
                unfinished_sequences = unfinished_sequences[next_tokens_source]
                scores = scores_sum_average * seq_lengths

            inputs_embeds = torch.cat((inputs_embeds, wte(next_tokens)), dim=1)

            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences & (next_tokens.squeeze(1) != eos_token_id)

            if not unfinished_sequences.any():
                break

        return GPTMixinGenerateBeamOutput(output_ids, scores, seq_lengths)


class GPT2PromptInputLM(GPTPromptInputMixin, GPT2LMHeadModel):
    """GPT-2 LM-head model extended with the prompt-input helpers."""

    def __init__(self, config):
        super(GPT2PromptInputLM, self).__init__(config)

    def _get_wte(self):
        """Return the token-embedding module ``transformer.wte``."""
        embedding = self.transformer.wte
        return embedding

class GPTJPromptTuningLM(GPTPromptInputMixin, GPTJForCausalLM):
    """GPT-J causal LM extended with the prompt-input helpers."""

    def __init__(self, config):
        super(GPTJPromptTuningLM, self).__init__(config)

    def _get_wte(self):
        """Return the token-embedding module ``transformer.wte``."""
        embedding = self.transformer.wte
        return embedding

class GPTNeoPromptTuningLM(GPTPromptInputMixin, GPTNeoForCausalLM):
    """GPT-Neo causal LM extended with the prompt-input helpers."""

    def __init__(self, config):
        super(GPTNeoPromptTuningLM, self).__init__(config)

    def _get_wte(self):
        """Return the token-embedding module ``transformer.wte``."""
        embedding = self.transformer.wte
        return embedding

class OPTPromptTuningLM(GPTPromptInputMixin, OPTForCausalLM):
    """OPT causal LM extended with the prompt-input helpers.

    The mixin reads GPT-2-style config names (e.g. ``n_embd``), so the
    constructor copies OPT's equivalents onto those names.
    """

    def __init__(self, config):
        super(OPTPromptTuningLM, self).__init__(config)
        cfg = self.config
        # (GPT-2 name, OPT source name), applied in this order.
        for gpt2_name, opt_name in (
            ('n_layer', 'num_hidden_layers'),
            ('n_embd', 'word_embed_proj_dim'),
            ('n_head', 'num_attention_heads'),
            ('n_positions', 'max_position_embeddings'),
        ):
            setattr(cfg, gpt2_name, getattr(cfg, opt_name))

    def _get_wte(self):
        """Return the token-embedding module ``model.decoder.embed_tokens``."""
        embedding = self.model.decoder.embed_tokens
        return embedding
