import os
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
from tqdm.auto import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from vlvq.model.hider import Hide2
from vlvq.model.finder import Find
from vlvq.learner.steg import StegLearner
from vlvq.utils.dataloader import ImageNetPair
from vlvq.logging.checkpoint import ModelCheckpoint

# Let cuDNN benchmark and cache the fastest convolution algorithms; this only
# pays off if input shapes stay constant between iterations (assumed here).
torch.backends.cudnn.benchmark = True
accelerator = Accelerator()

# TensorBoard output directory, created up front before the writer opens it.
# NOTE(review): every process builds its own writer on this same directory, so
# multi-GPU runs produce one event file per rank; consider logging only from the
# main process (would require StegLearner to tolerate a missing writer).
log_dir = 'log/'
Path(log_dir).mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

# Input-channel counts passed to the networks; presumably the secret (s) and
# cover (c) images are both 3-channel RGB — confirm against Hide2/Find.
ic_s = ic_c = 3

# Construct the hider first and the finder second, then replace their
# BatchNorm layers with SyncBatchNorm so that batch statistics are reduced
# across processes in distributed training.
hide, find = (
    nn.SyncBatchNorm.convert_sync_batchnorm(net)
    for net in (Hide2(ic_s, ic_c), Find(ic_c, ic_s))
)

pth = Path('data')
train_ds, val_ds = ImageNetPair(pth, 'train'), ImageNetPair(pth, 'val')

# Both splits share batch size and worker count; only the training split is
# shuffled.
_loader_opts = {'batch_size': 16, 'num_workers': 10}
train_dl = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_opts)

# Fixed visualisation batch: the first (unshuffled) validation batch. It is
# drawn before `accelerator.prepare`, so it comes from the plain loader, which
# is not sharded and not moved to the accelerator device.
# NOTE(review): presumably StegLearner moves it to the right device — confirm.
vis_batch = next(iter(val_dl))

# One Adam optimiser (default hyper-parameters) over both networks' parameters.
opt = optim.Adam([*hide.parameters(), *find.parameters()])
hide, find, opt, train_dl, val_dl = accelerator.prepare(
    hide, find, opt, train_dl, val_dl
)

# The learner receives the prepared networks, the optimiser, the TensorBoard
# writer, the Accelerator and the fixed visualisation batch.
learner = StegLearner(hide, find, opt, writer, accelerator, vis_batch)

# 'models' is presumably the directory checkpoint files are written to.
# NOTE(review): hide/find here are the objects returned by accelerator.prepare,
# which may be DDP-wrapped in multi-process runs; their state_dict keys would
# then carry a 'module.' prefix. Consider accelerator.unwrap_model if loaders
# expect bare keys.
checkpoint = ModelCheckpoint('models', hide, find, opt)

# Training schedule. `cur_log_step` counts single_step calls across epochs.
# Visualisation, validation and checkpointing trigger on multiples of their
# interval, including step 0, i.e. right after the first single_step call.
NUM_EPOCHS = 10
VIS_EVERY = 500
VAL_EVERY = 20000
CKPT_EVERY = 1000

cur_log_step = 0
for cur_epoch in range(NUM_EPOCHS):
    # Draw the progress bar only on the local main process; otherwise every
    # rank prints its own bar into the same terminal.
    progress = tqdm(train_dl, disable=not accelerator.is_local_main_process)
    for batch in progress:
        learner.single_step(batch, cur_log_step)
        if cur_log_step % VIS_EVERY == 0:
            learner.single_vis(batch, cur_log_step)
        if cur_log_step % VAL_EVERY == 0:
            learner.single_val(val_dl, cur_log_step)
        # All ranks hold identical weights after each synchronised step, so a
        # single process writes the checkpoint. This avoids several processes
        # racing to write the same file.
        if cur_log_step % CKPT_EVERY == 0 and accelerator.is_main_process:
            checkpoint.save(f'StegLearner-{cur_log_step}.pth')
        cur_log_step += 1