import glob
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from localreg import gaussian, localreg, epanechnikov
from scipy.interpolate import interp1d
from statsmodels.nonparametric.smoothers_lowess import lowess

from utils.misc_utils import TABLEAU10_RGB, uninterleave


def unzip(xs):
    """Transpose an iterable of rows into a list of column lists.

    The number of columns is fixed by the first row; the i-th item of every
    row is appended to column i. Returns ``None`` when ``xs`` yields no rows.

    >>> unzip([(1, "a"), (2, "b")])
    [[1, 2], ['a', 'b']]
    """
    rows = iter(xs)
    sentinel = object()
    row = next(rows, sentinel)
    if row is sentinel:
        return None

    # The first row decides how many output columns there are.
    columns = [[] for _ in range(len(row))]
    while row is not sentinel:
        for col_idx, item in enumerate(row):
            columns[col_idx].append(item)
        row = next(rows, sentinel)
    return columns


def _fit_loess_interpolator(x, y, deg, kernel, frac):
    """Fit ``localreg`` to (x, y) and return a 1-D interpolator over the fit.

    Args:
        x, y: 1-D sequences of equal length (training data).
        deg: Polynomial degree passed to ``localreg``.
        kernel: Kernel function passed to ``localreg``.
        frac: Neighbourhood fraction passed to ``localreg``.

    Returns:
        A ``scipy.interpolate.interp1d``. Queries outside
        ``[min(x), max(x)]`` return the fitted value at the nearer end point
        instead of raising.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Sort first so ypred[0] / ypred[-1] really are the fits at min(x) and
    # max(x); training data assembled from several folds is not guaranteed
    # to be in x order.
    order = np.argsort(x, kind="stable")
    x, y = x[order], y[order]
    ypred = localreg(x, y, degree=deg, kernel=kernel, frac=frac)
    return interp1d(
        x, ypred, bounds_error=False, fill_value=(ypred[0], ypred[-1]),
    )


def cross_validate_loess(x, y, nparts):
    """Choose LOESS hyper-parameters by k-fold cross-validation, then refit.

    Grid-searches polynomial degree (0, 1, 2), neighbourhood fraction
    (0.1 .. 0.5) and kernel (gaussian, epanechnikov). Each combination is
    scored by the mean, over folds, of the squared error on held-out points
    that lie inside that fold's training x-range. The best combination is
    refit on all of (x, y).

    Args:
        x, y: 1-D sequences of equal length.
        nparts: Number of folds; the data is split with ``uninterleave``.

    Returns:
        ``(interpolator, df)``. ``interpolator`` is an ``interp1d`` over the
        best fit on the full data. ``df`` has one row per combination, with
        columns ``deg``, ``frac``, ``kernel`` and ``loss``. ``loss`` is NaN
        when no fold could be scored.

    Raises:
        ValueError: if no combination produced a non-NaN loss.
    """
    parts = uninterleave(list(zip(x, y)), parts=nparts)

    # Fold splits do not depend on the hyper-parameters: build them once.
    folds = []
    for i, held_out in enumerate(parts):
        train = [xy for j, part in enumerate(parts) if j != i for xy in part]
        if not train:
            continue
        trainx, trainy = unzip(train)
        minx, maxx = min(trainx), max(trainx)
        # Score only points inside the training range, so the fit is
        # interpolated rather than extrapolated.
        test = [xy for xy in held_out if minx <= xy[0] <= maxx]
        if not test:
            continue
        testx, testy = unzip(test)
        folds.append((trainx, trainy, np.asarray(testx), np.asarray(testy)))

    rows = []
    for deg in [0, 1, 2]:
        for frac in [0.1, 0.2, 0.3, 0.4, 0.5]:
            for kernel in [gaussian, epanechnikov]:
                losses = []
                for trainx, trainy, testx, testy in folds:
                    interpolator = _fit_loess_interpolator(
                        trainx, trainy, deg, kernel, frac
                    )
                    losses.append(np.mean((testy - interpolator(testx)) ** 2))
                rows.append(
                    {
                        "deg": deg,
                        "frac": frac,
                        "kernel": kernel,
                        "loss": np.mean(losses) if losses else np.nan,
                    }
                )

    # DataFrame.append was removed in pandas 2.0; build the frame in one go.
    df = pd.DataFrame(rows, columns=["deg", "frac", "kernel", "loss"])
    if df["loss"].isna().all():
        raise ValueError("cross-validation produced no valid losses")

    best_row = df.loc[df["loss"].idxmin()]
    interpolator = _fit_loess_interpolator(
        x,
        y,
        int(best_row["deg"]),
        best_row["kernel"],
        float(best_row["frac"]),
    )
    return interpolator, df


if __name__ == "__main__":
    # Usage: python <this script> [PREFIX]
    # With PREFIX, only ./lr_optimization/PREFIX__2_15.tsv is analysed;
    # otherwise every .tsv file in that directory is.
    file_dir = "./lr_optimization/"

    if len(sys.argv) > 1:
        file_paths = [os.path.join(file_dir, "{}__2_15.tsv".format(sys.argv[1]))]
    else:
        # Sorted so files are processed in a deterministic order.
        file_paths = sorted(glob.glob(os.path.join(file_dir, "*.tsv")))

    unique_colors = TABLEAU10_RGB + [(0, 0, 0)]

    def color_for(index):
        """Palette colour for the index-th view radius, divided by 255.

        Cycles through the palette when there are more radii than colours.
        """
        return np.array(unique_colors[index % len(unique_colors)]) / 255

    for file_path in file_paths:
        df = pd.read_csv(file_path, sep="\t")
        # Sorting by lr first keeps each per-radius series ordered along x,
        # which the line plots below rely on.
        df = df.sort_values(by=["lr", "view_radius", "seed"])

        unique_view_radii = sorted(set(np.array(df["view_radius"], dtype=int)))
        max_view_radius = max(unique_view_radii)

        # Lower clamp for the negated expert CE: the negation of the largest
        # expert_ce that is not (numerically) zero. Presumably this keeps
        # np.log below away from 0.
        # NOTE(review): assumes expert_ce is stored negated (<= 0) -- confirm.
        expert_ce = df["expert_ce"]
        ce_floor = -expert_ce[np.round(expert_ce, 12) != 0.0].max()

        view_rad_to_lrs = {}
        view_rad_to_ep_lengths = {}
        view_rad_to_expert_ces = {}
        for vr in unique_view_radii:
            subdf = df[df["view_radius"] == vr]
            view_rad_to_lrs[vr] = np.array(subdf["lr"])
            view_rad_to_ep_lengths[vr] = np.array(subdf["avg_ep_length"])
            view_rad_to_expert_ces[vr] = np.maximum(
                -np.array(subdf["expert_ce"]), ce_floor
            )

        # Plot 1: average episode length vs. LR, LOWESS-smoothed in log(LR).
        plt.figure()
        view_rad_to_ep_lengths_preds = {}
        for i, vr in enumerate(unique_view_radii):
            lrs = view_rad_to_lrs[vr]
            avg_ep_lengths = view_rad_to_ep_lengths[vr]
            color = color_for(i)

            plt.scatter(lrs, avg_ep_lengths, color=color)
            # return_sorted=False keeps the fits in input order, one per
            # observation, so they stay index-aligned with ``lrs``.
            view_rad_to_ep_lengths_preds[vr] = lowess(
                avg_ep_lengths, np.log(lrs), frac=0.3, return_sorted=False
            )
            plt.plot(lrs, view_rad_to_ep_lengths_preds[vr], color=color)

        plt.xscale("log")
        plt.show()

        # Plot 2: expert cross entropy vs. LR on log-log axes, smoothed in
        # log space.
        plt.figure()
        view_rad_to_ep_expert_ce_preds = {}
        for i, vr in enumerate(unique_view_radii):
            lrs = view_rad_to_lrs[vr]
            expert_ces = view_rad_to_expert_ces[vr]
            view_rad_to_ep_expert_ce_preds[vr] = np.exp(
                lowess(
                    np.log(expert_ces), np.log(lrs), frac=0.3, return_sorted=False
                )
            )

            # The largest view radius is fitted but left off this plot.
            if vr != max_view_radius:
                color = color_for(i)
                plt.scatter(lrs, expert_ces, color=color)
                plt.plot(lrs, view_rad_to_ep_expert_ce_preds[vr], color=color)

        plt.xlabel("LR")
        plt.ylabel("Cross Entropy")
        plt.xscale("log")
        plt.yscale("log")
        plt.show()

        # Best LR per radius (largest radius excluded) is the LR at the minimum
        # of the smoothed curve. Index each radius's OWN lr array, which its
        # predictions are aligned with. nanargmin skips NaN fits.
        candidate_radii = [vr for vr in unique_view_radii if vr != max_view_radius]

        best_ep_len_lrs = [
            view_rad_to_lrs[vr][np.nanargmin(view_rad_to_ep_lengths_preds[vr])]
            for vr in candidate_radii
        ]
        print("Bests ep_len lrs")
        print(best_ep_len_lrs)

        best_lrs = [
            view_rad_to_lrs[vr][np.nanargmin(view_rad_to_ep_expert_ce_preds[vr])]
            for vr in candidate_radii
        ]
        print("Bests CE lrs")
        print(best_lrs)

        print("Mean lr {}".format(np.mean(best_lrs)))
