import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import Tranformer
from data import CustomDataset

N = 4  # Number of symbols per indirection level
SYMBOL_DIM = 4  # Per-symbol width; validation splits targets into SYMBOL_DIM-wide slices
EMBED_SIZE = 128  # Model embedding width (passed to Tranformer)
N_HEADS = 8
N_LAYERS = 8  # NOTE: not referenced below; the layer count is swept via max_layers

epoch_size = 1024  # Training samples per logged "epoch" (epoch_size // batch_size optimizer steps)
save_interval = 100000  # NOTE: not referenced in this script

valid_size = 128  # Validation set size; evaluated as a single batch

mlp = True
head_dim = None


# IC config
output_folder = "mega_run/d=5_IC"
IMPLICIT_CURRICULUM = True
D = 5  # Number of indirection levels (D = 1 is the basic induction head)

# Non-IC config
# output_folder = "mega_run/d=3_nonIC"
# IMPLICIT_CURRICULUM = False
# D = 3  # Number of indirection levels (D = 1 is the basic induction head)

num_epochs = 10000  # The epoch loop runs num_epochs + 1 rounds (epochs 0..num_epochs)
max_layers = 8  # Layer counts 1..max_layers are swept
runs_per_layer = 8  # Independent runs per layer count

block_size = N * (D * 2 + 1)  # Passed to the model as its block_size
dataset_size = 1024 * 1024
batch_size = 128

# The training loop measures an "epoch" in samples and steps in whole batches,
# so epoch_size must be an exact multiple of batch_size.
if epoch_size % batch_size != 0:
    raise ValueError(
        f"epoch_size ({epoch_size}) must be a multiple of batch_size ({batch_size})"
    )


# With the implicit curriculum the target carries one SYMBOL_DIM-wide slice per
# indirection level (validation splits it back apart per level); otherwise only
# a single SYMBOL_DIM-wide output is predicted.
output_dim = SYMBOL_DIM * D if IMPLICIT_CURRICULUM else SYMBOL_DIM


device = "cuda" if torch.cuda.is_available() else "cpu"

os.makedirs(output_folder, exist_ok=True)

# Optimizer steps per logged "epoch". epoch_size is counted in samples, so with
# the default config this is 1024 // 128 = 8 batches.
steps_per_epoch = max(1, epoch_size // batch_size)


def _train_epoch(model, optimizer, criterion, dataloader):
    """Run one "epoch" of at most ``steps_per_epoch`` optimizer steps.

    Only the last ``N`` output positions of the model are compared against the
    targets. Returns the mean training loss over the batches actually processed.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)[:, -N:]
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
        if n_batches >= steps_per_epoch:
            # Stop early; the next call starts a fresh pass over the dataloader.
            break
    # Divide by the real number of batches rather than an assumed count so the
    # mean stays correct even if the dataloader runs out early.
    return total_loss / max(n_batches, 1)


@torch.no_grad()  # Evaluation needs no autograd graph.
def _evaluate(model, criterion, dataloader):
    """Return ``(valid_loss, position_losses)`` averaged over validation batches.

    ``position_losses`` has one entry per indirection level (``D`` entries).
    The entries are filled only when IMPLICIT_CURRICULUM is on. Otherwise they
    stay 0.0, so every CSV row has the same number of columns.
    """
    model.eval()
    total_loss = 0.0
    position_totals = [0.0] * D
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)[:, -N:]
        total_loss += criterion(y_pred, y).item()

        if IMPLICIT_CURRICULUM:
            # Targets hold one SYMBOL_DIM-wide slice per level along dim 2.
            y_split = y.split(SYMBOL_DIM, dim=2)
            y_pred_split = y_pred.split(SYMBOL_DIM, dim=2)
            for i in range(D):
                position_totals[i] += criterion(y_pred_split[i], y_split[i]).item()
    model.train()

    n_batches = max(len(dataloader), 1)
    return total_loss / n_batches, [total / n_batches for total in position_totals]


for run in range(runs_per_layer):
    for num_layers in range(1, max_layers + 1):
        output_path = f"{output_folder}/layers_{num_layers}_run_{run}.csv"

        train_dataloader = DataLoader(CustomDataset(N, D, SYMBOL_DIM, IMPLICIT_CURRICULUM, dataset_size), batch_size=batch_size, shuffle=True)
        valid_dataloader = DataLoader(CustomDataset(N, D, SYMBOL_DIM, IMPLICIT_CURRICULUM, valid_size), batch_size=valid_size, shuffle=False)

        model = Tranformer(SYMBOL_DIM, output_dim, EMBED_SIZE, N_HEADS, block_size, num_layers, mlp=mlp, head_dim=head_dim).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
        criterion = torch.nn.MSELoss()

        print(f"Starting run layers={num_layers}, run={run}")

        # The epoch loop runs num_epochs + 1 times (epochs 0..num_epochs), so the
        # progress-bar total must match. `with` closes both the CSV log and the
        # bar even if training raises.
        with open(output_path, "w") as log, tqdm(total=num_epochs + 1) as bar:
            for epoch in range(num_epochs + 1):
                train_loss = _train_epoch(model, optimizer, criterion, train_dataloader)
                valid_loss, position_losses = _evaluate(model, criterion, valid_dataloader)

                bar.set_description(f"Loss: {train_loss:.4f}, Valid loss: {valid_loss:.4f}")
                bar.update(1)

                # CSV row: epoch, train loss, valid loss, then one loss column per level.
                log.write(f"{epoch},{train_loss},{valid_loss},{','.join(str(loss_value) for loss_value in position_losses)}\n")
                log.flush()
