import numpy as np
import tiktoken
import torch
import sys

from model import GPT

# CUDA device that hosts the model and every tensor built later on
device = 'cuda:0'

# Memory-map the openwebtext training tokens (stored as uint16 ids) read-only,
# so slices are read from disk on access instead of loading the whole file
print('Load the openwebtext datasets.....')
train_data = np.memmap('data/openwebtext/train.bin', dtype=np.uint16, mode='r')

# GPT-2 BPE tokenizer and two small text <-> token-id helpers
enc = tiktoken.get_encoding("gpt2")


def str2tok(s):
    """Encode the text *s* into token ids; the <|endoftext|> marker is allowed."""
    return enc.encode(s, allowed_special={"<|endoftext|>"})


def tok2str(l):
    """Decode the sequence of token ids *l* back into text."""
    return enc.decode(l)


# The probe text: everything except its final token is the input context,
# and the final token is the one whose prediction is analysed
test_seq = np.array(str2tok('The origin of life is a hot topic among scientists. In a paper published this week in Nature Communications, researchers from the University of California, San Diego, and the University of California, Los Angeles say they have discovered a new way of making the building blocks of RNA'))
input_seq = test_seq[:-1]
output_token = test_seq[-1]

# Float32 matmul precision setting.
# NOTE: 'high' is NOT the most precise mode -- it lets float32 matmuls use
# faster reduced-precision kernels (e.g. TF32) when available; 'highest'
# is the full-float32 setting.
torch.set_float32_matmul_precision('high')

# Load the pretrained GPT-2 XL weights, move the model to the chosen device
# and switch it to evaluation mode
print('Load and compile the pretrained GPT-2 model.....')
model = GPT.from_pretrained('gpt2-xl')
# Smaller alternative checkpoint:
# model = GPT.from_pretrained('gpt2')
model.to(device)
model.eval()
# torch.compile is currently not applied (despite the message above):
# model = torch.compile(model)

# Start evaluation
#
# For the probe context `input_seq` and its next token `output_token`:
#   1. take every occurrence of `output_token` in the training data and, for
#      each context length, the `seq_len` tokens immediately preceding it;
#   2. keep the contexts whose last-position feature has a smaller dot product
#      with the output token's lm_head weight row than the probe feature has
#      (the "support" set);
#   3. sort the supports by their dot product with the probe feature and
#      reduce them to a "spectrum" (see below);
#   4. save the spectrum as rows of (end index in train_data, context length).
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
with torch.no_grad(), ctx:

    print('test sequence: ' + tok2str(input_seq) + '[' + tok2str([output_token]) + ']')

    # Plain Python int copy of the target id, used for tensor indexing
    target = int(output_token)

    # Feature of the probe context, taken at its final position
    probe = torch.from_numpy(input_seq.astype(np.int64)).unsqueeze(0)
    probe = probe.pin_memory().to(device, non_blocking=True)
    input_feature = model.transform(probe)[0][len(input_seq) - 1]

    # Every position of the training data that holds the output token
    indices = np.where(train_data == output_token)[0]
    print('num of occurences: ' + str(len(indices)))

    # Weight row of the output token in the lm_head.
    # NOTE(review): assumes lm_head is a linear layer exposing `.weight`
    # (the previous code took the last named parameter, which would have
    # been the bias if lm_head had one) -- confirm against model.py.
    w = model.lm_head.weight[target]

    # Candidate contexts have min_seq_len..max_seq_len tokens and end right
    # before an occurrence of the output token; they are processed in batches
    min_seq_len = 1
    max_seq_len = len(input_seq) + 5
    batch_size = 1000

    # Per-batch results; concatenated once after the search instead of
    # growing a tensor with torch.cat on every batch
    found_sequences = []
    found_dot_products = []
    found_logits = []

    for seq_len in range(min_seq_len, max_seq_len + 1):

        print('Processing sequences of length: ' + str(seq_len) + '/' + str(max_seq_len), flush=True)

        # Skip occurrences too close to the start of the data: they cannot
        # supply a full window of seq_len preceding tokens
        ends = indices[indices >= seq_len]
        # Relative positions of the window tokens: -seq_len .. -1
        offsets = np.arange(-seq_len, 0)

        for lo in range(0, len(ends), batch_size):
            batch_ends = ends[lo:lo + batch_size]

            # Gather the windows train_data[i - seq_len : i] for the whole
            # batch with one fancy-indexing read -> (batch, seq_len) int64
            windows = np.asarray(train_data[batch_ends[:, None] + offsets], dtype=np.int64)
            batch = torch.from_numpy(windows).pin_memory().to(device, non_blocking=True)

            # Only the last-position feature of each context is needed
            features = model.transform(batch)[:, seq_len - 1, :]

            # Logit of the output token for each context. lm_head is applied
            # to the last-position features only, which avoids materialising
            # a (batch, seq_len, vocab) tensor; values may differ from a
            # full-sequence call at bf16 rounding level.
            logits = model.lm_head(features)[:, target]

            # Supports: contexts with (feature - probe feature) . w < 0
            keep = (torch.matmul(features - input_feature, w) < 0).nonzero().flatten()
            if keep.numel() == 0:
                continue

            features = torch.index_select(features, 0, keep)

            # Dot products with the probe feature (the later sort key)
            found_dot_products.append(torch.matmul(features, input_feature))
            found_logits.append(torch.index_select(logits, 0, keep))

            # Instead of storing the token values, store each support as
            # (end index in train_data, context length)
            kept_ends = torch.index_select(torch.from_numpy(batch_ends).to(device), 0, keep)
            lengths = torch.full((keep.numel(), 1), seq_len, device=device)
            found_sequences.append(torch.cat((kept_ends.view(-1, 1), lengths), dim=1))

    # Sort the supports by dot product (ascending); with no supports at all
    # the spectrum is simply empty instead of crashing in torch.sort
    if found_dot_products:
        support_dot_products = torch.cat(found_dot_products)
        _, support_sorted_idxs = torch.sort(support_dot_products, descending=False)
        support_sequences = torch.index_select(torch.cat(found_sequences), 0, support_sorted_idxs)
        support_logits = torch.index_select(torch.cat(found_logits), 0, support_sorted_idxs)
        num_supports = support_sorted_idxs.size()[0]
    else:
        num_supports = 0
    print('num of supports: ' + str(num_supports))

    # Generate the spectrum: walking the sorted supports, repeatedly pick the
    # support with the lowest logit among those at or after the current
    # position, then continue right after the picked one
    spec = []
    k = 0
    while k < num_supports:
        picked = k + torch.argmin(support_logits[k:]).item()
        spec.append(support_sequences[picked].tolist())
        k = picked + 1

    # Always an int64 array of shape (K, 2), including K == 0
    spec = np.array(spec, dtype=np.int64).reshape(-1, 2)
    print('spectrum size: ' + str(len(spec)))
    print(spec)

    # Save to file (np.save appends the .npy extension)
    np.save('spec_RNA', spec)
