import os
import pickle
import argparse
import numpy as np
import torch
from contextlib import nullcontext
from IPython import embed
from model import GPTConfig, GPT
from q_learning_utils import QLearningAgent, train_rl

def str2bool(v):
    """Convert a command-line token into a boolean.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive.

    Returns:
        ``True`` for "yes", "true", "t", "y", "1";
        ``False`` for "no", "false", "f", "n", "0".

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the
            recognised spellings.
    """
    if isinstance(v, bool):
        return v

    # Normalise once so both membership tests share the same lowered token.
    token = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")

    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
    
# Command-line interface of the script. Help strings use argparse's
# ``%(default)s`` placeholder so the documented default always matches the
# actual one (the original --n_embd help claimed 384 while the default is 120).
parser = argparse.ArgumentParser(description='Training of the NanoGPT.')

# Dataset and model architecture.
parser.add_argument('--dataset', type=str, default='random_walk_4', help='Name of the dataset to use')
parser.add_argument('--n_layer', type=int, default=1, help='Number of layers (default: %(default)s)')
parser.add_argument('--n_head', type=int, default=1, help='Number of attention heads (default: %(default)s)')
parser.add_argument('--n_embd', type=int, default=120, help='Size of the embeddings (default: %(default)s)')
parser.add_argument('--device', type=str, default='cuda:5', help='Torch device string (default: %(default)s)')

parser.add_argument('--num_nodes', type=int, default=100, help='Number of Nodes (default: %(default)s)')
parser.add_argument('--num_episodes', type=int, default=100000)

# Checkpoint selection. A real bool default keeps --fix_att consistent with
# --soft; argparse would have converted the string 'False' to False anyway.
parser.add_argument('--fix_att', type=str2bool, default=False)
parser.add_argument('--load_ckpt_num', type=int, default=-1,
                    help='Checkpoint number to load; -1 starts without a pretrained checkpoint')

# Q-learning hyper-parameters.
parser.add_argument('--epsilon', type=float, default=0.0)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.95)
parser.add_argument('--soft', type=str2bool, default=True)

parser.add_argument('--train_type', type=str, default='simple')
parser.add_argument('--reward_type', type=str, default='step')

# Evaluation settings.
parser.add_argument('--eval_interval', type=int, default=2000)
parser.add_argument('--eval_temperature', type=float, default=0.00001)
parser.add_argument('--eval_batch_size', type=int, default=500)
parser.add_argument('--eval_type_data', type=str, default='simple_test')
args = parser.parse_args()

# Module-level aliases for the CLI values used throughout the script.
# NOTE: every non-underscore int/float/bool/str global that exists when the
# `config_keys` scan further down runs is recorded in `config`, so the names
# (and their definition order) are deliberately left unchanged here.
dataset, n_layer, n_head, n_embd = args.dataset, args.n_layer, args.n_head, args.n_embd
num_nodes = args.num_nodes
train_start_path = args.train_type  # used below as the prefix of the data file names
reach_type = 'true'
device = args.device

# A checkpoint number of -1 means "do not load a pretrained checkpoint".
from_pretrain = args.load_ckpt_num != -1
    
# Build the run's output directory and the path of the pretrained checkpoint.
# Both live under a sub-tree named after the attention mode (fix/unfix).
if not args.fix_att:
    out_dir = f'qlearning_out/unfix_att/{dataset}/from_{args.n_layer}_{args.n_head}_{args.n_embd}_{args.load_ckpt_num}_{args.epsilon}'
    ckpt_path = f'./pretrained_out/unfix_att/{args.dataset}/simple/{args.n_layer}_{args.n_head}_{args.n_embd}_{args.num_nodes}_1_1/{args.load_ckpt_num}_ckpt.pt'
else:
    out_dir = f'qlearning_out/fix_att/{dataset}/from_{args.n_layer}_{args.n_head}_{args.n_embd}_{args.load_ckpt_num}_{args.epsilon}'
    ckpt_path = f'./pretrained_out/fix_att/{args.dataset}/simple/{args.n_layer}_{args.n_head}_{args.n_embd}_{args.num_nodes}_1_1/{args.load_ckpt_num}_ckpt.pt'

# Encode the remaining run options in the directory name.
out_dir += '_soft' if args.soft else '_hard'
out_dir += f'_{args.reward_type}_{args.train_type}'

# Persist the parsed arguments next to the run's outputs.
os.makedirs(out_dir, exist_ok=True)
with open(out_dir + '/args.pkl', 'wb') as args_file:
    pickle.dump(args, args_file)
        
# Locate the dataset directory and load its metadata pickle.
data_dir = os.path.join('data', f'{dataset}/{num_nodes}_1_1')
meta_path = os.path.join(data_dir, f'{train_start_path}_meta.pkl')
if not os.path.exists(meta_path):
    # block_size, stoi and itos are read unconditionally below, so a missing
    # meta file can never work; fail with the offending path instead of the
    # NameError on `meta` that the script used to hit.
    raise FileNotFoundError(f"dataset metadata file not found: {meta_path}")
# NOTE(security): pickle.load can execute arbitrary code; only point this
# script at dataset directories you trust.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
meta_vocab_size = meta['vocab_size']
print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
block_size = meta['block_size']
stoi, itos = meta['stoi'], meta['itos']


def decode(l):
    """Map a sequence of token ids to text by concatenating ``itos`` entries."""
    return ''.join([itos[i] for i in l])

# Fixed (non-CLI) settings. The names are part of the logged `config`, so they
# are kept verbatim (including `compile`, which shadows the builtin).
dropout = 0.0  # for pretraining 0 is good, for finetuning try 0.1+
bias = False  # do we use bias inside LayerNorm and Linear layers?
backend = 'nccl'  # 'nccl', 'gloo', etc.
dtype = 'bfloat16'
compile = False
# -----------------------------------------------------------------------------
# Snapshot every public scalar global defined so far; the resulting dict is
# handed to the training routine for logging. The helper tuple is
# underscore-prefixed so the scan itself skips it.
_SCALAR_TYPES = (int, float, bool, str)
config_keys = [
    name
    for name, value in globals().items()
    if not name.startswith('_') and isinstance(value, _SCALAR_TYPES)
]
#exec(open('configurator.py').read()) # overrides from command line or config file
config = {name: globals()[name] for name in config_keys}  # will be useful for logging
# ----------------------------------------------------------------------------
# if not ddp, we are running on a single gpu, and one process

# Runtime setup: output directory, RNG seed, TF32 switches and autocast context.
os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(42)

# allow tf32 on matmul and on cudnn
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# for later use in torch.autocast
device_type = 'cuda' if 'cuda' in device else 'cpu'

# note: float16 data type will automatically use a GradScaler
_TORCH_DTYPES = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}
ptdtype = _TORCH_DTYPES[dtype]

# On CPU no autocast is applied; otherwise run under autocast with `ptdtype`.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Read the training file as a flat uint16 array, widen it to int64 and move it
# to the target device as a tensor with a leading dimension of size 1.
_train_bin = os.path.join(data_dir, f'{train_start_path}_train.bin')
train_start = np.memmap(_train_bin, dtype=np.uint16, mode='r')
_tokens = torch.from_numpy(train_start.astype(np.int64))
train = torch.stack([_tokens]).to(args.device)

# Reinterpret the stream as rows of (block_size + 1) tokens and keep only the
# first two tokens of every row.
# NOTE(review): presumably these two tokens identify each training sample's
# start pair — confirm against the dataset generator.
_row_len = block_size + 1
_num_rows = train.shape[1] // _row_len
train_pairs = train.view(_num_rows, _row_len)[:, :2]

######
# train_pairs = torch.load('start.pt').to(args.device) - 2
######

# Progress counters, initialised here; they are not read in this part of the
# script.
iter_num = 0
best_val_loss = 1e9

# model init
# Resolve the vocabulary size first: use the one from the dataset metadata
# when available, otherwise fall back to the padded GPT-2 vocabulary.
if meta_vocab_size is None:
    print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    _vocab_size = 50304
else:
    _vocab_size = meta_vocab_size

# start with model_args from command line
model_args = dict(
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    block_size=block_size,
    bias=bias,
    vocab_size=_vocab_size,
    dropout=dropout,
)
gptconf = GPTConfig(**model_args)

# Build the model and, if requested, restore the pretrained weights.
model = GPT(gptconf, args.fix_att)
if from_pretrain:
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
model.to(device)

# Shrink the model's context window if it is larger than the dataset's.
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    # so that the checkpoint will have the right value
    model_args['block_size'] = block_size

# Load the adjacency and reachability matrices stored with the dataset.
true_adj = np.load(f'{data_dir}/true_adj_matrix.npy')
true_reach = np.load(f'{data_dir}/{reach_type}_reach_matrix.npy')

# Set the first (vocab_size - 2) diagonal entries to 1.
# NOTE(review): presumably the vocabulary holds two non-node tokens and this
# marks every node as reachable from itself — confirm against the metadata.
_diag = np.arange(meta_vocab_size - 2)
true_reach[_diag, _diag] = 1


agent = QLearningAgent(model, device, args.lr, args.gamma)

# Train
train_rl(
    agent,
    train_pairs,
    true_adj,
    true_reach,
    out_dir=out_dir,
    model_args=model_args,
    config=config,
    soft=args.soft,
    args=args,
)