import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from sklearn.metrics import r2_score
from scipy.stats import pearsonr, linregress


def _density_panel(fig, ax, xrange, yrange, density, title, bin_num, value_range):
    """Draw one filled-contour map of ``density`` on ``ax`` with its own colorbar."""
    if value_range is None:
        density = np.asarray(density)
        # contourf masks NaN cells on its own, so NaNs must not poison the level range.
        value_range = (np.nanmin(density), np.nanmax(density))
    low, high = value_range
    if low == high:
        # A flat field would give identical levels, which contourf rejects; pad symmetrically.
        low, high = low - 0.5, high + 0.5

    cont = ax.contourf(xrange, yrange, density, levels=np.linspace(low, high, bin_num),
                       cmap=mpl.colormaps['binary'])
    # Attach the colorbar to this axes explicitly instead of relying on the
    # implicit "current axes" choice of plt.colorbar.
    fig.colorbar(cont, ax=ax, fraction=0.1, pad=0.01)
    ax.set_title(title)


def plot_density(xrange, yrange, density1, density2, title1, title2, bin_num=101, range1=None, range2=None):
    """Draw two filled-contour density maps side by side, each with its own colorbar.

    Parameters
    ----------
    xrange, yrange : array-like
        Grid coordinates passed unchanged to ``Axes.contourf`` for both panels.
    density1, density2 : array-like
        Values contoured on the left and right panel respectively.
    title1, title2 : str
        Titles of the left and right panel.
    bin_num : int, optional
        Number of contour level boundaries sampled with ``np.linspace``.
    range1, range2 : (float, float), optional
        Explicit ``(low, high)`` colour range per panel. When omitted, the
        NaN-ignoring min/max of the corresponding density is used. A
        degenerate range (``low == high``) is widened by 0.5 on each side.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))
    _density_panel(fig, ax1, xrange, yrange, density1, title1, bin_num, range1)
    _density_panel(fig, ax2, xrange, yrange, density2, title2, bin_num, range2)


def plot_thetas(true_ctr, pred_ctr, stderrs=None, title="Prediction plot", max_dev=0.1):
    """Plot predicted CTR against true CTR in a new figure.

    The legend reports the R^2 score of ``pred_ctr`` against ``true_ctr``.
    A dashed identity line is drawn from (0, 0) to (1, 1). A shaded band of
    half-width ``max_dev`` around that line is drawn for x in [0.1, 0.9].

    Parameters
    ----------
    true_ctr, pred_ctr : array-like
        Ground-truth and predicted values, plotted on the x and y axes.
    stderrs : array-like, optional
        When given, drawn as vertical error bars on each predicted point.
    title : str, optional
        Figure title.
    max_dev : float, optional
        Half-width of the shaded tolerance band around the identity line.
    """
    score = r2_score(true_ctr, pred_ctr)

    plt.figure(figsize=(11, 6))
    ax = plt.gca()

    ax.scatter(true_ctr, pred_ctr, s=25, label=f"$R^2 = $ {score:.3f}")
    if stderrs is not None:
        ax.errorbar(true_ctr, pred_ctr, yerr=stderrs, fmt="None", alpha=0.5)

    band_x = np.linspace(0.1, 0.9)
    ax.fill_between(band_x, band_x - max_dev, band_x + max_dev, alpha=0.3)
    ax.plot([0, 1], [0, 1], "--", c='tab:purple')

    ax.set_ylim(-0.1, 1.1)
    ax.set_xlim(-0.1, 1.1)

    ax.set_xlabel('True CTR')
    ax.set_ylabel('Pred CTR')
    ax.set_title(title)
    ax.legend()


def _err_corr_stats(residual, stderr):
    """Return ``(pearson_corr, slope, intercept)`` of ``stderr`` regressed on ``residual``."""
    corr = pearsonr(residual, stderr).statistic
    slope, intercept, *_ = linregress(residual, stderr)
    return corr, slope, intercept


def _draw_err_corr(ax, residual, stderr, stats, title):
    """Render one residual-vs-stderr scatter panel, with its fitted line, onto ``ax``."""
    corr, slope, intercept = stats
    ax.scatter(residual, stderr, s=25, label=f"Corr = {corr:.3f}", alpha=0.5)
    # The fitted line is evaluated at x=0 and x=1, the span the x-limits are padded around.
    ax.plot([0, 1], [intercept, slope + intercept], "--", c='tab:purple')
    ax.set_xlim(-0.1, 1.1)
    ax.set_xlabel('Absolute difference in predicted and actual CTR')
    ax.set_ylabel(r'Beta $\sigma$')
    ax.set_title(title)
    ax.legend()


def err_corr(residual1, stderr1, residual2=None, stderr2=None, title1="", title2=""):
    """Scatter prediction residuals against standard errors, with correlation and a linear fit.

    One panel is drawn when ``residual2`` is None; otherwise two panels are
    drawn side by side. Each panel's legend shows the Pearson correlation, and
    a dashed line shows the least-squares fit of stderr on residual.

    Parameters
    ----------
    residual1, stderr1 : array-like
        x (residual) and y (standard error) values for the first panel.
    residual2, stderr2 : array-like, optional
        Values for an optional second panel; ``stderr2`` is required when
        ``residual2`` is given.
    title1, title2 : str, optional
        Panel titles.

    Raises
    ------
    ValueError
        If ``residual2`` is given without ``stderr2``.
    """
    if residual2 is None:
        panels = [(residual1, stderr1, title1)]
        figsize = (7, 5)
    else:
        if stderr2 is None:
            raise ValueError("stderr2 must be provided when residual2 is given")
        panels = [(residual1, stderr1, title1), (residual2, stderr2, title2)]
        figsize = (12, 5)

    # Compute all statistics before creating the figure, so a failing fit
    # does not leave an empty figure behind.
    all_stats = [_err_corr_stats(res, err) for res, err, _ in panels]

    fig, axes = plt.subplots(1, len(panels), figsize=figsize, squeeze=False)
    for ax, (res, err, title), stats in zip(axes[0], panels, all_stats):
        _draw_err_corr(ax, res, err, stats, title)


def dual_scatter(X, X_proj, y):
    """Compare original and projected 2-D points side by side, split by a binary label.

    Parameters
    ----------
    X : array-like, shape (n_samples, >=2)
        Original points; only the first two columns are plotted.
    X_proj : array-like, shape (n_samples, >=2)
        Projected points; only the first two columns are plotted.
    y : array-like, shape (n_samples,)
        Labels. Rows with ``y == 0`` are drawn as the "not C" series and rows
        with ``y == 1`` as the "C" series; any other value is not plotted.
    """
    X = np.asarray(X)
    X_proj = np.asarray(X_proj)
    # Without the conversion, a plain-list ``y`` makes ``y == 0`` a single
    # bool instead of an element-wise mask.
    y = np.asarray(y)
    classes = ((y == 0, r"$\neg C$"), (y == 1, r"$C$"))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    for ax, points, title in ((ax1, X, "Original"), (ax2, X_proj, "Projected")):
        for mask, label in classes:
            ax.scatter(points[mask, 0], points[mask, 1], label=label, alpha=0.5)
        ax.set_title(title)
        ax.legend()


def dual_scatter_color(X, color1, color2, title1, title2):
    """Scatter the same 2-D points twice, each panel coloured by a different value array.

    Parameters
    ----------
    X : array-like, shape (n_samples, >=2)
        Point coordinates; only the first two columns are plotted.
    color1, color2 : array-like
        Per-point values mapped through the "cool" colormap on the left and
        right panel respectively; each panel gets its own colorbar.
    title1, title2 : str
        Panel titles.
    """
    X = np.asarray(X)
    # Explicit column indexing (matching dual_scatter) instead of ``*X.T``:
    # with ``*X.T`` an extra column would be passed positionally as ``s`` and
    # clash with ``s=25``.
    xs, ys = X[:, 0], X[:, 1]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))
    for ax, colors, title in ((ax1, color1, title1), (ax2, color2, title2)):
        sct = ax.scatter(xs, ys, alpha=0.5, s=25, c=colors, cmap="cool")
        # Attach to this panel explicitly rather than the implicit "current axes".
        fig.colorbar(sct, ax=ax, fraction=0.1, pad=0.01)
        ax.set_title(title)
