import os
import numpy as np
import pandas as pd
import torch as th 
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.distributed as dist
from pathlib import Path
import sys
sys.path.append(str(Path.cwd())) 
import time
from Dataloader import loader
from configs import get_configs
from Diffusion import logger, dist_util
from sampling_utils import get_models_from_config,get_models_functions
from script_util import evaluate

# --- Inputs -----------------------------------------------------------------
# Test split: every column after the first is a patient ID (PTID); row 0 holds
# the factual scan age and row 1 the counterfactual (target) scan age.
# Both default paths can be overridden through environment variables.
test_path = os.environ.get('TADPOLE_TEST_CSV', '/home/s2263384/.cache/testset.csv')
test = pd.read_csv(test_path)
# Full scan table, filtered below by its 'PTID' and 'age_precise' columns.
path = os.environ.get('TADPOLE_DATA_CSV', '/home/s2263384/.cache/with_index.csv')
tadpole_data = pd.read_csv(path)

# --- Models -----------------------------------------------------------------
config = get_configs.get_default_configs()
dist_util.setup_dist()
logger.configure(
    Path(config.experiment_name) / ("counterfactual_sampling_" + "_".join(config.classifier.label))
)
logger.log("creating model and diffusion...")
classifier, diffusion, model = get_models_from_config(config)
# Keep two separate function sets: the default one drives the counterfactual
# sample, the reconstruction=True one drives the reconstruction sample.
# Binding both to the same names would silently discard the first set.
cond_fn, model_fn, denoised_fn = get_models_functions(config, model, classifier, reg_or_class='reg')
rec_cond_fn, rec_model_fn, rec_denoised_fn = get_models_functions(
    config, model, classifier, reg_or_class='reg', reconstruction=True
)

# Constants mapping an age in years onto the model's "age" conditioning value.
# NOTE(review): presumably the training-set offset/scale — confirm against the loader.
AGE_OFFSET = 54.4
AGE_SCALE = 38.2
sampling_progression_ratio = 1.0

# Per-patient metrics: counterfactual-based prediction vs. the real target scan,
# and the "_base" baseline of the unchanged input scan vs. the real target scan.
SSIM = []
MSE = []
PSNR = []
SSIM_base = []
MSE_base = []
PSNR_base = []


def _rows_at_age(rows, age):
    """Return the subset of *rows* whose ``age_precise`` column equals *age*."""
    # NOTE(review): exact float equality; relies on both CSVs storing identical values.
    return rows[rows['age_precise'] == age]


def _first_batch(rows):
    """Build a test-split loader over *rows* and return its first batch.

    Raises StopIteration if the loader yields no batch.
    """
    data_loader = loader.get_data_loader(rows, config.sampling.batch_size, split_set='test')
    # A DataLoader is iterable but not an iterator: go through iter() before
    # next(), otherwise next() raises TypeError.
    return next(iter(data_loader))


def _diffscm_sample(sampler_model_fn, classifier_fn, init_image, model_kwargs):
    """Run ``diffusion.diffscm_counterfactual_sample`` starting from *init_image*.

    The sample shape follows the actual number of images in *init_image*
    rather than the configured batch size, so partial batches stay consistent.
    Returns whatever tuple the diffusion sampler returns.
    """
    sample_shape = (
        init_image.shape[0],
        config.score_model.num_input_channels,
        config.score_model.image_size,
        config.score_model.image_size,
    )
    return diffusion.diffscm_counterfactual_sample(
        sampler_model_fn,
        sample_shape,
        factual_image=init_image,
        anticausal_classifier_fn=classifier_fn,
        model_kwargs=model_kwargs,
        device=dist_util.dev(),
        denoised_fn=None,
        mode="AUTO_STOP",  # ["FULL_RECORD", "FULL_ADJUSTED", "AUTO_STOP",]
        sampling_progression_ratio=sampling_progression_ratio,
    )


# --- Evaluation loop ----------------------------------------------------------
for PTID in test.keys()[1:]:
    print(PTID)
    # Filter into a separate frame; overwriting tadpole_data here would leave
    # every later patient searching an already-filtered (empty) table.
    patient_data = tadpole_data[tadpole_data['PTID'] == PTID]
    org_data = _rows_at_age(patient_data, test[PTID].iloc[0])
    coun = _rows_at_age(patient_data, test[PTID].iloc[1])
    if org_data.empty or coun.empty:
        logger.log(f"skipping {PTID}: no scan found at the factual or target age")
        continue

    logger.log("creating loader...")
    data_dict = _first_batch(org_data)

    logger.log("sampling...")
    results_per_sample = {"original": data_dict['image']}
    model_kwargs = {k: v.to(dist_util.dev()) for k, v in data_dict.items()}
    init_image = data_dict['image'].to(dist_util.dev())
    # Condition every image in the batch on the normalised target age.
    target_age = float(coun['age_precise'].iloc[0])
    model_kwargs["age"] = (
        (target_age - AGE_OFFSET) / AGE_SCALE * th.ones((init_image.shape[0],))
    ).to(dist_util.dev())

    counterfactual, label = _diffscm_sample(model_fn, cond_fn, init_image, model_kwargs)
    # NOTE(review): the reconstruction receives the same target-age model_kwargs;
    # confirm that reconstruction=True is what restores the factual conditioning.
    counterfactual_reconstruction, _ = _diffscm_sample(
        rec_model_fn, rec_cond_fn, init_image, model_kwargs
    )

    cf_tensor = counterfactual.detach().cpu()
    cf_res_tensor = counterfactual_reconstruction.detach().cpu()
    results_per_sample["counterfactual"] = cf_tensor.numpy()
    results_per_sample['reconstruction'] = cf_res_tensor.numpy()
    results_per_ = {
        "label": label.cpu().detach().numpy().reshape(-1),
        "age": data_dict['age'].cpu().detach().numpy(),
        "desired": model_kwargs["age"].cpu().detach().numpy(),
    }

    # Real scan of the same patient at the target age (ground truth).
    data_dict_val = _first_batch(coun)
    # Prediction = original input + (counterfactual - reconstruction).
    Diff = (cf_tensor - cf_res_tensor + data_dict['image']).cpu()

    # evaluate() returns three metrics; the (MSE, SSIM, PSNR) order matches the
    # lists they are appended to. NOTE(review): confirm against script_util.evaluate.
    mse_base, ssim_base, psnr_base = evaluate(data_dict['image'], data_dict_val['image'])
    mse_cf, ssim_cf, psnr_cf = evaluate(Diff, data_dict_val['image'])

    SSIM.append(ssim_cf)
    MSE.append(mse_cf)
    PSNR.append(psnr_cf)
    SSIM_base.append(ssim_base)
    MSE_base.append(mse_base)
    PSNR_base.append(psnr_base)

