import gc
import os
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
import re
from copy import deepcopy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from tqdm import tqdm
# import wandb
# from model import GPTConfig, GPT
import sys
sys.path.append("src/pycharmprojects/mass")
from model import GPTConfig, GPT
import gc
import argparse
import matplotlib.pyplot as plt
from torch.nn.functional import softmax
import matplotlib.colors as mcolors


# -----------------------------------------------------------------------------
# Configuration. Everything in this section is a plain module-level global on
# purpose: the `config` snapshot at the end of the section is built by scanning
# globals(), so adding or renaming a value here changes what gets logged.
# NOTE(review): many values (optimizer, LR schedule, ...) look carried over from
# a training setup; several are not read anywhere in the visible code.
# -----------------------------------------------------------------------------
out_dir = '/checkpoints/medium'  # directory the checkpoint is loaded from
save_dir = '/medium/attnmaps'  # root directory the attention maps are written under

ckpt_name = 'ckpt_full.pt'  # checkpoint file name inside out_dir
eval_interval = 1000  # the main block only dumps attention maps when iter_num % eval_interval == 0
log_interval = 1

eval_iters = 200  # number of batches per split averaged by estimate_loss()
eval_only = True  # if True, script exits right after the first eval (NOTE(review): not read in the visible code)
always_save_checkpoint = False  # if True, always save a checkpoint after each eval
init_from = 'resume'  # 'scratch' or 'resume' or 'gpt2*'
# wandb logging
wandb_log = False  # disabled by default
wandb_project = 'owt'
wandb_run_name = 'gpt2'  # 'run' + str(time.time())
# data
dataset = 'openwebtext'  # sub-directory of the data root holding train.bin / val.bin
gradient_accumulation_steps = 1  # used to simulate larger batch sizes, was 5 earlier
batch_size = 1  # if gradient_accumulation_steps > 1, this is the micro-batch size, was 12 earlier
# block_size = 1024, 100
# Length of every sampled window; a loaded model with a larger block_size is cropped down to this.
block_size = 100
# model small
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0  # for pretraining 0 is good, for finetuning try 0.1+
bias = False  # do we use bias inside LayerNorm and Linear layers?
# optimizer
optimizer_name = 'adamw'
learning_rate = 3e-4  # max learning rate, earlier it was 6e-4
max_iters = 100000  # total number of training iterations, earlier it was 600000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0  # clip gradients at this value, or disable if == 0.0
rho = 0.1  # forwarded to model.configure_optimizers() together with the betas
interval = 10
variant = 4
# learning rate decay settings
decay_lr = True  # whether to decay the learning rate
warmup_iters = 2000  # how many steps to warm up for
lr_decay_iters = 100000  # should be ~= max_iters per Chinchilla, it was 600000 earlier
min_lr = 6e-5  # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# DDP settings
backend = 'nccl'  # 'nccl', 'gloo', etc.
ddp = False  # when True, RANK and LOCAL_RANK are read from the environment during init
# system
device = 'cuda'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16'  # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = True  # use PyTorch 2.0 to compile the model to be faster (note: shadows the builtin `compile`)
scale_attn_by_inverse_layer_idx = True  # passed through to GPTConfig via model_args
# args = parser.parse_args()
# -----------------------------------------------------------------------------
# Names of every public (no leading underscore) global defined so far whose value is an
# int, float, bool or str -- this includes the path strings at the top of the section.
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
# exec(open('configurator.py').read())  # overrides from command line or config file
config = {k: globals()[k] for k in config_keys}  # will be useful for logging
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------

# The checkpoint directory must already exist; stop early with a readable message otherwise.
# isdir() (rather than exists()) also rejects a regular file sitting at that path.
if not os.path.isdir(out_dir):
    print(f"Directory {out_dir} does not exist!")
    # `exit` is a helper injected by the `site` module for interactive use and is not
    # guaranteed to be defined; raising SystemExit has the same effect and always works.
    raise SystemExit(1)

# various inits, derived attributes, I/O setup
# Decides which process is the "master", derives the per-process seed offset and, under DDP,
# rebinds `device` to this process's local GPU.
if ddp:
    init_process_group(backend=backend)
    # RANK / LOCAL_RANK must be present in the environment (a missing one raises KeyError)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
    seed_offset = ddp_rank  # each process gets a different seed
else:
    # if not ddp, we are running on a single gpu, and one process
    master_process = True
    seed_offset = 0
    gradient_accumulation_steps *= 8  # simulate 8 gpus

if master_process:
    # no-op when out_dir already exists (exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)
# Seeds torch's global RNG, which get_batch() uses to pick window offsets.
torch.manual_seed(5000 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
# Maps the `dtype` config string to the torch dtype; any other string raises KeyError.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast context for forward passes; a do-nothing context on CPU.
ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
# Both splits are opened as read-only memory maps of uint16 values, so nothing is loaded
# into RAM up front; get_batch() copies out only the windows it samples.
# presumably each value is a token id -- verify against the script that wrote the .bin files
data_dir = os.path.join('/Sophia/data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')


# Random-offset batch sampler over the memory-mapped split arrays.
def get_batch(split):
    """Draw `batch_size` random windows of `block_size` values from one split.

    `split == 'train'` reads from `train_data`; any other value reads from
    `val_data`.  Returns `(x, y)` as int64 tensors on `device`, where `y` is
    the same window as `x` shifted one position to the right in the data.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data

    # One random start offset per batch element (drawn from torch's global RNG).
    starts = torch.randint(len(source) - block_size, (batch_size,))

    def window(begin):
        # Copy one slice out of the memmap and widen it to int64.
        return torch.from_numpy(source[begin:begin + block_size].astype(np.int64))

    inputs = torch.stack([window(start) for start in starts])
    targets = torch.stack([window(start + 1) for start in starts])

    if device_type != 'cuda':
        return inputs.to(device), targets.to(device)

    # Pinned host memory lets the copy to the GPU be issued asynchronously (non_blocking=True).
    inputs = inputs.pin_memory().to(device, non_blocking=True)
    targets = targets.pin_memory().to(device, non_blocking=True)
    return inputs, targets

# init these up here; both are overwritten from the checkpoint when init_from == 'resume'


iter_num = 0  # also gates the attention-map dump in the main block (iter_num % eval_interval)
best_val_loss = 1e9  # large placeholder value

# attempt to derive vocab_size from the dataset
# meta_vocab_size stays None when meta.pkl is absent; it is only consulted by the 'scratch' init path.
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        # NOTE: unpickling can execute arbitrary code -- only point data_dir at trusted data.
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init
# Builds `model` (and fills `model_args`) according to `init_from`:
#   'scratch' -> fresh GPT using the configured hyper-parameters
#   'resume'  -> GPT rebuilt from the checkpoint at out_dir/ckpt_name, weights loaded
#   'gpt2*'   -> GPT.from_pretrained(init_from)
# Any other value raises ValueError.
# Hyper-parameters as configured above; vocab_size is decided per branch below.
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout,
                  scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx)

if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    # resume from a checkpoint; it must provide 'model_args', 'model', 'iter_num' and 'best_val_loss'
    ckpt_path = os.path.join(out_dir, ckpt_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    gptconf = GPTConfig(**model_args)
    # gptconf.n_layer = 16  # changing the layers of GPT
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Some checkpoints carry an '_orig_mod.' prefix on every key; strip it so the keys
    # match the freshly built (uncompiled) model.  Iterate over a snapshot because the
    # dict is mutated inside the loop.
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)
else:
    # Without this branch an unrecognised value only surfaced as a NameError on `model` below.
    raise ValueError(f"Unknown init_from {init_from!r}: expected 'scratch', 'resume' or 'gpt2*'")

# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size  # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
# (only enabled for dtype == 'float16')
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
# Built by the model itself from the configured optimizer name and hyper-parameters.
optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), rho,
                                       device_type)
if init_from == 'resume':
    # the checkpoint must also carry the optimizer state under 'optimizer'
    optimizer.load_state_dict(checkpoint['optimizer'])
    # drop the last references to the loaded checkpoint so its memory can be reclaimed
    del state_dict
    del checkpoint
# compile the model
# (`compile` here is the config flag defined above, not the builtin function)
if compile:
    print("compiling the model... (takes a ~minute)")
    # keeps a handle on the uncompiled module; NOTE(review): not used in the visible code
    unoptimized_model = model
    model = torch.compile(model)  # requires PyTorch 2.0

# wrap model into DDP container
# After this, the underlying module is reachable as `model.module`.
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])


# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    """Average the model loss over `eval_iters` random batches of each split.

    Returns a dict with keys 'train' and 'val' whose values are 0-d tensors
    holding the mean loss.  The model is put in eval mode while measuring and
    switched back to training mode before returning.
    """

    def split_mean(split_name):
        # Collect one scalar loss per batch, then reduce to the mean.
        per_batch = torch.zeros(eval_iters)
        for step in tqdm(range(eval_iters), desc="Evaluating", ncols=100):
            inputs, targets = get_batch(split_name)
            with ctx:
                _, batch_loss = model(inputs, targets)
            per_batch[step] = batch_loss.item()
        return per_batch.mean()

    model.eval()
    results = {split_name: split_mean(split_name) for split_name in ('train', 'val')}
    model.train()
    return results


@torch.no_grad()
def compute_causal_attention_map(model, eval_iters, layer_index):
    """Save the per-head causal attention maps of one transformer block as .npy files.

    For each of `eval_iters` random validation batches the input is pushed
    through the embeddings and the blocks before `layer_index`; the attention
    probabilities of block `layer_index` are then recomputed by hand as
    softmax(causal_mask(q @ k^T / sqrt(head_dim))) and written to

        <save_dir>/eval_iters_<eval_iters>/layer<L>_head<H>_iter<I>_att.npy

    where L, H and I are 1-based.  Each file holds one array of shape
    (batch, seq_len, seq_len).

    NOTE(review): the maps are rebuilt outside the block's own attention
    module, so any extra scaling that module applies (for instance the
    `scale_attn_by_inverse_layer_idx` option configured in this file) is not
    reflected here -- confirm against the model implementation.

    Args:
        model: module exposing `transformer.wte`, `transformer.wpe`,
            `transformer.drop` and `transformer.h`; the selected block must
            provide `ln_1` and `attn` (with `c_attn` and `n_head`).
        eval_iters: number of validation batches to process; also part of the
            output directory name.
        layer_index: 0-based index into `model.transformer.h`.

    Returns:
        None.  The model is put in eval mode for the computation and is
        switched back to training mode on exit, even if an error occurs.
    """
    model.eval()  # Set the model to evaluation mode
    try:
        # Output directory named after eval_iters; exist_ok avoids the check-then-create race.
        eval_iters_dir = os.path.join(save_dir, f'eval_iters_{eval_iters}')
        os.makedirs(eval_iters_dir, exist_ok=True)

        transformer = model.transformer
        block = transformer.h[layer_index]
        n_head = block.attn.n_head
        # True above the diagonal (future positions); built once per sequence length.
        future_mask = None

        for iteration in tqdm(range(1, eval_iters + 1), desc="Computing Attention Maps"):
            X, _ = get_batch('val')  # Fetch validation data
            seq_len = X.size(1)

            # Forward pass through the model up to the selected layer
            x = transformer.drop(transformer.wte(X) + transformer.wpe.weight[:seq_len, :])
            for earlier_block in transformer.h[:layer_index]:
                x = earlier_block(x)

            # x is the input of the selected block; normalise it as the block itself would
            x = block.ln_1(x)

            # Queries and keys for all heads: (batch, n_head, seq, head_dim) and
            # (batch, n_head, head_dim, seq) respectively.
            q, k, _ = block.attn.c_attn(x).chunk(3, dim=-1)
            head_dim = k.shape[-1] // n_head
            q = q.view(q.size(0), q.size(1), n_head, head_dim).transpose(1, 2)
            k = k.view(k.size(0), k.size(1), n_head, head_dim).transpose(1, 2).transpose(-2, -1)

            # Scaled dot-product scores, shape (batch, n_head, seq, seq)
            att = torch.matmul(q, k) / (head_dim ** 0.5)

            # Causal mask: each position may only attend to itself and earlier positions
            if future_mask is None or future_mask.size(-1) != seq_len:
                lower = torch.tril(torch.ones((seq_len, seq_len), device=att.device))
                future_mask = lower.unsqueeze(0).unsqueeze(0) == 0
            att = att.masked_fill(future_mask, float('-inf'))

            att = softmax(att, dim=-1)

            # One device-to-host copy for all heads instead of one per head.
            att_np = att.cpu().numpy()
            for head_idx in range(att_np.shape[1]):
                att_filename = os.path.join(
                    eval_iters_dir, f'layer{layer_index + 1}_head{head_idx + 1}_iter{iteration}_att.npy')
                np.save(att_filename, att_np[:, head_idx, :, :])
    finally:
        model.train()  # Set the model back to training mode


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    """Return the learning rate for iteration `it`.

    Three regimes, driven by the module-level schedule settings:
    a linear ramp up to `learning_rate` while `it < warmup_iters`, the floor
    `min_lr` once `it > lr_decay_iters`, and a cosine interpolation from
    `learning_rate` down to `min_lr` in between.
    """
    still_warming_up = it < warmup_iters
    if still_warming_up:
        return learning_rate * it / warmup_iters

    past_decay_window = it > lr_decay_iters
    if past_decay_window:
        return min_lr

    # Fraction of the decay window already covered, in [0, 1].
    decay_span = lr_decay_iters - warmup_iters
    progress = (it - warmup_iters) / decay_span
    assert 0 <= progress <= 1

    # Cosine weight goes from 1 (start of decay) to 0 (end of decay).
    cosine = math.cos(math.pi * progress)
    weight = 0.5 * (1.0 + cosine)
    return min_lr + weight * (learning_rate - min_lr)


# logging
# Only the master process starts a wandb run, and only when wandb_log is enabled.
if wandb_log and master_process:
    # imported here so the package is only needed when logging is switched on
    import wandb

    # `config` is the snapshot of scalar/string globals taken after the configuration section
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

if __name__ == "__main__":
    # Entry point: dump the attention maps of every layer of the loaded model.

    # Leftover first-batch fetch.  The batch itself is unused, but the call advances torch's
    # global RNG, so it is kept to leave the validation batches sampled below unchanged.
    X, Y = get_batch('train')

    # The attention-map code reads `.transformer` directly, which a DDP wrapper does not
    # expose, so always work on the unwrapped module.
    raw_model = model.module if ddp else model

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    if iter_num % eval_interval == 0 and master_process:
        # Iterate over the layers the loaded model actually has; a checkpoint may differ
        # from the configured n_layer.
        for layer_index in range(len(raw_model.transformer.h)):
            compute_causal_attention_map(model=raw_model, eval_iters=100, layer_index=layer_index)
    elif master_process:
        # Make the skip visible instead of exiting without any output.
        print(f"iter_num {iter_num} is not a multiple of eval_interval {eval_interval}; "
              f"no attention maps were computed")

    if ddp:
        destroy_process_group()
