import os
import numpy as np
import pickle
from contextlib import nullcontext
import torch
from model import GPTConfig, GPT
from othello_mamba.models.mamba.mamba import MambaConfig, Mamba
import pdb
import json
import tqdm
from collections import defaultdict
from sklearn.model_selection import train_test_split

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# requesting a CUDA device that does not exist. This mirrors the availability
# guard already used for `dtype` below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Autocast precision name: bfloat16 when CUDA reports support for it,
# float16 otherwise.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
batch_size = 1024  # validation sequences per forward pass

# Every (architecture, pretrain) combination listed here is evaluated in turn.
architectures = ['rnn', 'lstm', 'transformer', 'mamba', 'mamba2']
pretrains = ['NTP', 'state']
othello_type = 'championship' # synthetic or championship
state_transformation = 'parity' # majority or balance-black or parity
# When True, the 'probe-...' checkpoints and the untransformed dataset
# directory are used instead of the state-transform ones.
probe = False
seed = 1  # seed suffix of the checkpoint directory name


num_examples = 500  # not referenced in the visible part of this script

# Dataset tag embedded in checkpoint directory names.
othello_slug = 'synthetic-1M' if othello_type == 'synthetic' else 'championship'
# Per-model metrics, appended in evaluation order.
cross_entropies = []
accuracies = []
# Cross-entropies keyed by (architecture, pretrain).
cross_ent_dict = defaultdict(list)

# Numeric settings and the autocast context do not depend on which model is
# being evaluated, so they are configured once rather than per iteration.
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Resolve the dataset root up front so that an unsupported `othello_type`
# fails with a clear message instead of an unbound `dataset` variable.
if othello_type == 'synthetic':
    dataset_root = 'othello/synthetic-othello-1M'
elif othello_type == 'championship':
    dataset_root = 'othello/championship-othello'
else:
    raise ValueError(f"Unsupported othello_type: {othello_type!r}")

for architecture in architectures:
    for pretrain in pretrains:
        print(f"Working on {architecture} with {pretrain} pretraining for probe={probe}")
        if probe:
            name = f'probe-{architecture}-{othello_slug}-transfer-{pretrain}-only-to-state-seed-{seed}'
            dataset = dataset_root
        else:
            name = f'{architecture}-{othello_slug}-transfer-{pretrain}-only-to-{state_transformation}-seed-{seed}'
            dataset = f'{dataset_root}/state-transform/{state_transformation}'

        out_dir = f'out/{name}/'

        # --- Load the best checkpoint and rebuild the model -----------------
        ckpt_path = os.path.join(out_dir, 'ckpt_best.pt')
        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=device)
        if architecture == 'transformer':
            gptconf = GPTConfig(**checkpoint['model_args'])
            model = GPT(gptconf)
        else:
            # NOTE(review): 'rnn' and 'lstm' also take this branch; presumably
            # the checkpointed config selects the right layer type -- confirm.
            conf = MambaConfig(**checkpoint['model_args'])
            model = Mamba(conf)
        state_dict = checkpoint['model']
        # Strip the '_orig_mod.' prefix from parameter names so they match
        # the freshly constructed model.
        unwanted_prefix = '_orig_mod.'
        for k in list(state_dict):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        # --- Dataset metadata ----------------------------------------------
        data_dir = os.path.join('data', dataset)
        data_config_file = os.path.join(data_dir, 'config.json')
        with open(data_config_file, 'r') as f:
            data_config = json.load(f)

        pad_id = data_config['pad_id']
        vocab_size = data_config['vocab_size']# + 1
        num_train_examples = data_config['num_train_examples']
        num_val_examples = data_config['num_val_examples']
        seq_len = data_config['seq_len']
        num_state_dimensions = data_config['num_state_dimensions']

        def get_batch(start_ind, end_ind, split):
            """Return `(x, y, states)` tensors on `device` for rows [start_ind, end_ind).

            `x` is every token but the last of each sequence, `y` is every
            token but the first, and `states` holds the per-position state
            labels. All three are int64.

            Args:
                start_ind: first row of the slice (inclusive).
                end_ind: last row of the slice (exclusive); a value past the
                    end of the split is clamped by NumPy slicing.
                split: 'train' or 'valid'.

            Raises:
                ValueError: if `split` is neither 'train' nor 'valid'.
            """
            if split == 'train':
                tokens_file, states_file, num_rows = 'train.bin', 'train_states.bin', num_train_examples
            elif split == 'valid':
                tokens_file, states_file, num_rows = 'val.bin', 'val_states.bin', num_val_examples
            else:
                raise ValueError(f"Unknown split: {split!r}")
            # The memmaps are opened on every call; only the requested rows
            # are copied into memory by the astype() calls below.
            data = np.memmap(os.path.join(data_dir, tokens_file), dtype=np.uint8, mode='r', shape=(num_rows, seq_len))
            all_states = np.memmap(os.path.join(data_dir, states_file), dtype=np.uint8, mode='r', shape=(len(data), seq_len, num_state_dimensions))
            x = torch.from_numpy((data[start_ind:end_ind, :-1]).astype(np.int64))
            y = torch.from_numpy((data[start_ind:end_ind, 1:]).astype(np.int64))
            states = torch.from_numpy(all_states[start_ind:end_ind, :, :].astype(np.int64))
            if device_type == 'cuda':
                # Pinned host memory allows the asynchronous copy to the GPU.
                x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
                states = states.pin_memory().to(device, non_blocking=True)
            else:
                # pin_memory() is only usable with CUDA; plain transfer on CPU.
                x, y, states = x.to(device), y.to(device), states.to(device)
            return x, y, states

        # --- Run the model over the whole validation split ------------------
        valid_preds = []
        valid_labels = []
        with torch.no_grad():
            for i in range(0, num_val_examples, batch_size):
                x, _, states = get_batch(i, i + batch_size, 'valid')
                with ctx:
                    _, _, _, state_preds = model(x)
                    # softmax stays inside the autocast context so the
                    # conversion to NumPy happens on its output dtype.
                    valid_preds.append(state_preds.softmax(-1).cpu().numpy())
                    # The inputs drop the last token, so drop the last
                    # position of the labels to keep them aligned.
                    valid_labels.append(states.cpu().numpy()[:, :-1, :])

        # --- Metrics ---------------------------------------------------------
        # Flatten to one row per (example, position, state dimension).
        valid_labels = np.concatenate(valid_labels).reshape(-1)
        valid_preds = np.concatenate(valid_preds).reshape(-1, data_config['num_states'])
        if len(valid_preds) != len(valid_labels):
            raise ValueError(
                f"Prediction/label count mismatch: {len(valid_preds)} vs {len(valid_labels)}"
            )
        # Cross entropy: pick the probability assigned to the true class of
        # each row directly instead of materialising a dense one-hot matrix.
        label_probs = valid_preds[np.arange(len(valid_labels)), valid_labels]
        cross_entropy = -np.sum(np.log(label_probs + 1e-8), dtype=np.float64) / len(valid_labels)
        cross_ent_dict[(architecture, pretrain)].append(cross_entropy)
        cross_entropies.append(cross_entropy)
        accuracy = np.mean(np.argmax(valid_preds, axis=-1) == valid_labels)
        accuracies.append(accuracy)

# Reference scores that the collected metrics are correlated against.
# There are ten values; presumably one per (architecture, pretrain) pair in
# the order the models are evaluated -- NOTE(review): confirm provenance.
ib_champ = [0.478, 0.507, 0.792, 0.746, 0.602, 0.734, 0.552, 0.847, 0.496, 0.693]

# Report the results explicitly. As bare expression statements these values
# were computed and thrown away, and the debugger breakpoint that preceded
# them stalled any unattended run.
print("cross entropies:", np.round(cross_entropies, 3))
print("corr(cross entropy, ib_champ):", np.corrcoef(cross_entropies, ib_champ)[0][1])

print("accuracies:", np.round(accuracies, 3))
print("corr(accuracy, ib_champ):", np.corrcoef(accuracies, ib_champ)[0][1])