import wandb
from data import Dataset
from model import LSTMClassifier
from argparse import Namespace
import torch as t

# Hyper-parameters for the run. `config_dict` is the plain dict handed to
# wandb.init; `CONFIG` is an attribute-access (Namespace) view of the same values.
config_dict = dict(
    # --- optimisation ---
    epochs=50000,               # number of training iterations; one batch is drawn per iteration
    lr=1e-4,                    # AdamW learning rate for every parameter except the output layer
    lrout=1e-4,                 # AdamW learning rate for output.weight / output.bias
    weight_decay=0,             # AdamW weight decay, shared by both optimizers
    batch_size=12,
    log_rate=100,               # log loss/accuracy to wandb every `log_rate` iterations
    cut=True,                   # passed through to the classifier's forward call
    # Main (non-output) optimizer steps once every this-many iterations; the
    # output optimizer steps every iteration. The misspelled key is kept
    # because it is the name logged to wandb.
    ext_to_upd_udates_ratio=1,
    # --- model (arguments to LSTMClassifier) ---
    width=64,
    depth=1,
    dropout=0.0,
    # --- data (arguments to Dataset) ---
    tokens=10,                  # token count; the classifier scores tokens + 1 classes
    sequence_length=10,         # overwritten per run by the sweep loop
    # (earliest step at which a recall query is possible — an integer,
    #  second value — presumably the per-step probability; confirm in data.Dataset)
    recall_probability=(0, 0.3),
    # (earliest step at which reminder-noise injection is possible — an integer,
    #  second value — presumably the per-step probability; confirm in data.Dataset)
    reminder_probability=(0, 0.05),
)
CONFIG = Namespace(**config_dict)

# Training sweep: one wandb run per sequence length. Each run builds a fresh
# model, dataset and optimizer pair, then trains for CONFIG.epochs iterations
# (one freshly sampled batch per iteration).

# Per-element losses; averaged explicitly below. Hoisted: it is stateless and
# identical for every run.
CE_loss_fn = t.nn.CrossEntropyLoss(reduction="none")

# Output-projection parameters get their own optimizer (learning rate
# CONFIG.lrout); every other parameter goes to the main optimizer.
OUTPUT_PARAM_NAMES = {"output.weight", "output.bias"}

for sequence_length in [10]:

    # Keep the dict logged to wandb and the attribute view in sync.
    config_dict["sequence_length"] = sequence_length
    CONFIG.sequence_length = sequence_length

    run = wandb.init(project="lstm_recall_short_summary", config=config_dict, reinit=True)
    with run:
        classifier = LSTMClassifier(CONFIG.tokens, CONFIG.width, CONFIG.depth, CONFIG.dropout)
        classifier.cuda()
        wandb.watch(classifier)
        dataset = Dataset(CONFIG.sequence_length, CONFIG.tokens, CONFIG.recall_probability,
                          reminder_probability=CONFIG.reminder_probability, device="cuda")

        # Materialise once: the parameter list is partitioned twice below.
        named_params = list(classifier.named_parameters())
        optimizer = t.optim.AdamW(
            [param for name, param in named_params if name not in OUTPUT_PARAM_NAMES],
            lr=CONFIG.lr, weight_decay=CONFIG.weight_decay)
        optimizer_out = t.optim.AdamW(
            [param for name, param in named_params if name in OUTPUT_PARAM_NAMES],
            lr=CONFIG.lrout, weight_decay=CONFIG.weight_decay)

        for i in range(CONFIG.epochs):

            optimizer.zero_grad()
            optimizer_out.zero_grad()
            sequence, target = dataset.get_data(CONFIG.batch_size)

            # Call the module rather than .forward() so any registered forward
            # hooks run (calling .forward directly silently skips them).
            pred = classifier(sequence, CONFIG.cut)

            # pred is assumed to be (batch, seq, tokens + 1), given the indexing
            # below. reshape (unlike view) also works if the model returns a
            # non-contiguous tensor; it is a no-op copy-wise when contiguous.
            loss = CE_loss_fn(pred.reshape(-1, CONFIG.tokens + 1), target.flatten()).mean()

            if (i + 1) % CONFIG.log_rate == 0:
                with t.no_grad():
                    # Accuracy is measured on the final position of each sequence only.
                    accuracy = (pred[:, -1, :].argmax(dim=1) == target[:, -1]).float().mean()
                # Log plain Python floats instead of (CUDA, grad-tracking) tensors.
                wandb.log({"loss": loss.item(), "accuracy": accuracy.item()}, step=i)
            loss.backward()

            # The main optimizer steps only every `ext_to_upd_udates_ratio`
            # iterations. Its gradients are still zeroed at the top of every
            # iteration, so when the ratio is > 1 the gradients from the skipped
            # iterations are discarded rather than accumulated.
            if i % CONFIG.ext_to_upd_udates_ratio == 0:
                optimizer.step()

            # The output layer is updated every iteration.
            optimizer_out.step()

