import os
import shutil
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from model import Tranformer
from data import CustomDataset

### MODEL CONFIGURATIONS
#
# One entry per trained checkpoint whose attention can be visualized. Set
# ACTIVE_MODEL to choose which entry the rest of the script uses; the selected
# values are exposed below under the usual module-level names.

MODEL_CONFIGS = {
    # NOTE(review): this configuration was headed "MODEL A", but its paths
    # point at model_C -- confirm which checkpoint it actually describes.
    "C": {
        "N": 4,
        "D": 3,
        "SYMBOL_DIM": 4,
        "EMBED_SIZE": 128,
        "N_HEADS": 1,
        "N_LAYERS": 12,
        "IMPLICIT_CURRICULUM": True,
        "mlp": True,
        "head_dim": None,
        "dataset_size": 1,
        "batch_size": 64,
        "valid_size": 128,
        "model_path": "models/model_C/3000.pt",
        "output_folder": "attn_viz/model_C_pdf",
    },
    "D": {
        "N": 4,
        "D": 4,
        "SYMBOL_DIM": 4,
        "EMBED_SIZE": 512,
        "N_HEADS": 1,
        "N_LAYERS": 24,
        "IMPLICIT_CURRICULUM": True,
        "mlp": False,
        "head_dim": 16,
        "dataset_size": 1,
        "batch_size": 64,
        "valid_size": 128,
        "model_path": "models/model_D/1000.pt",
        "output_folder": "attn_viz/model_D",
    },
}

# Which entry of MODEL_CONFIGS the rest of the script uses.
ACTIVE_MODEL = "D"
_config = MODEL_CONFIGS[ACTIVE_MODEL]

N = _config["N"]
D = _config["D"]

SYMBOL_DIM = _config["SYMBOL_DIM"]
EMBED_SIZE = _config["EMBED_SIZE"]
N_HEADS = _config["N_HEADS"]
N_LAYERS = _config["N_LAYERS"]

IMPLICIT_CURRICULUM = _config["IMPLICIT_CURRICULUM"]

mlp = _config["mlp"]
head_dim = _config["head_dim"]

# Passed to the model as block_size (presumably its context length -- confirm
# in model.py); always N * (2 * D + 1).
block_size = N * (D * 2 + 1)
dataset_size = _config["dataset_size"]
batch_size = _config["batch_size"]

valid_size = _config["valid_size"]

model_path = _config["model_path"]
output_folder = _config["output_folder"]

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

dataloader = DataLoader(
    CustomDataset(N, D, SYMBOL_DIM, IMPLICIT_CURRICULUM, dataset_size),
    batch_size=batch_size,
)

# output_dim is SYMBOL_DIM * D with the implicit curriculum, else SYMBOL_DIM.
output_dim = SYMBOL_DIM * D if IMPLICIT_CURRICULUM else SYMBOL_DIM
model = Tranformer(
    SYMBOL_DIM,
    output_dim,
    EMBED_SIZE,
    N_HEADS,
    block_size,
    N_LAYERS,
    mlp=mlp,
    head_dim=head_dim,
).to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved from
# a GPU run still loads on a CPU-only machine (the fallback chosen above).
# Without it, torch.load tries to restore CUDA tensors and fails there.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Start from an empty output folder: remove results of any previous run.
if os.path.exists(output_folder):
    shutil.rmtree(output_folder)
os.makedirs(output_folder, exist_ok=True)

def _canonical_labels(sample):
    """Give every position of ``sample`` a letter, equal rows sharing a letter.

    Letters are handed out in order of first appearance, starting at "A".
    Rows are compared by exact element-wise equality.

    Args:
        sample: one element of a batch, iterated along its first dimension
            (here ``x[0]``).

    Returns:
        A list of one-character strings, one per position.
    """
    letter_for = {}
    labels = []
    for row in sample:
        # Hashable key for an exact row match; flatten() also covers 0-d rows.
        key = tuple(row.flatten().tolist())
        if key not in letter_for:
            letter_for[key] = chr(ord("A") + len(letter_for))
        labels.append(letter_for[key])
    return labels


def _relabel_by_query_chains(labels, n_queries):
    """Rename labels so chains reachable from the trailing queries read A, B, C, ...

    The positions ``0, 2, 4, ...`` before the last ``n_queries`` positions are
    read as pairs ``labels[i] -> labels[i + 1]`` (a later pair with the same
    key overrides an earlier one). Starting from each of the last
    ``n_queries`` labels in order, the chain is followed through those pairs,
    and every label visited receives the next fresh letter. A label reached
    from several queries keeps the letter from its last visit. Labels never
    visited are left unchanged.

    Args:
        labels: per-position labels, e.g. from ``_canonical_labels``.
        n_queries: number of trailing positions used as chain starts.

    Returns:
        A new list with the renamed labels.
    """
    split = len(labels) - n_queries
    successor = {labels[i]: labels[i + 1] for i in range(0, split, 2)}

    renamed = {}
    next_index = 0
    for q in range(n_queries):
        current = labels[split + q]
        visited = set()
        # Without the visited check, a cyclic chain (e.g. a pair mapping a
        # label to itself) would make this loop run forever.
        while current is not None and current not in visited:
            visited.add(current)
            renamed[current] = chr(ord("A") + next_index)
            next_index += 1
            current = successor.get(current)
    return [renamed.get(label, label) for label in labels]


# Visualization pass: inference only, so autograd tracking is disabled
# (model.eval() alone does not turn it off).
with torch.no_grad():
    for x, y in dataloader:
        # Labels are derived from the first sample of the batch only.
        labels = _relabel_by_query_chains(_canonical_labels(x[0]), N)

        x = x.to(device)
        y = y.to(device)
        # The model receives output_folder and the labels; what it does with
        # them is defined in model.py (not visible here).
        y_pred = model(x, output_folder, labels)
