import os 
import numpy as np
import torch
import argparse
import os
from utils.rogi_xd import rogi, Metric

import warnings
# Suppress all Python warnings for the remainder of the run.
warnings.filterwarnings(action="ignore")

# Command-line interface: a single integer seed, which is also embedded in the
# output filename so results from different seeds do not overwrite each other.
parser = argparse.ArgumentParser()
parser.add_argument("--randseed", default=1, type=int)
args = parser.parse_args()

# Seed NumPy's global RNG first, then PyTorch's, from the same value.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(args.randseed)

# Accumulator for ROGI results, indexed as rogi_dict[problem][feature_type][metric].
# Each leaf holds:
#   "vals": bootstrap scores, one entry appended per scoring run;
#   "mean": the most recent ROGI point estimate (overwritten each run);
#   "std":  the most recent ROGI uncertainty (overwritten each run).
# Fingerprints are keyed under the "tanimoto" metric; every other representation
# is keyed under "euclidean".
# Built with comprehensions instead of a hand-written literal so that every
# problem shares exactly the same structure. The leaf dict literal is
# re-evaluated per entry, so no two leaves share a "vals" list.
rogi_dict = {
    problem: {
        feature_type: {
            ("tanimoto" if feature_type == "fingerprints" else "euclidean"): {
                "vals": [], "mean": 0., "std": 0.,
            },
        }
        for feature_type in (
            "fingerprints", "molformer", "t5-base-chem", "mordred",
            "degree_of_conjugation", "force_field", "dft", "all_features",
            "data_driven", "hand_crafted_expert", "hand_crafted_general",
        )
    }
    for problem in (
        "redox-mer", "solvation", "kinase", "laser",
        "pce", "photoswitch", "ampc", "d4",
    )
}

# Representations scored for every problem. Base representations are read
# straight from their own cache files; composite ones (see COMPOSITE_FEATURES)
# are assembled by concatenating several base caches.
feature_type_list = [
    "fingerprints", "molformer", "mordred", "t5-base-chem", "dft", "force_field", "degree_of_conjugation", "data_driven", "hand_crafted_expert", "hand_crafted_general", "all_features"
]

# Composite representations: the base caches whose per-entry feature vectors
# are concatenated, in exactly this order, to form each composite vector.
COMPOSITE_FEATURES = {
    "all_features": ["fingerprints", "molformer", "t5-base-chem", "mordred", "degree_of_conjugation", "force_field", "dft"],
    "data_driven": ["molformer", "t5-base-chem"],
    "hand_crafted_expert": ["degree_of_conjugation", "force_field", "dft"],
    "hand_crafted_general": ["fingerprints", "mordred"],
}


def _load_cached(cache_path, feature_type):
    """Load the cached features and targets for one representation.

    Args:
        cache_path: Directory containing ``<name>_feats.bin`` and
            ``<name>_targets.bin`` files written with ``torch.save``.
        feature_type: A base representation name, or a key of
            ``COMPOSITE_FEATURES``.

    Returns:
        ``(features, targets)`` as loaded by ``torch.load``. For a composite,
        ``features`` is a list with one tensor per entry, built by
        concatenating the base tensors along ``dim=0``. ``targets`` is taken
        from the fingerprints cache.

    Raises:
        ValueError: If the caches combined into a composite (or its targets)
            disagree in length. Without this check, ``zip`` would silently
            truncate the features and misalign them with the targets.
    """
    if feature_type not in COMPOSITE_FEATURES:
        features = torch.load(os.path.join(cache_path, f"{feature_type}_feats.bin"))
        targets = torch.load(os.path.join(cache_path, f"{feature_type}_targets.bin"))
        return features, targets

    sources = COMPOSITE_FEATURES[feature_type]
    # Composite targets come from the fingerprint cache; this presumes all base
    # caches list their entries in the same order (TODO confirm upstream).
    targets = torch.load(os.path.join(cache_path, "fingerprints_targets.bin"))
    per_source = [torch.load(os.path.join(cache_path, f"{name}_feats.bin")) for name in sources]

    lengths = {name: len(feats) for name, feats in zip(sources, per_source)}
    if len(set(lengths.values()) | {len(targets)}) != 1:
        raise ValueError(
            f"Cached data for '{feature_type}' in {cache_path} has mismatched lengths: "
            f"{lengths} (targets: {len(targets)})"
        )

    features = [torch.cat(parts, dim=0) for parts in zip(*per_source)]
    return features, targets


for problem in ["redox-mer", "solvation", "kinase", "laser", "pce", "photoswitch", "ampc", "d4"]:
    for feature_type in feature_type_list:
        print()
        print(f"Problem: {problem} - Feature type: {feature_type}")
        print("--------------------------------------------------")
        # Dataset
        CACHE_PATH = f"data/cache/{problem}/"
        features, targets = _load_cached(CACHE_PATH, feature_type)

        train_x, train_y = np.stack(features), np.stack(targets)

        # Fingerprints use the Tanimoto metric; all other representations use Euclidean.
        metric = [Metric.TANIMOTO] if feature_type == "fingerprints" else [Metric.EUCLIDEAN]

        for m in metric:
            rogi_score = rogi(train_x, train_y, normalize=True, metric=m, nboots=10)
            print(f"Metric: {m} - ROGI score: {rogi_score.rogi} +/- {rogi_score.uncertainty}", flush=True)
            # str(m) must match the metric keys used in rogi_dict ("tanimoto"/"euclidean").
            leaf = rogi_dict[problem][feature_type][str(m)]
            leaf["vals"].append(rogi_score.boot_scores)
            leaf["mean"] = rogi_score.rogi
            leaf["std"] = rogi_score.uncertainty

# Save results, one file per seed. torch.save does not create missing parent
# directories, so make sure "results/" exists first.
os.makedirs("results", exist_ok=True)
torch.save(rogi_dict, f"results/rogi_scores_{args.randseed}.bin")