"""
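Use the board-state predictions (probe) of a pretrained model to measure board
reconstruction error: for every validation position, rebuild a board from the
predicted tile states and compare it, and the legal moves it implies, with the
true board.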
"""
import os
import numpy as np
import pickle
from contextlib import nullcontext
import torch
from model import GPTConfig, GPT
from othello_mamba.models.mamba.mamba import MambaConfig, Mamba
import pdb
import json
import tqdm
import torch.nn.functional as F

othello_world_path = os.path.expanduser("~/emergence/othello_world")
import sys
sys.path.append(othello_world_path)
from data.othello import OthelloBoardState

# ---- Run configuration -------------------------------------------------------
# Checkpoint iterations to evaluate; one entry of every metric list per item.
# ckpt_nums = [0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
ckpt_nums = [1000]

# Run to evaluate. Keep exactly one (name, architecture) pair uncommented.
name, architecture = 'transformer-transfer-NTP-only-to-state-save-every-1K', 'transformer'
# name, architecture = 'mamba-transfer-NTP-only-to-state-save-every-1K', 'mamba'
# name, architecture = 'mamba2-transfer-NTP-only-to-state-save-every-1K', 'mamba2'

# Checkpoints are read from out_dir, data from data/<dataset>.
out_dir = 'out/' + name + '/'
dataset = 'othello/synthetic-othello-1M'

batch_size = 1024
device = 'cuda'
# Prefer bfloat16 autocast when the GPU supports it, otherwise use float16.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'

# Allow TF32 in matmuls and in cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# torch.autocast wants the bare device type, not a string such as 'cuda:0'.
device_type = 'cpu' if 'cuda' not in device else 'cuda'
ptdtype = getattr(torch, dtype)
# Mixed precision only applies on GPU; on CPU the context is a no-op.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Metric histories across checkpoints. The "next board" variants
# (same_next_board / next_next_move / one_next_move_in_common /
# next_reconstructed_is_subset) are only filled by code that is currently
# commented out, so they stay empty.
same_board_fracs, same_next_board_fracs = [], []
next_move_fracs, next_next_move_fracs = [], []
one_move_in_common_fracs, one_next_move_in_common_fracs = [], []
reconstructed_is_subset_fracs, next_reconstructed_is_subset_fracs = [], []

for ckpt_num in ckpt_nums:
  print(f"Working on checkpoint {ckpt_num}")

  # Load the checkpoint saved at this iteration straight onto `device`.
  ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_num}.pt')
  checkpoint = torch.load(ckpt_path, map_location=device)

  # Rebuild the network from the hyper-parameters stored next to the weights.
  if architecture != 'transformer':
      # Every non-transformer `architecture` value is built with the Mamba classes.
      conf = MambaConfig(**checkpoint['model_args'])
      model = Mamba(conf)
  else:
      gptconf = GPTConfig(**checkpoint['model_args'])
      model = GPT(gptconf)

  # Strip the '_orig_mod.' prefix from parameter names so they match the bare
  # model (presumably left behind by a compiled-model wrapper -- confirm).
  state_dict = checkpoint['model']
  unwanted_prefix = '_orig_mod.'
  prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
  for key in prefixed_keys:
      state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)

  # Inference only: load the weights, move to the device, switch to eval mode.
  model.load_state_dict(state_dict)
  model.to(device)
  model.eval()

  # Dataset metadata is stored in config.json inside the dataset directory.
  data_dir = os.path.join('data', dataset)
  data_config_file = os.path.join(data_dir, 'config.json')
  with open(data_config_file, 'r') as config_handle:
      data_config = json.load(config_handle)

  # Pull out the fields used below (and a few kept around for inspection).
  # To debug on a handful of games, override num_val_examples here, e.g. = 25.
  (pad_id, vocab_size, num_train_examples, num_val_examples,
   seq_len, num_states, num_state_dimensions) = (
      data_config[field_name]
      for field_name in (
          'pad_id', 'vocab_size', 'num_train_examples', 'num_val_examples',
          'seq_len', 'num_states', 'num_state_dimensions',
      )
  )


  # Batches can be drawn from either split; the evaluation below only uses 'valid'.
  def get_batch(start_ind, end_ind, split):
      """Return `(x, y, states)` tensors on `device` for rows [start_ind, end_ind).

      Args:
          start_ind: first example row (inclusive).
          end_ind: last example row (exclusive); may run past the end of the
              split, in which case the batch is simply shorter.
          split: 'train' or 'valid'.

      Returns:
          x: int64 tokens, every position but the last of each sequence.
          y: int64 tokens, shifted one position to the right of `x`.
          states: int64 states for every position, read from `<split>_states.bin`.

      Raises:
          ValueError: if `split` is neither 'train' nor 'valid'.
      """
      if split == 'train':
          file_stem, num_examples = 'train', num_train_examples
      elif split == 'valid':
          file_stem, num_examples = 'val', num_val_examples
      else:
          # Previously this fell through and failed later with an unrelated
          # UnboundLocalError on `data`.
          raise ValueError(f"Unknown split {split!r}; expected 'train' or 'valid'")

      # Memory-map the files so only the requested rows are actually read.
      data = np.memmap(os.path.join(data_dir, f'{file_stem}.bin'), dtype=np.uint8, mode='r',
                       shape=(num_examples, seq_len))
      state_data = np.memmap(os.path.join(data_dir, f'{file_stem}_states.bin'), dtype=np.uint8, mode='r',
                             shape=(num_examples, seq_len, num_state_dimensions))

      x = torch.from_numpy(data[start_ind:end_ind, :-1].astype(np.int64))
      y = torch.from_numpy(data[start_ind:end_ind, 1:].astype(np.int64))
      states = torch.from_numpy(state_data[start_ind:end_ind, :, :].astype(np.int64))

      if device_type == 'cuda':
          # Pinned host memory lets the copies to the GPU be asynchronous.
          x, y, states = (
              tensor.pin_memory().to(device, non_blocking=True)
              for tensor in (x, y, states)
          )
      else:
          # Pinning is only meaningful (and only available) with CUDA.
          x, y, states = x.to(device), y.to(device), states.to(device)
      return x, y, states

  # Run the model over the whole validation split and collect, batch by batch,
  # the arg-max state prediction together with the matching true states.
  all_preds = []
  all_states = []
  bar = tqdm.tqdm(range(0, num_val_examples, batch_size))
  with torch.no_grad():
    for batch_start in bar:
        batch_end = batch_start + batch_size
        x, y, states = get_batch(batch_start, batch_end, 'valid')
        b, t = x.shape
        with ctx:
            # Only the fourth model output (the state predictions) is used.
            _, _, _, state_preds = model(x, states=states)
            top_preds = torch.argmax(state_preds, dim=-1)
            # Drop the last position: x covers tokens [:-1], so the targets
            # must cover the same positions.
            states = states[:, :-1]
            all_preds.append(top_preds.cpu().numpy())
            all_states.append(states.cpu().numpy())


  # Merge the batches and flatten to one row of 64 tile values per board.
  all_preds = np.concatenate(all_preds).reshape(-1, 64)
  all_states = np.concatenate(all_states).reshape(-1, 64)

  # Tile accuracy: fraction of individual tiles predicted correctly.
  # Board accuracy: fraction of boards with all 64 tiles correct.
  tile_matches = all_states == all_preds
  tile_accuracy = tile_matches.mean()
  board_accuracy = tile_matches.all(axis=-1).mean()

  # Counters over every board of this checkpoint.
  total = 0
  total_same_next_moves = total_same_boards = total_empty = 0
  total_one_move_in_common = total_reconstructed_is_subset = 0

  # Counters for the "next board" metrics; nothing updates them at present.
  total_same_next_next_moves = total_same_next_boards = total_next_empty = 0
  total_one_next_move_in_common = total_next_reconstructed_is_subset = 0

  # Indices of boards whose legal moves are recovered exactly even though the
  # reconstructed board itself is clearly wrong.
  good_examples = []
  bar = tqdm.tqdm(range(len(all_states)))
  for ind in bar:
    # Stored tile values are shifted down by one before being handed to
    # OthelloBoardState (presumably mapping them onto -1 / 0 / 1 -- confirm).
    true_board = (all_states[ind] - 1).reshape(8, 8)
    reconstructed_board = (all_preds[ind] - 1).reshape(8, 8)

    # Heuristic: infer whose turn it is from the parity of the number of
    # occupied tiles on the true board (even -> colour 1, odd -> colour -1).
    # NOTE(review): a passed turn would break this parity -- confirm.
    num_pieces = np.sum(np.abs(true_board) == 1)
    next_hand_color = -1 if num_pieces % 2 else 1

    reconstructed_game = OthelloBoardState()
    reconstructed_game.state = np.copy(reconstructed_board)
    reconstructed_game.next_hand_color = next_hand_color

    true_game = OthelloBoardState()
    true_game.state = np.copy(true_board)
    true_game.next_hand_color = next_hand_color

    valid_moves_reconstructed = reconstructed_game.get_valid_moves()
    valid_moves_true = true_game.get_valid_moves()
    reconstructed_move_set = set(valid_moves_reconstructed)
    true_move_set = set(valid_moves_true)

    same_board = np.all(reconstructed_board == true_board)
    same_next_moves = reconstructed_move_set == true_move_set
    at_least_one_move_in_common = not reconstructed_move_set.isdisjoint(true_move_set)
    # Every move legal on the reconstructed board is also legal on the true one.
    is_subset = reconstructed_move_set <= true_move_set

    total += 1
    total_same_next_moves += int(same_next_moves)
    total_same_boards += int(same_board)
    total_empty += int(len(valid_moves_true) == 0)
    total_one_move_in_common += int(at_least_one_move_in_common)
    total_reconstructed_is_subset += int(is_subset)

    # Keep striking cases: identical move sets, more than 4 legal moves, more
    # than 8 wrong tiles, and fewer than 50 occupied tiles on the true board.
    num_wrong_tiles = np.sum(true_board != reconstructed_board)
    num_occupied_tiles = np.sum(true_board != 0)
    if (same_next_moves and not same_board and len(valid_moves_true) > 4
        and num_wrong_tiles > 8 and num_occupied_tiles < 50):
      good_examples.append(ind)
  
  # Turn this checkpoint's counts into fractions of all evaluated boards.
  same_board_frac, next_move_frac, one_move_in_common_frac, reconstructed_is_subset_frac = (
      count / total
      for count in (
          total_same_boards,
          total_same_next_moves,
          total_one_move_in_common,
          total_reconstructed_is_subset,
      )
  )

  # same_next_board_frac = total_same_next_boards / total
  # next_next_move_frac = total_same_next_next_moves / total
  # one_next_move_in_common_frac = total_one_next_move_in_common / total
  # next_reconstructed_is_subset_frac = total_next_reconstructed_is_subset / total

  # Record one value per metric for this checkpoint.
  for history, frac in (
      (same_board_fracs, same_board_frac),
      (next_move_fracs, next_move_frac),
      (one_move_in_common_fracs, one_move_in_common_frac),
      (reconstructed_is_subset_fracs, reconstructed_is_subset_frac),
  ):
      history.append(frac)
  # same_next_board_fracs.append(same_next_board_frac)
  # next_next_move_fracs.append(next_next_move_frac)
  # one_next_move_in_common_fracs.append(one_next_move_in_common_frac)
  # next_reconstructed_is_subset_fracs.append(next_reconstructed_is_subset_frac)

def _format_fracs(fracs):
    """Render fractions as a short list literal at float16 precision.

    The values are rounded through float16 to keep the printed numbers short.
    Each one is formatted with str() rather than repr(): a list inside an
    f-string uses repr(), which NumPy >= 2 renders as `np.float16(0.98)`,
    whereas str() is the bare number (`0.98`) on every NumPy version. That
    keeps the printed line paste-able as Python.
    """
    short_values = np.asarray(fracs).astype(np.float16)
    return "[" + ", ".join(str(value) for value in short_values) + "]"


# One paste-able assignment per metric, aligned with `iters`.
print(f"iters = {ckpt_nums}")
print(f"{architecture}_next_move_fracs = {_format_fracs(next_move_fracs)}")
print(f"{architecture}_one_move_in_common_fracs = {_format_fracs(one_move_in_common_fracs)}")
print(f"{architecture}_reconstructed_is_subset_fracs = {_format_fracs(reconstructed_is_subset_fracs)}")
print(f"{architecture}_same_board_fracs = {_format_fracs(same_board_fracs)}")

# print(f"next_next_move_fracs = {list(np.array(next_next_move_fracs).astype(np.float16))}")
# print(f"one_next_move_in_common_fracs = {list(np.array(one_next_move_in_common_fracs).astype(np.float16))}")
# print(f"next_reconstructed_is_subset_fracs = {list(np.array(next_reconstructed_is_subset_fracs).astype(np.float16))}")
# print(f"same_next_board_fracs = {list(np.array(same_next_board_fracs).astype(np.float16))}")


# Drop into the debugger so the arrays of the last checkpoint can be inspected.
pdb.set_trace()

# Tile-wise difference (prediction minus truth) for one of the boards collected
# in `good_examples`. The list can be empty (the filter is restrictive), so
# guard the lookup instead of failing with an IndexError at the very end.
example_ind = 0
if example_ind < len(good_examples):
    board_diff = ((all_preds[good_examples[example_ind]] - 1).reshape(8, 8)
                  - (all_states[good_examples[example_ind]] - 1).reshape(8, 8))
    print(board_diff)
else:
    print(f"No good example at index {example_ind} ({len(good_examples)} collected)")
# blacks turn
