import torch
import numpy as np
import pandas as pd
import random
from omegaconf import OmegaConf
import pickle, functools, copy
from itertools import product
import math
import ray

from gflownet.MDPs import molstrmdp
from gflownet.GFNs import models
from sehstr import SEH_il
from datasets.sehstr import gbr_proxy
from gflownet.data import Experience

from rdkit import Chem
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from rdkit.DataStructs import FingerprintSimilarity


def check_novelty(test, train_datasets):
    """Return the fraction of items in ``test`` that are absent from ``train_datasets``.

    Iterating ``test`` yields the items (for a dict, its keys); membership is
    checked with ``in`` against ``train_datasets``. Raises ZeroDivisionError
    when ``test`` is empty.
    """
    # `not in` always yields a bool, so summing counts the novel items.
    n_novel = sum(item not in train_datasets for item in test)
    return n_novel / len(test)

def check_novelty_with_diversity(test, train_datasets):
    """Return the mean pairwise fingerprint distance between training and test items.

    Thin wrapper around ``check_diversity_seh(train_datasets, test)`` that
    keeps only the scalar mean and discards the full distance matrix.
    """
    return check_diversity_seh(train_datasets, test)[1]

def check_unique(test):
    """Return the fraction of distinct items in ``test``.

    Args:
        test: either a mapping (its keys are used, as before) or any iterable
            of hashable samples, e.g. a list of molecule strings.

    Note: mapping keys are unique by construction, so a non-empty dict always
    yields 1.0. To measure duplicates among generated samples, pass the raw
    list of samples instead of a deduplicated dict.

    Raises:
        ZeroDivisionError: if ``test`` is empty.
    """
    # Materialise once so generators work and len() is well-defined.
    items = list(test.keys()) if hasattr(test, "keys") else list(test)
    return len(set(items)) / len(items)

def check_diversity_seh(set1_dict, set2_dict):
    """Compute the pairwise fingerprint *distance* matrix between two collections.

    Each key of ``set1_dict`` / ``set2_dict`` is converted to an MDP state with
    ``mdp.state(key, True)`` and fingerprinted with ``mdp.get_morgan_fp``.
    The distance is ``1 - FingerprintSimilarity(fp1, fp2)``; RDKit's default
    metric for ``FingerprintSimilarity`` is Tanimoto.

    NOTE: ``mdp`` is the module-level global assigned in the ``__main__``
    block; this function does not receive it as an argument.

    Args:
        set1_dict: object with ``.keys()`` (e.g. a dict) whose keys index rows.
        set2_dict: object with ``.keys()`` whose keys index columns.

    Returns:
        ``(D, mean)``. ``D`` has shape ``(len(set1_dict), len(set2_dict))``
        in key-iteration order, and ``mean`` is ``np.mean(D)``. When both
        arguments are the same collection, the self-pairs on the diagonal are
        included in the mean. An empty input gives a NaN mean.
    """
    def fingerprints(collection):
        # Fingerprint each key once, rather than once per (i, j) pair.
        return [mdp.get_morgan_fp(mdp.state(key, True)) for key in collection.keys()]

    fps1 = fingerprints(set1_dict)
    # Comparing a collection with itself (diversity) needs only one pass.
    fps2 = fps1 if set2_dict is set1_dict else fingerprints(set2_dict)

    D = np.empty((len(fps1), len(fps2)), dtype=float)
    for i, fp1 in enumerate(fps1):
        for j, fp2 in enumerate(fps2):
            D[i, j] = 1 - FingerprintSimilarity(fp1, fp2)
    return D, np.mean(D)

def true_reward(test):
    """Print mean and top-10% statistics of the policy and true rewards.

    Args:
        test: mapping from a sample key to a ``(policy_reward, true_reward)``
            pair, as built in the ``__main__`` block.

    The "top 10%" subset contains ``max(1, len(test) // 10)`` items. It is
    selected by policy reward (for the policy statistics) or by true reward
    (for the true statistics).

    Raises:
        ValueError: if ``test`` is empty.
    """
    if not test:
        raise ValueError("true_reward() requires a non-empty mapping")
    policy, true = zip(*[(p, t) for p, t in test.values()])
    policy = np.asarray(policy)
    true = np.asarray(true)
    # Keep at least one item: with k == 0, `arr[-0:]` would be the whole array
    # and the "top 10%" figures would silently equal the overall means.
    k = max(1, len(test) // 10)
    policy_idx_top = np.argpartition(policy, -k)[-k:]
    true_idx_top = np.argpartition(true, -k)[-k:]
    policy_rmean = np.mean(policy)
    true_rmean = np.mean(true)
    policy_top_trueR = np.mean(true[policy_idx_top])
    policy_rtop = np.mean(policy[policy_idx_top])
    true_rtop = np.mean(true[true_idx_top])
    print('policy reward mean: ', policy_rmean, ' / policy reward top10% mean: ', policy_rtop)
    print('true reward mean: ', true_rmean, ' / true reward top10% mean: ', true_rtop)
    print('true reward of top 10% policy: ', policy_top_trueR)

def true_reward_final(test, ks=(10, 50, 100, 500)):
    """Compute top-k true-reward statistics for several k.

    Args:
        test: mapping from a sample key to a ``(policy_reward, true_reward)``
            pair.
        ks: iterable of top-k sizes. The default ``(10, 50, 100, 500)`` matches
            the original hard-coded list.

    Returns:
        ``(true_rtop, policy_top_trueR)``: two lists aligned with ``ks``.
        ``true_rtop[i]`` is the mean true reward of the k items with the
        highest true reward. ``policy_top_trueR[i]`` is the mean true reward of
        the k items with the highest policy reward.

    Each k is clamped to ``len(test)``, so a small (e.g. heavily deduplicated)
    sample set averages over all available items instead of making
    ``np.argpartition`` raise.

    Raises:
        ValueError: if ``test`` is empty or any k is < 1.
    """
    policy, true = zip(*[(p, t) for p, t in test.values()])
    policy = np.asarray(policy)
    true = np.asarray(true)
    n = len(true)
    true_rtop = []
    policy_top_trueR = []
    for k in ks:
        if k < 1:
            # k == 0 would make `arr[-0:]` select the whole array.
            raise ValueError(f"top-k size must be >= 1, got {k}")
        k = min(int(k), n)
        policy_idx_top = np.argpartition(policy, -k)[-k:]
        true_idx_top = np.argpartition(true, -k)[-k:]
        true_rtop.append(np.mean(true[true_idx_top]))
        policy_top_trueR.append(np.mean(true[policy_idx_top]))
    return true_rtop, policy_top_trueR

def true_reward_mean(test):
    """Return the mean of the second element of each ``(policy, true)`` value pair.

    Raises ValueError when ``test`` is empty, or when a value is not a 2-item
    pair.
    """
    # Only the true-reward column is needed; the policy column is discarded.
    _, true_rewards = zip(*((p, t) for p, t in test.values()))
    return np.mean(np.asarray(true_rewards))

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and torch
            (plus ``torch.cuda`` when CUDA is available).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # The flag is `deterministic`. A misspelled name is silently accepted as a
    # new module attribute and has no effect. Setting these flags is harmless
    # when CUDA/cuDNN is not in use.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    
if __name__ == "__main__":
    # Environment used to convert the keys of the full ground-truth dataset into
    # MDP states. NOTE: `mdp` is a module-level global that check_diversity_seh
    # reads; it is rebound for every model version in the loop below.
    args = OmegaConf.load("./settings/seh_1000_proxy_2_0.yaml")
    mdp = SEH_il(args)
    # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
    with open('datasets/sehstr/block_18_stop6.pkl', 'rb') as f:
        all_datasets = pickle.load(f)
    # Keys without a `content` attribute are converted with mdp.state(k, True);
    # presumably keys that have one are already states, so they are kept as-is.
    keys = [mdp.state(k, True) if not hasattr(k, "content") else k for k in all_datasets.keys()]
    values = np.array(list(all_datasets.values()))
    true_datasets = dict(zip(keys, values))

    version = ['seh_1000_proxy_2_0', 'seh_1000_proxy_2_0.5']
    number = ['', '_1', '_2']  # run suffixes; '' is reported as version '_0'
    rows = []
    metric_cols = [
        'nov', 'nov_div', 'uniq', 'div',
        'rtop_10', 'rtop_50', 'rtop_100', 'rtop_500', 'true_mean'
    ]

    for ver in version:
        set_seed(42)
        args = OmegaConf.load(f"./settings/{ver}.yaml")
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        with open(args.offline_dataset, 'rb') as f:
            train_datasets = pickle.load(f)
        mdp = SEH_il(args)
        actor = molstrmdp.MolStrActor(args, mdp)
        model = models.make_model(args, mdp, actor)
        for num in number:
            latest_model = f'saved_models/{ver}{num}/{ver}{num}_round_25000.pth'
            model.load_params(latest_model)
            test = model.batch_fwd_sample(5000, epsilon=0)
            # Reuse the `logp_guide` field to carry each sample's ground-truth
            # reward looked up from the full dataset.
            test_trueR = [
                exp._replace(logp_guide=true_datasets[exp.x]) for exp in test
            ]
            # Deduplicate by content: content -> (exp.r, ground-truth reward).
            test_dict = {}
            for exp in test_trueR:
                test_dict[exp.x.content] = (exp.r, exp.logp_guide)

            # NOTE(review): assumes train_datasets is keyed by the same content
            # strings as test_dict; confirm against the offline dataset format.
            nov = check_novelty(test_dict, train_datasets)
            nov_div = check_novelty_with_diversity(test_dict, train_datasets)
            # Fraction of distinct samples among all generated ones.
            # check_unique(test_dict) would always be 1.0, because test_dict
            # is already deduplicated by its keys.
            uniq = len(test_dict) / len(test_trueR)
            div_matrix, div_value = check_diversity_seh(test_dict, test_dict)
            true_rtop, _ = true_reward_final(test_dict)
            rtop_10, rtop_50, rtop_100, rtop_500 = true_rtop
            trues_mean = true_reward_mean(test_dict)

            result = {
                'model': ver,
                'version': '_0' if num == '' else num,
                'nov': nov,
                'nov_div': nov_div,
                'uniq': uniq,
                'div': div_value,
                'rtop_10': rtop_10, 'rtop_50': rtop_50, 'rtop_100': rtop_100, 'rtop_500': rtop_500,
                'true_mean': trues_mean
            }
            rows.append(result)
    df = pd.DataFrame(rows)
    # Per-model mean and sample std (pandas default ddof=1) across the runs.
    avg_df = (df.groupby('model', as_index=False)[metric_cols].mean().assign(version='avg'))
    std_df = (df.groupby('model', as_index=False)[metric_cols].std().assign(version='std'))
    final_df = pd.concat([df, avg_df, std_df], ignore_index=True)
    print(final_df.sort_values(['model', 'version']))
    final_df.to_csv('seh_3round_results.csv', index=False)