# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model
import torch
from torch import nn
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass
import numpy as np
from typing import List


EOS_TOKEN = 50256  # GPT-2 '<|endoftext|>' id; used as the document delimiter and as padding filler in eval_loss
NEWLINE_TOKEN = 198  # GPT-2 id for "\n" (not referenced in the visible code — TODO confirm usage)


class Rotary(torch.nn.Module):
    """Produce rotary-embedding cos/sin tables for the sequence length of the input.

    Tables are computed in float32 and returned as bfloat16 with shape (1, T, 1, dim // 2)
    so they broadcast against (B, T, n_head, head_dim) queries/keys. When ``use_cache`` is
    True, the tables for the most recently seen sequence length are memoized.
    """

    def __init__(self, dim, base=10000, use_cache=True):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.inv_freq = 1.0 / (base ** exponents)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        self.use_cache = use_cache

    def _tables(self, x):
        # Angle for position p and frequency f is p * inv_freq[f]; type_as aligns with inv_freq.
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq).to(x.device)
        return angles.cos().bfloat16(), angles.sin().bfloat16()

    def forward(self, x):
        seq_len = x.shape[1]
        if not self.use_cache:
            cos, sin = self._tables(x)
        else:
            if seq_len != self.seq_len_cached:
                self.seq_len_cached = seq_len
                self.cos_cached, self.sin_cached = self._tables(x)
            cos, sin = self.cos_cached, self.sin_cached
        # Add broadcast axes for batch and heads.
        return cos[None, :, None, :], sin[None, :, None, :]


def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of a 4-D tensor by (cos, sin).

    With a = first half and b = second half, the result is
    ``[a*cos + b*sin, b*cos - a*sin]``, cast back to ``x``'s dtype.
    """
    assert x.ndim == 4  # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return torch.cat((rotated_first, rotated_second), dim=3).type_as(x)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings and RMS-normed queries/keys.

    When the input token ids are supplied to ``forward``, attention is additionally
    restricted to tokens of the same packed document (see ``create_packing_mask``);
    otherwise plain causal attention is used.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0, f"n_embd {self.n_embd} must be divisible by n_head {self.n_head}"
        self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # Zero init: the attention branch contributes nothing to the residual at initialization.
        self.c_proj.weight.data.zero_()
        self.rotary = Rotary(self.head_dim)

    def create_packing_mask(self, x, eos_token=50256):
        """Build a boolean (B, T, T) mask combining causality with document boundaries.

        Every occurrence of ``eos_token`` opens a new segment (the EOS token belongs to the
        segment it opens). Entry [b, i, j] is True when query i may attend to key j, i.e.
        j <= i and both positions lie in the same segment.

        Args:
            x: LongTensor of token ids, shape (B, T).
            eos_token: token id that delimits documents.

        Returns:
            torch.BoolTensor of shape (B, T, T) on ``x.device``.
        """
        _, T = x.size()
        device = x.device

        eos_mask = (x == eos_token).to(torch.int32)  # (B, T)
        # The running count of EOS tokens is the segment id of each position. Sequences are
        # expected to begin with EOS; any tokens before the first EOS simply form segment 0.
        segment_ids = torch.cumsum(eos_mask, dim=1)

        same_segment = segment_ids.unsqueeze(1) == segment_ids.unsqueeze(2)  # (B, T, T)
        causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
        return same_segment & causal_mask.unsqueeze(0)  # (B, T, T)

    def forward(self, x, idx=None):
        """Attend over ``x`` of shape (B, T, n_embd).

        Args:
            x: input activations, shape (B, T, n_embd).
            idx: optional token ids, shape (B, T); when given, a document-packing mask is built
                from them, otherwise SDPA's built-in causal mask is used.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, T, _ = x.size()
        # TODO: could optimize this by sharing the mask across layers
        attention_mask = self.create_packing_mask(idx) if idx is not None else None

        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)

        cos, sin = self.rotary(q)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),))
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)

        # A packing mask already encodes causality, so it is passed (broadcast over heads via
        # unsqueeze(1)) with is_causal=False; SDPA rejects setting both. Without a mask, fall
        # back to the built-in causal mask.
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_mask=attention_mask.unsqueeze(1) if attention_mask is not None else None,
            is_causal=attention_mask is None,
        )

        y = y.transpose(1, 2).contiguous().view_as(x)
        return self.c_proj(y)


class MLP(nn.Module):
    """Feed-forward block: Linear(n_embd, 4*n_embd) -> squared ReLU -> Linear(4*n_embd, n_embd), no biases."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
        # zero init suggested by @Grad62304977: the MLP branch outputs zeros at initialization
        self.c_proj.weight.data.zero_()

    def forward(self, x):
        activated = F.relu(self.c_fc(x))
        # Squared ReLU (https://arxiv.org/abs/2109.08668v2); ~1-2% better than GELU;
        # suggested by @SKYLINEZ007 and @Grad62304977
        return self.c_proj(activated.square())


class Block(nn.Module):
    """Pre-norm transformer block: x + attn(rms_norm(x)), then x + mlp(rms_norm(x))."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        # The caller may stash the token ids on the attention module as `idx` so it can build
        # a packing mask; when that attribute is absent, None is passed instead.
        token_ids = getattr(self.attn, 'idx', None)
        attn_in = F.rms_norm(x, (x.size(-1),))
        x = x + self.attn(attn_in, idx=token_ids)
        mlp_in = F.rms_norm(x, (x.size(-1),))
        return x + self.mlp(mlp_in)


# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    """Hyper-parameters used to construct a GPT model."""
    vocab_size: int = 50304  # rows of the tied token-embedding / lm_head matrix
    n_layer: int = 12  # number of transformer Blocks
    n_head: int = 6  # attention heads per layer; n_embd must be divisible by it
    n_embd: int = 768  # width of the residual stream


class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = (
            self.lm_head.weight
        )  # https://paperswithcode.com/method/weight-tying
        
    @property
    def device(self):
        return self.lm_head.weight.device

    def forward(self, idx, targets=None, return_logits=True):
        # forward the GPT model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        
        # Pass through transformer blocks with idx for attention masking
        for block in self.transformer.h:
            # Pass idx to the attention layer
            block.attn.idx = idx  # Store idx temporarily
            x = block(x)
            block.attn.idx = None  # Clean up

        x = F.rms_norm(x, (x.size(-1),))
        if targets is not None:
            logits = self.lm_head(x)
            logits = logits.float()
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.contiguous().view(-1), ignore_index=-100
            )
        else:
            logits = self.lm_head(x[:, [-1], :])
            logits = logits.float()
            loss = None

        if not return_logits:
            logits = None
        return logits, loss

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wte.weight.numel()
        return n_params

    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, stop_token=50256):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx)
            if temperature == 0:
                idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            else:
                # pluck the logits at the final step and scale by desired temperature
                logits = logits[:, -1, :] / temperature #the last word
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                # sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1)
            if idx_next == stop_token:
                break
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
    @classmethod
    def from_pretrained(cls, path, device='cpu'):
        """
        Load a saved model from a file.

        Args:
            path (str): The path to the saved model file.
            device: The device to load the model onto (e.g., 'cpu' or 'cuda').

        Returns:
            GPT: The loaded model.

        Example:
            >>> model_path = "path/to/saved_model.pt"
            >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            >>> loaded_model = GPT.from_pretrained(model_path, device)
        """
        checkpoint = torch.load(path, map_location=device, weights_only=True)
        model_state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        
        if 'config' not in checkpoint:
            # Try to infer model parameters from the state dict
            config = {}
            for key in model_state_dict.keys():
                if 'wte.weight' in key:
                    config['vocab_size'] = model_state_dict[key].shape[0]
                elif 'wpe.weight' in key:
                    config['block_size'] = model_state_dict[key].shape[0]
                elif '.attn.' in key and '.weight' in key:
                    if 'attn.c_attn.weight' in key:
                        config['n_embd'] = model_state_dict[key].shape[0] // 3
                    elif 'attn.c_proj.weight' in key:
                        config['n_embd'] = model_state_dict[key].shape[1]
                elif '.mlp.' in key and '.weight' in key:
                    if 'mlp.c_fc.weight' in key:
                        config['n_embd'] = model_state_dict[key].shape[1]
            
            # Count the number of transformer blocks
            # breakpoint()
            config['n_layer'] = len([k for k in model_state_dict.keys() if 'h.' in k and '.attn.c_q.weight' in k])
            
            # Infer n_head if possible
            if 'n_embd' in config:
                for key in model_state_dict.keys():
                    if '.attn.bias' in key:
                        n_head = model_state_dict[key].shape[0]
                        if n_head * (config['n_embd'] // n_head) == config['n_embd']:
                            config['n_head'] = n_head
                            break
            print("Warning, this code might produce wrong n_embd")
        else:
            config = checkpoint['config']
        
        # Create a GPTConfig object from the inferred configuration
        gpt_config = GPTConfig(**config)
        print(gpt_config)
        
        model = cls(gpt_config)
        model.load_state_dict(model_state_dict, strict=True)
        model = model.to(device)
        return model
    
    @property
    def config(self):
        return {
            'vocab_size': self.transformer.wte.weight.shape[0],
            'n_embd': self.transformer.wte.weight.shape[1],
            'n_layer': len(self.transformer.h),
            'n_head': self.transformer.h[0].attn.n_head if self.transformer.h else 0
        }
    
    def get_params(self, include_lm_head=True):
        count = 0
        for name, param in self.named_parameters():
            if not include_lm_head or not name.endswith('wte.weight'):
                count += param.numel()
        return count
    
    def disable_rotary_cache(self):
        for block in self.transformer.h:
            block.attn.rotary.use_cache = False
    
def tokenize(s, enc):
    """Tokenize one document as ``[<|endoftext|>] + enc.encode_ordinary(s)`` in a uint16 array.

    The leading <|endoftext|> token delimits documents. Raises AssertionError if any id is
    outside [0, 2**16).
    """
    eot = enc._special_tokens['<|endoftext|>']  # end of text token
    ids = np.array([eot, *enc.encode_ordinary(s)])
    in_range = (0 <= ids).all() and (ids < 2**16).all()
    assert in_range, "token dictionary too large for uint16"
    return ids.astype(np.uint16)

def inference(model, input_text: str, tokenizer=None, max_new_tokens=100, **kwargs):
    """
    Perform inference using a GPT model.

    The model is switched to eval mode and its rotary cache is disabled (both persist after
    the call).

    Args:
        model: Either a path to a saved model file or a GPT model instance.
        input_text (str): The input text to start the generation from.
        tokenizer: Either a tokenizer instance or None. If None, load from tiktoken.
        max_new_tokens (int): Maximum number of new tokens to generate.
        **kwargs: Additional keyword arguments forwarded to ``GPT.generate``.

    Returns:
        str: Only the generated continuation; the prompt is not included.

    Raises:
        ValueError: if ``model`` is neither a path nor a GPT instance.
    """
    # Load the model if a path is provided
    if isinstance(model, str):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = GPT.from_pretrained(model, device)
    elif not isinstance(model, GPT):
        raise ValueError("model must be either a path to a saved model or a GPT instance")

    model.eval()
    model.disable_rotary_cache()

    # Use provided tokenizer or load from tiktoken
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding("gpt2")

    # tokenize() returns uint16; widen first since older torch versions cannot convert
    # uint16 NumPy arrays.
    prompt_ids = tokenize(input_text, tokenizer).astype(np.int64)
    input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=model.device).unsqueeze(0)

    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens, **kwargs)

    # Decode only the newly generated ids. This does not depend on how the tokenizer renders
    # the leading end-of-text token or on exact prompt round-tripping.
    new_ids = output_ids[0, input_ids.size(1):].tolist()
    return tokenizer.decode(new_ids)

def eval_loss(model, texts, tokenizer=None, special_texts: List[List[str]]=None, ignore_first_token=False):
    """
    Evaluate per-text average token log-probabilities, split into keyword / non-keyword parts.

    Despite the "loss" naming, every returned number is a mean log-probability of the target
    tokens (<= 0, higher is better), i.e. the negated cross-entropy.

    Args:
        model: The GPT model instance used for evaluation (switched to eval mode).
        texts: A single string or a list of strings.
        tokenizer: The tokenizer instance; if None, the tiktoken "gpt2" encoding is used.
        special_texts: List of keyword lists, each corresponding to a text in texts.
        ignore_first_token: Whether to drop the first scored token (the prediction right
            after the leading <|endoftext|>) from all averages.

    Returns:
        tuple: (total, keyword, non_keyword)
            - total: mean log-prob over the scored tokens (nan if a text has none)
            - keyword: mean over keyword tokens, or None if no keywords were given or found
            - non_keyword: mean over the remaining tokens. It equals total when there are no
              keyword tokens, and is 0.0 when every scored token is a keyword.
            If multiple texts are provided, each element is a list with one value per text.
    """
    device = model.device

    # 1. Initialize tokenizer
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding("gpt2")

    # 2. Handle single text input
    single_text = isinstance(texts, str)
    if single_text:
        texts = [texts]

    # 3. Encode all texts
    encoded_texts = [tokenize(text, tokenizer) for text in texts]
    max_len = max(len(tokens) for tokens in encoded_texts)

    # 4. Right-pad with -100 (the cross-entropy ignore_index) into an explicit int64 buffer,
    #    so the dtype never depends on NumPy promotion of an empty padding array.
    padded = np.full((len(encoded_texts), max_len), -100, dtype=np.int64)
    for row, tokens in enumerate(encoded_texts):
        padded[row, :len(tokens)] = tokens
    input_ids = torch.tensor(padded, dtype=torch.long, device=device)

    # 5. Prepare input and target IDs (targets are the inputs shifted left by one)
    target_ids = input_ids[:, 1:].clone()
    input_ids = input_ids[:, :-1]
    input_ids[input_ids == -100] = EOS_TOKEN  # padding inputs become EOS; their targets stay ignored

    # 6. Compute model output
    model.eval()
    with torch.no_grad():
        logits, _ = model(input_ids, target_ids)
        log_probs = F.log_softmax(logits, dim=-1)

        total_losses = []
        keyword_losses = []
        non_keyword_losses = []

        for i, tokens in enumerate(encoded_texts):
            n_targets = len(tokens) - 1
            positions = torch.arange(n_targets, device=log_probs.device)
            token_log_probs = log_probs[i, positions, target_ids[i, :n_targets]]
            if ignore_first_token:
                token_log_probs = token_log_probs[1:]
            # Average over the tokens actually scored. The mean also stays correct when the
            # first token is dropped, unlike dividing by a fixed seq_len - 1.
            total_loss = token_log_probs.mean().item()
            total_losses.append(total_loss)

            # Handle keyword parts
            if special_texts is not None and i < len(special_texts) and special_texts[i]:
                # Mask over token positions; [1:n_targets + 1] aligns it with the targets.
                keyword_mask = _get_keyword_mask(
                    tokens,
                    special_texts[i],
                    tokenizer
                )[1:n_targets + 1]
                if ignore_first_token:
                    keyword_mask = keyword_mask[1:]
                keyword_mask = keyword_mask.to(token_log_probs.device)

                if keyword_mask.any():
                    keyword_loss = token_log_probs[keyword_mask].mean().item()
                    # 0.0 is kept as the "no non-keyword tokens" sentinel for compatibility.
                    non_keyword_loss = token_log_probs[~keyword_mask].mean().item() if (~keyword_mask).any() else 0.0

                    keyword_losses.append(keyword_loss)
                    non_keyword_losses.append(non_keyword_loss)
                else:
                    keyword_losses.append(None)
                    non_keyword_losses.append(total_loss)
            else:
                keyword_losses.append(None)
                non_keyword_losses.append(total_loss)

    # 7. Return results
    if single_text:
        return total_losses[0], keyword_losses[0], non_keyword_losses[0]
    return total_losses, keyword_losses, non_keyword_losses

def _get_keyword_mask(encoded_text, keywords, tokenizer):
    """
    Mark every token position covered by an exact occurrence of any keyword's token sequence.

    Args:
        encoded_text: Encoded text (token sequence, e.g. the array returned by tokenize)
        keywords: List of keyword strings
        tokenizer: The tokenizer used to encode each keyword

    Returns:
        torch.BoolTensor of length len(encoded_text), True at keyword positions

    NOTE(review): each keyword is encoded on its own, so with GPT-2 BPE a keyword that
    follows a space in the text may tokenize differently (leading-space token) and not
    match. Confirm whether callers pass keywords with a leading space where needed.
    """
    n_tokens = len(encoded_text)
    mask = torch.zeros(n_tokens, dtype=torch.bool)

    for keyword in keywords:
        # tokenize() prepends <|endoftext|>; drop it to get the bare keyword tokens
        pattern = tokenize(keyword, tokenizer)[1:]
        width = len(pattern)

        # Naive sliding-window search; overlapping matches are all marked
        for start in range(n_tokens - width + 1):
            window = encoded_text[start:start + width]
            if (window == pattern).all():
                mask[start:start + width] = True

    return mask


if __name__ == "__main__":
    import sys

    # Usage: python <this file> [checkpoint_path] [prompt]
    # With no arguments the default checkpoint and prompt below are used.
    PATH = sys.argv[1] if len(sys.argv) > 1 else "logs/345ab7f3-1ba5-4518-b02d-ff1a7f8c3b64/state_step005100.pt"
    prompt = sys.argv[2] if len(sys.argv) > 2 else "Jade Emiliano Long entered life on"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT.from_pretrained(PATH, device)
    # Initialize the tokenizer
    enc = tiktoken.get_encoding("gpt2")
    print(inference(model, prompt, tokenizer=enc))