
import gc
import io
import os
import time

import numpy as np
import logging
# Keep the import below for registering all model definitions
from models import ddpm, ncsnv2, ncsnpp, classifier
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets2 as datasets
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint

import pytorch_lightning as pl
import score_model

from pytorch_lightning.callbacks import ModelCheckpoint
import pickle
FLAGS = flags.FLAGS

def _load_initial_weights(model, config):
    """Seed a freshly built model with pretrained score weights, if configured.

    Intended to be called only when there is no ``last.ckpt`` to resume from.

    Two sources are supported, checked in this order:

    1. ``config.training.score_path`` is a non-empty path: the object loaded
       from it is handed to ``model.load_pretrained_state_dict``.
    2. ``config.training.clf_model`` is truthy and
       ``config.training.score_model`` is falsy: the Lightning checkpoint at
       ``scoredirs/<dataset>_<first two chars of sde>/last.ckpt`` is read, the
       entries of its ``state_dict`` whose key starts with ``score_model``
       are loaded non-strictly, and its ``ema`` entry is loaded into
       ``model.ema``.

    If neither condition holds the model is left untouched.

    Args:
        model: The ``score_model.ScoreModel`` instance to initialise.
        config: Unlocked experiment configuration.
    """
    training = config.training
    if training.score_path != '':
        model.load_pretrained_state_dict(torch.load(training.score_path))
    elif training.clf_model and not training.score_model:
        score_ckpt = f"scoredirs/{config.data.dataset.lower()}_{training.sde[:2]}/last.ckpt"
        ckpt = torch.load(score_ckpt)
        # Keep only the score-network entries; strict=False lets the remaining
        # parameters of `model` keep their fresh initialisation.
        score_weights = {
            key: value
            for key, value in ckpt['state_dict'].items()
            if key.startswith('score_model')
        }
        model.load_state_dict(score_weights, strict=False)
        model.ema.load_state_dict(ckpt['ema'])
        print("Loaded score-weights and ema from disk")


def train(config, workdir):
    """Train a ``ScoreModel`` with PyTorch Lightning, resuming when possible.

    The function:

    * pickles ``config`` to ``<workdir>/config.pkl`` (creating ``workdir`` if
      it does not exist yet),
    * builds the model and, when ``<workdir>/last.ckpt`` is absent, optionally
      seeds it with pretrained weights (see ``_load_initial_weights``),
    * builds the train/test data loaders via ``datasets.get_dataset``,
    * runs ``Trainer.fit`` on all visible CUDA devices with DDP, resuming from
      ``<workdir>/last.ckpt`` when that file exists.

    Args:
        config: Experiment configuration object exposing ``unlock()`` and the
            ``training`` / ``data`` sections read below.
        workdir: Directory used for the pickled config, TensorBoard logs and
            checkpoints.
    """
    # Make sure the output directory exists before anything is written to it.
    os.makedirs(workdir, exist_ok=True)
    # The context manager guarantees the handle is closed even if pickling fails.
    with open(f"{workdir}/config.pkl", "wb") as f:
        pickle.dump(config, f)

    config = config.unlock()
    model = score_model.ScoreModel(config, workdir)

    ckpt_file = os.path.join(workdir, "last.ckpt")
    resuming = os.path.exists(ckpt_file)
    print(not resuming, config.training.clf_model, not config.training.score_model)
    if not resuming:
        # A resumable checkpoint takes precedence over any pretrained weights.
        _load_initial_weights(model, config)
    model.adjust_parameters()

    train_loader, test_loader = datasets.get_dataset(
        config, uniform_dequantization=config.data.uniform_dequantization)
    print(config)

    logger = pl.loggers.TensorBoardLogger(save_dir=workdir, name="")

    # Always keeps `last.ckpt`, which is what a later run resumes from.
    last_checkpoint = ModelCheckpoint(dirpath=workdir, save_last=True)

    # Keeps the two best checkpoints. Accuracy is maximised; the loss has to
    # be minimised, otherwise the *worst* checkpoints would be retained.
    if config.training.clf_model:
        monitor, mode = "acc_v_epoch", "max"
    else:
        monitor, mode = "unc_loss", "min"
    best_checkpoint = ModelCheckpoint(
        dirpath=workdir, save_top_k=2, monitor=monitor, mode=mode)

    trainer = pl.Trainer(
        devices=torch.cuda.device_count(),
        accelerator='gpu',
        strategy=pl.strategies.DDPStrategy(find_unused_parameters=False),
        num_sanity_val_steps=1,
        logger=logger,
        benchmark=True,
        sync_batchnorm=False,
        callbacks=[last_checkpoint, best_checkpoint],
        max_steps=config.training.n_iters,
    )
    # `ckpt_path` on fit() replaces the deprecated
    # `Trainer(resume_from_checkpoint=...)` argument; None starts from scratch.
    trainer.fit(
        model,
        train_dataloaders=train_loader,
        val_dataloaders=test_loader,
        ckpt_path=ckpt_file if resuming else None,
    )
    
