import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch as th
from mpi4py import MPI
from PIL import Image
from Dataloader import loader
from pathlib import Path
import sys
sys.path.append(str(Path.cwd()))
from configs import get_configs
from Diffusion import logger, dist_util
import Diffusion.dist_util
from Diffusion.train_util import TrainLoop
from Diffusion.unet import UNetModel, EncoderUNetModel
from Diffusion.resample import create_named_schedule_sampler
import os
from script_util import create_gaussian_diffusion, create_score_model_, create_image_cond_score_model



def main(
    data_csv: str = "/home/s2263384/.cache/cross_training_set.csv",
    cuda_visible_devices: "str | None" = "3",
):
    """Train the score model with the default project configuration.

    Steps: pin the visible GPU(s), set up the distributed environment,
    configure logging under ``<config.experiment_name>/score_train``, build
    train/validation loaders from ``data_csv``, create the score model,
    Gaussian diffusion and timestep schedule sampler, then run ``TrainLoop``
    for ``config.score_model.training.iterations`` iterations.

    Args:
        data_csv: Path of the CSV file read with ``pandas.read_csv`` and handed
            to ``loader.get_data_loader`` for both the 'train' and 'val'
            splits. Defaults to the path that was previously hard-coded here.
        cuda_visible_devices: Value written to the ``CUDA_VISIBLE_DEVICES``
            environment variable before distributed setup. Pass ``None`` to
            leave the current environment untouched. Defaults to ``"3"``, the
            GPU that was previously hard-coded.
    """
    # Must happen before anything initialises CUDA (setup_dist / dev() below),
    # since CUDA reads this variable only once at initialisation.
    # NOTE(review): if Diffusion.dist_util.setup_dist follows guided-diffusion,
    # it may overwrite CUDA_VISIBLE_DEVICES itself -- confirm.
    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

    config = get_configs.get_default_configs()
    training = config.score_model.training  # hoisted: used for every hyper-parameter below

    Diffusion.dist_util.setup_dist()
    print(Diffusion.dist_util.dev())
    logger.configure(Path(config.experiment_name) / "score_train",
                     format_strs=["log", "stdout", "csv", "tensorboard"])

    logger.log("creating data loader...")
    tadpole_data = pd.read_csv(data_csv)
    # No conditioning on x: the loaders are built from the tabular data only.
    train_loader = loader.get_data_loader(tadpole_data, training.batch_size, split_set='train')
    val_loader = loader.get_data_loader(tadpole_data, training.batch_size, split_set='val')

    logger.log("creating model and diffusion...")
    model = create_score_model_(config)
    diffusion = create_gaussian_diffusion(config)
    model.to(Diffusion.dist_util.dev())
    schedule_sampler = create_named_schedule_sampler(training.schedule_sampler, diffusion)

    logger.log("training...")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=train_loader,
        data_val=val_loader,
        batch_size=training.batch_size,
        microbatch=training.microbatch,
        lr=training.lr,
        ema_rate=training.ema_rate,
        log_interval=training.log_interval,
        save_interval=training.save_interval,
        resume_checkpoint=training.resume_checkpoint,
        use_fp16=training.use_fp16,
        fp16_scale_growth=training.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=training.weight_decay,
        lr_anneal_steps=training.lr_anneal_steps,
    ).run_loop(training.iterations)

# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
