import os
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
from tqdm.auto import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from vlvq.model.hider import Hide2
from vlvq.model.finder import FindMulti
from vlvq.learner.steg import StegLearnerMulti
from vlvq.utils.dataloader import ImageAudioPair
from vlvq.logging.checkpoint import ModelCheckpoint

# Script configuration. These are plain values, so they are safe to evaluate
# when DataLoader worker processes re-import this module.
LOG_DIR = 'log/'
CKPT_DIR = 'models'
IMG_PATH = Path('img_data')
AUDIO_PATH = Path('audio_data')
# Passed to Hide2(IC_S, IC_C) and FindMulti(IC_C, IC_C, IC_S); presumably the
# channel counts of the secret and cover inputs -- TODO confirm against models.
IC_S, IC_C = 2, 3
MAX_AUDIO_LEN = 10
BATCH_SIZE = 16
NUM_WORKERS = 10
NUM_EPOCHS = 10
# Intervals are measured in optimisation steps (not epochs).
VIS_EVERY = 500
VAL_EVERY = 20000
CKPT_EVERY = 1000


def _run_training(learner, checkpoint, train_dl, val_dl, device, accelerator):
    """Run the step loop: train, and periodically visualise/validate/save.

    Args:
        learner: object exposing ``single_step``, ``single_vis`` and
            ``single_val`` (a ``StegLearnerMulti`` in this script).
        checkpoint: object exposing ``save(filename)``.
        train_dl: iterable of training batches.
        val_dl: validation loader, forwarded unchanged to ``single_val``.
        device: target ``torch.device`` for manual batch transfer, or ``None``
            when batches are not moved by this script.
        accelerator: the ``Accelerator`` in use, or ``None``.
    """
    # Only one process should touch the checkpoint files / draw a progress
    # bar; without an accelerator there is only one process anyway.
    is_main_process = accelerator is None or accelerator.is_main_process

    step = 0
    for _ in range(NUM_EPOCHS):
        for batch in tqdm(train_dl, disable=not is_main_process):
            if device is not None:
                batch = [b.to(device) for b in batch]
            learner.single_step(batch, step)
            # All three periodic actions also fire at step 0.
            if step % VIS_EVERY == 0:
                learner.single_vis(batch, step)
            if step % VAL_EVERY == 0:
                learner.single_val(val_dl, step, device)
            if step % CKPT_EVERY == 0 and is_main_process:
                checkpoint.save(f'StegLearnerMulti-{step}.pth')
            step += 1


def main(use_accelerate=True):
    """Build data, models and learner, then train for ``NUM_EPOCHS`` epochs.

    Args:
        use_accelerate: when true (default) an ``Accelerator`` prepares the
            models, optimizer and loaders; when false everything is moved to
            ``cuda:0`` by hand.
    """
    torch.backends.cudnn.benchmark = True
    accelerator = Accelerator() if use_accelerate else None

    os.makedirs(LOG_DIR, exist_ok=True)
    writer = SummaryWriter(log_dir=LOG_DIR)
    try:
        train_ds = ImageAudioPair(IMG_PATH, AUDIO_PATH, MAX_AUDIO_LEN, 'train')
        val_ds = ImageAudioPair(IMG_PATH, AUDIO_PATH, MAX_AUDIO_LEN, 'val')
        train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS)
        val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS)
        # Fixed batch handed to the learner; taken before `prepare`, so it
        # lives on the CPU until moved explicitly below.
        vis_batch = next(iter(val_dl))

        hide = Hide2(IC_S, IC_C)
        find = FindMulti(IC_C, IC_C, IC_S)
        # A single optimizer updates both networks jointly.
        opt = optim.Adam(list(hide.parameters()) + list(find.parameters()))

        if accelerator is None:
            device = torch.device('cuda:0')
            hide, find = hide.to(device), find.to(device)
            vis_batch = [b.to(device) for b in vis_batch]
        else:
            # Prepared loaders are expected to place batches on the right
            # device, so the training loop does no manual transfer.
            device = None
            hide = nn.SyncBatchNorm.convert_sync_batchnorm(hide)
            find = nn.SyncBatchNorm.convert_sync_batchnorm(find)
            hide, find, opt, train_dl, val_dl = accelerator.prepare(
                hide, find, opt, train_dl, val_dl)
            # Mirror the manual branch: keep the fixed batch on the same
            # device as the prepared models.
            vis_batch = [b.to(accelerator.device) for b in vis_batch]

        learner = StegLearnerMulti(
            hide,
            find,
            MAX_AUDIO_LEN,
            opt,
            writer,
            accelerator,
            vis_batch,
        )
        checkpoint = ModelCheckpoint(CKPT_DIR, hide, find, opt)

        _run_training(learner, checkpoint, train_dl, val_dl, device,
                      accelerator)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


# The guard is required for multi-worker DataLoaders on platforms that start
# workers with `spawn`: workers re-import this module and must not re-run
# the training entry point.
if __name__ == '__main__':
    main()