from math import pi
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Radius at which spoke labels are drawn; the radial axis itself stops at 100
# (see ``set_ylim(0, 100)`` below), so labels sit just outside the outer ring.
_DEFAULT_LABEL_RADIUS = 110

# Manual per-label radius adjustments (presumably tuned by eye to keep these
# labels clear of the chart/neighbouring labels — confirm if the data changes).
_LABEL_RADIUS_OVERRIDES = {
    "Ghose": 112,
    "ZINCDruglike": 115,
    "FAF4Leadlike": 115,
    "ZINCBasic": 120,
}


def _draw_radar(ax, data, colors, datasets, *, legend=False):
    """Draw one radar chart of ``data`` onto the polar axes ``ax``.

    Args:
        ax: A polar matplotlib Axes to draw into.
        data: DataFrame whose index holds the metric names (one spoke each)
            and whose columns include every name in ``datasets``. Values are
            drawn on a fixed 0-100 radial scale.
        colors: Sequence of colors, one per entry of ``datasets``.
        datasets: Column names of ``data`` to plot, one filled polygon each.
        legend: If True, add a legend anchored outside the upper-right corner.
    """
    num_vars = len(data.index)

    # One evenly spaced spoke per metric; the first angle is repeated at the
    # end so each polygon closes back on its starting point.
    angles = [n / float(num_vars) * 2 * pi for n in range(num_vars)]
    angles += angles[:1]

    # Put angle 0 at the top and run clockwise.
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)

    ax.set_rlabel_position(0)
    ax.set_xticks(angles[:-1])

    # Replace the default tick labels with manually placed text so each label
    # can be pushed out beyond the outer ring by a per-label amount.
    ax.set_xticklabels([])
    ax.tick_params(axis="x", pad=0)

    for angle, label in zip(angles[:-1], data.index, strict=True):
        radius = _LABEL_RADIUS_OVERRIDES.get(label, _DEFAULT_LABEL_RADIUS)
        ax.text(
            angle,
            radius,
            label,
            size=10,
            horizontalalignment="center",
            verticalalignment="center",
        )

    ax.set_yticks([20, 40, 60, 80, 100])
    ax.set_yticklabels(["20", "40", "60", "80", "100"], color="grey", size=8)
    ax.set_ylim(0, 100)

    for color, db in zip(colors, datasets, strict=True):
        values = data[db].astype(float).tolist()
        values += values[:1]  # close the polygon, matching ``angles``
        ax.plot(
            angles,
            values,
            linewidth=2,
            linestyle="solid",
            label=db,
            color=color,
        )
        ax.fill(angles, values, color=color, alpha=0.2)

    if legend:
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1), fontsize=8)


def make_radar_plot(
    input_path="visualizations/raw_data/molecular_filters.csv",
    output_path="visualizations/plots/molecular_filters.png",
):
    """Render two side-by-side radar charts of molecular-filter results.

    Reads a CSV whose first column is used as the row index (filter names)
    and which has one column per dataset ("MolPILE", "ChEMBL", "GDB-17",
    "ZINC"). Saves a two-panel figure ("Druglike Metrics" and "Other
    Metrics") to ``output_path`` at 300 dpi.

    Args:
        input_path: Path of the CSV to read.
        output_path: Destination image path. Missing parent directories are
            created.

    Raises:
        KeyError: If an expected filter row or dataset column is missing from
            the CSV.
    """
    df = pd.read_csv(input_path, index_col=0)
    druglike = [
        "FAF4Druglike",
        "FAF4Leadlike",
        "Ghose",
        "GSK",
        "Lipinski",
        "Oprea",
        "Veber",
        "Xu",
        "ZINCDruglike",
    ]
    other = ["BeyondRo5", "Brenk", "Glaxo", "Pfizer", "REOS", "ZINCBasic", "Hao"]

    datasets = ["MolPILE", "ChEMBL", "GDB-17", "ZINC"]
    colors = plt.cm.viridis(np.linspace(0, 1, len(datasets)))

    fig = plt.figure(figsize=(12, 6))
    try:
        ax1 = fig.add_subplot(121, polar=True)
        ax2 = fig.add_subplot(122, polar=True)

        _draw_radar(ax1, df.loc[druglike], colors, datasets, legend=True)
        _draw_radar(ax2, df.loc[other], colors, datasets)

        ax1.set_title("Druglike Metrics", size=14, y=1.1)
        ax2.set_title("Other Metrics", size=14, y=1.1)
        fig.tight_layout()

        # savefig does not create directories; make sure the target exists.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure so repeated calls don't accumulate open figures
        # in pyplot's global state.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: render the figure using the default input/output paths.
    make_radar_plot()
