import argparse
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import JSBDataset
from models import FCNN

# Command-line options as (flag, default, type), registered in this order.
_CLI_OPTIONS = (
    ("--epochs", 100, int),
    ("--lr", 0.001, float),
    ("--seq_len", 4, int),
    ("--batch_size", 32, int),
    ("--trials", 10, int),
    ("--idx", 1, int),
)
parser = argparse.ArgumentParser()
for _flag, _default, _type in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type)
args = parser.parse_args()
# --idx is a 1-based array task index; shift it so it can index Python lists.
args.idx = args.idx - 1

# Hyper-parameter grid: each 0-based task index selects one
# (activation function, hidden width) pair.
HIDDEN = [32, 64, 128, 256, 512, 1024, 2048]
ACTFUN = [
    "relu",
    "prelu",
    "maxout",
    "max_min_dup",
    "signedgeomean",
    "ail_and",
    "ail_xnor",
    "ail_or",
    "ail_and_or_dup",
    "ail_or_xnor_part",
    "ail_or_xnor_dup",
    "ail_and_or_xnor_part",
    "ail_and_or_xnor_dup",
]

# Reject indices outside the grid instead of letting negative indexing wrap
# around (e.g. --idx 0 would otherwise silently pick the last activation).
n_configs = len(HIDDEN) * len(ACTFUN)
if not 0 <= args.idx < n_configs:
    parser.error(
        "--idx must be between 1 and %d, got %d" % (n_configs, args.idx + 1)
    )

# Hidden width varies fastest, then activation function.
act_idx, hidden_idx = divmod(args.idx, len(HIDDEN))
h = HIDDEN[hidden_idx]
a = ACTFUN[act_idx]

# Data file shared by both splits, and the num_tokens value JSBDataset is
# constructed with (its interpretation lives in JSBDataset).
fname = "Jsb16thSeparated.json"
NUM_TOKENS = 37

train_data = JSBDataset(
    fname, seq_len=args.seq_len, num_tokens=NUM_TOKENS, split="train"
)
train_loader = DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True, drop_last=True
)

test_data = JSBDataset(
    fname, seq_len=args.seq_len, num_tokens=NUM_TOKENS, split="test"
)
# The evaluation loader must wrap the held-out split; wrapping train_data here
# would report training-set metrics as test metrics.
# NOTE: drop_last=True means trailing test samples that do not fill a whole
# batch are not evaluated.
test_loader = DataLoader(
    test_data, batch_size=args.batch_size, shuffle=False, drop_last=True
)

# Input width of the network: the number of elements in one sample's input.
nc_in = train_data[0][0].numel()

# Throwaway CPU instance used only to report the configuration and its size;
# it is not trained here.
model = FCNN(nc_in, nc_hidden=h, actfun=a)
print(h, a, flush=True)
n_trainable = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Parameters:", n_trainable, flush=True)

def _run_epochs(net, loader, optim, loss_fn, n_epochs, device):
    """Train ``net`` in place for ``n_epochs`` full passes over ``loader``.

    Each batch is indexed as ``batch[0]`` (inputs) and ``batch[1]`` (targets).
    Inputs are flattened to ``(batch, -1)`` before the forward pass.
    """
    net.train()
    for _ in range(n_epochs):
        for batch in loader:
            inp = batch[0].to(device)
            tgt = batch[1].to(device)
            optim.zero_grad()
            loss = loss_fn(net(inp.flatten(start_dim=1)), tgt)
            loss.backward()
            optim.step()


def _evaluate(net, loader, loss_fn, device):
    """Return ``(accuracy, mean_loss)`` of ``net`` over every sample in ``loader``.

    Both statistics are weighted by each batch's actual size, so a short final
    batch is counted correctly. Accuracy compares ``argmax(dim=1)`` of the
    output against the targets (assumes targets are class indices — the same
    assumption the accuracy formula has always made).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    net.eval()
    n_seen = 0
    n_correct = 0
    loss_sum = 0.0
    with torch.no_grad():
        for batch in loader:
            inp = batch[0].to(device)
            tgt = batch[1].to(device)
            out = net(inp.flatten(start_dim=1))
            batch_n = tgt.size(0)
            # loss_fn returns the batch mean (CrossEntropyLoss default
            # reduction), so scale by batch size to get a per-sample average.
            loss_sum += loss_fn(out, tgt).item() * batch_n
            n_correct += (out.argmax(dim=1) == tgt).sum().item()
            n_seen += batch_n
    if n_seen == 0:
        raise ValueError("evaluation loader produced no samples")
    return n_correct / n_seen, loss_sum / n_seen


# Use the GPU when present, otherwise fall back to CPU instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

# Fail before spending time on training if there is nothing to evaluate.
if len(test_loader) == 0:
    raise SystemExit(
        "test_loader yields no batches; check --batch_size against the size "
        "of the evaluation split"
    )

for _trial in range(args.trials):
    # A fresh model and optimizer per trial, so each trial starts from its own
    # initialization. One "accuracy,loss" line is printed per trial.
    model = FCNN(nc_in, nc_hidden=h, actfun=a).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    _run_epochs(model, train_loader, optimizer, criterion, args.epochs, device)
    test_acc, test_loss = _evaluate(model, test_loader, criterion, device)
    print("%.4f,%.4f" % (test_acc, test_loss), flush=True)
