import json
import os
import pdb
import pickle
import shutil
from collections import defaultdict
from contextlib import nullcontext

import numpy as np
import torch
import tqdm
from sklearn.model_selection import train_test_split

from model import GPTConfig, GPT
from othello_mamba.models.mamba.mamba import MambaConfig, Mamba

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
device = 'cuda'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
batch_size = 1024

pretrain = 'NTP'  # 'NTP' or 'state' or 'random'
architecture = 'transformer'  # 'transformer' or 'mamba' or 'mamba2'
othello_type = 'synthetic'  # 'synthetic' or 'championship'

num_examples = 500

# Fail fast on a typo in the settings above. Otherwise the mistake only surfaces
# much later as a NameError, because `dataset` / the checkpoint name never get assigned.
if pretrain not in ('NTP', 'state', 'random'):
    raise ValueError(f"pretrain must be 'NTP', 'state' or 'random', got {pretrain!r}")
if architecture not in ('transformer', 'mamba', 'mamba2'):
    raise ValueError(f"architecture must be 'transformer', 'mamba' or 'mamba2', got {architecture!r}")
if othello_type not in ('synthetic', 'championship'):
    raise ValueError(f"othello_type must be 'synthetic' or 'championship', got {othello_type!r}")

othello_slug = 'synthetic-1M' if othello_type == 'synthetic' else 'championship'

# Checkpoint names are prefixed with the architecture, except for the transformer.
architecture_name = architecture + '-' if architecture != 'transformer' else ''

dataset = {
    'synthetic': 'othello/synthetic-othello-1M',
    'championship': 'othello/championship-othello',
}[othello_type]

# Output datasets: one keyed on per-position states, one keyed on the raw token sequence.
save_dir_from_state = f'data/{dataset}/{num_examples}-examples-white-noise-{pretrain}-{architecture}'
save_dir_from_sequence = f'data/{dataset}/{num_examples}-examples-white-noise-{pretrain}-{architecture}-seq'

seeds = [1, 2, 3, 4, 5] if num_examples == 500 else [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
ckpt_num = 1000

# One entry per seed, filled by the inference loop.
all_train_preds = []
all_valid_preds = []

# ---------------------------------------------------------------------------
# Seed-independent setup. Precision settings and dataset metadata are identical
# for every checkpoint, so they are computed once rather than once per seed.
# ---------------------------------------------------------------------------
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

data_dir = os.path.join('data', dataset)
data_config_file = os.path.join(data_dir, 'config.json')
with open(data_config_file, 'r') as f:
    data_config = json.load(f)

pad_id = data_config['pad_id']
vocab_size = data_config['vocab_size']
num_train_examples = data_config['num_train_examples']
num_val_examples = data_config['num_val_examples']
seq_len = data_config['seq_len']
num_state_dimensions = data_config['num_state_dimensions']

# Checkpoint-name fragment for each pretraining regime.
_PRETRAIN_TAGS = {'NTP': 'NTP-only', 'state': 'state-only', 'random': 'random-init'}


def _open_tokens(split):
    """Memory-map the uint8 token file for `split` ('train' or 'valid') as (num_rows, seq_len)."""
    if split == 'train':
        path, num_rows = os.path.join(data_dir, 'train.bin'), num_train_examples
    elif split == 'valid':
        path, num_rows = os.path.join(data_dir, 'val.bin'), num_val_examples
    else:
        raise ValueError(f"unknown split {split!r}")
    return np.memmap(path, dtype=np.uint8, mode='r', shape=(num_rows, seq_len))


def _predict_state_probs(model, tokens):
    """Return softmax probability of class 1 from `model`'s state head, per input position.

    `tokens` is processed `batch_size` rows at a time. Each row is fed without its last
    token. The result is therefore presumably (len(tokens), seq_len - 1), assuming the
    model emits one state prediction per input position (TODO confirm).
    Only the token file is read; the (unused) state targets are never loaded.
    """
    batches = []
    with torch.no_grad():
        for start in tqdm.tqdm(range(0, len(tokens), batch_size)):
            x = torch.from_numpy(tokens[start:start + batch_size, :-1].astype(np.int64))
            # Pinned memory requires CUDA; fall back to a plain copy on CPU.
            if device_type == 'cuda':
                x = x.pin_memory().to(device, non_blocking=True)
            else:
                x = x.to(device)
            with ctx:
                _, _, _, state_preds = model(x)
                # Kept inside autocast (as before) so softmax precision is unchanged.
                batches.append(state_preds.softmax(-1)[:, :, 1].float().cpu().numpy())
    return np.concatenate(batches, axis=0)


for seed in seeds:
    print(f"Working on seed {seed}")
    name = (f'{architecture_name}{num_examples}-examples-{othello_slug}'
            f'-transfer-{_PRETRAIN_TAGS[pretrain]}-to-white-noise-seed-{seed}')
    out_dir = f'out/{name}/'

    ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_num}.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    if architecture == 'transformer':
        model = GPT(GPTConfig(**checkpoint['model_args']))
    else:
        model = Mamba(MambaConfig(**checkpoint['model_args']))

    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix (presumably left by torch.compile) so keys match the bare model.
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    all_train_preds.append(_predict_state_probs(model, _open_tokens('train')))
    all_valid_preds.append(_predict_state_probs(model, _open_tokens('valid')))

    # Release this seed's weights before the next checkpoint is loaded onto the device.
    del model, checkpoint, state_dict


def _write_memmap(path, data, dtype):
    """Write `data` to `path` as a raw C-ordered `dtype` array and return the written memmap.

    `data` may itself be a (strided) memmap view; values are copied straight into the
    output file instead of first being materialised in RAM.
    """
    out = np.memmap(path, dtype=dtype, mode='w+', shape=data.shape)
    out[:] = data
    out.flush()
    return out


# Stack the per-seed predictions with seeds on the last axis, then flatten to one row
# per (example, position) and one column per seed.
all_train_preds = np.stack(all_train_preds, axis=2)
all_valid_preds = np.stack(all_valid_preds, axis=2)
flattened_train_preds = all_train_preds.reshape(-1, len(seeds))
flattened_valid_preds = all_valid_preds.reshape(-1, len(seeds))

# Create both output dataset directories.
os.makedirs(save_dir_from_state, exist_ok=True)
os.makedirs(save_dir_from_sequence, exist_ok=True)

train_states = np.memmap(os.path.join(data_dir, 'train_states.bin'), dtype=np.uint8, mode='r',
                         shape=(num_train_examples, seq_len, num_state_dimensions))
valid_states = np.memmap(os.path.join(data_dir, 'val_states.bin'), dtype=np.uint8, mode='r',
                         shape=(num_val_examples, seq_len, num_state_dimensions))

###### SAVE THINGS FOR TRAINING OTHELLO FROM SEQUENCE ######
# Token files are copied unchanged. Unlike os.system('cp ...'), shutil raises on failure.
shutil.copyfile(os.path.join(data_dir, 'train.bin'), os.path.join(save_dir_from_sequence, 'train.bin'))
shutil.copyfile(os.path.join(data_dir, 'val.bin'), os.path.join(save_dir_from_sequence, 'val.bin'))
data_config_for_sequence = data_config.copy()
data_config_for_sequence['num_states'] = 1
data_config_for_sequence['num_state_dimensions'] = len(seeds)
with open(os.path.join(save_dir_from_sequence, 'config.json'), 'w') as f:
    json.dump(data_config_for_sequence, f)

# The per-seed predictions become this dataset's float16 "states". One zero-valued
# position is appended on axis 1, so each row again spans as many positions as the token file.
all_train_preds = np.pad(all_train_preds, ((0, 0), (0, 1), (0, 0)))
all_valid_preds = np.pad(all_valid_preds, ((0, 0), (0, 1), (0, 0)))
_write_memmap(os.path.join(save_dir_from_sequence, 'train_states.bin'), all_train_preds, np.float16)
_write_memmap(os.path.join(save_dir_from_sequence, 'val_states.bin'), all_valid_preds, np.float16)

###### SAVE THINGS FOR TRAINING FROM STATE ######
# num_moves_in_game = (reshaped_train_states != 1).sum(-1)
# np.corrcoef(num_moves_in_game, train_preds)  # 0.32 for random, -0.01 for NTP, -0.05 for state

# One example per (sequence, position), excluding each sequence's final position.
data_config['num_train_examples'] = num_train_examples * (seq_len - 1)
data_config['num_val_examples'] = num_val_examples * (seq_len - 1)
data_config['num_heads'] = len(seeds)
with open(os.path.join(save_dir_from_state, 'config.json'), 'w') as f:
    json.dump(data_config, f)

# Targets: one float16 column per seed.
_write_memmap(os.path.join(save_dir_from_state, 'train_y.bin'), flattened_train_preds, np.float16)
_write_memmap(os.path.join(save_dir_from_state, 'val_y.bin'), flattened_valid_preds, np.float16)

# Inputs: the original states minus each sequence's final position. The view is streamed
# into the output file; C order makes the bytes identical to a
# (rows, num_state_dimensions) layout.
reshaped_train_states = _write_memmap(
    os.path.join(save_dir_from_state, 'train_states.bin'), train_states[:, :-1, :], np.uint8
).reshape(-1, num_state_dimensions)
reshaped_valid_states = _write_memmap(
    os.path.join(save_dir_from_state, 'val_states.bin'), valid_states[:, :-1, :], np.uint8
).reshape(-1, num_state_dimensions)

print(f"Saved data to {save_dir_from_state} and {save_dir_from_sequence}")
