import numpy as np
import torch
from lm import nn_log_probs
from lm import nn_next_token_probs
from lm import nn_rep_from_hidden_states

def trim_cache(full_cache):
    """Return a copy of a key/value cache with the last cached position removed.

    Args:
        full_cache: Iterable of ``(key, value)`` tensor pairs. Each tensor is
            indexed with four subscripts, so it must be 4-D; dimension 2 is
            the one that gets shortened (presumably the sequence axis of a
            ``(batch, heads, seq, head_dim)`` layout -- TODO confirm).

    Returns:
        A new list of ``(key, value)`` tuples, each sliced to drop the final
        index along dimension 2.
    """
    # Keep everything on dims 0, 1 and 3; drop the final entry along dim 2.
    drop_last = (slice(None), slice(None), slice(None, -1), slice(None))
    trimmed = []
    for layer_keys, layer_values in full_cache:
        trimmed.append((layer_keys[drop_last], layer_values[drop_last]))
    return trimmed

def expand_past_key_values(past_key_values, batch_size):
    """Broadcast every cached key/value tensor to ``batch_size`` along dim 0.

    Args:
        past_key_values: Iterable of ``(key, value)`` pairs of 4-D tensors.
            ``Tensor.expand`` only broadcasts size-1 dimensions, so dim 0 of
            each tensor must be 1 (or already equal to ``batch_size``).
        batch_size: Target size of dimension 0.

    Returns:
        A new list of ``(key, value)`` tuples of expanded tensors; all other
        dimensions keep their size (``-1``).
    """
    target_shape = (batch_size, -1, -1, -1)
    expanded = []
    for layer_keys, layer_values in past_key_values:
        expanded.append(
            (layer_keys.expand(*target_shape), layer_values.expand(*target_shape))
        )
    return expanded

class LMPolicy:
    """Sampling policy around a causal language model with an incremental KV cache.

    The wrapped ``model`` is called as
    ``model(input_ids, past_key_values=..., use_cache=..., output_hidden_states=...)``
    and its result must expose ``.logits``, ``.hidden_states`` and
    ``.past_key_values`` (presumably a Hugging Face causal LM -- TODO confirm).

    Cache state kept on the instance:
        kv_sequence: list of token ids the cache currently covers.
        kv_cache: the model's ``past_key_values`` for ``kv_sequence``
            (``None`` until the first cached forward pass).
        kv_reps: ``(max_len, hidden_size)`` buffer; row ``p`` holds the hidden
            state written for token position ``p`` by the most recent cached
            forward pass that covered it. The rows are only meaningful for the
            ``layer`` that was used in that pass.
    """

    def __init__(self, model, tokenizer, max_new_tokens, device, max_prompt_tokens=500):
        """Store the model and allocate the per-position representation buffer.

        Args:
            model: Model object; ``model.config.vocab_size`` and
                ``model.config.hidden_size`` are read here.
            tokenizer: Stored as ``self.tokenizer``; not used by this class.
            max_new_tokens: Generation budget used by ``is_complete`` and for
                sizing ``kv_reps``.
            device: Device for ``kv_reps`` and for every input tensor built by
                this class.
            max_prompt_tokens: Added to ``max_new_tokens`` to size ``kv_reps``
                (presumably the longest supported prompt). Defaults to the
                previously hard-coded 500.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        self.vocab_size = model.config.vocab_size
        self.kv_cache = None
        self.kv_sequence = []
        self.max_len = max_prompt_tokens + self.max_new_tokens
        self.kv_reps = torch.zeros((self.max_len, model.config.hidden_size), device=device)
        self.device = device

    def is_complete(self, prefix, seq):
        """Return True when ``seq`` has ``max_new_tokens - 1`` tokens more than ``prefix``.

        Raises:
            AssertionError: if ``seq`` already has ``max_new_tokens`` or more
                tokens beyond ``prefix``.
        """
        new_tokens = len(seq) - len(prefix)
        assert new_tokens < self.max_new_tokens
        return new_tokens == self.max_new_tokens - 1

    def next_token_probs_and_rep_unoptimized(self, sequence, layer):
        """Run the model on the whole ``sequence`` without touching the cache.

        Returns:
            ``(probs, rep)``: ``probs`` is the softmax over the last position's
            logits, moved to CPU; ``rep`` is ``hidden_states[layer][:, -1]``
            (batch dimension kept, left on the model's device).
        """
        with torch.no_grad():
            output = self.model(torch.tensor([sequence]).to(self.device), output_hidden_states=True)
            probs = torch.softmax(output.logits[0, -1, :].cpu(), dim=-1)
            rep = output.hidden_states[layer][:, -1]
            return probs, rep

    def _rollback_cache(self, sequence):
        """Pop cached tokens until ``kv_sequence`` is a strict prefix of ``sequence``.

        A strict prefix is required so that at least one token is left to feed
        to the model. Raises IndexError (pop from an empty list) when
        ``sequence`` is empty.
        """
        while self.kv_sequence == sequence or self.kv_sequence != sequence[:len(self.kv_sequence)]:
            self.kv_sequence.pop()
            self.kv_cache = trim_cache(self.kv_cache)

    def _advance_cache(self, sequence, layer):
        """Feed the uncached suffix of ``sequence`` to the model and refresh the cache.

        Updates ``kv_cache``, ``kv_sequence`` and the rows of ``kv_reps`` for
        the newly processed positions. Must be called under ``torch.no_grad()``
        by the public entry points.

        Returns:
            ``(last_logits, hidden_state)``: the logits of the final position
            (1-D) and ``hidden_states[layer]`` for the tokens that were fed.

        Raises:
            IndexError: if ``sequence`` is longer than the ``kv_reps`` buffer.
        """
        capacity = self.kv_reps.shape[0]
        if len(sequence) > capacity:
            # Fail before running the model instead of part-way through the copy.
            raise IndexError(
                "sequence length {} exceeds the representation buffer size {}".format(
                    len(sequence), capacity
                )
            )
        self._rollback_cache(sequence)
        cached_len = len(self.kv_sequence)
        if cached_len == 0:
            output = self.model(
                torch.tensor([sequence]).to(self.device),
                use_cache=True,
                output_hidden_states=True,
            )
        else:
            output = self.model(
                torch.tensor([sequence[cached_len:]]).to(self.device),
                past_key_values=self.kv_cache,
                use_cache=True,
                output_hidden_states=True,
            )
        hidden_state = output.hidden_states[layer]
        # One slice assignment instead of a per-position Python loop.
        self.kv_reps[cached_len:len(sequence)] = hidden_state[0, :len(sequence) - cached_len]
        self.kv_cache = output.past_key_values
        self.kv_sequence = sequence.copy()
        return output.logits[0, -1, :], hidden_state

    def next_token_probs_and_rep(self, sequence, layer, which_tokens, pool_range):
        """Return next-token probabilities and a representation, reusing the cache.

        Args:
            sequence: List of token ids (must be a list: it is compared with
                and copied into ``kv_sequence``).
            layer: Index into the model's ``hidden_states``.
            which_tokens: ``'last'`` to use the final position's hidden state,
                or any string starting with ``'mean'`` to average rows of
                ``kv_reps``.
            pool_range: Index into ``kv_reps`` rows (token positions) selecting
                what to average; only used for the ``'mean'`` modes.

        Returns:
            ``(probs, rep)``. For ``'last'``, ``rep`` keeps the batch dimension
            (``hidden_state[:, -1]``); for ``'mean'`` it is the mean over
            dimension 0 of ``kv_reps[pool_range]``.

        Raises:
            NotImplementedError: for any other ``which_tokens``.
        """
        use_last = which_tokens == 'last'
        if not use_last and which_tokens[:4] != 'mean':
            # The original `assert NotImplementedError(...)` could never fire.
            raise NotImplementedError("unsupported which_tokens: {!r}".format(which_tokens))
        with torch.no_grad():
            preds, hidden_state = self._advance_cache(sequence, layer)
            probs = torch.softmax(preds, dim=-1)
            if use_last:
                rep = hidden_state[:, -1]
            else:
                rep = torch.mean(self.kv_reps[pool_range], dim=0)
            return probs, rep

    def next_token_probs(self, sequence, layer=-1):
        """Return the next-token probability vector for ``sequence``, reusing the cache.

        ``layer`` selects which hidden states are recorded in ``kv_reps`` as a
        side effect; it does not affect the returned probabilities.
        """
        with torch.no_grad():
            preds, _ = self._advance_cache(sequence, layer)
            return torch.softmax(preds, dim=-1)

    def next_token(self, sequence, layer_for_vf):
        """Sample one next token for ``sequence``.

        Returns:
            A 0-dim tensor holding the sampled token id.
        """
        # Previously called `self.piref.next_token_probs`, but `piref` is never
        # set on this class; the policy's own method is the intended target.
        token_probs = self.next_token_probs(sequence, layer_for_vf)
        return torch.multinomial(token_probs, num_samples=1, replacement=True)[0]

    def next_k_tokens(self, sequence, k, layer_for_vf):
        """Sample ``k`` tokens autoregressively after ``sequence``.

        ``sequence`` itself is not modified.

        Returns:
            List of ``k`` 0-dim tensors, in generation order.
        """
        sampled = []
        # Keep the running context as plain ints so the cached `kv_sequence`
        # never ends up holding tensors.
        context = list(sequence)
        for _ in range(k):
            token = self.next_token(context, layer_for_vf)
            sampled.append(token)
            context.append(int(token))
        return sampled

    def compute_batch_reps_unoptimized(self, base_sequence, sequences, layer):
        """Return the last-position hidden state for each of ``sequences``, without the cache.

        Args:
            base_sequence: Unused; kept for signature parity with
                ``compute_batch_reps``.
            sequences: Equal-length token id lists forming one batch.
            layer: Index into the model's ``hidden_states``.

        Returns:
            ``hidden_states[layer][:, -1]`` for the batch.
        """
        with torch.no_grad():
            batch_tensor = torch.tensor(sequences, dtype=torch.int64).to(self.device)
            batch_output = self.model(batch_tensor, output_hidden_states=True)
            batch_rep = batch_output.hidden_states[layer][:, -1]
        return batch_rep

    def compute_batch_blocks_and_reps(self, sequence, batch_size, block_length, layer, which_tokens, pool_range, expected_token_probs=None):
        """Sample ``batch_size`` independent continuations of ``block_length`` tokens.

        The instance cache is rolled back to a strict prefix of ``sequence``
        but is not otherwise updated; the per-batch cache built while sampling
        is discarded on return.

        Args:
            sequence: List of prompt token ids.
            batch_size: Number of continuations to sample.
            block_length: Number of tokens sampled per continuation.
            layer: Index into the model's ``hidden_states``.
            which_tokens: Must be exactly ``'mean'``.
            pool_range: Index along the position axis selecting what to average.
            expected_token_probs: Unused (kept for backward compatibility).

        Returns:
            ``(blocks, pooled)``: ``blocks`` holds the sampled ids, one row per
            continuation; ``pooled`` is the mean over ``pool_range`` of the
            per-position hidden states.

        NOTE: the hidden state of the final sampled token is never computed, so
        its row in the pooled-from tensor stays zero. Whether ``pool_range`` is
        expected to exclude that position cannot be told from here -- TODO confirm.
        """
        # Validate before doing any sampling work.
        assert which_tokens == 'mean'
        with torch.no_grad():
            self._rollback_cache(sequence)
            prompt_len = len(sequence)
            cached_len = len(self.kv_sequence)

            prompt = torch.tensor(sequence).to(self.device).unsqueeze(0).expand(batch_size, -1)
            block_slots = torch.zeros((batch_size, block_length), dtype=torch.int64, device=self.device)
            seq_tensor = torch.cat([prompt, block_slots], dim=1)

            cached_reps = self.kv_reps[:cached_len].unsqueeze(0).expand(batch_size, -1, -1)
            pending_reps = torch.zeros(
                (batch_size, block_length + prompt_len - cached_len, self.model.config.hidden_size),
                device=self.device,
            )
            rep_tensor = torch.cat([cached_reps, pending_reps], dim=1)

            # With nothing cached (e.g. a fresh policy, kv_cache is None) the
            # first step runs without a cache, so there is nothing to expand.
            interim_cache = None
            if cached_len > 0:
                interim_cache = expand_past_key_values(self.kv_cache, batch_size)

            for step in range(block_length):
                end = prompt_len + step
                input_tensor = seq_tensor[:, cached_len:end]
                if cached_len == 0:
                    output = self.model(input_tensor, use_cache=True, output_hidden_states=True)
                else:
                    output = self.model(
                        input_tensor,
                        past_key_values=interim_cache,
                        use_cache=True,
                        output_hidden_states=True,
                    )
                rep_tensor[:, cached_len:end, :] = output.hidden_states[layer]

                probs = torch.softmax(output.logits[:, -1, :], dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1, replacement=True)
                seq_tensor[:, end] = next_tokens.squeeze(1)

                interim_cache = output.past_key_values
                cached_len = end

            pooled_rep_tensor = torch.mean(rep_tensor[:, pool_range], dim=1)
            return seq_tensor[:, prompt_len:], pooled_rep_tensor

    def compute_batch_reps(self, base_sequence, sequences, layer, which_tokens, pool_range):
        """Return a representation for each one-token extension of ``base_sequence``.

        The cached ``base_sequence`` is extended by the last token of every
        entry in ``sequences`` in a single batched forward pass. The instance
        cache is left unchanged.

        Args:
            base_sequence: Must equal the currently cached ``kv_sequence``.
            sequences: Token id lists, each equal to ``base_sequence`` plus one
                extra token.
            layer: Index into the model's ``hidden_states``.
            which_tokens: ``'last'`` for the new token's hidden state, or any
                string starting with ``'mean'`` to average over ``pool_range``.
            pool_range: Index along the position axis of the cached rows of
                ``kv_reps`` followed by the new token's hidden state; only used
                for the ``'mean'`` modes.

        Raises:
            AssertionError: if the cache does not match ``base_sequence`` or a
                sequence is not a one-token extension of it.
            NotImplementedError: for any other ``which_tokens`` (previously
                this surfaced as an UnboundLocalError after the forward pass).
        """
        assert self.kv_sequence == base_sequence
        continuations = []
        for seq in sequences:
            assert seq[:-1] == base_sequence
            continuations.append([seq[-1]])

        use_last = which_tokens == 'last'
        if not use_last and which_tokens[:4] != 'mean':
            raise NotImplementedError("unsupported which_tokens: {!r}".format(which_tokens))

        with torch.no_grad():
            batch_tensor = torch.tensor(continuations, dtype=torch.int64).to(self.device)
            batch_size = batch_tensor.shape[0]
            batch_output = self.model(
                batch_tensor,
                past_key_values=expand_past_key_values(self.kv_cache, batch_size),
                output_hidden_states=True,
                use_cache=True,
            )
            last_rep = batch_output.hidden_states[layer][:, -1]
            if use_last:
                batch_rep = last_rep
            else:
                cached_reps = self.kv_reps[:len(self.kv_sequence)].unsqueeze(0).expand(batch_size, -1, -1)
                all_reps = torch.cat([cached_reps, last_rep.unsqueeze(1)], dim=1)
                batch_rep = torch.mean(all_reps[:, pool_range], dim=1)
        return batch_rep
