# -*- coding: utf-8 -*-
"""ntk_laziness.ipynb

Automatically generated by Colab.


"""


!pip install nbconvert
!pip install transformers
!pip install datasets
import torch
from transformers import AutoModel, AutoTokenizer, BertForSequenceClassification
from functorch import jacrev, make_functional_with_buffers
import gc
from torch import nn
from torch.nn.functional import relu
from torch.autograd import grad

import numpy as np
import torch
from datasets import load_dataset#, load_metric
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
from torch.utils.data import Dataset
import logging

from datasets import load_dataset
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'serif'

"""## NTK for laziness from pkl files ##

## Plotting the NTK from the Malladi et al. paper ##
"""

def Kernel_Whole_SIM(tensor):
    """Assemble a 2-D block matrix from a 4-D kernel tensor.

    Indices 0 and 1 are taken along the second and fourth axes of
    ``tensor`` (shape ``(N, M, P, Q)``); the four resulting ``(N, P)``
    slices are tiled as::

        [[tensor[:, 0, :, 0], tensor[:, 0, :, 1]],
         [tensor[:, 1, :, 0], tensor[:, 1, :, 1]]]

    Args:
        tensor: A 4-D tensor with at least two entries along axes 1 and 3.

    Returns:
        A tensor of shape ``(2 * N, 2 * P)``.
    """
    # The unpacking doubles as a rank check: a non-4-D input raises here.
    _n, _m, _p, _q = tensor.shape

    # Build each block row (fixed index on axis 1) by placing the two
    # axis-3 slices side by side, then stack the block rows vertically.
    block_rows = [
        torch.cat([tensor[:, row, :, 0], tensor[:, row, :, 1]], dim=1)
        for row in (0, 1)
    ]
    return torch.cat(block_rows, dim=0)

def Kernel_Whole_ntk_convention(tensor):
    """Flatten a 4-D kernel tensor into a 2-D matrix.

    The first two axes are merged into the row axis and the last two
    into the column axis, in the tensor's existing element order.

    Args:
        tensor: A 4-D tensor of shape ``(N, M, P, Q)``.

    Returns:
        The same data viewed as shape ``(N * M, P * Q)``.
    """
    # Unpacking exactly four sizes also rejects inputs that are not 4-D.
    dim0, dim1, dim2, dim3 = tensor.shape
    n_rows = dim0 * dim1
    n_cols = dim2 * dim3
    return tensor.reshape(n_rows, n_cols)

# Commented out IPython magic to ensure Python compatibility.
# %ls

import torch
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.family'] = 'serif'

# ---- Run configuration ----
# These values select which saved kernel file is loaded and where the
# resulting figure is written.
task='cr'
task_c='cr'
model='roberta-base'
method='row'  # alternative used in other runs: 'col'
instance="train"

# Directory layout shared by the kernel file and the results file.
result_dir = f'result_{method}_subset'
run_dir = f'{result_dir}/{task_c}-{model}-prompt-kernel-kernel-prompting/16-13'

# Load the saved kernels for the selected split.
file_path = f'{run_dir}/{instance}_kernels_{task}.pt'
data = torch.load(file_path)

# Print the test results stored next to the kernels.
file_path = f'{run_dir}/test_results_{task}.txt'
with open(file_path, 'r') as file:
    content = file.read()
print(content)


# Quick inspection of what was loaded: keys for a dict, otherwise the
# shapes of the first two entries.
# NOTE(review): the code below indexes data[0] regardless of this branch,
# so a dict payload would need an integer key 0 -- confirm the file format.
if isinstance(data, dict):
    print(data.keys())
else:
    print(data[0].shape)
    print(data[1].shape)


# Flatten the first entry into a 2-D matrix. Kernel_Whole_ntk_convention
# unpacks four dimensions, so data[0] is expected to be 4-D.
ntk=Kernel_Whole_ntk_convention(data[0])
print('ntk',ntk.shape)


N, M = ntk.shape
plt.figure(figsize=(10, 8))
sns.heatmap(ntk, cmap="viridis")

plt.ylabel(r"NTK index $i$")
plt.xlabel(r"NTK index $j$")

# Three ticks per axis: start, midpoint and end of the matrix.
plt.xticks([0, M/2, M],[f'{0}',f'{M//2}', f'{M}'])
plt.yticks([0, N/2, N],[f'{0}',f'{N//2}', f'{N}'])
# Title follows the configured task so it cannot drift from the paths
# above ('cr' -> 'CR').
plt.title(task_c.upper())
plt.savefig(f'{result_dir}/NTK_{task_c}_{method}.eps', format='eps', dpi=300)
plt.show()

def Kernal_Whole(tensor):
    """Assemble a 2-D block matrix from the trailing axes of a 4-D tensor.

    Indices 0 and 1 are taken along the third and fourth axes of
    ``tensor`` (shape ``(N, M, P, Q)``); the four resulting ``(N, M)``
    slices are tiled as::

        [[tensor[:, :, 0, 0], tensor[:, :, 0, 1]],
         [tensor[:, :, 1, 0], tensor[:, :, 1, 1]]]

    Args:
        tensor: A 4-D tensor with at least two entries along axes 2 and 3.

    Returns:
        A tensor of shape ``(2 * N, 2 * M)``.
    """
    # The unpacking doubles as a rank check: a non-4-D input raises here.
    _n, _m, _p, _q = tensor.shape

    # One block row per axis-2 index; within it the two axis-3 slices sit
    # side by side. The block rows are then stacked vertically.
    block_rows = [
        torch.cat([tensor[:, :, row, 0], tensor[:, :, row, 1]], dim=1)
        for row in (0, 1)
    ]
    return torch.cat(block_rows, dim=0)

# Function to tokenize input text
def tokenize_input(texts, tokenizer, max_length=10):
    """Call ``tokenizer`` on ``texts`` with fixed-length padding settings.

    Args:
        texts: Passed to ``tokenizer`` as its single positional argument.
        tokenizer: Callable accepting the ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments.
        max_length: Value forwarded as the ``max_length`` keyword.

    Returns:
        Whatever ``tokenizer`` returns.
    """
    call_options = {
        "padding": 'max_length',
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    return tokenizer(texts, **call_options)
