import json
import os
import random
from typing import Tuple, List, Dict, Any, Optional

import numpy as np
import torch
import matplotlib.pyplot as plt
import tqdm

from data.utils import set_seed
from data.structures import MaskMatrices, PackedMolGraph
from data.load_data import load_data, SupportedDatasets
from data.load_multi_data import load_multi_data, SupportedMultiDatasets
from net.conformation.comparers.Equivalent import HYPER_DIR

# Names of the geometric quantities gathered per batch and summarised per dataset:
# bond lengths ('d'), bond angles ('phi'), 3-bond chain angles ('psi') and the
# distance terms derived from 2- / 3-step adjacency ('d2', 'd3').
COLLECT_SET = ['d', 'phi', 'psi', 'd2', 'd3']


def _canonical_chain_mask(n_directed: int, device: torch.device) -> torch.Tensor:
    """Return a bool [n_directed, n_directed] mask keeping one orientation of every chain.

    ``collect`` builds its directed edges as ``[w1, w2]`` / ``[w2, w1]``, so directed
    edge ``k`` and ``(k + n_directed // 2) % n_directed`` are the two orientations of
    the same bond. A chain stored as ``(i, j)`` shows up again reversed as
    ``(rev(j), rev(i))``; exactly one of the two satisfies ``i < rev(j)`` (whenever
    ``i != rev(j)``), so masking with it counts every chain exactly once regardless
    of how the bonds happen to be oriented in the input.
    """
    idx = torch.arange(n_directed, device=device)
    reversed_idx = (idx + n_directed // 2) % n_directed
    return torch.unsqueeze(idx, 1) < torch.unsqueeze(reversed_idx, 0)


def collect(packed_mol_graph: 'PackedMolGraph', pos: torch.FloatTensor,
            collect_set: Optional[List[str]] = None) -> Dict[str, List[float]]:
    """Collect geometric quantities of every molecule in a packed batch.

    Args:
        packed_mol_graph: batch whose ``mask_matrices.vertex_edge_w1`` and
            ``vertex_edge_w2`` (vertex x bond) select the two endpoint atoms of each
            bond when multiplied with ``pos`` (assumed one-hot per column).
        pos: atom coordinates, one row per vertex.
        collect_set: names of the quantities to compute; defaults to the
            module-level ``COLLECT_SET``.

    Returns:
        Mapping from quantity name to a flat list of values, each bond / atom pair /
        chain counted once:

        * ``'d'``: bond lengths.
        * ``'d2'`` / ``'d3'``: distances between atoms whose shortest bond path has
          exactly 2 / 3 bonds.
        * ``'phi'``: bond angles, in radians.
        * ``'psi'``: an angle in (0, pi/2) computed from the six pairwise distances of
          every 3-bond chain of four distinct atoms (presumably the torsion folded
          through its sine - verify).
    """
    if collect_set is None:
        collect_set = COLLECT_SET
    ret_dict = {}

    mask_matrices = packed_mol_graph.mask_matrices
    # Directed edges: column k of vew1 is the tail atom, column k of vew2 the head atom.
    # The first half are the bonds as stored, the second half the same bonds reversed.
    vew1 = torch.cat([mask_matrices.vertex_edge_w1, mask_matrices.vertex_edge_w2], dim=1)
    vew2 = torch.cat([mask_matrices.vertex_edge_w2, mask_matrices.vertex_edge_w1], dim=1)
    n_v, n_e = vew1.shape

    distance_matrix = torch.norm(torch.unsqueeze(pos, 1) - torch.unsqueeze(pos, 0), dim=2)
    u_pos = vew1.t() @ pos  # tail coordinates of each directed edge
    v_pos = vew2.t() @ pos  # head coordinates of each directed edge
    edge_lengths = torch.norm((vew1.t() - vew2.t()) @ pos, dim=1)
    # [n_v, n_v], symmetric; > 0 iff the two atoms are bonded.
    adjacency = vew1 @ vew2.t()

    if 'd' in collect_set:
        d1_matrix = torch.tril(adjacency, diagonal=-1) * distance_matrix
        ret_dict['d'] = (np.ravel(d1_matrix)[np.flatnonzero(d1_matrix)]).tolist()

    if 'd2' in collect_set or 'd3' in collect_set:
        vertex_idx = torch.arange(n_v, device=adjacency.device)
        lower = torch.unsqueeze(vertex_idx, 1) > torch.unsqueeze(vertex_idx, 0)
        same_atom = torch.unsqueeze(vertex_idx, 1) == torch.unsqueeze(vertex_idx, 0)
        within_1 = (adjacency > 0) | same_atom
        walks_2 = adjacency @ adjacency
        # Binarise walk counts: a pair is "k hops apart" iff a k-walk exists and no
        # shorter path does. Multiplying distances by raw walk counts would scale
        # them (e.g. x2 across a 4-ring) and A^3 would also hit bonded pairs.
        hop_2 = (walks_2 > 0) & ~within_1
        if 'd2' in collect_set:
            ret_dict['d2'] = distance_matrix[hop_2 & lower].tolist()
        if 'd3' in collect_set:
            hop_3 = ((walks_2 @ adjacency) > 0) & ~within_1 & ~hop_2
            ret_dict['d3'] = distance_matrix[hop_3 & lower].tolist()

    if 'phi' in collect_set or 'psi' in collect_set:
        canonical = _canonical_chain_mask(n_e, vew1.device)
        # (i, j) > 0 iff the head of directed edge i is the tail of directed edge j.
        head_to_tail = vew2.t() @ vew1

        if 'phi' in collect_set:
            # Zero where j merely walks back along i (tail of i == head of j).
            not_reversal = 1 - vew1.t() @ vew2
            rows, cols = torch.nonzero(((head_to_tail * not_reversal) > 0) & canonical,
                                       as_tuple=True)
            a = np.asarray(edge_lengths[rows])
            b = np.asarray(edge_lengths[cols])
            c = np.asarray(torch.norm(u_pos[rows] - v_pos[cols], dim=1))
            # Law of cosines for the angle at the shared atom.
            cos_phi = np.divide(a ** 2 + b ** 2 - c ** 2, 2 * a * b)
            phi = np.arccos(cos_phi.clip(-1 + 1e-6, 1 - 1e-6))
            ret_dict['phi'] = phi.tolist()

        if 'psi' in collect_set:
            endpoints = vew1 + vew2
            # 1 iff the two bonds share no atom (0 for one shared atom, -1 for the same bond).
            disjoint = 1 - endpoints.t() @ endpoints
            # (i, k) > 0 iff i -> j -> k is a 3-bond chain over four distinct atoms.
            chain_2 = (head_to_tail @ head_to_tail) * disjoint
            rows, cols = torch.nonzero((chain_2 > 0) & canonical, as_tuple=True)
            a = np.asarray(edge_lengths[rows])
            b = np.asarray(torch.norm(v_pos[rows] - u_pos[cols], dim=1))
            c = np.asarray(edge_lengths[cols])
            d = np.asarray(torch.norm(u_pos[rows] - u_pos[cols], dim=1))
            e = np.asarray(torch.norm(v_pos[rows] - v_pos[cols], dim=1))
            f = np.asarray(torch.norm(u_pos[rows] - v_pos[cols], dim=1))
            r1 = b ** 2 + c ** 2 - e ** 2
            r2 = b ** 2 - c ** 2 + e ** 2
            t1 = a ** 2 + b ** 2 - d ** 2
            t2 = a ** 2 + e ** 2 - f ** 2
            sin2_psi = np.divide(
                4 * a ** 2 * b ** 2 * e ** 2
                - b ** 2 * t2 ** 2
                - a ** 2 * r2 ** 2
                - e ** 2 * t1 ** 2
                + r2 * t1 * t2,
                4 * a ** 2 * b ** 2 * c ** 2 - a ** 2 * r1 ** 2 + 1e-6
            )
            sin_psi = np.sqrt(sin2_psi.clip(1e-6, 1 - 1e-6))
            psi = np.arcsin(sin_psi)
            ret_dict['psi'] = psi.tolist()

    return ret_dict


def plt_distribution(dist: List[float], name: str,
                     out_dir: str = 'data/distribution') -> Tuple[float, float]:
    """Print summary statistics of ``dist`` and save a jittered scatter plot of it.

    Quantities whose name starts with ``'d'`` are treated as distances and reported
    and plotted as-is. All others are treated as angles in radians: they are reported
    in degrees (radians in parentheses) and plotted in degrees on [0, 90] for
    ``'psi'`` and [0, 180] otherwise.

    Args:
        dist: collected values.
        name: quantity name; also the plot file stem.
        out_dir: directory receiving ``<name>.png``; created if missing.

    Returns:
        ``(mean, std)`` of ``dist`` in its input units; ``std`` uses ``ddof=1``.
    """
    mean = np.mean(dist)
    std = np.std(dist, ddof=1)
    is_distance = name.startswith('d')
    if is_distance:
        print(f'{name}: mean = {mean:.3f}, std = {std:.3f}')
    else:
        # np.degrees is exact; the previous 180 / 3.1416 was a rounded approximation.
        print(f'{name}: mean = {np.degrees(mean):.3f}°({mean:.3f}), std = {np.degrees(std):.3f}°({std:.3f})')

    fig = plt.figure(figsize=(6, 4))
    try:
        # Random vertical jitter keeps overlapping values visible in the scatter.
        ys = [random.random() for _ in dist]
        if is_distance:
            plt.xlim(0, 8)
            plt.scatter(dist, ys, s=1)
        else:
            plt.xlim(0, 90 if name == 'psi' else 180)
            plt.scatter(np.degrees(dist), ys, s=1)
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f'{out_dir}/{name}.png')
    finally:
        # Release the figure even if saving fails, so repeated calls do not leak it.
        plt.close(fig)
    return mean, std


def degree_of_d_phi_psi_freedom(packed_mol_graph: 'PackedMolGraph') -> Tuple[int, int, int]:
    """Count the degrees of freedom attributed to bond lengths, angles and psi terms.

    Args:
        packed_mol_graph: batch whose ``mask_matrices`` provide ``mol_vertex_w``
            (molecule x vertex) and ``vertex_edge_w1`` / ``vertex_edge_w2``
            (vertex x bond endpoint incidence).

    Returns:
        ``(n_d, n_phi, n_psi)`` where

        * ``n_d = n_vertices - n_molecules`` (bonds of a spanning forest),
        * ``n_phi`` sums ``2 * degree - 3`` over atoms with at least two
          neighbours, minus two per bond beyond the spanning forest,
        * ``n_psi = 3 * n_vertices - 6 * n_molecules - n_d - n_phi``
          (the remaining coordinates after removing rigid motions).
    """
    mask_matrices = packed_mol_graph.mask_matrices
    n_m, n_v = mask_matrices.mol_vertex_w.shape
    vew1 = mask_matrices.vertex_edge_w1
    vew2 = mask_matrices.vertex_edge_w2
    n_e = vew1.shape[1]

    n_d = n_v - n_m

    # Bonds that close rings (beyond a spanning forest).
    n_extra_bonds = n_e - n_d
    degree = (vew1 + vew2).sum(dim=1)
    branched = degree[degree > 1.5]
    # Vectorised sum instead of Python-level iteration over tensor elements.
    n_phi = int((2 * branched - 3).sum()) - 2 * n_extra_bonds

    n_psi = 3 * n_v - 6 * n_m - n_d - n_phi

    return n_d, n_phi, n_psi


def generate_conf_feature(dataset_name: str, n_mol_per_pack=20):
    """Compute conformation statistics over a dataset's training split and save them.

    Runs in a single pass over the training packs. In that pass it gathers the
    quantities in ``COLLECT_SET`` (via ``collect``) and sums the degree-of-freedom
    counts (via ``degree_of_d_phi_psi_freedom``). It then plots every distribution
    under ``data/distribution`` and writes ``{HYPER_DIR}/{dataset_name}.json`` with,
    for ``k`` in ``d``/``phi``/``psi``:

    * ``f_k``: degrees of freedom / number of collected values,
    * ``s_k``: 1 / standard deviation of the collected values,
    * ``k``: ``f_k * s_k``.

    Args:
        dataset_name: a dataset name from ``SupportedDatasets`` or
            ``SupportedMultiDatasets``.
        n_mol_per_pack: molecules per packed batch passed to the loader.
    """
    distributions = {c: [] for c in COLLECT_SET}
    dict_freedom_cnt = {'d': 0, 'phi': 0, 'psi': 0}
    set_seed(0)

    if dataset_name in SupportedMultiDatasets.tolist():
        train_dataset, _, _ = load_multi_data(dataset_name, dataset_token='train_only', n_mol_per_pack=n_mol_per_pack,
                                              train_only=True)
    else:
        train_dataset, _, _, _ = load_data(dataset_name, dataset_token='train_only', n_mol_per_pack=n_mol_per_pack)

    # One pass gathers both the distributions and the freedom counts
    # (previously the dataset was iterated twice).
    iterator = tqdm.tqdm(
        iterable=train_dataset,
        total=len(train_dataset)
    )
    for pmg, _, _, geom, _, _ in iterator:
        for k, v in collect(pmg, geom).items():
            distributions[k].extend(v)
        df_d, df_phi, df_psi = degree_of_d_phi_psi_freedom(pmg)
        dict_freedom_cnt['d'] += df_d
        dict_freedom_cnt['phi'] += df_phi
        dict_freedom_cnt['psi'] += df_psi

    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs('data/distribution', exist_ok=True)
    dict_s = {}
    dict_dist_cnt = {}
    for k, v in distributions.items():
        _, dict_s[k] = plt_distribution(v, name=k)
        dict_dist_cnt[k] = len(v)

    dict_f_p = {}
    for k in ('d', 'phi', 'psi'):
        ratio = dict_freedom_cnt[k] / dict_dist_cnt[k]
        print(f'# {k}: {dict_freedom_cnt[k]} / {dict_dist_cnt[k]} = {ratio:.3f}')
        dict_f_p[k] = ratio

    os.makedirs(HYPER_DIR, exist_ok=True)
    with open(f'{HYPER_DIR}/{dataset_name}.json', 'w') as fp:
        json.dump({
            'd': dict_f_p['d'] / dict_s['d'],
            'phi': dict_f_p['phi'] / dict_s['phi'],
            'psi': dict_f_p['psi'] / dict_s['psi'],
            'f_d': dict_f_p['d'],
            'f_phi': dict_f_p['phi'],
            'f_psi': dict_f_p['psi'],
            's_d': 1. / dict_s['d'],
            's_phi': 1. / dict_s['phi'],
            's_psi': 1. / dict_s['psi'],
        }, fp)


if __name__ == '__main__':
    # (dataset, molecules per packed batch) for each run.
    for _dataset_name, _mols_per_pack in (
            (SupportedDatasets.QM9, 20),
            (SupportedMultiDatasets.GEOM_QM9_SMALL, 20),
            (SupportedMultiDatasets.GEOM_DRUGS_SMALL, 5),
    ):
        generate_conf_feature(dataset_name=_dataset_name, n_mol_per_pack=_mols_per_pack)

# Console output recorded from earlier runs of this script, kept for reference.
'''
QM9:
d: mean = 1.452, std = 0.103
phi: mean = 107.569°(1.877), std = 22.201°(0.387)
psi: mean = 30.427°(0.531), std = 24.824°(0.433)
d2: mean = 2.510, std = 0.645
d3: mean = 4.692, std = 2.192

# d: 830932 / 1002205 = 0.829
# phi: 1030264 / 1570148 = 0.656
# psi: 311779 / 1927138 = 0.162

GEOM_QM9:
d: mean = 1.442, std = 0.101
phi: mean = 108.005°(1.885), std = 21.765°(0.380)
psi: mean = 31.259°(0.546), std = 25.064°(0.437)
d2: mean = 2.493, std = 0.633
d3: mean = 4.642, std = 2.160
# d: 38959 / 46946 = 0.830
# phi: 48274 / 75635 = 0.638
# psi: 14644 / 94642 = 0.155

GEOM_DRUGS:
d: mean = 1.425, std = 0.109
phi: mean = 117.421°(2.049), std = 6.975°(0.122)
psi: mean = 17.032°(0.297), std = 22.599°(0.394)
d2: mean = 2.434, std = 0.127
d3: mean = 4.326, std = 1.389
# d: 119701 / 133582 = 0.896
# phi: 153801 / 209399 = 0.734
# psi: 70601 / 271308 = 0.260
'''
