# ===================================================================
#
#   Training script for the synthetic data experiments
#   Example usage:
#           python main.py --exp baseline
#           python main.py --exp sparse
#
# ===================================================================
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from model import Transformer
from dataset import ThreeRowShapeDataset
from utils import save_example_plot

# Command-line interface. Parsed at module level because main() reads the
# resulting `args` as a global.
parser = argparse.ArgumentParser(description='our_model')
parser.add_argument(
    '--exp',
    type=str,
    default='baseline',
    help="experiment name, forwarded to Transformer(model_type=...) and to "
         "save_example_plot; e.g. 'baseline' or 'sparse' (default: %(default)s)",
)
args = parser.parse_args()

def main():
    """Train a one-layer Transformer on ThreeRowShapeDataset and save example plots.

    The experiment name is read from the module-level ``args.exp`` and is
    passed to the model as ``model_type`` and to ``save_example_plot``.

    Every 10th epoch (starting with epoch 0) the test set is evaluated and a
    line is printed with the mean per-batch train loss of that epoch, the mean
    per-batch test loss, and the second value returned by the model for the
    last test batch (labelled "Sparsity").

    Runs on CUDA when it is available and falls back to CPU otherwise.
    """
    # The device used to be hard-coded via .cuda(), which raises on machines
    # without a GPU. Selecting it once keeps GPU behaviour identical.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ## ---------------- Model configs -------------------
    patch_size = 4
    # NOTE(review): presumably 8x8 inputs cut into patch_size x patch_size
    # patches, times 8 images per sample -- confirm against dataset/model.
    seq_len = (8 // patch_size) ** 2 * 8
    model = Transformer(
        seq_len=seq_len,
        d_model=16,
        n_heads=1,
        patch_size=patch_size,
        num_layers=1,
        dropout=0.,
        model_type=args.exp,
    )
    model = model.to(device)

    # Total number of scalar parameters in the model.
    num_params = sum(p.numel() for p in model.parameters())
    print("params: ", num_params)

    ## ---------------- Dataset configs -------------------
    dataset = ThreeRowShapeDataset(1000, stype='train')
    dataset_tst = ThreeRowShapeDataset(100, stype='test')
    batch_size = 128
    trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=16)
    # batch_size matches the test set size, so evaluation is a single batch.
    testloader = DataLoader(dataset_tst, batch_size=100, shuffle=False, num_workers=16)

    ## ---------------- Training configs -------------------
    epoch = 200
    lr = 1e-3
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 1e-8
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(beta1, beta2), eps=epsilon)
    criterion = nn.MSELoss()

    ## ---------------- Training loop -------------------
    for e in range(epoch):
        tra_loss = 0
        tra_cnt = 0
        for img, tgt in trainloader:
            img = img.to(device)
            tgt = tgt.to(device)

            optimizer.zero_grad()
            # The model returns a pair; only the prediction enters the loss.
            pred, _ = model(img)
            loss = criterion(pred, tgt)
            loss.backward()
            optimizer.step()
            tra_loss += loss.item()
            tra_cnt += 1

        if e % 10 == 0:
            # NOTE(review): the model is left in training mode here. dropout
            # is 0. in this config, but confirm Transformer has no other
            # train/eval-dependent behaviour before switching to model.eval().
            tst_loss = 0
            tst_cnt = 0
            with torch.no_grad():
                for img, tgt in testloader:
                    img = img.to(device)
                    tgt = tgt.to(device)

                    pred, sparsity = model(img)
                    loss = criterion(pred, tgt)
                    tst_loss += loss.item()
                    tst_cnt += 1
            # `sparsity` is the value from the last test batch only.
            print("Epoch: {}, Train loss:{:.4f}, Test loss:{:.4f}, Sparsity: {:.4f}".format(e, tra_loss / tra_cnt, tst_loss / tst_cnt, sparsity))

    ## ---------------- Visualization -------------------
    print("Saving output figures...")
    save_example_plot(model, args.exp)
    print("Done!")


# Invoke main() only when this file is executed directly as a script.
if "__main__" == __name__:
    main()