import hydra
import pandas as pd
import sys
from env.mfenv import MultiFidelityEnvWrapper
from sklearn.metrics import explained_variance_score

@hydra.main(config_path="./oracle_eval", config_name="/oracle_eval_cfg")
def main(config):
    """Compare each oracle's scores against those of the last (highest-fidelity) oracle.

    Reads the ``samples`` column of the CSV file at ``config.dataset``, turns
    every readable sample into an environment state, scores the whole batch
    with every oracle listed in ``config._oracle_dict`` and prints, for each
    oracle, the explained variance score of its predictions with respect to
    the scores of the last oracle (fidelity ``N``).

    Args:
        config: Hydra config. Fields read here: ``dataset`` (CSV path),
            ``env`` (environment config), ``device``, ``float_precision`` and
            ``_oracle_dict`` (oracle configs keyed ``"1"`` ... ``"N"``).
    """
    dataset = pd.read_csv(config.dataset)
    samples = dataset["samples"]

    env = hydra.utils.instantiate(
        config.env,
        device=config.device,
        float_precision=config.float_precision,
    )
    # NOTE(review): the tokenizer is disabled before state conversion; the
    # reason is not visible from this file — confirm against the env class.
    env.tokenizer = None
    states = [env.readable2state(s) for s in samples]
    states_oracle = env.statebatch2oracle(states)

    n_fid = len(config._oracle_dict)
    # Oracle configs are keyed by their 1-based fidelity index, stored as strings.
    oracles = [
        hydra.utils.instantiate(
            config._oracle_dict[str(fid)],
            env=env,
            device=config.device,
            float_precision=config.float_precision,
        )
        for fid in range(1, n_fid + 1)
    ]

    # Keep per-fidelity scores in a local list rather than injecting dynamically
    # named ``y_<fid>`` globals into the module namespace.
    # assumes each oracle returns a torch tensor (uses .detach().cpu().numpy()) — TODO confirm
    scores = [oracle(states_oracle).detach().cpu().numpy() for oracle in oracles]
    if not scores:
        # No oracles configured: nothing to compare.
        return

    # The last oracle is the reference; convert it once, outside the loop.
    y_max = scores[-1]
    for fid, y_current in enumerate(scores, start=1):
        score = explained_variance_score(y_max, y_current, force_finite=True)
        print("Explained variance score for oracle {} is {}".format(fid, score))



if __name__ == "__main__":
    # Run the Hydra-managed entry point, then terminate explicitly
    # (equivalent to sys.exit() with no argument: exit status 0).
    main()
    raise SystemExit