# -*- coding: utf-8 -*-
"""Codebase_for_AIEG

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/15JhYlLs97mf4jNtg8DPKXIo_TbjyO0C7

**This section imports the dependencies and downloads the GPT-2 tokeniser and model.**
"""

import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import re
import matplotlib.colors as mcolors
from IPython.display import HTML, display

# Step 1: Fetch the pre-trained GPT-2 checkpoint and its matching tokenizer.
# Larger checkpoints ('gpt2-medium', 'gpt2-large', ...) can be swapped in here.
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# output_attentions=True makes forward passes also return per-layer attention
# maps. nn.Module.eval() returns the module itself, so it can be chained; it
# switches the model to inference behaviour (e.g. dropout disabled).
model = GPT2LMHeadModel.from_pretrained(model_name, output_attentions=True).eval()

"""**In this section we give the model a prompt and set the generation hyperparameters (length and sampling settings) to produce text.**"""

# Step 2: Define the prompt text (starting text for generation)
prompt_text = "The movie"

# Step 3: Tokenize the prompt text into a (1, prompt_len) tensor of token ids
input_ids = tokenizer.encode(prompt_text, return_tensors='pt')

# Step 4: Generate text using the model.
# The attention mask and pad id are passed explicitly so `generate` does not
# have to infer them (it warns when they are missing). A single unpadded
# prompt attends to every position, hence an all-ones mask.
output_sequences = model.generate(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    pad_token_id=tokenizer.eos_token_id,
    max_length=20,  # total length in tokens, prompt included; adjust as needed
    num_return_sequences=1,  # generate a single sequence
    no_repeat_ngram_size=2,  # never repeat any bigram
    top_k=5,  # sample only among the 5 most likely next tokens...
    top_p=0.95,  # ...further restricted to the 95% probability mass (nucleus)
    temperature=2.0,  # >1 flattens the distribution: more "creative" output
    do_sample=True,  # sample instead of greedy decoding
)

# Step 5: Decode the generated sequence back to text
generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)

# Step 6: Print the generated text
print("Generated Text:")
print(generated_text)

"""**In this section we list the tokens of the generated text and their indices (use an index as `word_index` below).**"""

# List every token of the generated text with its position. fnAIG re-tokenises
# the same text the same way, so a position printed here can be used directly
# as its `word_index` argument.
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(generated_text))
for position, token in enumerate(tokens):
    print(token, "-->", position)

"""**In this section we do the main computation of our proposed AIEG method**"""

def _integrated_gradients_ef(model, input_ids, t_idx, baseline=None, steps=50):
    """Return EF-weighted integrated-gradient attributions per embedding value.

    The input embeddings are interpolated from ``baseline`` (zeros by default)
    to the real embeddings in ``steps`` evenly spaced points. At each point,
    the gradient of the summed logits at position ``t_idx`` is taken with
    respect to the interpolated embeddings.

    Every gradient except the first one (alpha == 0) is scaled by the EF
    factor ``|(s_k - s_{k-1}) / (s_k + s_{k-1})|``, where ``s_k`` is that
    summed score at step ``k``. The accumulated (summed, not averaged)
    gradients are finally multiplied element-wise by
    ``input_embeds - baseline``.

    Args:
        model: GPT-2 LM-head model exposing ``transformer.wte``.
        input_ids: token ids as returned by
            ``tokenizer.encode(..., return_tensors="pt")`` (batch of one).
        t_idx: position whose output score is attributed.
        baseline: optional tensor shaped like the input embeddings.
        steps: number of interpolation points, both endpoints included.

    Returns:
        Tensor shaped like the input embeddings, without autograd history.
    """
    embedding_layer = model.transformer.wte
    # Only the gradient w.r.t. the interpolated inputs is needed, so the
    # embeddings are computed once and cut off from the embedding weights.
    input_embeds = embedding_layer(input_ids).detach()
    if baseline is None:
        baseline = torch.zeros_like(input_embeds)
    baseline = baseline.detach()
    delta = input_embeds - baseline

    total_gradients = torch.zeros_like(input_embeds)
    prev_score = None
    for alpha in torch.linspace(0, 1, steps):
        interpolated_input = (baseline + alpha * delta).requires_grad_(True)
        logits = model(inputs_embeds=interpolated_input)[0]
        output_score = logits[0, t_idx].sum()

        # autograd.grad returns just the input gradient and, unlike
        # backward(), does not accumulate .grad on every model parameter.
        (gradients,) = torch.autograd.grad(output_score, interpolated_input)

        score = output_score.detach()
        if prev_score is not None:  # every step except alpha == 0
            ef = ((score - prev_score) / (score + prev_score)).abs().item()
            gradients = gradients * ef
        prev_score = score

        total_gradients += gradients

    return delta * total_gradients


def _mean_attention_row(model, input_ids, word_index):
    """Return attention from ``word_index`` to every token, averaged over all layers and heads."""
    with torch.no_grad():
        attentions = model(input_ids, output_attentions=True).attentions
    # One tensor per layer, indexed [batch, head, query, key]; stacking adds a
    # leading layer axis, so layer/head counts are taken from the model itself.
    rows = torch.stack(attentions)[:, 0, :, word_index, :]
    return rows.mean(dim=(0, 1)).numpy()


def _merge_subword_scores(tokens, scores):
    """Merge GPT-2 sub-word tokens into words, summing their scores.

    A token starting with 'Ġ' (space) or 'Ċ' (newline) starts a new word, as
    does a token not starting with a word character once a word is under
    construction. Returns an ordered list of ``(word, score)`` pairs, so
    repeated words keep their own positions and scores.
    """
    word_scores = []
    current_word = ""
    current_score = 0.0
    for token, score in zip(tokens, scores):
        starts_new_word = token.startswith(('Ġ', 'Ċ')) or (
            current_word != "" and not re.match(r'\w', token)
        )
        if starts_new_word:
            if current_word:
                word_scores.append((current_word, current_score))
            current_word = token.lstrip('ĠĊ')
            current_score = score
        else:
            current_word += token
            current_score += score
    if current_word:
        word_scores.append((current_word, current_score))
    return word_scores


def _plot_word_contributions(word_scores):
    """Render ``(word, score)`` pairs as HTML, shaded light-to-dark green by score."""
    from html import escape

    if not word_scores:
        return

    cmap = plt.get_cmap("Greens")
    scores = [score for _, score in word_scores]
    min_score = min(scores)
    score_range = max(scores) - min_score

    parts = ["<html><body>"]
    for word, score in word_scores:
        # Min-max normalise into [0, 1]; all-equal scores map to the lightest shade.
        norm_score = (score - min_score) / score_range if score_range > 0 else 0.0
        color_hex = mcolors.to_hex(cmap(norm_score)[:3])
        parts.append(
            f'<span style="background-color: {color_hex}; color: black; font-size: 20px; '
            f'margin-right: 5px; padding: 2px; border-radius: 3px;">{escape(word)}</span>'
        )
    parts.append("</body></html>")
    display(HTML("".join(parts)))


def fnAIG(generated_text, word_index, *, steps=50):
    """Compute, print and display AIEG contribution scores for one token.

    The function uses the module-level ``model`` and ``tokenizer`` and runs in
    four stages:

      1. EF-weighted integrated gradients for position ``word_index``, summed
         over the embedding dimension to give one score per token. The
         analysed token's own score is zeroed, negatives are clipped to 0,
         and the rest are normalised to sum to 1.
      2. Attention from ``word_index`` to every token, averaged over all
         layers and heads.
      3. Contribution score = attention * normalised IG score, per token.
      4. Sub-word tokens are merged into words (scores summed), and the words
         are rendered as colour-coded HTML.

    Intermediate scores are printed at every stage.

    Args:
        generated_text: text to analyse. It is re-tokenised here, so
            positions match ``tokenizer.encode(generated_text)``.
        word_index: token position to explain (negative indices allowed).
        steps: number of interpolation steps for integrated gradients.

    Raises:
        IndexError: if ``word_index`` lies outside the tokenised text.
    """
    input_ids = tokenizer.encode(generated_text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    n_tokens = len(tokens)
    if not -n_tokens <= word_index < n_tokens:
        raise IndexError(
            f"word_index {word_index} is out of range for a text of {n_tokens} tokens"
        )

    # Stage 1: EF-weighted integrated gradients, one score per token.
    ig = _integrated_gradients_ef(model, input_ids, word_index, steps=steps)
    # ig[0] drops only the batch axis; squeeze() would also drop a length-1
    # sequence axis and break single-token inputs.
    ig_scores = ig[0].sum(dim=-1).detach().numpy()
    ig_scores[word_index] = 0  # exclude the analysed token itself

    print("Output x EF Scores")
    for token, score in zip(tokens, ig_scores):
        print(f"{token}: {score}")

    # Drop negative attributions, then normalise to a distribution.
    ig_scores = np.where(ig_scores < 0, 0, ig_scores)
    total_sum = np.sum(ig_scores)
    if total_sum > 0:
        normalized_scores = ig_scores / total_sum
    else:
        # Nothing positive to distribute; avoid 0/0 producing NaN everywhere.
        normalized_scores = np.zeros_like(ig_scores)

    print("Normalised Scores")
    for token, score in zip(tokens, normalized_scores):
        print(f"{token}: {score}")

    # Stage 2: attention from the analysed token, averaged over layers and heads.
    attention_row = _mean_attention_row(model, input_ids, word_index)
    print("Attention")
    for token, score in zip(tokens, attention_row):
        print(f"{token}: {score}")

    # Stage 3: per-token AIEG contribution scores.
    contribution_scores = attention_row * normalized_scores
    print("Contribution Scores")
    for token, score in zip(tokens, contribution_scores):
        print(f"{token}: {score}")

    # Stage 4: merge sub-word tokens into words and visualise.
    word_contributions = _merge_subword_scores(tokens, contribution_scores)
    print("Word Contribution Scores")
    for word, score in word_contributions:
        print(f"{word}: {score}")
    print(generated_text)

    _plot_word_contributions(word_contributions)

"""**Here we choose the token of interest for the AIEG value calculation**"""

# Position of the token to explain, taken from the token listing printed above.
word_index = 15  # set this to the index of the token of interest
print("word index-->", word_index)
fnAIG(generated_text, word_index=word_index)

