import wandb
from data import *
from model import *
from argparse import Namespace
import torch as t

# Run hyperparameters. `config_dict` is the plain dict that gets handed to wandb (and has a
# "seed" entry added per run); `CONFIG` is an attribute-access Namespace snapshot of the same
# settings, used by the training code.
config_dict = {
    # Optimisation and logging.
    "epochs": 50000,
    "lr": 1e-4,
    "weight_decay": 0,
    "batch_size": 128,
    "log_rate": 100,
    # True: gradients are cut between steps (truncated BPTT); False: full BPTT.
    "cut": True,
    # Extractor settings.
    "width_extractor": 64,
    "depth_extractor": 0,
    "dropout_extractor": 0.0,
    # Updater settings.
    "width_updater": 64,
    "depth_updater": 1,
    "dropout_updater": 0.0,
    # Task / data settings.
    "tokens": 10,
    "sequence_length": 25,
    # Tuple. The first element is an integer: the earliest step at which a recall query is
    # possible. NOTE(review): the second element is presumably the recall probability itself —
    # confirm against SeparateDataset.
    "recall_probability": (0, 0.3),
}
CONFIG = Namespace(**config_dict)

for s in range(1):
    # Seed torch per run so initialisation and sampling are reproducible for each seed.
    t.manual_seed(s)
    config_dict["seed"] = s  # record the seed in the config stored with the wandb run
    run = wandb.init(project="lstm_recall_separate", config=config_dict, reinit=True)
    with run:  # leaving the context finishes this wandb run
        updater = LSTMUpdater(CONFIG.tokens, CONFIG.width_updater, CONFIG.depth_updater, CONFIG.dropout_updater)
        extractor = MLPExtractor(CONFIG.tokens, CONFIG.width_extractor, CONFIG.depth_extractor, CONFIG.dropout_extractor)
        updater.cuda()
        extractor.cuda()
        wandb.watch(updater)
        wandb.watch(extractor)
        dataset = SeparateDataset(CONFIG.sequence_length, CONFIG.tokens, CONFIG.recall_probability, device="cuda")
        # One optimizer updates both networks jointly.
        optimizer = t.optim.AdamW(list(updater.parameters()) + list(extractor.parameters()),
                                  lr=CONFIG.lr, weight_decay=CONFIG.weight_decay)
        # Built once rather than on every iteration; per-element losses are averaged below.
        criterion = t.nn.CrossEntropyLoss(reduction="none")
        n_classes = CONFIG.tokens + 1  # predictions are flattened to (-1, tokens + 1) classes
        # Each iteration draws one fresh batch, so "epochs" effectively counts optimizer steps.
        for i in range(CONFIG.epochs):
            optimizer.zero_grad()
            instruction, query, target = dataset.get_data(CONFIG.batch_size)

            # Call the module instances, not `.forward()`, so registered module hooks still run.
            state = updater(instruction, CONFIG.cut)
            pred = extractor(query, state)
            loss = criterion(pred.view(-1, n_classes), target.flatten()).mean()

            if (i + 1) % CONFIG.log_rate == 0:
                with t.no_grad():
                    # argmax over dim 2 implies pred is rank 3 — presumably (batch, seq, tokens + 1).
                    accuracy = (pred.argmax(dim=2) == target).float().mean()
                # Log plain Python floats instead of (graph-attached) tensors.
                wandb.log({"loss": loss.item(), "accuracy": accuracy.item()}, step=i)
            loss.backward()
            optimizer.step()

