import os
import shutil
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import Tranformer
from data import CustomDataset

N = 4  # Number of symbols per indirection level
D = 3  # Number of indirection levels (D = 1 is the basic induction head)

SYMBOL_DIM = 4
EMBED_SIZE = 128
N_HEADS = 1
N_LAYERS = 12

IMPLICIT_CURRICULUM = True

mlp = True
head_dim = None

# Not used by the evaluation code in this file (apparently kept in sync with the
# training configuration).
epoch_size = 1024
save_interval = 50

num_epochs = 5000

# Sequence length. The evaluation loop reads it as (key, value) pairs followed by
# N query symbols, i.e. presumably N * D pairs + N queries.
block_size = N * (D * 2 + 1)
dataset_size = 16
batch_size = 1  # attention maps are built for one sequence at a time

valid_size = 128  # not used by the evaluation code in this file

# Trained checkpoints and the folder each model's ablation results go to.
# Trailing numbers are losses noted from earlier runs of this script
# (presumably the ablated "Total loss" where only one value is given).
CHECKPOINTS = {
    "A": ("models/model_A/3000.pt", "attn_ablat/model_A"),  # 0.03227014990989119
    "B": ("models/model_B/3000.pt", "attn_ablat/model_B"),  # 0.04213446215726435
    "C": ("models/model_C/3000.pt", "attn_ablat/model_C"),  # 0.04818190215155482
    "D": ("models/model_D/500.pt", "attn_ablat/model_D"),  # Total loss: 0.04475724999792874, Base loss: 0.03099941398249939
}

# Model to evaluate. Keep in sync with the `flow = flow_<X>` selection further down.
MODEL_NAME = "C"
model_path, output_folder = CHECKPOINTS[MODEL_NAME]

device = "cuda" if torch.cuda.is_available() else "cpu"

dataloader = DataLoader(CustomDataset(N, D, SYMBOL_DIM, IMPLICIT_CURRICULUM, dataset_size), batch_size=batch_size)

criterion = torch.nn.MSELoss()

model = Tranformer(SYMBOL_DIM, SYMBOL_DIM * D, EMBED_SIZE, N_HEADS, block_size, N_LAYERS, mlp=mlp, head_dim=head_dim).to(device)
# map_location restores the checkpoint tensors onto `device`; without it torch.load
# puts them back on the device they were saved from, which fails for a GPU-saved
# checkpoint on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # evaluation mode (affects layers such as dropout, if the model has any)


# Hypothesised information flow per model, used to force each head's attention map.
# Keys are "<layer>.<head>", both 1-based. Values:
#   "uniform"  -> row i of the map gets weight 1/(i+1) on positions 0..i;
#   "identity" -> row i gets weight 1 on position i only;
#   [(dst, src), ...] -> entry [dst position, src position] is set to 1 whenever
#       both positions lie on the same query chain (rows are not renormalised).
# Role names are a letter (A = the query symbol, B = the symbol A maps to, C = the
# symbol B maps to, ...) followed by "1" for the symbol's first occurrence in the
# sequence and "2" for any later occurrence.
# Heads not listed keep a None attention map.
# NOTE(review): what the model does with a None map is decided in model.Tranformer; confirm.
flow_A = {
    "1.1": "identity",
    "2.1": "uniform",
    "3.1": "uniform",
    "4.1": "uniform",
    "5.1": "uniform",
    "6.1": "uniform",
    "7.1": "uniform",
    "8.1": [("D1", "C1"), ("C2", "B1"), ("B2", "A1")],
    "9.1": [("C2", "D1")],
    "10.1": [("B2", "C2")],
    "11.1": "uniform",
    "12.1": [("A2", "B2")],
}

flow_B = {
    "1.1": "uniform",
    "2.1": "uniform",
    "3.1": "uniform",
    "4.1": "uniform",
    "5.1": [("D1", "C1"), ("C2", "B1"), ("B2", "A1")],
    "6.1": "uniform",
    "7.1": [("C2", "D1")],
    "8.1": "uniform",
    "9.1": "uniform",
    "10.1": "uniform",
    "11.1": [("A2", "B2")],
    "12.1": [("A2", "C2")],
}

flow_C = {
    "1.1": "uniform",
    "2.1": "uniform",
    "3.1": "uniform",
    "4.1": "uniform",
    "5.1": "uniform",
    "6.1": "uniform",
    "7.1": [("D1", "C1"), ("C2", "B1"), ("B2", "A1")],
    "8.1": "uniform",
    "9.1": "uniform",
    "10.1": [("B2", "C2")],
    "11.1": [("A2", "B2")],
    "12.1": [("A2", "D1")],
}

# NOTE(review): flow_D lists layers 13-24 and a role "E", but N_LAYERS is 12 and D is 3
# in the configuration above, so selecting it as-is would index past the attention-map
# list. Presumably model_D was trained with a larger configuration; confirm before use.
flow_D = {
    "1.1": "uniform",
    "2.1": "uniform",
    "3.1": "uniform",
    "4.1": "uniform",
    "5.1": "uniform",
    "6.1": "uniform",
    "7.1": "uniform",
    "8.1": "uniform",
    "9.1": "uniform",
    "10.1": "uniform",
    "11.1": "uniform",
    "12.1": "identity",
    "13.1": "identity",
    "14.1": [("E1", "D1"), ("D2", "C1"), ("C2", "B1"), ("B2", "A1")],
    "15.1": "uniform",
    "16.1": "uniform",
    "17.1": "uniform",
    "18.1": "uniform",
    "19.1": "uniform",
    "20.1": "uniform",
    "21.1": [("D2", "E1"), ("C2", "D2"), ("A2", "B2"), ("B2", "C2")],
    "22.1": [("C2", "D2"), ("A2", "B2"), ("B2", "C2")],
    "23.1": [("A2", "B2")],
    "24.1": [("A2", "C2")],
}

# The flow chosen here must describe the checkpoint loaded from model_path.
flow = flow_C

output_folder_2 = output_folder + "_ablation"


def _reset_output_dir(path):
    """Remove `path` and all of its contents if it exists, then recreate it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


# NOTE(review): nothing visible in this script writes into these folders; confirm
# they are still needed before wiping them on every run.
_reset_output_dir(output_folder)
_reset_output_dir(output_folder_2)

def _assign_labels(sample):
    """Label each position of one sequence with a letter, in order of first appearance.

    Positions holding identical symbols (compared with torch.equal) share a label
    ("A", "B", ...). `sample` is a single sequence indexed per position as sample[i].
    Returns a list with one label per position.
    """
    labels = []
    next_free_label = "A"
    for i in range(sample.shape[0]):
        for j in range(i):
            if torch.equal(sample[i], sample[j]):
                labels.append(labels[j])
                break
        else:  # no earlier position holds this symbol: allocate a new letter
            labels.append(next_free_label)
            next_free_label = chr(ord(next_free_label) + 1)
    return labels


def _relabel_by_chain(labels, num_queries):
    """Name every position by its role in the indirection chain of some query.

    The sequence is read as (key, value) pairs followed by `num_queries` query
    symbols. From each query the key -> value mapping is followed. The k-th symbol
    visited on chain c gets role letter chr(ord("A") + k).

    Returns (chain_labels, new_labels):
      * chain_labels[i] is (role + "1" for the symbol's first occurrence or "2" for a
        later one, chain index). Symbols on no chain keep their original letter as
        the role and have chain index None.
      * new_labels[i] is a fresh letter unique to each chain symbol, or None.

    NOTE(review): assumes the key -> value mapping is acyclic; a cycle would make the
    chain walk below loop forever.
    """
    num_symbols = len(labels)
    # Every position before the final queries is part of a (key, value) pair.
    mapping = {labels[i]: labels[i + 1] for i in range(0, num_symbols - num_queries, 2)}

    relabel = {}
    next_free_label = "A"
    for chain in range(num_queries):
        role = "A"
        symbol = labels[num_symbols - num_queries + chain]
        while symbol is not None:
            relabel[symbol] = (role, chain, next_free_label)
            role = chr(ord(role) + 1)
            next_free_label = chr(ord(next_free_label) + 1)
            symbol = mapping.get(symbol)

    encountered = set()
    chain_labels = []
    new_labels = []
    for label in labels:
        role, chain, new_label = relabel.get(label, (label, None, None))
        chain_labels.append((role + ("2" if label in encountered else "1"), chain))
        encountered.add(label)
        new_labels.append(new_label)
    return chain_labels, new_labels


def _build_attention_maps(flow, chain_labels, n_layers, n_heads):
    """Build the forced attention map for every head listed in `flow`.

    Returns an n_layers x n_heads nested list. Heads absent from `flow` stay None.
    Each present entry is a (1, S, S) CPU tensor, where S = len(chain_labels):
      * "uniform":  row i has weight 1/(i+1) on columns 0..i;
      * "identity": row i has weight 1 on column i;
      * list of (dst, src) role names: entry [dst_pos, src_pos] is 1 when both
        positions lie on the same (non-None) chain. Rows are not renormalised.
    """
    num_symbols = len(chain_labels)
    maps = [[None] * n_heads for _ in range(n_layers)]
    for head_id, edges in flow.items():
        layer, head = (int(part) - 1 for part in head_id.split("."))  # keys are 1-based
        attn = torch.zeros(1, num_symbols, num_symbols)
        if edges == "uniform":
            for i in range(num_symbols):
                attn[0, i, : i + 1] = 1.0 / (i + 1)
        elif edges == "identity":
            for i in range(num_symbols):
                attn[0, i, i] = 1
        else:
            for dst, src in edges:
                for i, (src_name, src_chain) in enumerate(chain_labels):
                    # Symbols on no chain keep their raw letter as role name; skip them
                    # so that e.g. a stray "C1" cannot match a role edge via None == None.
                    if src_name != src or src_chain is None:
                        continue
                    for j, (dst_name, dst_chain) in enumerate(chain_labels):
                        if dst_name == dst and dst_chain == src_chain:
                            attn[0, j, i] = 1
        maps[layer][head] = attn
    return maps


total_loss = 0
total_loss_base = 0
num_batches = 0

pbar = tqdm(total=dataset_size)

# Inference only: no_grad avoids building autograd graphs for both forward passes.
with torch.no_grad():
    for x, y in dataloader:
        # Labels and attention maps are derived from x[0] alone, so larger batches
        # would silently reuse sample 0's maps for every sample.
        if x.shape[0] != 1:
            raise ValueError(f"attention maps are built per sequence; expected batch_size == 1, got {x.shape[0]}")

        chain_labels, labels = _relabel_by_chain(_assign_labels(x[0]), N)
        attention_maps = [
            [a.to(device) if a is not None else None for a in row]
            for row in _build_attention_maps(flow, chain_labels, N_LAYERS, N_HEADS)
        ]

        x = x.to(device)
        y = y.to(device)
        # Forward pass with the forced attention maps vs. the unmodified model.
        y_pred = model(x, attention_maps=attention_maps, labels=labels)
        y_pred_base = model(x)

        # Only the predictions at the final N (query) positions are scored.
        loss = criterion(y_pred[:, -N:], y).item()
        loss_base = criterion(y_pred_base[:, -N:], y).item()

        pbar.set_description(f"Loss: {loss:.4f}, Base loss: {loss_base:.4f}")
        pbar.update(1)

        total_loss += loss
        total_loss_base += loss_base
        num_batches += 1

pbar.close()

# Mean of the per-batch losses; max() guards against an empty dataloader.
denominator = max(num_batches, 1)
print(f"Total loss: {total_loss / denominator}, Base loss: {total_loss_base / denominator}")
