import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics.functional.classification import binary_accuracy, multiclass_accuracy
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel


class PromptEncoder(pl.LightningModule):
    """Turn text prompts into embeddings with a pretrained T5 encoder.

    Every prompt is tokenized to exactly ``config.prompt_max_tokens`` tokens,
    passed through the T5 encoder, and the per-token hidden states are projected
    by a linear layer from the T5 width (``d_model``) to ``config.n_embd``.
    """

    def __init__(self, config):
        """
        config: object providing
            prompt_max_tokens: fixed token length every prompt is padded/truncated to
            n_embd: output width of the projection layer
        """
        super().__init__()

        self.prompt_len = config.prompt_max_tokens
        # NOTE(review): the original author states the T5 weights are frozen in
        # configure_optimizers(); that method is not defined in this class, so the
        # freezing presumably happens in the owning module -- confirm.
        # The checkpoint is loaded from "<parent of the current working directory>/t5-base".
        model_path = os.path.join(os.path.dirname(os.getcwd()), "t5-base")
        self.config = AutoConfig.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = T5EncoderModel.from_pretrained(model_path)
        self.model_out = nn.Linear(self.config.d_model, config.n_embd)

    def forward(self, prompts, ptm_mode=False):
        """
        prompts: a batch of strings
        ptm_mode: currently has no effect. It was meant to prepend a [CLS] token
            to every prompt, but that logic is disabled; the parameter is kept so
            existing callers keep working.

        Return (a 3-tuple):
            prompt_embed: (B, prompt_len, n_embd)
            attention_mask: (B, prompt_len), 1 for real tokens, 0 for padding
            input_ids: (B, prompt_len)
        """
        # truncation=True is required for the fixed (B, prompt_len) shape:
        # padding='max_length' alone only pads short prompts, it does not shorten
        # prompts that tokenize to more than prompt_len tokens.
        prompt_inputs = self.tokenizer(
            prompts,
            padding='max_length',
            truncation=True,
            max_length=self.prompt_len,
            return_tensors="pt",
        )
        # The tokenizer returns CPU tensors; move them next to the model weights.
        prompt_inputs = {k: v.to(self.device) for k, v in prompt_inputs.items()}
        prompt_hidden_states = self.model(**prompt_inputs).last_hidden_state
        prompt_embed = self.model_out(prompt_hidden_states)  # (B, prompt_len, n_embd)

        return prompt_embed, prompt_inputs['attention_mask'], prompt_inputs['input_ids']

    def get_all_prompt_embed(self, all_prompts, segment_lengths):
        """
        Encode the full list of candidate prompts once and reuse the result.

        all_prompts: list of strings, only read on the first call
        segment_lengths: dict, 'batch_size' is read on every call

        Return:
            all_prompts_expand: (batch_size, all_prompt_num, prompt_len) input ids,
                a broadcast view of the cached ids sized for the current batch
            all_prompts_embed: (all_prompt_num, prompt_len, n_embd)
            all_prompts_attn_mask: (all_prompt_num, prompt_len)

        NOTE(review): the cache is filled on the first call and never refreshed,
        so the embeddings reflect the model weights (and train/eval mode) at that
        moment -- confirm this is intended if model_out is trained.
        """
        if not hasattr(self, "all_prompts_embed"):
            # The cached tensors are reused across training steps, so they must not
            # carry an autograd graph; otherwise the second backward() fails with
            # "Trying to backward through the graph a second time". Running under
            # no_grad also avoids storing the encoder activations in the first place.
            with torch.no_grad():
                all_embed, all_attn_mask, all_input_ids = self(all_prompts, ptm_mode=True)
            self.all_prompts_embed = all_embed.detach()
            self.all_prompts_attn_mask = all_attn_mask.detach()  # (all_prompt_num, prompt_len)
            self.all_prompts_input_ids = all_input_ids.detach()  # (all_prompt_num, prompt_len)

        # Expand on every call (a cheap view) so the leading dimension always
        # matches the current batch size instead of the first batch ever seen.
        self.all_prompts_expand = self.all_prompts_input_ids.unsqueeze(0).expand(
            segment_lengths['batch_size'], -1, -1
        )  # (B, all_prompt_num, prompt_len)

        return self.all_prompts_expand, self.all_prompts_embed, self.all_prompts_attn_mask


def masked_modeling_attention_mask(attention_mask, mask_ratio):
    """
    Randomly hide a fraction of the attended positions for masked modeling.

    Parameters:
        attention_mask: transformer-style mask where 1 means kept and 0 means masked,
            e.g. (B, S, seg_len) where seg_len can be state_len, action_dim, or prompt_len
        mask_ratio: per-position probability of selecting a token for masking

    Return:
        masked_modeling_attn_mask: attention_mask with the selected positions set to 0
        loss_mask: 1 where a non-padding token was selected (the loss is computed
            there), 0 everywhere else
    """
    # One independent Bernoulli draw per position; 1 marks a masking candidate.
    candidate_prob = torch.ones_like(attention_mask).mul(mask_ratio)
    candidates = torch.bernoulli(candidate_prob)

    # Padding positions (0 in attention_mask) never contribute to the loss.
    loss_mask = candidates * attention_mask

    # Selected positions are removed from the attention mask; the rest are untouched.
    still_visible = 1 - loss_mask
    return attention_mask * still_visible, loss_mask


class MaskedPromptModelingHead(pl.LightningModule):
    """Predict the token ids of masked prompt positions from their hidden states."""

    def __init__(self, hidden_size, tokenizer, config):
        """
        hidden_size: width of the hidden states fed to the head
        tokenizer: used only to decode predicted/true token ids for logging
        config: object providing vocab_size, the number of output classes
        """
        super().__init__()
        self.tokenizer = tokenizer
        # 32128 is the vocab_size of T5, https://huggingface.co/transformers/model_doc/t5.html#t5encodermodel
        self.lm_head = nn.Linear(hidden_size, config.vocab_size)

    def prepare_inputs(self, attn_mask, mask_ratio):
        """
        Sample which prompt tokens to mask.

        Return:
            prompt_attn_mask: attn_mask with the sampled positions zeroed
            prompt_loss_mask: 1 at the sampled (non-padding) positions, 0 elsewhere
        """
        return masked_modeling_attention_mask(attn_mask, mask_ratio)

    def forward(self, hidden_states, prompt_input_ids, prompt_loss_mask, segment_lengths):
        """
        hidden_states: tensor viewable as (batch_size_prime, seq_len, hidden)
        prompt_input_ids: (B, prompt_len) true token ids
        prompt_loss_mask: (B, prompt_len), non-zero where a token was masked
        segment_lengths: dict with 'batch_size', 'batch_size_prime', 'prompt_length'

        Return:
            loss_mpm: scalar tensor, cross entropy over the masked positions
            accuracy: Python float, micro accuracy over the masked positions
        """
        B = segment_lengths['batch_size']
        B_prime = segment_lengths['batch_size_prime']
        prompt_length = segment_lengths['prompt_length']
        hidden_size = hidden_states.size(-1)

        # The prompt occupies the first prompt_length positions of the first B rows.
        prompt_hidden_states = hidden_states.view(B_prime, -1, hidden_size)[:B, :prompt_length, :]  # (B, prompt_len, hidden)
        masked_positions = prompt_loss_mask.bool()
        masked_prompt_hidden_states = torch.masked_select(prompt_hidden_states, masked_positions.unsqueeze(-1))
        logits = self.lm_head(masked_prompt_hidden_states.view(-1, hidden_size))  # (num_masked, vocab_size)

        if logits.size(0) == 0:
            # No token was masked. The sum over the empty logits is exactly 0 but
            # stays attached to the autograd graph, so backward() still works and
            # lm_head still receives (zero) gradients on this step. Accuracy is a
            # plain float here to match the type returned by the regular path.
            return logits.sum(), 0.0

        # Construct training targets and calculate loss
        true_token_ids = torch.masked_select(prompt_input_ids, masked_positions)
        loss_mpm = F.cross_entropy(logits, true_token_ids)
        accuracy = multiclass_accuracy(logits, true_token_ids, num_classes=logits.size(-1), average='micro')
        accuracy = accuracy.detach().cpu().item()

        # Debug logging of predictions, only on rank 0 and only when a logger is
        # attached (self.logger is None when the Trainer runs without one).
        logger = self.logger
        if self.global_rank == 0 and logger is not None:
            pred_token_ids = torch.argmax(logits, dim=-1).detach().cpu().numpy()
            pred_tokens = self.tokenizer.batch_decode(pred_token_ids)
            true_tokens = self.tokenizer.batch_decode(true_token_ids.detach().cpu().numpy())
            # Log pred and true tokens in tensorboard
            logger.experiment.add_text('pretrain_mpm_gen/pred_tokens', str(pred_tokens), self.global_step)
            logger.experiment.add_text('pretrain_mpm_gen/true_tokens', str(true_tokens), self.global_step)

        return loss_mpm, accuracy


class PromptTrajectoryMatchingHead(pl.LightningModule):
    """Binary matching head: scores each sequence from its first hidden state.

    The first ``batch_size`` rows of the input are treated as matching (positive)
    pairs and the remaining rows as non-matching (negative) pairs.
    """

    def __init__(self, hidden_size):
        """
        hidden_size: width of the hidden states fed to the head
        """
        super().__init__()

        self.match_score = nn.Linear(hidden_size, 1)

    def prepare_inputs(self, all_prompts_inputs, prompt_input_ids):
        """
        Sample, for every prompt in the batch, one candidate prompt that differs from it.

        all_prompts_inputs: tuple (all_prompts_expand, all_prompts_embed, all_prompts_attn_mask)
            all_prompts_expand: (*, all_prompt_num, prompt_len) candidate input ids
            all_prompts_embed: (all_prompt_num, prompt_len, hidden)
            all_prompts_attn_mask: (all_prompt_num, prompt_len)
        prompt_input_ids: (B, prompt_len)

        Return:
            negative_prompt_embed: (B, prompt_len, hidden)
            negative_prompt_attn_mask: (B, prompt_len)

        Raises:
            ValueError: if some prompt in the batch is identical to every candidate,
                so no negative can be drawn for it
        """
        all_prompts_expand, all_prompts_embed, all_prompts_attn_mask = all_prompts_inputs

        # The sampled indices address the single shared all_prompts_embed table, so
        # every row of all_prompts_expand must hold the same candidate list. Using
        # only the first row and broadcasting makes the comparison independent of
        # the batch size all_prompts_expand was originally expanded to.
        candidate_ids = all_prompts_expand[:1]  # (1, all_prompt_num, prompt_len)
        # A candidate equals the prompt only if all of its token ids match.
        same_prompts = (candidate_ids == prompt_input_ids.unsqueeze(1)).all(dim=-1)  # (B, all_prompt_num)

        if same_prompts.all(dim=-1).any():
            # An all-zero probability row would otherwise surface as an opaque
            # validation error from Categorical.
            raise ValueError(
                "Cannot sample a negative prompt: at least one prompt in the batch "
                "matches every candidate prompt."
            )

        # Uniform sampling over the candidates that differ from the prompt.
        negative_prompts = 1 - same_prompts.float()  # (B, all_prompt_num)
        cat_dist = torch.distributions.categorical.Categorical(probs=negative_prompts)
        negative_indices = cat_dist.sample()  # (B,)
        negative_prompt_embed = all_prompts_embed[negative_indices]  # (B, prompt_len, hidden)
        negative_prompt_attn_mask = all_prompts_attn_mask[negative_indices]  # (B, prompt_len)

        return negative_prompt_embed, negative_prompt_attn_mask

    def forward(self, hidden_states, segment_lengths):
        """
        hidden_states: tensor viewable as (batch_size_prime, seq_len, hidden)
        segment_lengths: dict with 'batch_size' and 'batch_size_prime'

        Return:
            loss_ptm: scalar tensor, binary cross entropy with logits
            accuracy: Python float, binary accuracy of the match predictions
        """
        B = segment_lengths['batch_size']
        B_prime = segment_lengths['batch_size_prime']

        # Score every sequence from the hidden state at its first position.
        hidden_states = hidden_states.view(B_prime, -1, hidden_states.size(-1))[:, 0, :]  # (B_prime, hidden)
        logits = self.match_score(hidden_states)  # (B_prime, 1)

        # Construct training targets: the first B rows are positives, the rest negatives.
        ptm_target = torch.zeros((B_prime, 1), dtype=torch.float, device=self.device)  # (B_prime, 1)
        ptm_target[:B] = 1
        # Binary cross entropy loss
        loss_ptm = F.binary_cross_entropy_with_logits(logits, ptm_target)
        accuracy = binary_accuracy(logits, ptm_target)
        accuracy = accuracy.detach().cpu().item()

        return loss_ptm, accuracy
