import os
from collections import defaultdict

import hydra
from omegaconf import DictConfig, OmegaConf
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import TimeSeriesSplit
from scipy.optimize import curve_fit
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import statsmodels.api as sm

# Root directory of the experiment: each entry is treated as one run directory and
# is scanned for checkpoint files whose names end in '<flop>.tar'.
EXP_DIR = 'paper/llama_iso_flop'

class PowerLawPlotter:
    def __init__(self, plot_type, plot_cfg):
        """Load the experiment results and cache every plot setting.

        Args:
            plot_type: ``"loss"`` to work on evaluation losses; any other
                value makes ``_build_result_table`` read evaluation returns
                and additionally requires the ``corr_*`` settings.
            plot_cfg: Config object read by attribute access. It supplies the
                tick, label, axis-limit, legend and colour settings copied
                below, plus ``ignore_flops`` and ``inference_flops``.
        """
        # Only the keys are consulted (as the set of recognised FLOP budgets);
        # the list values are never filled in by this class.
        self.flop_to_run_path = {
            flop: []
            for flop in (
                '1e14', '2e14', '5e14',
                '1e15', '2e15', '5e15',
                '1e16', '2e16', '5e16',
                '1e17', '2e17', '5e17',
                '1e18', '2e18', '3e18', '4e18', '5e18',
            )
        }

        self.plot_type = plot_type
        self.plot_cfg = plot_cfg
        self.result_table, self.flop_params_to_samples = self._build_result_table(plot_type)

        # Shared x axis (FLOPs).
        self.xticks = plot_cfg.xticks
        self.xlabels = plot_cfg.xlabels

        # y axis of the loss / return plot.
        self.yticks = plot_cfg.yticks
        self.ylabel = plot_cfg.ylabel
        self.ylabels = plot_cfg.ylabels
        self.ylim_min = plot_cfg.ylim_min
        self.ylim_max = plot_cfg.ylim_max
        self.legend_loc = plot_cfg.legend_loc

        # Font sizes are kept as strings, exactly as they are handed to matplotlib.
        self.tick_font_size = "17"
        self.legend_font_size = "18"
        self.title_font_size = "21"
        self.label_font_size = "20"

        # y axis of the parameters-vs-FLOPs plot.
        self.model_ylabels = plot_cfg.model_ylabels
        self.model_yticks = plot_cfg.model_yticks
        self.model_ylim_min = plot_cfg.model_ylim_min
        self.model_ylim_max = plot_cfg.model_ylim_max
        self.model_legend_loc = plot_cfg.model_legend_loc

        # y axis of the samples-vs-FLOPs plot.
        self.samples_ylabels = plot_cfg.samples_ylabels
        self.samples_yticks = plot_cfg.samples_yticks
        self.samples_ylim_min = plot_cfg.samples_ylim_min
        self.samples_ylim_max = plot_cfg.samples_ylim_max
        self.samples_legend_loc = plot_cfg.samples_legend_loc

        # The return-vs-loss correlation plot is never produced for loss plots,
        # so its settings are only read for the other plot types.
        if self.plot_type != "loss":
            self.corr_yticks = plot_cfg.corr_yticks
            self.corr_xticks = plot_cfg.corr_xticks
            self.corr_xlabels = plot_cfg.corr_xlabels
            self.corr_ylim_min = plot_cfg.corr_ylim_min
            self.corr_ylim_max = plot_cfg.corr_ylim_max
            self.corr_ylabels = plot_cfg.corr_ylabels

        self.color = plot_cfg.color

        # Drop the ignored budgets. A budget without any checkpoint is not in
        # the table; `del` used to raise KeyError for it, so warn and carry on.
        for flop in plot_cfg.ignore_flops:
            if self.result_table.pop(flop, None) is None:
                print(f"Warning: ignored FLOP budget {flop!r} has no results to remove")

        self.inference_flops = plot_cfg.inference_flops

    def _build_result_table(self, plot_type: str):
        """Scan ``EXP_DIR`` and collect one score per (FLOP budget, model size).

        Every run directory under ``EXP_DIR`` is searched for checkpoints whose
        file name ends in ``<flop>.tar``, where ``<flop>`` is a key of
        ``self.flop_to_run_path``. Each checkpoint is read with ``torch.load``
        and must provide the keys ``'params'`` and ``'samples'`` (plus
        ``'eval_loss'`` when ``plot_type == 'loss'``).

        Args:
            plot_type: ``'loss'`` stores the checkpoint's ``'eval_loss'``.
                Any other value stores the mean of the floats listed one per
                line in the sibling file ``eval_returns_model_<flop>.txt``;
                returns of several checkpoints sharing the same budget and
                parameter count are pooled before averaging.

        Returns:
            A pair ``(result_table, flop_params_to_samples)`` of
            ``defaultdict(dict)`` objects, both indexed as ``[flop][params]``.
        """
        use_loss = plot_type == 'loss'
        result_table = defaultdict(dict)
        flop_params_to_samples = defaultdict(dict)

        for run_id in os.listdir(EXP_DIR):
            run_dir = os.path.join(EXP_DIR, run_id)
            if not os.path.isdir(run_dir):
                # Stray files next to the run directories are not runs.
                continue

            for run_file in os.listdir(run_dir):
                # First recognised budget this file name ends with, if any.
                flop = next(
                    (key for key in self.flop_to_run_path if run_file.endswith(f'{key}.tar')),
                    None,
                )
                if flop is None:
                    continue

                returns = None
                if not use_loss:
                    returns_path = os.path.join(run_dir, f'eval_returns_model_{flop}.txt')
                    if not os.path.exists(returns_path):
                        # Without its returns file a checkpoint carries no
                        # score. The previous code fell through here and either
                        # hit an unbound `returns` or re-used the list read for
                        # the preceding checkpoint.
                        continue
                    with open(returns_path, 'r') as f:
                        returns = [float(line.strip()) for line in f if line.strip()]

                # NOTE(review): torch.load unpickles the file; only point
                # EXP_DIR at checkpoints you trust.
                data = torch.load(os.path.join(run_dir, run_file))
                params = data['params']
                if use_loss:
                    result_table[flop][params] = data['eval_loss']
                else:
                    result_table[flop].setdefault(params, []).extend(returns)
                flop_params_to_samples[flop][params] = data['samples']

        if not use_loss:
            # Collapse the pooled returns of every (flop, params) cell to their mean.
            for scores_by_params in result_table.values():
                for params, returns in scores_by_params.items():
                    scores_by_params[params] = np.mean(returns)

        return result_table, flop_params_to_samples

    def plot(self, cross_validate: bool = False, predict: bool = False):
        """Plot the best score per FLOP budget and fit scaling laws to it.

        For every FLOP budget in ``self.result_table`` the best score across
        model sizes is taken (minimum when ``plot_type == "loss"``, maximum
        otherwise). Budgets listed in ``self.inference_flops`` are drawn as
        "Test" points and excluded from both fits. Two fits are made on the
        remaining points: a log-log OLS regression (printed and returned) and
        a parametric curve, ``A * x**B + C`` for losses and
        ``A / (1 + B * x**C)`` otherwise, which is the one drawn. The figure
        is saved as ``paper/figures/<plot_type>_vs_flops_scaling_law.pdf``.

        Args:
            cross_validate: If true, also run ``_cross_validate`` on the
                training points and print the average RMSE and the average
                prediction-interval hit rate.
            predict: Accepted for call-site compatibility; not used here.

        Returns:
            ``((beta_0, beta_1), (A, B, C))``: intercept and slope of the
            log-log regression, followed by the fitted curve parameters.
        """
        is_loss = self.plot_type == "loss"

        def best_scores(flops):
            # Best score over all model sizes trained at each budget.
            pick = np.min if is_loss else np.max
            return np.array([pick(list(self.result_table[flop].values())) for flop in flops])

        def latex_sci(value):
            # '1.23e-04' -> '1.23e^{-4}'. Splitting on 'e' (rather than on the
            # sign character) keeps positive exponents and negative mantissas
            # from being mis-parsed.
            mantissa, exponent = f"{value:.2e}".split("e")
            return f"{mantissa}e^{{{int(exponent)}}}"

        train_flops = sorted(
            (flop for flop in self.result_table.keys() if flop not in self.inference_flops),
            key=float,
        )
        test_flops = sorted(
            (flop for flop in self.result_table.keys() if flop in self.inference_flops),
            key=float,
        )
        ys = best_scores(train_flops)
        test_ys = best_scores(test_flops)
        xs = np.array([float(flop) for flop in train_flops])
        test_xs = np.array([float(flop) for flop in test_flops])

        # fit log-linear regression
        X, y = np.expand_dims(np.log(xs), 1), np.log(np.array(ys))
        lr = sm.OLS(y, sm.add_constant(X, has_constant="add")).fit()

        if cross_validate:
            max_clip = self.plot_cfg.expert_score if self.plot_type == "return" else None
            avg_rmse, _, _, _, avg_pi = self._cross_validate(xs, ys, max_clip=max_clip)
            print(f'Avg. RMSE: {avg_rmse:.3f}')
            print(f'Avg. PI: {avg_pi:.3f}')

        conf_int = lr.conf_int()
        print(f"Beta 0 CI: {conf_int[0]}")
        print(f"Beta 1 CI: {conf_int[1]}")
        print(f"Beta 0: {lr.params[0]}")
        print(f"Beta 1: {lr.params[1]}")

        # matplotlib renamed its bundled seaborn style to "seaborn-v0_8" (3.6)
        # and later removed the old name; use whichever this install provides.
        plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
        sns.set_style("whitegrid")
        fig, ax = plt.subplots()

        plt.ylim(self.ylim_min, self.ylim_max)

        plt.xscale("log")
        plt.yscale("log")

        plt.xlabel("FLOPs", fontsize=self.label_font_size)
        plt.ylabel(self.ylabel, fontsize=self.label_font_size)

        has_test_points = len(self.inference_flops) > 0
        plt.scatter(xs, ys, s=80, color=self.color, label=("Train" if has_test_points else None))
        if has_test_points:
            plt.scatter(test_xs, test_ys, s=80, color='green', label="Test")

        if self.plot_type == "return":
            # plot expert score
            print('plotting expert line')
            plt.axhline(
                y=self.plot_cfg.expert_score,
                color="black",
                linestyle="--",
                label="Expert",
                alpha=1,
                linewidth=2.2,
            )

        plt.yticks(ticks=self.yticks, labels=self.ylabels, fontsize=self.tick_font_size)
        plt.xticks(ticks=self.xticks, labels=self.xlabels, fontsize=self.tick_font_size)
        ax.minorticks_off()

        if is_loss:
            def power_law_plus_constant(x, A, B, C):
                return A * x**B + C

            params, _ = curve_fit(
                power_law_plus_constant, xs, ys, maxfev=10000, p0=[0, 0, 0.08]
            )
            A, B, C = params
            print(f"Fitted parameters: A = {A}, B = {B}, C = {C}")

            # Draw the curve past the data, out to the extrapolation budgets.
            # Sorting keeps the polyline monotone in x.
            full_xs = np.sort(
                np.concatenate((xs, test_xs, [1e19, 1.5e19, 2e19, 2.5e19, 3e19, 3.327295e+19]))
            )
            y_fit = power_law_plus_constant(full_xs, *params)
            plt.plot(
                full_xs,
                y_fit,
                "--",
                label=rf'$L = {A:.2f} \cdot C^{{{B:.2f}}} + {C:.2f}$',
                linewidth=2.2,
                color=sns.color_palette()[0],
            )
        else:
            def power_law_plus_constant(x, A, B, C):
                return A / (1 + B * x**C)

            params, _ = curve_fit(
                power_law_plus_constant, xs, ys, maxfev=100000, p0=[10_000, 5000, -0.06]
            )
            A, B, C = params
            print(f"Fitted parameters: A = {A}, B = {B}, C = {C}")

            full_xs = np.sort(np.concatenate((xs, test_xs)))
            y_fit = power_law_plus_constant(full_xs, *params)

            # The legend shows the curve in the equivalent form
            # R = ((B/A) * C^c + 1/A)^-1.
            print("B / A", f"{B / A:.2e}")
            b_over_a = latex_sci(B / A)
            print("1 / A", f"{1/A:.2e}")
            one_over_a = latex_sci(1 / A)

            plt.plot(
                full_xs,
                y_fit,
                "--",
                label=rf'$R = ({b_over_a} \cdot C^{{{C:.2f}}} + {one_over_a})^{{{-1}}}$',
                linewidth=2.2,
                color=self.color,
            )

        plt.legend(fontsize=self.legend_font_size, frameon=True, loc=self.legend_loc)

        os.makedirs("paper/figures", exist_ok=True)
        plt.savefig(f"paper/figures/{self.plot_type}_vs_flops_scaling_law.pdf")
        plt.close()

        return (lr.params[0], lr.params[1]), (A, B, C)

    def plot_model(self, cross_validate: bool = False, predict: bool = False):
        """Plot the best-scoring parameter count per FLOP budget.

        For each budget in ``self.result_table`` the parameter count with the
        lowest loss (``plot_type == "loss"``) or highest return (otherwise) is
        selected, a log-log OLS regression of that count on FLOPs is fitted and
        drawn with its confidence band, and the figure is saved as
        ``paper/figures/<plot_type>_parameters_vs_flops_scaling_law.pdf``.

        Args:
            cross_validate: If true, also run ``_cross_validate`` and print
                the average RMSE and average prediction-interval hit rate.
            predict: Accepted for call-site compatibility; not used here.

        Returns:
            ``(beta_0, beta_1)``: intercept and slope of the log-log fit.
        """
        minimise = self.plot_type == "loss"

        def optimal_param(scores):
            # Linear scan starting from the same +/-1e9 sentinels as before, so
            # scores that never beat the sentinel (or are NaN) still yield None.
            best_param, best_score = None, (1e9 if minimise else -1e9)
            for param, score in scores.items():
                improved = score < best_score if minimise else score > best_score
                if improved:
                    best_param, best_score = param, score
            return best_param

        # use all data
        flops = sorted(self.result_table.keys(), key=float)
        ys = np.array([optimal_param(self.result_table[flop]) for flop in flops])
        xs = np.array([float(flop) for flop in flops])

        # fit log-linear regression
        X, y = np.expand_dims(np.log(xs), 1), np.log(np.array(ys))
        lr = sm.OLS(y, sm.add_constant(X, has_constant="add")).fit()

        if cross_validate:
            avg_rmse, _, _, _, avg_pi = self._cross_validate(xs, ys)
            print(f'Avg. RMSE: {avg_rmse:.3f}')
            print(f'Avg. PI: {avg_pi:.3f}')

        conf_int = lr.conf_int()
        print(f"Beta 0 CI: {conf_int[0]}")
        print(f"Beta 1 CI: {conf_int[1]}")
        print(f"Beta 0: {lr.params[0]}")
        print(f"Beta 1: {lr.params[1]}")

        # matplotlib renamed its bundled seaborn style to "seaborn-v0_8" (3.6)
        # and later removed the old name; use whichever this install provides.
        plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
        sns.set_style("whitegrid")
        fig, ax = plt.subplots()

        plt.ylim(self.model_ylim_min, self.model_ylim_max)

        plt.xscale("log")
        plt.yscale("log")

        plt.xlabel("FLOPs", fontsize=self.label_font_size)
        plt.ylabel("Parameters", fontsize=self.label_font_size)

        plt.scatter(xs, ys, s=80, color=self.color)
        plt.yticks(
            ticks=self.model_yticks,
            labels=self.model_ylabels,
            fontsize=self.tick_font_size,
        )
        plt.xticks(ticks=self.xticks, labels=self.xlabels, fontsize=self.tick_font_size)
        ax.minorticks_off()

        # plot regression line on log plot (raw f-string: "\l" and "\c" are
        # not valid string escapes)
        label = rf"$\log N = {lr.params[0]:.2f} + {lr.params[1]:.2f} \cdot \log C$"
        self._plot_log_line(plt, lr, label=label, color=self.color)

        plt.legend(
            fontsize=self.legend_font_size, frameon=True, loc=self.model_legend_loc
        )

        os.makedirs("paper/figures", exist_ok=True)
        plt.savefig(
            f"paper/figures/{self.plot_type}_parameters_vs_flops_scaling_law.pdf"
        )
        plt.close()

        return (lr.params[0], lr.params[1])

    def plot_samples(self, cross_validate: bool = False, predict: bool = False):
        """Plot the sample count of the best-scoring model per FLOP budget.

        For each budget in ``self.result_table`` the parameter count with the
        lowest loss (``plot_type == "loss"``) or highest return (otherwise) is
        selected and its entry in ``self.flop_params_to_samples`` is plotted.
        A log-log OLS regression of samples on FLOPs is fitted and drawn with
        its confidence band, and the figure is saved as
        ``paper/figures/<plot_type>_samples_vs_flops_scaling_law.pdf``.

        Args:
            cross_validate: If true, also run ``_cross_validate`` and print
                the average RMSE and average prediction-interval hit rate.
            predict: Accepted for call-site compatibility; not used here.

        Returns:
            ``(beta_0, beta_1)``: intercept and slope of the log-log fit.
        """
        minimise = self.plot_type == "loss"

        def optimal_param(scores):
            # Linear scan starting from the same +/-1e9 sentinels as before, so
            # scores that never beat the sentinel (or are NaN) still yield None.
            best_param, best_score = None, (1e9 if minimise else -1e9)
            for param, score in scores.items():
                improved = score < best_score if minimise else score > best_score
                if improved:
                    best_param, best_score = param, score
            return best_param

        # use all data
        flops = sorted(self.result_table.keys(), key=float)
        ys = np.array(
            [
                self.flop_params_to_samples[flop][optimal_param(self.result_table[flop])]
                for flop in flops
            ]
        )
        xs = np.array([float(flop) for flop in flops])

        # fit log-linear regression
        X, y = np.expand_dims(np.log(xs), 1), np.log(np.array(ys))
        lr = sm.OLS(y, sm.add_constant(X, has_constant="add")).fit()

        if cross_validate:
            avg_rmse, _, _, _, avg_pi = self._cross_validate(xs, ys)
            print(f'Avg. RMSE: {avg_rmse:.3f}')
            print(f'Avg. PI: {avg_pi:.3f}')

        conf_int = lr.conf_int()
        print(f"Beta 0 CI: {conf_int[0]}")
        print(f"Beta 1 CI: {conf_int[1]}")
        print(f"Beta 0: {lr.params[0]}")
        print(f"Beta 1: {lr.params[1]}")

        # matplotlib renamed its bundled seaborn style to "seaborn-v0_8" (3.6)
        # and later removed the old name; use whichever this install provides.
        plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
        sns.set_style("whitegrid")
        fig, ax = plt.subplots()

        plt.xscale("log")
        plt.yscale("log")

        plt.ylim(self.samples_ylim_min, self.samples_ylim_max)

        plt.xlabel("FLOPs", fontsize=self.label_font_size)
        plt.ylabel("Samples", fontsize=self.label_font_size)

        plt.scatter(xs, ys, s=80, color=self.color)
        plt.yticks(
            ticks=self.samples_yticks,
            labels=self.samples_ylabels,
            fontsize=self.tick_font_size,
        )
        plt.xticks(ticks=self.xticks, labels=self.xlabels, fontsize=self.tick_font_size)
        ax.minorticks_off()

        # plot regression line on log plot (raw f-string: "\l" and "\c" are
        # not valid string escapes)
        label = rf"$\log D = {lr.params[0]:.2f} + {lr.params[1]:.2f} \cdot \log C$"
        self._plot_log_line(plt, lr, label=label, color=self.color)

        plt.legend(
            fontsize=self.legend_font_size, loc=self.samples_legend_loc, frameon=True
        )

        os.makedirs("paper/figures", exist_ok=True)
        plt.savefig(
            f"paper/figures/{self.plot_type}_samples_vs_flops_scaling_law.pdf"
        )
        plt.close()

        return (lr.params[0], lr.params[1])

    def plot_correlation(self, cross_validate: bool = False):
        """Plot return against 1 / loss for the loss-optimal model per budget.

        Does nothing (returns ``None``) when ``plot_type == "loss"``.
        Otherwise the loss and return tables are rebuilt from disk, the
        budgets in ``plot_cfg.ignore_flops`` are dropped, and for every
        remaining budget the parameter count with the lowest loss contributes
        one point ``(1 / loss, return)``. A log-log OLS regression is fitted
        and printed; the curve ``A / (1 + B * x**C)`` is fitted and drawn. The
        figure is saved as ``paper/figures/<plot_type>_vs_loss_scaling_law.pdf``.

        Args:
            cross_validate: Accepted for call-site compatibility; not used.

        Returns:
            ``(beta_0, beta_1)`` of the log-log regression, or ``None`` for
            loss plots.
        """
        if self.plot_type == "loss":
            return

        def latex_sci(value):
            # '1.23e-04' -> '1.23e^{-4}'. Splitting on 'e' (rather than on the
            # sign character) keeps positive exponents and negative mantissas
            # from being mis-parsed.
            mantissa, exponent = f"{value:.2e}".split("e")
            return f"{mantissa}e^{{{int(exponent)}}}"

        self.loss_table, _ = self._build_result_table("loss")
        self.return_table, _ = self._build_result_table("return")

        # An ignored budget may be absent from either table; `del` raised
        # KeyError in that case.
        for flop in self.plot_cfg.ignore_flops:
            self.loss_table.pop(flop, None)
            self.return_table.pop(flop, None)

        # use all data
        xs = []
        ys = []
        for flop in sorted(self.loss_table.keys(), key=float):
            min_loss, min_param = 1e9, None
            for param, loss in self.loss_table[flop].items():
                if loss < min_loss:
                    min_loss, min_param = loss, param
            ys.append(self.return_table[flop][min_param])
            xs.append(1 / min_loss)

        xs = np.array([float(x) for x in xs])

        # fit log-linear regression
        X, y = np.expand_dims(np.log(xs), 1), np.log(np.array(ys))
        lr = sm.OLS(y, sm.add_constant(X, has_constant="add")).fit()

        conf_int = lr.conf_int()
        print(f"Beta 0 CI: {conf_int[0]}")
        print(f"Beta 1 CI: {conf_int[1]}")
        print(f"Beta 0: {lr.params[0]}")
        print(f"Beta 1: {lr.params[1]}")

        # matplotlib renamed its bundled seaborn style to "seaborn-v0_8" (3.6)
        # and later removed the old name; use whichever this install provides.
        plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
        sns.set_style("whitegrid")
        fig, ax = plt.subplots()

        plt.ylim(self.corr_ylim_min, self.corr_ylim_max)

        plt.xscale("log")
        plt.yscale("log")

        plt.xlabel("1 / Loss", fontsize=self.label_font_size)
        plt.ylabel("Return", fontsize=self.label_font_size)

        plt.scatter(xs, ys, s=80, color="#7b0d8a")
        plt.yticks(
            ticks=self.corr_yticks,
            labels=self.corr_ylabels,
            fontsize=self.tick_font_size,
        )
        plt.xticks(
            ticks=self.corr_xticks,
            labels=self.corr_xlabels,
            fontsize=self.tick_font_size,
        )
        ax.minorticks_off()

        # curve fit
        def power_law_plus_constant(x, A, B, C):
            return A / (1 + B * x**C)

        params, _ = curve_fit(
            power_law_plus_constant, xs, ys, maxfev=10000, p0=[5000, 0, -0.06]
        )
        A, B, C = params
        print(f"Fitted parameters: A = {A}, B = {B}, C = {C}")

        # The points are ordered by FLOP budget, not by 1 / loss; sort the x
        # values so the fitted curve is drawn as a monotone polyline.
        full_xs = np.sort(xs)
        y_fit = power_law_plus_constant(full_xs, *params)

        # The legend shows the curve in the equivalent form
        # R = ((B/A) * (1/L)^c + 1/A)^-1.
        print("B / A", f"{B / A:.2e}")
        b_over_a = latex_sci(B / A)
        print("1 / A", f"{1/A:.2e}")
        one_over_a = latex_sci(1 / A)

        plt.plot(
            full_xs,
            y_fit,
            "--",
            label=rf'$R = ({b_over_a} \cdot (1/L)^{{{C:.2f}}} + {one_over_a})^{{{-1}}}$',
            linewidth=2.2,
            color='#7b0d8a',
        )

        plt.legend(fontsize=self.legend_font_size, frameon=True, loc=self.legend_loc)

        os.makedirs("paper/figures", exist_ok=True)
        plt.savefig(f"paper/figures/{self.plot_type}_vs_loss_scaling_law.pdf")
        plt.close()

        return (lr.params[0], lr.params[1])

    def _plot_log_line(self, plt: plt, lr, label=None, color=None, predict: bool = False):
        """
        Draw a fitted log-log regression line and its confidence band.

        Assumes the current axes use a log x scale, so both x limits are
        positive.

        Args:
            plt: The pyplot module (or an object with the same ``gca`` and
                ``plot`` functions) to draw with.
            lr: Fitted statsmodels OLS result whose regressors are a constant
                followed by ``log(x)`` and whose response is ``log(y)``.
            label: Legend label for the line.
            color: Line and band colour; defaults to the first colour of the
                current seaborn palette.
            predict: Accepted for call-site compatibility; not used.
        """
        axes = plt.gca()
        start, stop = axes.get_xlim()
        # Sample evenly in log space. Linearly spaced points all crowd into the
        # top decade of a log axis, leaving the curved confidence band
        # interpolated across most of the plot from just two samples.
        x_vals = np.geomspace(start, stop)
        X = sm.add_constant(np.expand_dims(np.log(x_vals), 1), has_constant="add")
        pred_results = lr.get_prediction(X)
        frame = pred_results.summary_frame()
        y_vals = pred_results.predicted_mean
        y_vals_upper = frame["mean_ci_upper"].to_numpy()
        y_vals_lower = frame["mean_ci_lower"].to_numpy()
        color = sns.color_palette()[0] if not color else color
        # The model lives in log space; exponentiate to get back to data units.
        plt.plot(x_vals, np.exp(y_vals), "--", color=color, label=label, linewidth=2.2)
        axes.fill_between(
            x_vals, np.exp(y_vals_lower), np.exp(y_vals_upper), alpha=0.2, color=color
        )

    def _cross_validate(self, xs, ys, n_splits=10, max_clip=None):
        """Cross-validate the log-log regression with single-point test sets.

        Splits come from ``TimeSeriesSplit(test_size=1)`` over the points in
        the order given, and each split is fitted and scored by
        ``_fit_and_evaluate``.

        Args:
            xs: NumPy array of x values (indexed with the split indices).
            ys: NumPy array of y values, same length as ``xs``.
            n_splits: Currently ignored; see the note in the body.
            max_clip: Forwarded to ``_fit_and_evaluate``.

        Returns:
            ``(mean_rmse, beta_0, beta_1, num_in_sample, mean_pi)`` where
            ``beta_0`` / ``beta_1`` are per-split lists of
            ``(ci_lower, estimate, ci_upper)`` tuples, ``num_in_sample`` lists
            the training-set sizes, and ``mean_pi`` is the fraction of splits
            whose held-out point fell inside the prediction interval.
        """
        # NOTE: the caller-supplied ``n_splits`` is overridden with a value
        # derived from the data size. Kept as-is so existing results do not
        # change.
        n_splits = max(len(xs) - 6, 2)
        tscv = TimeSeriesSplit(gap=0, max_train_size=None, n_splits=n_splits, test_size=1)

        rmses, pred_ints = [], []
        beta_0, beta_1 = [], []
        num_in_sample = []
        for train_idx, dev_idx in tscv.split(xs):
            rmse, lr, is_in_pred_interval = self._fit_and_evaluate(
                xs[train_idx], ys[train_idx], xs[dev_idx], ys[dev_idx], max_clip
            )
            rmses.append(rmse)
            pred_ints.append(float(is_in_pred_interval))

            conf_int = lr.conf_int()
            beta_0_ci, beta_1_ci = conf_int[0], conf_int[1]
            beta_0.append((beta_0_ci[0], lr.params[0], beta_0_ci[1]))
            # The lower bound is index 0; the upper bound was previously
            # stored in both slots.
            beta_1.append((beta_1_ci[0], lr.params[1], beta_1_ci[1]))
            num_in_sample.append(len(train_idx))

        return np.mean(rmses), beta_0, beta_1, num_in_sample, np.mean(pred_ints)

    def _fit_and_evaluate(self, train_xs, train_ys, dev_xs, dev_ys, max_clip = None):
        """Fit a log-log OLS regression on the train split and score the dev split.

        Args:
            train_xs, train_ys: Training points; both are log-transformed
                before fitting.
            dev_xs, dev_ys: Held-out points (one or more).
            max_clip: Optional upper bound, in original units, applied to the
                predictions before the error is computed.

        Returns:
            ``(rmse, lr, is_in_pred_interval)``: the root-mean-square error in
            original (exponentiated) units, the fitted statsmodels result, and
            whether every held-out observation lies inside its prediction
            interval.
        """
        train_xs = [float(x) for x in train_xs]
        dev_xs = [float(x) for x in dev_xs]

        # fit log-linear regression
        train_X, train_y = np.expand_dims(np.log(train_xs), 1), np.log(np.array(train_ys))
        lr = sm.OLS(train_y, sm.add_constant(train_X, has_constant='add')).fit()

        # evaluate on dev
        dev_X, dev_y = np.expand_dims(np.log(dev_xs), 1), np.log(np.array(dev_ys))
        pred_results = lr.get_prediction(sm.add_constant(dev_X, has_constant='add'))
        y_vals = pred_results.predicted_mean
        if max_clip is not None:
            # Predictions are in log space, so the cap is applied there too.
            y_vals = np.clip(y_vals, None, np.log(max_clip))

        frame = pred_results.summary_frame()
        lower = np.exp(np.array(frame["obs_ci_lower"]))
        upper = np.exp(np.array(frame["obs_ci_upper"]))
        observed = np.exp(dev_y)
        for lo, obs, hi in zip(lower, observed, upper):
            print(f'Prediction Interval: {float(lo):,} <= {float(obs):,} <= {float(hi):,}')
        # Element-wise `&`: Python's `and` raises on arrays with more than one
        # element, which limited this method to a single dev point.
        is_in_pred_interval = np.all((observed >= lower) & (observed <= upper))

        rmse = np.sqrt(np.square(observed - np.exp(y_vals)).mean())
        return rmse, lr, is_in_pred_interval



@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig):
    """Produce the scaling-law figures and print extrapolated predictions.

    Builds a ``PowerLawPlotter`` from ``cfg.loss`` or ``cfg.returns``
    (selected by ``cfg.plot_type``), draws the score, parameter and sample
    plots, and then uses the fitted laws to print the compute, model size and
    loss / return implied by fixed sample budgets (and, for returns, by a
    target return of 10,000).
    """
    def compute_for(target, b0, b1):
        # Invert log(y) = b0 + b1 * log(C) for the compute C that reaches `target`.
        return np.exp((np.log(target) - b0) / b1)

    def log_linear(c, b0, b1):
        # Evaluate y = exp(b0 + b1 * log(C)).
        return np.exp(b0 + b1 * np.log(c))

    plot_cfg = cfg.loss if cfg.plot_type == "loss" else cfg.returns

    plotter = PowerLawPlotter(cfg.plot_type, plot_cfg.power_law)

    print("***** Plotting LOSS / RETURN vs. FLOPS *****")
    (l_b0, l_b1), (l_A, l_B, l_C) = plotter.plot(cross_validate=True, predict=True)

    print("***** Plotting PARAMETERS vs. FLOPS *****")
    (n_b0, n_b1) = plotter.plot_model(cross_validate=True, predict=True)

    print("***** Plotting SAMPLES vs. FLOPS *****")
    (d_b0, d_b1) = plotter.plot_samples(cross_validate=True, predict=True)

    print("***** Plotting RETURNS vs. 1 / LOSS *****")
    # NOTE: the correlation plot (plotter.plot_correlation) is currently disabled.

    print()
    print("PREDICTIONS")
    print("-----------")
    print()

    if cfg.plot_type == "loss":
        # scaling laws from parametric fit
        d2_b0 = 6.19778850558802
        d2_b1 = 0.40998042306743543
        n2_b0 = -7.989547974816075
        n2_b1 = 0.5900195769325646

        # LOSS @ 40B
        print('LOSS @ 40B')
        NUM_SAMPLES = 40e9
        # 1. Get compute for NUM_SAMPLES
        c = compute_for(NUM_SAMPLES, d_b0, d_b1)
        c2 = compute_for(NUM_SAMPLES, d2_b0, d2_b1)
        print(f'C: {c:e}')
        print(f'C2: {c2:e}')

        # 2. Get model size for c
        n = log_linear(c, n_b0, n_b1)
        n2 = log_linear(c2, n2_b0, n2_b1)
        print(f'N: {n:e}')
        print(f'N2: {n2:e}')

        # 3. Get loss for c, using the curve fitted by plot(): A * C**B + C0
        loss = l_A * c**l_B + l_C
        loss2 = l_A * c2**l_B + l_C
        print(f'L: {loss:.5f}')
        print(f'L2: {loss2:.5f}')
    else:
        def fitted_return(c):
            # For returns, plot() fits R = A / (1 + B * C**c0). The loss-style
            # formula A * C**B + C0 was previously applied to these parameters.
            return l_A / (1 + l_B * c**l_C)

        # scaling laws from parametric fit
        d2_b0 = 6.845904805307816
        d2_b1 = 0.3947230743289811
        n2_b0 = -8.63766427453587
        n2_b1 = 0.6052769256710189

        # RETURN @ 40B
        NUM_SAMPLES = 40e9
        print('RETURN @ 40B')
        # 1. Get compute for NUM_SAMPLES
        c = compute_for(NUM_SAMPLES, d_b0, d_b1)
        c2 = compute_for(NUM_SAMPLES, d2_b0, d2_b1)
        print(f'C: {c:e}')
        print(f'C2: {c2:e}')

        # 2. Get model size for c
        n = log_linear(c, n_b0, n_b1)
        n2 = log_linear(c2, n2_b0, n2_b1)
        print(f'N: {n:e}')
        print(f'N2: {n2:e}')

        # 3. Get return for c (log-linear law)
        r = log_linear(c, l_b0, l_b1)
        r2 = log_linear(c2, l_b0, l_b1)
        print(f'R: {r:.1f}')
        print(f'R2: {r2:.1f}')

        # 4. Get return (fitted curve)
        r3 = fitted_return(c)
        print(f'R3: {r3:.3f}')
        print()

        # RETURN @ 55B
        NUM_SAMPLES = 55e9
        print('RETURN @ 55B')
        # 1. Get compute for NUM_SAMPLES
        c = compute_for(NUM_SAMPLES, d_b0, d_b1)
        print(f'C: {c:e}')

        # 2. Get model size for c
        n = log_linear(c, n_b0, n_b1)
        print(f'N: {n:e}')

        # 3. Get return for c (log-linear law)
        r = log_linear(c, l_b0, l_b1)
        print(f'R: {r:.1f}')

        # 4. Get return (fitted curve)
        r2 = fitted_return(c)
        print(f'R2: {r2:.3f}')
        print()

        # RETURN @10k
        TARGET_RETURN = 10_000
        print('RETURN @10k')
        # 1. Get compute for TARGET_RETURN
        c1 = compute_for(TARGET_RETURN, l_b0, l_b1)
        # Invert R = A / (1 + B * C**c0):  C = ((A / R - 1) / B) ** (1 / c0).
        # A non-positive ratio means the fitted curve never reaches the target.
        ratio = (l_A / TARGET_RETURN - 1) / l_B
        if ratio > 0:
            c2 = ratio ** (1 / l_C)
        else:
            print('Target return is not reachable under the fitted curve')
            c2 = np.nan

        # 2. Get model size for c
        n1 = log_linear(c1, n_b0, n_b1)
        n2 = log_linear(c2, n_b0, n_b1)

        # 3. Get samples for c
        d1 = log_linear(c1, d_b0, d_b1)
        d2 = log_linear(c2, d_b0, d_b1)

        print(f'C1: {c1:e}')
        print(f'N1: {n1:e}')
        print(f'D1: {d1:e}')
        print()

        print(f'C2: {c2:e}')
        print(f'N2: {n2:e}')
        print(f'D2: {d2:e}')
        print()


# Script entry point. `main` is wrapped by `@hydra.main`, which is why it is
# called here without the `cfg` argument its signature declares.
if __name__ == "__main__":
    main()