import numpy as np
import torch
import sys
import ast

from transformers import LlamaTokenizer, LlamaForCausalLM
from datasets import load_dataset

# CUDA device on which the model and all tensors in this script are placed
device = 'cuda:6'

# Load the training dataset.
# NOTE: the message below mentions the RedPajama wikipedia subset, but what is
# actually read is the array stored in 'train_data.npy'. The two commented-out
# load_dataset calls are alternative ways of obtaining the raw text.
# NOTE(review): the array presumably holds the token ids written by the
# commented-out tokenization recipe later in this file -- confirm.
print('Load the RedPajama wikipedia subset.....')
# train_data = load_dataset("togethercomputer/RedPajama-Data-1T", "wikipedia")
# train_data = load_dataset("fancyzhx/ag_news")['train']['text'] # for testing
train_data = np.load('train_data.npy')
def filter_english_examples(example):
    """Return True when the example's metadata declares the language 'en'.

    ``example['meta']`` must be a string containing a Python literal for a
    mapping; it is parsed with ``ast.literal_eval`` and its ``'language'``
    entry (``None`` when absent) is compared against ``'en'``.
    """
    language = ast.literal_eval(example['meta']).get('language')
    return language == 'en'

# Apply the filter to the dataset
# train_data = train_data.filter(filter_english_examples)['train']['text']

# import pdb; pdb.set_trace()

# Checkpoint name handed to from_pretrained for both the tokenizer and the model
model_path = 'openlm-research/open_llama_7b'

# Tokenizer, plus two thin helpers converting between text and token ids
tokenizer = LlamaTokenizer.from_pretrained(model_path)


def str2tok(s):
    """Encode the string *s* into token ids using the module-level tokenizer."""
    return tokenizer.encode(s)


def tok2str(l):
    """Decode the token ids *l* into a string using the module-level tokenizer."""
    return tokenizer.decode(l)


# Recipe for (re)building train_data.npy from raw text. It stays commented out
# because tokenizing the whole corpus takes a very long time, which is why the
# result is cached on disk:
# concatenated_text = ' '.join(train_data)
# train_data = tokenizer(concatenated_text)['input_ids']
# train_data = np.array(train_data)
# np.save('train_data.npy', train_data)

# The test sequence: every token but the last is the query context,
# and the last token is the output token being explained.
test_seq = np.array(
    str2tok('The theory of relativity is a theory of space and time. It was developed by Albert Einstein')
)
input_seq = test_seq[:-1]
output_token = test_seq[-1]

# Load the pretrained llama model, move it to the chosen device and switch it
# to evaluation mode.
print('Load and compile the pretrained llama model.....')
model = LlamaForCausalLM.from_pretrained(model_path)
model.to(device)
# model = torch.compile(model)
model.eval()

# Select the 'high' float32 matmul precision setting
torch.set_float32_matmul_precision('high')

# Start evaluation
#
# Overview of this block:
#   1. Compute the feature of the query (input_seq): the final hidden state of
#      the transformer at the last position.
#   2. For every occurrence of the output token in train_data and every context
#      length in [min_seq_len, max_seq_len], take the `seq_len` tokens that
#      immediately precede the occurrence as a candidate sequence.
#   3. Keep the candidates whose feature has a smaller dot product with the
#      output token's lm_head row than the query feature has (the "supports").
#   4. Sort the supports by dot product with the query feature and walk through
#      them, repeatedly picking the remaining support with the lowest logit.
#   5. Save the picked (end index, length) pairs to 'spec_Einstein.npy'.
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
with torch.no_grad(), ctx:

    # Plain Python int, usable for both numpy comparisons and tensor indexing
    token_id = int(output_token)

    print('test sequence: ' + tok2str(input_seq) + '[' + tok2str([output_token]) + ']')

    # Get the feature of the input sequence (hidden state at its last position)
    query_ids = torch.from_numpy(input_seq.astype(np.int64)).unsqueeze(0)
    query_ids = query_ids.pin_memory().to(device, non_blocking=True)
    input_feature = model.model(query_ids).last_hidden_state[0, -1]

    # Get all indices where output token appears in the data
    indices = np.where(train_data == token_id)[0]
    print('num of occurences: ' + str(len(indices)))

    # Row of the lm_head weight matrix belonging to the output token. The head
    # is treated as bias-free here, so the token's logit is features . w
    w = model.lm_head.weight[token_id]

    # Pieces of the support set, collected per batch and concatenated once at
    # the end (avoids re-copying the growing tensors on every batch).
    sequence_chunks = []
    dot_product_chunks = []
    logit_chunks = []

    # Consider sequences of length min_seq_len to max_seq_len that
    # immediately precede an occurrence of the output token.
    # These are the candidate sequences
    min_seq_len = 1
    max_seq_len = len(input_seq) + 5

    # Number of candidate sequences pushed through the model at once
    batch_size = 50

    for seq_len in range(min_seq_len, max_seq_len + 1):

        print('Processing sequences of length: ' + str(seq_len) + '/' + str(max_seq_len))
        sys.stdout.flush()

        # Only occurrences with at least seq_len tokens before them can supply
        # a full-length context. Shorter ones would produce ragged rows that
        # cannot be batched, and their truncated context is already covered by
        # the iteration whose seq_len equals the available context length.
        end_indices = indices[indices >= seq_len]

        # Offsets of the context tokens relative to the occurrence position
        window = np.arange(-seq_len, 0)

        # Process the sequences in batches
        for idx in range(0, len(end_indices), batch_size):

            batch_ends = end_indices[idx:idx + batch_size]

            # Gather the (batch, seq_len) context windows in one vectorized step
            batch = train_data[batch_ends[:, None] + window].astype(np.int64)
            batch = torch.from_numpy(batch).pin_memory().to(device, non_blocking=True)

            # Single forward pass; only the last position of each sequence is
            # needed, and only the output token's logit at that position.
            features = model.model(batch).last_hidden_state[:, -1, :]
            logits = model.lm_head(features)[:, token_id].float().unsqueeze(1)

            # Find the indices of the support sequences
            support_indices = (torch.matmul(features - input_feature, w) < 0)
            support_indices = support_indices.nonzero().flatten()
            if support_indices.numel() == 0:
                continue

            # Remove points that are not in the support set
            features = torch.index_select(features, 0, support_indices)
            logits = torch.index_select(logits, 0, support_indices)

            # Compute the dot products with the input feature
            dot_products = torch.matmul(features, input_feature)

            # Instead of storing the sequences (token values)
            # just store the end index and the length
            ends = torch.from_numpy(batch_ends).to(device)
            ends = torch.index_select(ends, 0, support_indices)
            lengths = torch.full_like(ends, seq_len)

            # Add to the list of supports found so far
            sequence_chunks.append(torch.stack((ends, lengths), dim=1))
            dot_product_chunks.append(dot_products)
            logit_chunks.append(logits)

    # The support set (the sequences, dot products, and logits). Fall back to
    # empty tensors so that the steps below also work when nothing was found.
    if sequence_chunks:
        support_sequences = torch.cat(sequence_chunks)
        support_dot_products = torch.cat(dot_product_chunks)
        support_logits = torch.cat(logit_chunks)
    else:
        support_sequences = torch.empty((0, 2), dtype=torch.int64, device=device)
        support_dot_products = torch.empty(0, device=device)
        support_logits = torch.empty((0, 1), device=device)

    # Sort the supports by dot product
    _, support_sorted_idxs = torch.sort(support_dot_products, descending=False)
    support_sequences = torch.index_select(support_sequences, 0, support_sorted_idxs)
    support_logits = torch.index_select(support_logits, 0, support_sorted_idxs)

    # Generate the spectrum
    spec = []
    num_supports = support_sorted_idxs.size()[0]
    print('num of supports: ' + str(num_supports))
    k = 0
    while k < num_supports:

        # Find the sequence with the lowest logit among the remaining supports,
        # record it, and continue with the supports that come after it
        offset = torch.argmin(support_logits[k:]).item()
        spec.append(support_sequences[k + offset].tolist())
        k += offset + 1

    # Rows are (end index, length); reshape keeps two columns even when empty
    spec = np.array(spec, dtype=np.int64).reshape(-1, 2)
    print('spectrum size: ' + str(len(spec)))
    print(spec)

    # Save to file
    np.save('spec_Einstein', spec)
                    
