import pickle
import sys

import torch

class NGram():
  """Count-based n-gram model over token IDs.

  `corpus` maps a prior (context) to a dict of {next_token_id: count}. For the
  bigram model the prior is keyed by the bare token ID (see `prob`, which
  unwraps the 1-tuple); for higher orders it is looked up with the key as
  given, presumably a tuple of token IDs — TODO confirm against the pickles.
  """

  # Number of context tokens each known model type expects in its prior.
  _PRIOR_LEN = {
    "bigram": 1,
    "trigram": 2,
    "fourgram": 3,
    "fivegram": 4,
    "sixgram": 5,
    "sevengram": 6,
  }

  def __init__(self, corpus, corpus_counts, type):
    self.corpus = corpus
    # Stored as-is; not read by any method of this class.
    self.counts = corpus_counts
    # One of the keys of _PRIOR_LEN ("bigram" ... "sevengram").
    self.type = type

  def _distribution(self, key):
    """Return the {next_token_id: count} dict stored for `key`, or None if unseen."""
    if key in self.corpus:
      return self.corpus[key]
    # Bigram tables are keyed by the bare token ID; also accept the 1-tuple
    # form documented for `key`, so `ntd((tok,))` finds the same entry that
    # `prob((tok,), ...)` does.
    if (self.type == "bigram" and isinstance(key, tuple) and len(key) == 1
        and key[0] in self.corpus):
      return self.corpus[key[0]]
    return None

  def prob(self, key, next):
    """Return P(next | key) estimated from raw counts.

    Args:
      key (tuple): tuple of token IDs forming the prior; its length must match
        the model order (1 for bigram, 2 for trigram, ... 6 for sevengram).
      next (int): token ID whose conditional probability is requested.

    Returns:
      float: count(key, next) / total count after `key` (0.0 if `next` never
        followed `key`), or -1 if `key` is not present in the corpus.

    Raises:
      AssertionError: if len(key) does not match the model order (checked with
        `assert`, so skipped when Python runs with -O).
    """
    prior_len = len(key)
    expected = self._PRIOR_LEN.get(self.type)
    if expected is not None:
      assert prior_len == expected, (
        f"{self.type} expects a prior of length {expected}, got {prior_len}")
    if self.type == "bigram":
      # Bigram tables are keyed by the bare token ID rather than a 1-tuple.
      key = key[0]

    dist = self._distribution(key)
    if dist is None:
      return -1
    return dist.get(next, 0) / sum(dist.values())

  def ntd(self, key, vocab_size=32000):
    """Return the full next-token distribution for the prior `key`.

    Args:
      key (tuple): tuple of token IDs forming the prior. For the bigram model
        either the 1-tuple or the bare token ID is accepted.
      vocab_size (int): length of the returned tensor; every stored next-token
        ID must be < vocab_size. Defaults to 32000.

    Returns:
      torch.Tensor | None: (vocab_size,) tensor of next-token probabilities
        (all zeros if the stored distribution is empty), or None if `key` is
        not present in the corpus.
    """
    dist = self._distribution(key)
    if dist is None:
      return None
    prob_tensor = torch.zeros(vocab_size)
    if dist:
      total = sum(dist.values())
      token_ids = torch.tensor(list(dist.keys()), dtype=torch.long)
      counts = torch.tensor(list(dist.values()), dtype=torch.float64)
      # Divide in float64 (like Python's count / total), then round once to
      # the output dtype; one scatter instead of a per-token Python loop.
      prob_tensor[token_ids] = (counts / total).to(prob_tensor.dtype)
    return prob_tensor

def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
  """Load the requested n-gram tables from `ckpt_path` and wrap each in an NGram.

  Args:
    ckpt_path (str | os.PathLike): directory holding the pickled tables
      (b_d_final.pkl, t_d_final.pkl, fo_d_final.pkl, fi_d_final.pkl,
      si_d_final.pkl, se_d_final.pkl).
    bigram, trigram, fourgram, fivegram, sixgram, sevengram: truthy flags
      selecting which orders to load.

  Returns:
    list[NGram]: the loaded models, ordered from bigram up to sevengram; only
      the selected orders are present.
  """
  from pathlib import Path

  # (model type, pickle filename, caller's flag), lowest order first so the
  # returned list keeps that ordering.
  specs = (
    ("bigram", "b_d_final.pkl", bigram),
    ("trigram", "t_d_final.pkl", trigram),
    ("fourgram", "fo_d_final.pkl", fourgram),
    ("fivegram", "fi_d_final.pkl", fivegram),
    ("sixgram", "si_d_final.pkl", sixgram),
    ("sevengram", "se_d_final.pkl", sevengram),
  )

  models = []
  for ngram_type, filename, wanted in specs:
    if not wanted:
      continue
    print(f"Making {ngram_type}...")
    # NOTE: pickle.load can execute arbitrary code; only point ckpt_path at
    # trusted checkpoints.
    with open(Path(ckpt_path) / filename, "rb") as f:
      table = pickle.load(f)
    models.append(NGram(table, None, ngram_type))
    # sys.getsizeof is shallow: this reports only the top-level table object,
    # not the nested per-prior count dicts.
    print(sys.getsizeof(table))

  return models