import os
import numpy as np
import pickle
from contextlib import nullcontext
import torch
from model import GPTConfig, GPT
from othello_mamba.models.mamba.mamba import MambaConfig, Mamba
import pdb
import json
import tqdm
from collections import defaultdict
from sklearn.model_selection import train_test_split

# ---- Run configuration -------------------------------------------------------
device = 'cuda'
# Autocast dtype: bf16 when the GPU supports it, otherwise fp16.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
batch_size = 1024  # sequences per forward pass during inference

pretrain = 'state'  # 'NTP' or 'state' or 'random'
architecture = 'mamba'  # 'transformer' or 'mamba' or 'mamba2'

# Fail fast on a typo instead of silently falling through to the Mamba branch
# (any architecture other than 'transformer' is loaded as Mamba below).
_VALID_PRETRAIN = ('NTP', 'state', 'random')
_VALID_ARCHITECTURES = ('transformer', 'mamba', 'mamba2')
if pretrain not in _VALID_PRETRAIN:
    raise ValueError(f"pretrain must be one of {_VALID_PRETRAIN}, got {pretrain!r}")
if architecture not in _VALID_ARCHITECTURES:
    raise ValueError(f"architecture must be one of {_VALID_ARCHITECTURES}, got {architecture!r}")

# Run-name prefix: transformer checkpoints carry no architecture prefix.
architecture_name = architecture + '-' if architecture != 'transformer' else ''

dataset = 'gridworld/5-states-restrict-moves'

# Output locations for the two derived datasets written at the end of the script.
save_dir_from_state = f'data/{dataset}/white-noise-{pretrain}-{architecture}'
save_dir_from_sequence = f'data/{dataset}/white-noise-{pretrain}-{architecture}-seq'

seeds = [1, 2, 3, 4, 5]  # extend with 6-10 for the full sweep
ckpt_num = 1000  # 1000 or 1

# Per-seed prediction arrays, appended in `seeds` order by the loop below.
all_train_preds = []
all_valid_preds = []

# ---- Seed-invariant setup ----------------------------------------------------
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Dataset metadata (the shapes of the raw .bin files) does not depend on the seed.
data_dir = os.path.join('data', dataset)
data_config_file = os.path.join(data_dir, 'config.json')
with open(data_config_file, 'r') as f:
    data_config = json.load(f)

pad_id = data_config['pad_id']
vocab_size = data_config['vocab_size']
num_train_examples = data_config['num_train_examples']
num_val_examples = data_config['num_val_examples']
seq_len = data_config['seq_len']
num_state_dimensions = data_config['num_state_dimensions']

# Checkpoint run-name template for each supported pretraining objective.
# 'random' has no checkpoint naming wired up yet, so it is rejected up front
# rather than failing in the middle of the loop with an unset run name.
_RUN_NAME_TEMPLATES = {
    'NTP': '{arch}500-examples-transfer-gridworld-NTP-only-to-white-noise-seed-{seed}',
    'state': '{arch}500-examples-transfer-gridworld-state-only-to-white-noise-seed-{seed}',
}
if pretrain not in _RUN_NAME_TEMPLATES:
    raise ValueError(
        f"pretrain={pretrain!r} is not supported by this script; "
        f"expected one of {sorted(_RUN_NAME_TEMPLATES)}"
    )


def get_batch(start_ind, end_ind, split):
    """Load rows [start_ind, end_ind) of *split* ('train' or 'valid') onto `device`.

    Returns (x, y, states): x is every token but the last, y is every token
    but the first (next-token targets), and states is the full per-position
    state tensor for those rows. All are int64 tensors.
    """
    if split == 'train':
        prefix, num_examples = 'train', num_train_examples
    elif split == 'valid':
        prefix, num_examples = 'val', num_val_examples
    else:
        raise ValueError(f"split must be 'train' or 'valid', got {split!r}")
    data = np.memmap(os.path.join(data_dir, f'{prefix}.bin'), dtype=np.uint8, mode='r',
                     shape=(num_examples, seq_len))
    all_states = np.memmap(os.path.join(data_dir, f'{prefix}_states.bin'), dtype=np.uint8, mode='r',
                           shape=(num_examples, seq_len, num_state_dimensions))
    x = torch.from_numpy(data[start_ind:end_ind, :-1].astype(np.int64))
    y = torch.from_numpy(data[start_ind:end_ind, 1:].astype(np.int64))
    states = torch.from_numpy(all_states[start_ind:end_ind, :, :].astype(np.int64))
    x = x.pin_memory().to(device, non_blocking=True)
    y = y.pin_memory().to(device, non_blocking=True)
    states = states.pin_memory().to(device, non_blocking=True)
    return x, y, states


def _load_model(ckpt_path):
    """Build the configured architecture from *ckpt_path* and return it in eval mode on `device`."""
    checkpoint = torch.load(ckpt_path, map_location=device)
    if architecture == 'transformer':
        model = GPT(GPTConfig(**checkpoint['model_args']))
    else:
        model = Mamba(MambaConfig(**checkpoint['model_args']))
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix (presumably left by torch.compile) so
    # keys match the uncompiled module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _predict_split(model, split, num_examples):
    """Run *model* over all of *split* and concatenate the per-batch predictions.

    For each batch this keeps the softmax probability at index 1 of the last
    axis of the model's fourth output, as a float32 NumPy array.
    """
    batch_preds = []
    with torch.no_grad():
        for start in tqdm.tqdm(range(0, num_examples, batch_size)):
            x, _, _ = get_batch(start, start + batch_size, split)
            with ctx:
                _, _, _, state_preds = model(x)
                batch_preds.append(state_preds.softmax(-1)[:, :, 1].float().cpu().numpy())
    return np.concatenate(batch_preds, axis=0)


# ---- Per-seed inference ------------------------------------------------------
for seed in seeds:
    print(f"Working on seed {seed}")
    name = _RUN_NAME_TEMPLATES[pretrain].format(arch=architecture_name, seed=seed)
    out_dir = f'out/{name}/'
    ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_num}.pt')
    model = _load_model(ckpt_path)

    all_train_preds.append(_predict_split(model, 'train', num_train_examples))
    all_valid_preds.append(_predict_split(model, 'valid', num_val_examples))

# Combine the per-seed prediction arrays along a new third axis, so seeds
# become the innermost dimension.
all_train_preds = np.stack(all_train_preds, axis=2)
all_valid_preds = np.stack(all_valid_preds, axis=2)
# Collapse every leading axis: one row per (example, position) — assuming the
# per-seed arrays are 2-D — and one column per seed.
flattened_train_preds, flattened_valid_preds = (
    stacked.reshape(-1, len(seeds)) for stacked in (all_train_preds, all_valid_preds)
)

# Make sure both output directories exist.
for target_dir in (save_dir_from_state, save_dir_from_sequence):
    os.makedirs(target_dir, exist_ok=True)

# Ground-truth states with the final position dropped, so they line up with the
# model inputs (every token but the last), then flattened to one row per position.
train_states = np.memmap(
    os.path.join(data_dir, 'train_states.bin'),
    dtype=np.uint8,
    mode='r',
    shape=(num_train_examples, seq_len, num_state_dimensions),
)[:, :-1, :]
reshaped_train_states = train_states.reshape(-1, num_state_dimensions).astype(np.uint8)

valid_states = np.memmap(
    os.path.join(data_dir, 'val_states.bin'),
    dtype=np.uint8,
    mode='r',
    shape=(num_val_examples, seq_len, num_state_dimensions),
)[:, :-1, :]
reshaped_valid_states = valid_states.reshape(-1, num_state_dimensions).astype(np.uint8)

###### SAVE THINGS FOR TRAINING OTHELLO FROM SEQUENCE ######
import shutil


def _write_memmap(path, values, dtype):
    """Write *values* to *path* as a raw np.memmap of *dtype* with the same shape."""
    out = np.memmap(path, dtype=dtype, mode='w+', shape=values.shape)
    out[:] = values
    out.flush()
    del out


def _pad_last_position(preds):
    """Append one all-zero position on axis 1 of a 3-D prediction array.

    The predictions cover every position but the last (they come from the
    inputs x = tokens[:, :-1]); padding restores one entry per token position.
    """
    pad = np.zeros((preds.shape[0], 1, preds.shape[2]), dtype=np.float16)
    return np.concatenate([preds, pad], axis=1)


# The token files are reused unchanged. Copy them in-process so a failed copy
# raises instead of being silently ignored (os.system discards the exit status).
for split_file in ('train.bin', 'val.bin'):
    shutil.copyfile(
        os.path.join('data', dataset, split_file),
        os.path.join(save_dir_from_sequence, split_file),
    )

# Same config, but the "state" is now the vector of per-seed predictions.
data_config_for_sequence = data_config.copy()
data_config_for_sequence['num_states'] = 1
data_config_for_sequence['num_state_dimensions'] = len(seeds)
with open(os.path.join(save_dir_from_sequence, 'config.json'), 'w') as f:
    json.dump(data_config_for_sequence, f)

# Save the (padded) per-position predictions as the new float16 state files.
all_train_preds = _pad_last_position(all_train_preds)
_write_memmap(os.path.join(save_dir_from_sequence, "train_states.bin"), all_train_preds, np.float16)

all_valid_preds = _pad_last_position(all_valid_preds)
_write_memmap(os.path.join(save_dir_from_sequence, "val_states.bin"), all_valid_preds, np.float16)


# ---- Flat per-position dataset: ground-truth state -> per-seed predictions ----
# Earlier finding: corr(num_moves_in_game, preds) was about 0.32 for random,
# -0.01 for NTP and -0.05 for state pretraining.

data_config.update(
    num_train_examples=len(reshaped_train_states),
    num_val_examples=len(reshaped_valid_states),
    num_heads=len(seeds),
)
with open(os.path.join(save_dir_from_state, 'config.json'), 'w') as f:
    json.dump(data_config, f)

# Predictions become float16 targets; ground-truth states stay uint8.
for file_name, values, out_dtype in (
    ("train_y.bin", flattened_train_preds, np.float16),
    ("val_y.bin", flattened_valid_preds, np.float16),
    ("train_states.bin", reshaped_train_states, np.uint8),
    ("val_states.bin", reshaped_valid_states, np.uint8),
):
    arr = np.memmap(
        os.path.join(save_dir_from_state, file_name),
        dtype=out_dtype,
        mode='w+',
        shape=values.shape,
    )
    arr[:] = values[:]
    arr.flush()

print(f"Saved data to {save_dir_from_state} and {save_dir_from_sequence}")

# ---- State-mean R² baseline --------------------------------------------------
# Predict each validation row's per-seed outputs by the mean training output of
# rows with the same ground-truth state value (first state dimension), and
# report R² against the per-seed-centred variance of the validation outputs.
val_y = flattened_valid_preds
var_y = np.mean((val_y - np.mean(val_y, axis=0)) ** 2)

train_y = flattened_train_preds
# Use the first state dimension on BOTH sides. Flattening all dimensions for
# train (reshape(-1)) would misalign the mask with train_y when
# num_state_dimensions > 1.
train_x = reshaped_train_states[:, 0]
val_x = reshaped_valid_states[:, 0]

# One row per state value seen in either split. States absent from train stay NaN.
num_states = int(max(train_x.max(), val_x.max())) + 1
output = np.full((num_states, len(seeds)), np.nan)
counts = np.bincount(train_x, minlength=num_states)
for state in range(num_states):
    if counts[state]:
        output[state] = train_y[train_x == state].mean(axis=0)

preds = output[val_x]
mse = np.mean((preds - val_y) ** 2)
r2 = 1 - mse / var_y

print(f"R²: {r2:.4f}")


