import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from tqdm import tqdm

from data.load_data import SupportedDatasets
from net.utils.components import GraphIsomorphismNetwork
from train.recover_conf import recover_datasets, recover_conf


def eval_subspace_remapping(*, use_tqdm=True, plot_stride=100):
    """Compare edge distances in the model's ``q`` features with 3-D geometry.

    Recovers the QM7 ``phi-psi`` dataset splits and the
    ``lstm-langevin-equiv-trunc`` model (both on CPU). It runs the model over
    the training split and, for every edge of the extended graph returned by
    ``GraphIsomorphismNetwork.extend_graph_no_dis``, collects three distances:

    * between rows of the last entry of ``return_dict['list_q_ftr']``
      (presumably the final latent / subspace coordinates — TODO confirm),
    * between rows of the model's ``pos_ftr`` output,
    * between rows of the reference ``dft_geometry``.

    A seaborn regression plot of q-distance vs. pos-distance (every
    ``plot_stride``-th edge) is saved to
    ``train/figure/subspace_remapping/{dataset_name}-{token}.png``. The
    pairwise correlations (pandas ``Series.corr`` default, i.e. Pearson) of
    the three distance lists, computed over *all* edges, are printed.

    Args:
        use_tqdm: wrap the training-set iteration in a tqdm progress bar.
        plot_stride: subsampling step applied to the points passed to the
            regression plot; the correlations always use every edge.
    """
    derive = 'langevin'
    compare = 'equiv-trunc'
    special_config = {'DERIVE_TYPE': derive, 'COMPARE_TYPE': compare}
    dataset_name = SupportedDatasets.QM7
    dataset_token = 'phi-psi'
    token = f'lstm-{derive}-{compare}'
    train_set, validate_set, test_set = recover_datasets(
        special_config=special_config,
        dataset_name=dataset_name,
        dataset_token=dataset_token,
        seed=0,
        force_save=False,
        use_cuda=False
    )
    model = recover_conf(
        special_config=special_config,
        dataset_name=dataset_name,
        token=token,
        use_cuda=False
    )

    list_q_dis = []
    list_pos_dis = []
    list_geom_dis = []
    iteration = tqdm(train_set, total=len(train_set)) if use_tqdm else train_set
    # NOTE(review): not wrapped in torch.no_grad() on purpose — the 'langevin'
    # derive step may rely on autograd internally; confirm before adding it.
    for packed_mol_graphs, smiles_set, target, dft_geometry, rdkit_geometry, extra_dict in iteration:
        return_dict = model.get_intermediate_dict(
            atom_ftr=packed_mol_graphs.atom_ftr,
            bond_ftr=packed_mol_graphs.bond_ftr,
            mask_matrices=packed_mol_graphs.mask_matrices,
            target_pos_ftr=dft_geometry,
            rdkit_pos_ftr=rdkit_geometry,
            smiles_set=smiles_set,
            extra_dict=extra_dict,
            return_list=['derive']
        )
        # Only the last q feature in the returned list is compared.
        q = return_dict['list_q_ftr'][-1]
        pos = return_dict['pos_ftr']
        vew1, vew2 = GraphIsomorphismNetwork.extend_graph_no_dis(
            vew1=packed_mol_graphs.mask_matrices.vertex_edge_w1,
            vew2=packed_mol_graphs.mask_matrices.vertex_edge_w2,
            use_cuda=False
        )
        # Same edge set for all three feature spaces, so the lists stay aligned.
        q_dis = GraphIsomorphismNetwork.edge_distances(q, vew1, vew2, keepdim=False).detach().cpu().tolist()
        pos_dis = GraphIsomorphismNetwork.edge_distances(pos, vew1, vew2, keepdim=False).detach().cpu().tolist()
        geom_dis = GraphIsomorphismNetwork.edge_distances(dft_geometry, vew1, vew2, keepdim=False
                                                          ).detach().cpu().tolist()
        list_q_dis.extend(q_dis)
        list_pos_dis.extend(pos_dis)
        list_geom_dis.extend(geom_dis)

    print(len(list_pos_dis))

    # Create the figure only after the (long) data loop so that an error in
    # the loop does not leave an open figure behind.
    fig = plt.figure(figsize=(8, 8))
    sns.regplot(x=list_q_dis[::plot_stride], y=list_pos_dis[::plot_stride])

    fig_dir = 'train/figure/subspace_remapping'
    os.makedirs(fig_dir, exist_ok=True)
    plt.savefig(f'{fig_dir}/{dataset_name}-{token}.png')
    plt.close(fig)

    q_series = pd.Series(list_q_dis)
    pos_series = pd.Series(list_pos_dis)
    geom_series = pd.Series(list_geom_dis)
    corr1 = q_series.corr(pos_series)
    corr2 = q_series.corr(geom_series)
    corr3 = pos_series.corr(geom_series)
    print(f'corr q pos: {corr1}')
    print(f'corr q geom: {corr2}')
    print(f'corr pos geom: {corr3}')
