import argparse
from datetime import datetime
import json
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import wandb

from data import StandardDataGenerator, PauseDataGenerator, generate_data
from models import TransformerModel

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Train a TransformerModel in standard or pause mode.')
parser.add_argument('--mode', type=str, required=True, help='Mode of operation', choices=['standard', 'pause'])
parser.add_argument('--n_bits', type=int, default=50, help='Number of bits')
parser.add_argument('--non_causal', action='store_true', help='Use non causal transformer')
# The script reads args.pos_enc below; without this option that read raised
# AttributeError. The default reproduces the value that used to be hard-coded.
parser.add_argument('--pos_enc', type=str, default='learnedYes',
                    help='Positional encoding setting passed to TransformerModel')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Reproducibility: seed every RNG the script touches
# ---------------------------------------------------------------------------
seed = args.seed

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this only
# affects child processes spawned from here, not the current process.
os.environ['PYTHONHASHSEED'] = str(seed)

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
mode = args.mode
n_bits = args.n_bits
pos_enc = args.pos_enc
is_causal = not args.non_causal

print('Mode: ', mode)

# Fixed hyper-parameters (not exposed on the command line).
epochs = 50
n_layer = 2
n_head = 4
n_embd = 32
n_dims = 4  # number of classes used when one-hot encoding the inputs
hint = True  # in non-pause mode, selects train_hint / evaluate_hint
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the dataset splits once per (n_bits, seed) pair and cache them on disk
# so later runs with the same settings reuse identical data.
data_path = f'data/sequences_{n_bits}_{seed}.pkl'
if os.path.exists(data_path):
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at cache files produced by this script.
    with open(data_path, 'rb') as f:
        datasets = pickle.load(f)
else:
    # The three sizes are presumably the train / val / test example counts --
    # TODO confirm against generate_data.
    datasets = generate_data(n_bits, 500000, 5000, 50000)
    # A fresh checkout may not have the cache directory yet.
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    with open(data_path, 'wb') as f:
        pickle.dump(datasets, f)

train, val, test = datasets['train'], datasets['val'], datasets['test']

# Pause mode works on longer sequences: pause_loss and evaluate_pause in this
# file address positions up to 2 * lengths, so the model needs room for
# 2 * n_bits + 1 positions. The previous condition tested `'cot' in mode`,
# which can never hold for the allowed modes ('standard', 'pause'), leaving
# this branch unreachable.
# NOTE(review): this also activates mask_percentage = 0.5 for pause training
# (the only consumer of mask_percentage) -- confirm that is the intended value.
if mode == 'pause':
    mask_percentage = 0.5
    input_len = 2 * n_bits + 1
else:
    mask_percentage = 0.0
    input_len = n_bits + 1

train_len = len(train)
val_len = len(val)
test_len = len(test)

# Each generator is iterated as (x, y, lengths) batches by the training and
# evaluation loops below; only the training generator shuffles.
if mode == 'standard':
    train_loader = StandardDataGenerator(data=train, n_bits=n_bits, d_size=train_len, shuffle=True, hint=hint, is_causal=is_causal)
    val_loader = StandardDataGenerator(data=val, n_bits=n_bits, d_size=val_len, hint=hint, is_causal=is_causal)
    test_loader = StandardDataGenerator(data=test, n_bits=n_bits, d_size=test_len, hint=hint, is_causal=is_causal)
elif mode == 'pause':
    train_loader = PauseDataGenerator(data=train, n_bits=n_bits, d_size=train_len, shuffle=True, is_train=True, mask_percentage=mask_percentage)
    val_loader = PauseDataGenerator(data=val, n_bits=n_bits, d_size=val_len, is_train=False)
    test_loader = PauseDataGenerator(data=test, n_bits=n_bits, d_size=test_len, is_train=False)
else:
    # Defensive: argparse `choices` already rejects other values.
    raise ValueError(f"Invalid mode: {args.mode}. Choose 'standard' or 'pause'.")

print('Data loaded')

# Optimizer hyper-parameters. These are also recorded in the run config, so
# they are passed to Adam explicitly to keep the logged values and the actual
# optimizer settings in sync (the values equal Adam's defaults).
lr = 5e-4
wd = 0.0
betas = (0.9, 0.999)

model = TransformerModel(n_dims, input_len, n_embd, n_layer, n_head, pos_enc, is_causal=is_causal).to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, weight_decay=wd)

# A per-run timestamp names the config file and every checkpoint of this run.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"model_{timestamp}"

# Everything needed to reproduce the run; sent to wandb and mirrored to a JSON
# file next to the checkpoints.
config = {
    "name": model_name,
    "mode": mode,
    "n_bits": n_bits,
    "input_len": input_len,
    "mask_percentage": mask_percentage,
    "pos_enc": pos_enc,
    "epochs": epochs,
    "n_layer": n_layer,
    "n_head": n_head,
    "n_embd": n_embd,
    "n_dims": n_dims,
    "lr": lr,
    "wd": wd,
    "adam_betas": betas,
    "causal": is_causal,
    "hint": hint,
    "n_train": len(train),
    "n_val": len(val),
    "n_test": len(test),
}

# Initialize wandb
wandb.init(
    project="parity-pause",
    config=config
)

# The config file and all checkpoints are written under models/; create the
# directory up front so the first write cannot fail on a fresh checkout.
os.makedirs('models', exist_ok=True)
config_path = os.path.join('models', f'{model_name}_config.json')
with open(config_path, 'w') as f:
    json.dump(config, f)


# Loss used by train_standard (a single selected position per example).
loss_fn = torch.nn.BCEWithLogitsLoss()

def hint_loss(pred, target, lengths):
    """Masked binary cross-entropy (on logits) over a prefix of each sequence.

    For every example, positions ``0 .. lengths + 1`` (inclusive) along
    dimension 1 contribute to the loss; later positions are ignored.

    Args:
        pred: logits; indexed as a 3-D tensor whose dimension 1 is the
            sequence position (assumes a trailing feature dim of size 1 --
            TODO confirm against the model output).
        target: labels with the same shape as ``pred``; cast to float.
        lengths: one length per example (one entry per row of ``pred``).

    Returns:
        Scalar tensor: sum of the kept element-wise losses divided by the
        number of kept (example, position) pairs.
    """
    seq_len = pred.shape[1]
    # Position indices shaped (1, seq_len, 1) so they broadcast against the
    # per-example cutoffs shaped (batch, 1, 1).
    positions = torch.arange(0, seq_len, dtype=torch.float32, device=pred.device)[None, :, None]
    cutoff = lengths.view(-1, 1, 1) + 1
    keep = (positions <= cutoff).float()

    per_element = torch.nn.functional.binary_cross_entropy_with_logits(
        pred, target.float(), reduction='none'
    )
    return (per_element * keep).sum() / keep.sum()

def pause_loss(pred, target, lengths):
    """Masked binary cross-entropy (on logits) over a window of each sequence.

    For every example, only positions ``i`` along dimension 1 with
    ``lengths <= i < 2 * lengths + 1`` contribute to the loss.

    Args:
        pred: logits; indexed as a 3-D tensor whose dimension 1 is the
            sequence position (assumes a trailing feature dim of size 1 --
            TODO confirm against the model output).
        target: labels with the same shape as ``pred``; cast to float.
        lengths: one length per example (one entry per row of ``pred``).

    Returns:
        Scalar tensor: sum of the kept element-wise losses divided by the
        number of kept (example, position) pairs.
    """
    seq_len = pred.shape[1]
    positions = torch.arange(0, seq_len, dtype=torch.float32, device=pred.device)[None, :, None]

    # Per-example window bounds shaped (batch, 1, 1) for broadcasting.
    window_start = lengths.view(-1, 1, 1)
    window_end = 2 * window_start + 1
    keep = torch.logical_and(positions >= window_start, positions < window_end).float()

    per_element = torch.nn.functional.binary_cross_entropy_with_logits(
        pred, target.float(), reduction='none'
    )
    return (per_element * keep).sum() / keep.sum()


# Progress marker: the model, optimizer and loss helpers exist at this point.
print('Model initialized')

def train_standard(x, y, l):
    """Run one optimisation step using the loss at a single position per example.

    The prediction at sequence index ``l - 1`` is picked for every example
    and compared with ``y`` through the module-level ``loss_fn``.

    NOTE(review): ``evaluate`` reads the prediction at index ``lengths``
    rather than ``lengths - 1`` -- confirm which offset is intended.

    Args:
        x: model input batch.
        y: targets; cast to float before the loss.
        l: per-example lengths used to pick the supervised position.

    Returns:
        The loss of this step as a Python float.
    """
    logits = model(x)
    rows = torch.arange(logits.shape[0])
    selected = logits[rows, (l - 1)].unsqueeze(2)
    step_loss = loss_fn(selected, y.float())

    optim.zero_grad()
    step_loss.backward()
    optim.step()
    return step_loss.item()


def train_hint(x, y, l):
    """Run one optimisation step with ``hint_loss`` and return the loss value.

    Args:
        x: model input batch.
        y: targets, passed to ``hint_loss`` together with the model output.
        l: per-example lengths that define the loss mask.

    Returns:
        The loss of this step as a Python float.
    """
    # Clear stale gradients first; the forward pass does not depend on them.
    optim.zero_grad()
    step_loss = hint_loss(model(x), y, l)
    step_loss.backward()
    optim.step()
    return step_loss.item()


def train_pause(x, y, l):
    """Run one optimisation step with ``pause_loss`` and return the loss value.

    Args:
        x: model input batch.
        y: targets, passed to ``pause_loss`` together with the model output.
        l: per-example lengths that define the loss window.

    Returns:
        The loss of this step as a Python float.
    """
    # Clear stale gradients first; the forward pass does not depend on them.
    optim.zero_grad()
    step_loss = pause_loss(model(x), y, l)
    step_loss.backward()
    # Gradient clipping is intentionally left disabled:
    # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step()
    return step_loss.item()


@torch.no_grad()
def evaluate(x, y, lengths):
    """Compare the thresholded prediction at index ``lengths`` with ``y``.

    Logits ``>= 0`` count as 1, everything else as 0. Runs without gradient
    tracking.

    Args:
        x: model input batch.
        y: targets compared against the selected predictions.
        lengths: per-example sequence index of the prediction to score.

    Returns:
        Float tensor of 1.0 / 0.0 correctness indicators.
    """
    hard_pred = model(x).ge(0).float()
    rows = torch.arange(hard_pred.shape[0])
    picked = hard_pred[rows, lengths].unsqueeze(2)
    return (picked == y).float()


@torch.no_grad()
def evaluate_hint(x, y, lengths):
    """Score the prediction at index ``lengths`` against the target at the same index.

    Unlike ``evaluate``, ``y`` holds a value per position, so both the
    thresholded prediction (logit ``>= 0`` -> 1) and the target are read at
    index ``lengths``. Runs without gradient tracking.

    Args:
        x: model input batch.
        y: per-position targets.
        lengths: per-example sequence index to score.

    Returns:
        Float tensor of 1.0 / 0.0 correctness indicators.
    """
    hard_pred = model(x).ge(0).float()
    picked_pred = hard_pred[torch.arange(hard_pred.shape[0]), lengths].unsqueeze(2)
    picked_target = y[torch.arange(y.shape[0]), lengths].unsqueeze(2)
    return (picked_pred == picked_target).float()


@torch.no_grad()
def evaluate_pause(x, y, lengths):
    """Score the prediction at index ``2 * lengths`` against the target there.

    Logits ``>= 0`` count as 1, everything else as 0. Runs without gradient
    tracking.

    Args:
        x: model input batch.
        y: per-position targets.
        lengths: per-example lengths; the scored index is twice this value.

    Returns:
        Float tensor of 1.0 / 0.0 correctness indicators.
    """
    hard_pred = model(x).ge(0).float()
    picked_pred = hard_pred[torch.arange(hard_pred.shape[0]), 2 * lengths]
    picked_target = y[torch.arange(y.shape[0]), 2 * lengths]
    return (picked_pred == picked_target).float()


# Choose the training step and its matching evaluation function.
# Pause mode takes precedence; otherwise `hint` selects between the
# hint-supervised and the single-position variants.
if mode == 'pause':
    train_fn, eval_fn = train_pause, evaluate_pause
elif hint:
    train_fn, eval_fn = train_hint, evaluate_hint
else:
    train_fn, eval_fn = train_standard, evaluate


def eval_loader(loader):
    """Return the mean of ``eval_fn``'s correctness indicators over a loader.

    Every batch is one-hot encoded exactly as in the training loop, moved to
    ``device`` and scored with the module-level ``eval_fn``; the per-batch
    results are concatenated along dimension 0 and averaged.

    Args:
        loader: iterable yielding ``(x, y, lengths)`` batches.

    Returns:
        Mean accuracy over all scored entries as a Python float.
    """
    per_batch = []
    for inputs, targets, lens in loader:
        # Values are shifted by one before one-hot encoding (presumably so a
        # -1 marker maps to class 0 -- TODO confirm against the generators).
        # NOTE(review): the bare squeeze() drops every size-1 dimension and
        # would also drop the batch dimension for a batch of one.
        encoded = torch.nn.functional.one_hot(inputs.long() + 1, n_dims).float().squeeze()

        encoded, targets, lens = encoded.to(device), targets.to(device), lens.to(device)
        per_batch.append(eval_fn(encoded, targets, lens))

    return torch.cat(per_batch, dim=0).mean().item()


print('Training')

# Not read anywhere in the visible code; kept in case other code relies on it.
lr_adjusted = False

# Best validation accuracy seen so far; a "_max" checkpoint is written each
# time it improves.
max_val_acc = 0
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_acc = 0.0
    num_batches = 0

    for i, (x, y, l) in enumerate(train_loader):
        # Same input encoding as eval_loader: shift by one, one-hot, squeeze.
        x = torch.nn.functional.one_hot(x.long() + 1, n_dims).float().squeeze()

        x, y, l = x.to(device), y.to(device), l.to(device)
        loss = train_fn(x, y, l)
        # Accuracy on the batch just trained on, measured after the update.
        # Convert to a Python float so running sums, logs and the progress
        # line hold plain numbers instead of device tensors.
        acc = torch.mean(eval_fn(x, y, l)).item()
        epoch_loss += loss
        epoch_acc += acc
        num_batches += 1

        wandb.log({
            "batch/loss": loss,
            "batch/accuracy": acc,
        })

        print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss}, Acc: {acc}")

    val_acc = eval_loader(val_loader)
    if val_acc > max_val_acc:
        # Weights are copied to CPU so the checkpoint loads on any machine.
        torch.save({k: v.cpu() for k, v in model.state_dict().items()}, f"models/model_{timestamp}_max.pt")
        max_val_acc = val_acc

    # Periodic snapshot every fifth epoch (including epoch 0).
    if epoch % 5 == 0:
        torch.save({k: v.cpu() for k, v in model.state_dict().items()}, f"models/model_{timestamp}_epoch_{epoch}.pt")

    # Log epoch metrics; guard the averages against an empty training loader.
    batches_seen = max(num_batches, 1)
    wandb.log({
        "epoch/loss": epoch_loss / batches_seen,
        "epoch/train_acc": epoch_acc / batches_seen,
        "epoch/val_acc": val_acc,
        "epoch": epoch
    })

print('Evaluating')

# Test accuracy of the weights as they are after the last epoch.
test_accuracy = eval_loader(test_loader)
wandb.log({"test_accuracy": test_accuracy})
print(f"Test accuracy: {test_accuracy}")

# Persist the final weights (copied to CPU) alongside the other checkpoints.
final_weights = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
torch.save(final_weights, f"models/model_{timestamp}_final.pt")

# If a best-validation checkpoint was written during training, reload it and
# report its test accuracy as well.
best_checkpoint = f"models/model_{timestamp}_max.pt"
if os.path.exists(best_checkpoint):
    model.load_state_dict(torch.load(best_checkpoint, map_location=device))

    test_accuracy = eval_loader(test_loader)
    wandb.log({"test_accuracy_max": test_accuracy})
    print(f"Test accuracy (max): {test_accuracy}")

# Close wandb run
wandb.finish()
