"""
Like score_sampling.py, but use a noisy image classifier to guide the sampling
process towards more realistic images.
"""
import os
import numpy as np
import pandas as pd
import torch as th
import matplotlib.pyplot as plt
import torch.distributed as dist
from pathlib import Path
import sys
sys.path.append(str(Path.cwd()))

from Dataloader import loader
from configs import get_configs
from Diffusion import logger, dist_util
from sampling_utils import get_models_from_config,get_models_functions



def _to_uint8(images):
    """Map a tensor from the presumed model range [-1, 1] to a uint8 numpy array in [0, 255]."""
    return ((images + 1) * 127.5).clamp(0, 255).to(th.uint8).cpu().numpy()


def main():
    """Generate counterfactuals for the first test batch and save them to an .npz file.

    Steps:
      1. Load the default config, set up distributed execution and the logger.
      2. Read the CSV at ``config.data.ANDI_path`` and take the first batch of
         the test split.
      3. Build classifier / diffusion / score model and their guidance functions.
      4. Run ``diffusion.diffscm_counterfactual_sample`` with the batch as the
         factual image and ``config.sampling.counterfactual_class`` as the
         target written into ``model_kwargs["age"]``.
      5. On rank 0, save the original and counterfactual images (uint8 numpy
         arrays) to ``samples_<label_of_intervention>.npz`` in the logger dir.
    """
    config = get_configs.get_default_configs()

    dist_util.setup_dist()
    logger.configure(
        Path(config.experiment_name)
        / ("counterfactual_sampling_" + "_".join(config.classifier.label))
    )

    logger.log("creating loader...")
    tadpole_data = pd.read_csv(config.data.ANDI_path)
    test_loader = loader.get_data_loader(
        tadpole_data, config.sampling.batch_size, split_set="test"
    )
    # TODO: when config.sampling.image_conditional is set, use
    # loader.get_img_cond_dataloader(tadpole_data, config.sampling.batch_size,
    # split_set="test") instead.

    logger.log("creating model and diffusion...")
    classifier, diffusion, model = get_models_from_config(config)
    cond_fn, model_fn = get_models_functions(config, model, classifier)

    # Only the first batch is sampled. `next` requires get_data_loader to
    # return an iterator/generator (TODO confirm; a plain DataLoader would
    # need iter() first).
    data_dict = next(test_loader)

    logger.log("sampling...")
    device = dist_util.dev()

    # Capture the factual images before sampling, in the same format as the
    # counterfactuals so both npz entries are numpy uint8 arrays.
    results_per_sample = {"original": _to_uint8(data_dict["image"])}

    # Send the batch to the sampling device; non-tensor entries (if any) are
    # passed through unchanged instead of crashing on `.to`.
    model_kwargs = {
        key: value.to(device) if th.is_tensor(value) else value
        for key, value in data_dict.items()
    }
    init_image = data_dict["image"].to(device)
    # Use the real batch size: the first test batch may be smaller than
    # config.sampling.batch_size if the test split is small.
    batch_size = init_image.shape[0]

    # Counterfactual target: every sample is intervened on to
    # config.sampling.counterfactual_class (cast to long, so fractional
    # values truncate).
    # NOTE(review): the intervened key is hard-coded as "age" while the output
    # file is named after config.sampling.label_of_intervention; confirm
    # they refer to the same attribute.
    model_kwargs["age"] = (
        (config.sampling.counterfactual_class * th.ones((batch_size,)))
        .to(th.long)
        .to(device)
    )

    counterfactual = diffusion.diffscm_counterfactual_sample(
        model_fn,
        (
            batch_size,
            config.score_model.num_input_channels,
            config.score_model.image_size,
            config.score_model.image_size,
        ),
        factual_image=init_image,
        anticausal_classifier_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
    )
    results_per_sample["counterfactual"] = _to_uint8(counterfactual)

    # Only rank 0 writes, so several processes never write the same path
    # concurrently; the barrier below keeps all ranks in step.
    if dist.get_rank() == 0:
        out_path = os.path.join(
            logger.get_dir(), f"samples_{config.sampling.label_of_intervention}.npz"
        )
        logger.log(f"saving to {out_path}")
        # NOTE: the dict is passed positionally, so it is stored as a single
        # pickled object under "arr_0" (load with allow_pickle=True). This is
        # kept for compatibility with existing readers;
        # np.savez(out_path, **results_per_sample) would store named arrays.
        np.savez(out_path, results_per_sample)
    dist.barrier()
    logger.log("sampling complete")
    


if __name__ == "__main__":
    main()