import os
import shutil
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from model import Tranformer
from data import CustomDataset
import random
import numpy as np

# Seed associated with each of the labels A, B and C (this file uses B's):
#   A -> 43
#   B -> 42
#   C -> 44
seed = 42

# Feed the same seed to every random number generator the script can touch:
# Python's `random`, NumPy, and PyTorch (CPU, current GPU, all GPUs).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Disable the cuDNN auto-tuner and request deterministic cuDNN algorithms so
# that repeated runs with the same seed are as reproducible as possible.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# --- Task ---------------------------------------------------------------
N = 4  # symbols per indirection level
D = 4  # indirection levels; D = 1 is the basic induction head

# When True, the model's output width is SYMBOL_DIM * D instead of
# SYMBOL_DIM, and the validation pass logs one loss per SYMBOL_DIM-wide chunk.
IMPLICIT_CURRICULUM = True

# --- Model --------------------------------------------------------------
# All of these are forwarded to the model constructor.
SYMBOL_DIM = 4
EMBED_SIZE = 128
N_HEADS = 8
N_LAYERS = 8
mlp = True
head_dim = None

# Also forwarded to the model constructor.
# Presumably the input sequence length -- TODO confirm against the dataset.
block_size = N * (2 * D + 1)

# --- Training schedule ----------------------------------------------------
num_epochs = 20000
epoch_size = 1024  # training samples consumed per "epoch" of the loop
batch_size = 128
dataset_size = 256 * 1024
valid_size = 128  # also used as the validation batch size
save_interval = 100000

# --- Output ---------------------------------------------------------------
output_folder = "models/model"

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("device is: ", device)

# Training data is reshuffled on every pass; validation data is served as a
# single fixed-order batch (batch size equals the validation set size).
# The construction order (train first, then validation) is kept as-is in case
# dataset creation draws from the seeded random number generators.
train_dataloader = DataLoader(
    CustomDataset(N, D, SYMBOL_DIM, IMPLICIT_CURRICULUM, dataset_size),
    batch_size=batch_size,
    shuffle=True,
)
valid_dataloader = DataLoader(
    CustomDataset(N, D, SYMBOL_DIM, IMPLICIT_CURRICULUM, valid_size),
    batch_size=valid_size,
    shuffle=False,
)

# With the implicit curriculum the model emits one SYMBOL_DIM-wide chunk per
# indirection level; without it, a single chunk.
if IMPLICIT_CURRICULUM:
    output_dim = SYMBOL_DIM * D
else:
    output_dim = SYMBOL_DIM

model = Tranformer(
    SYMBOL_DIM,
    output_dim,
    EMBED_SIZE,
    N_HEADS,
    block_size,
    N_LAYERS,
    mlp=mlp,
    head_dim=head_dim,
)
model = model.to(device)

# AdamW with weight decay, trained against a mean-squared-error objective.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
criterion = torch.nn.MSELoss()

# Start every run from an empty output folder so that checkpoints and logs
# from a previous run are never mixed with the new ones.
if os.path.exists(output_folder):
    shutil.rmtree(output_folder)
os.makedirs(output_folder, exist_ok=True)

# Checkpoint every 10 epochs. This overrides the `save_interval` value set in
# the configuration section and matches the interval the loop effectively used.
save_interval = 10

num_valid_batches = len(valid_dataloader)
global_step = 0  # training samples seen so far, counted in whole batches

# The loop covers epochs 0..num_epochs inclusive, hence num_epochs + 1 ticks.
# The `with` statement closes the log file and the progress bar even when
# training is interrupted or raises.
with open(f"{output_folder}/log.txt", "w") as log, tqdm(total=num_epochs + 1) as bar:
    for epoch in range(num_epochs + 1):
        # ---- Training: consume `epoch_size` samples from a fresh shuffle ----
        train_loss = 0
        num_train_batches = 0
        for x, y in train_dataloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            # Only the last N positions of the model output are scored.
            y_pred = model(x)[:, -N:]
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            num_train_batches += 1

            global_step += batch_size
            if global_step % epoch_size == 0:
                break

        # ---- Validation: no gradients needed, so skip building the graph ----
        valid_loss = 0
        position_losses = [0 for _ in range(D)]
        model.eval()
        with torch.no_grad():
            for x, y in valid_dataloader:
                x = x.to(device)
                y = y.to(device)
                y_pred = model(x)[:, -N:]
                loss = criterion(y_pred, y)

                if IMPLICIT_CURRICULUM:
                    # Split the last dimension into SYMBOL_DIM-wide chunks and
                    # score the first D of them separately (presumably one
                    # chunk per indirection level -- TODO confirm).
                    y_split = y.split(SYMBOL_DIM, dim=2)
                    y_pred_split = y_pred.split(SYMBOL_DIM, dim=2)
                    for i in range(D):
                        position_losses[i] += criterion(y_pred_split[i], y_split[i]).item()

                valid_loss += loss.item()
        model.train()

        # Average over the batches that actually ran rather than an assumed
        # count, so the figure stays correct if the dataloader is exhausted
        # before `epoch_size` samples have been seen.
        train_loss /= max(num_train_batches, 1)
        valid_loss /= num_valid_batches

        bar.set_description(f"Loss: {train_loss:.4f}, Valid loss: {valid_loss:.4f}")
        bar.update(1)

        # One CSV row per epoch: epoch, train loss, valid loss, then the D
        # per-chunk validation losses.
        log.write(f"{epoch},{train_loss},{valid_loss},{','.join(str(total / num_valid_batches) for total in position_losses)}\n")
        log.flush()

        if epoch % save_interval == 0:
            torch.save(model.state_dict(), f"{output_folder}/{epoch}.pt")
