import os
import shutil
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from model import Tranformer, plot_attention
from data import CustomDataset

# ---------------------------------------------------------------------------
# Task configuration.
# The analysis functions in this file read each sequence as (key, value)
# pairs followed by N trailing query symbols, and score only the last N
# model outputs.
# ---------------------------------------------------------------------------
N = 4  # Number of symbols per indirection level
D = 4  # Number of indirection levels (D = 1 is the basic induction head)

# Architecture hyper-parameters forwarded to Tranformer(...) below. The
# checkpoints loaded later are restored into this exact architecture via
# load_state_dict, so these presumably have to match the training run
# (TODO confirm against the training script).
SYMBOL_DIM = 4
EMBED_SIZE = 512
N_HEADS = 1
N_LAYERS = 24

# Forwarded to CustomDataset; its effect on the generated sequences is
# defined in data.py (not visible here).
IMPLICIT_CURRICULUM = True

# Extra keyword arguments forwarded to Tranformer (mlp=..., head_dim=...).
mlp = False
head_dim = 16

# epoch_size is not referenced in the visible part of this script.
# save_interval is the step between the checkpoint numbers that the driver
# loop at the bottom of the file iterates over.
epoch_size = 1024
save_interval = 10

# Highest checkpoint number (inclusive) visited by the driver loop.
num_epochs = 1000

# N * D (key, value) pairs = 2 * N * D symbols, plus N query symbols.
# Passed to Tranformer; presumably its maximum sequence length (TODO confirm).
block_size = N * (D * 2 + 1)
dataset_size = 32
# NOTE: the analysis functions label symbols from x[0] only, i.e. they look
# at the first element of each batch, so batch_size must stay 1.
batch_size = 1

# valid_size is not referenced in the visible part of this script.
valid_size = 128

# Model selection. Each pair below overwrites the previous one, so ONLY THE
# LAST (model_path, output_folder) PAIR TAKES EFFECT; the earlier pairs are
# dead assignments kept as a record of previous runs. The meaning of the
# trailing numbers is not stated in this file (presumably a loss value).
model_path = "models/model_C/3000.pt"  # 0.04818190215155482
output_folder = "attn_ablat/model_C"

model_path = "models/model_B/3000.pt"  # 0.04213446215726435
output_folder = "attn_ablat/model_B"

model_path = "models/model_A/3000.pt"  # 0.03227014990989119
output_folder = "attn_ablat/model_A"

# Unlike the entries above, this is a path PREFIX: the driver loop appends
# "<epoch>.pt" to it, so the trailing slash is significant.
model_path = "models/model_D/"
output_folder = "attn_ablat/model_D"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Evaluation data shared by every analysis function in this file.
dataloader = DataLoader(CustomDataset(N, D, SYMBOL_DIM, IMPLICIT_CURRICULUM, dataset_size), batch_size=batch_size)

criterion = torch.nn.MSELoss()

# Freshly initialised model; the driver loop loads checkpoint weights into it.
model = Tranformer(SYMBOL_DIM, SYMBOL_DIM * D, EMBED_SIZE, N_HEADS, block_size, N_LAYERS, mlp=mlp, head_dim=head_dim).to(device)


def _parse_head_id(head_id):
    """Convert a 1-based ``"layer.head"`` string into 0-based ``(layer, head)`` indices."""
    layer, head = head_id.split(".")
    return int(layer) - 1, int(head) - 1


def _label_chains(x, num_queries):
    """Name every position of the first sequence in ``x`` by its role in a lookup chain.

    The sequence is read as consecutive (key, value) pairs followed by
    ``num_queries`` trailing query symbols. Starting from each query, the
    key -> value mapping is followed and every symbol met on the way gets a
    step letter ("A" for the query itself, "B" for the symbol it maps to, ...).

    Args:
        x: batch tensor indexed as ``x[batch, position]``; only batch element
            0 is inspected, so callers must use a batch size of 1.
        num_queries: number of trailing query symbols (one chain per query).

    Returns:
        A ``(chain_labels, labels)`` tuple of per-position lists:

        * ``chain_labels[p]`` is ``(name, chain)``. ``name`` is the step
          letter followed by "1" for the first occurrence of that symbol in
          the sequence or "2" for any later occurrence; ``chain`` is the index
          of the query the symbol belongs to. Symbols that are not reachable
          from any query keep their raw letter as name and get ``chain=None``.
        * ``labels[p]`` is a letter that is unique per (chain, step), or
          ``None`` for symbols outside every chain.

    NOTE(review): an unreachable symbol whose raw letter happens to equal a
    step letter is indistinguishable by name from a chain symbol, and two such
    symbols compare equal on ``chain`` (both ``None``). This is inherited from
    the original labelling scheme.
    """
    seq = x[0]
    num_symbols = x.shape[1]

    # Pass 1: identical symbol vectors share one raw letter, assigned in
    # order of first appearance.
    raw = []
    next_raw = "A"
    for i in range(num_symbols):
        for j in range(i):
            if (seq[i] == seq[j]).all().item():
                raw.append(raw[j])
                break
        else:
            raw.append(next_raw)
            next_raw = chr(ord(next_raw) + 1)

    # Pass 2: key -> value mapping taken from the pairs preceding the queries.
    mapping = {raw[i]: raw[i + 1] for i in range(0, num_symbols - num_queries, 2)}

    # Pass 3: walk each chain from its query symbol and record, per raw
    # letter, (step letter, chain index, globally unique letter).
    relabel = {}
    next_unique = "A"
    for chain in range(num_queries):
        step = "A"
        symbol = raw[num_symbols - num_queries + chain]
        while symbol is not None:
            relabel[symbol] = (step, chain, next_unique)
            step = chr(ord(step) + 1)
            next_unique = chr(ord(next_unique) + 1)
            symbol = mapping.get(symbol)

    # Pass 4: attach the occurrence suffix and build the per-position outputs.
    seen = set()
    chain_labels = []
    labels = []
    for symbol in raw:
        name, chain, unique = relabel.get(symbol, (symbol, None, None))
        name += "1" if symbol not in seen else "2"
        seen.add(symbol)
        chain_labels.append((name, chain))
        labels.append(unique)

    return chain_labels, labels


def _edge_positions(chain_labels, dst, src):
    """Yield ``(row, col)`` for every dst/src position pair lying on the same chain.

    ``row`` is the position named ``dst`` and ``col`` the position named
    ``src``. Pairs are produced with the source position as the outer loop so
    that the order matches the original nested loops exactly.
    """
    for col, (src_name, src_chain) in enumerate(chain_labels):
        if src_name != src:
            continue
        for row, (dst_name, dst_chain) in enumerate(chain_labels):
            if dst_name == dst and dst_chain == src_chain:
                yield row, col


def find_delta_loss(flow):
    """Compare the loss with prescribed attention maps against the unmodified model.

    Args:
        flow: dict mapping 1-based ``"layer.head"`` ids to either
            ``"uniform"`` (position i spreads its weight evenly over
            positions 0..i), ``"identity"`` (every position attends only to
            itself) or a list of ``(dst, src)`` name pairs (position ``dst``
            gets weight 1 on position ``src`` of the same chain). Heads that
            are not listed are passed to the model as ``None``.

    Returns:
        ``(mean_loss, mean_loss_base)``: the MSE on the last ``N`` outputs
        with the prescribed maps and with a plain forward pass, each averaged
        over the batches produced by ``dataloader``.

    Reads the module-level ``dataloader``, ``model``, ``criterion``,
    ``device``, ``N``, ``N_HEADS`` and ``N_LAYERS``. How the model consumes
    ``attention_maps`` and ``labels`` is defined in model.py (not visible
    here).
    """
    total_loss = 0
    total_loss_base = 0
    num_batches = 0

    # Pure evaluation: no autograd graph is needed for either forward pass.
    with torch.no_grad():
        for x, y in dataloader:
            num_symbols = x.shape[1]
            chain_labels, labels = _label_chains(x, N)

            attention_maps = [[None for _ in range(N_HEADS)] for _ in range(N_LAYERS)]
            for head_id, edges in flow.items():
                layer, head = _parse_head_id(head_id)

                if edges == "uniform":
                    # Row i holds 1 / (i + 1) in columns 0..i and 0 elsewhere.
                    lower = torch.tril(torch.ones(num_symbols, num_symbols))
                    row_sizes = torch.arange(1, num_symbols + 1, dtype=lower.dtype)
                    forced = (lower / row_sizes.unsqueeze(1)).unsqueeze(0)
                elif edges == "identity":
                    forced = torch.eye(num_symbols).unsqueeze(0)
                else:
                    forced = torch.zeros(1, num_symbols, num_symbols)
                    for dst, src in edges:
                        for row, col in _edge_positions(chain_labels, dst, src):
                            forced[0, row, col] = 1

                attention_maps[layer][head] = forced.to(device)

            x = x.to(device)
            y = y.to(device)
            y_pred = model(x, attention_maps=attention_maps, labels=labels)
            y_pred_base = model(x)

            total_loss += criterion(y_pred[:, -N:], y).item()
            total_loss_base += criterion(y_pred_base[:, -N:], y).item()
            num_batches += 1

    # Average over the batches actually seen rather than a separately
    # configured dataset size; an empty dataloader yields (0.0, 0.0).
    denom = max(num_batches, 1)
    return total_loss / denom, total_loss_base / denom


def find_avg_attn(head_id, edge):
    """Score how strongly one layer attends along ``edge``, relative to uniform attention.

    Args:
        head_id: 1-based ``"layer.head"`` string. Only the layer index is
            forwarded to the model; the head part is parsed but not used.
        edge: ``(dst, src)`` pair of position names as produced by
            ``_label_chains`` (e.g. ``("A2", "B2")``).

    Returns:
        ``(mean - expected) / (1 - expected)`` where ``mean`` is
        ``obj["sum"] / obj["count"]`` after running every batch through the
        model, and ``expected`` is the weight a uniform causal head would put
        on the same edges. 0 therefore corresponds to uniform attention and 1
        to all weight on the edge.

    Raises:
        ZeroDivisionError: if no batch was processed, if ``obj["count"]`` is
            still 0 afterwards, or if ``expected`` equals 1.

    The ``obj`` dict is handed to the model on every forward pass; the model
    is expected to accumulate into ``obj["sum"]`` and ``obj["count"]`` using
    ``obj["layer"]`` and ``obj["attention_map"]`` (TODO confirm in model.py).

    NOTE(review): ``expected`` is normalised assuming exactly ``N`` matching
    edges per sequence (one per chain).
    """
    layer, _ = _parse_head_id(head_id)
    dst, src = edge

    obj = {
        "layer": layer,
        "edge": edge,
        "sum": 0,
        "count": 0,
    }

    expected = 0
    num_batches = 0

    # Pure evaluation: the forward pass is run only for its effect on obj.
    with torch.no_grad():
        for x, _ in dataloader:
            num_symbols = x.shape[1]
            chain_labels, _ = _label_chains(x, N)

            attention_map = torch.zeros(1, num_symbols, num_symbols)
            for row, col in _edge_positions(chain_labels, dst, src):
                attention_map[0, row, col] = 1
                # A uniform causal head at position `row` gives each of its
                # row + 1 visible positions the same weight.
                expected += 1 / (row + 1)

            obj["attention_map"] = attention_map.to(device)
            model(x.to(device), obj=obj)
            num_batches += 1

    expected /= num_batches * N
    return (obj["sum"] / obj["count"] - expected) / (1 - expected)


# Hypothesised attention flow for model D.
#
# Keys are 1-based "layer.head" identifiers (they are split on "." and
# decremented to 0-based indices by the analysis functions).
#
# Values describe the attention pattern assumed for that head:
#   * "uniform"  - position i spreads its attention evenly over positions 0..i
#   * "identity" - every position attends only to itself
#   * a list of (dst, src) pairs - the position named `dst` attends to the
#     position named `src` on the same lookup chain.
#
# Position names are a letter plus a digit:
#   * the letter is the step along a chain, starting at "A" for the query
#     symbol and advancing one letter each time the key -> value mapping is
#     followed ("B" is what the query maps to, and so on);
#   * the digit is "1" for the first occurrence of that symbol in the
#     sequence and "2" for any later occurrence.
flow_D = {
    "1.1": "uniform",
    "2.1": "uniform",
    "3.1": "uniform",
    "4.1": "uniform",
    "5.1": "uniform",
    "6.1": "uniform",
    "7.1": "uniform",
    "8.1": "uniform",
    "9.1": "uniform",
    "10.1": "uniform",
    "11.1": "uniform",
    "12.1": "identity",
    "13.1": "identity",
    "14.1": [("E1", "D1"), ("D2", "C1"), ("C2", "B1"), ("B2", "A1")],
    "15.1": "uniform",
    "16.1": "uniform",
    "17.1": "uniform",
    "18.1": "uniform",
    "19.1": "uniform",
    "20.1": "uniform",
    "21.1": [("D2", "E1"), ("C2", "D2"), ("A2", "B2"), ("B2", "C2")],
    "22.1": [("C2", "D2"), ("A2", "B2"), ("B2", "C2")],
    "23.1": [("A2", "B2")],
    "24.1": [("A2", "C2")],
}

# Flow table used by the rest of the script.
flow = flow_D

# One (head_id, edge) entry per explicit edge in the flow table; the
# "uniform" / "identity" heads have no individual edges to measure.
paths = []
for head_id, edges in flow.items():
    if isinstance(edges, list):
        for edge in edges:
            paths.append((head_id, edge))

# The results directory may not exist yet on a fresh checkout.
os.makedirs(output_folder, exist_ok=True)

# One CSV row per checkpoint: the checkpoint number followed by the attention
# score of every entry in `paths`, in order. No header row is written.
with open(os.path.join(output_folder, "results.csv"), "w") as fout:
    for epoch in range(0, num_epochs + 1, save_interval):
        # model_path is a prefix, so plain concatenation is intentional.
        model_path_epoch = f"{model_path}{epoch}.pt"
        # map_location lets checkpoints saved on a GPU load on whatever
        # device this run selected (including CPU-only machines).
        model.load_state_dict(torch.load(model_path_epoch, map_location=device))
        model.eval()

        row = [epoch]
        for head_id, edge in paths:
            attn = find_avg_attn(head_id, edge)
            print(head_id, edge, attn)
            row.append(attn)
        fout.write(",".join(map(str, row)) + "\n")
        # Flush per checkpoint so partial results survive an interrupted run.
        fout.flush()
        print(f"Epoch {epoch} done")
