import os
from collections import defaultdict

import hydra
from omegaconf import DictConfig, OmegaConf
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

EXP_DIR = 'paper/llama_iso_flop'

class IsoFLOPPlotter:
    def __init__(self, plot_type: str, plot_cfg):
        """
        Load the iso-FLOP results and cache the plot settings.

        Args:
            plot_type: "loss" makes `_build_result_table` read eval losses;
                any other value makes it read evaluation returns.
            plot_cfg: config node providing `legend_loc`, `ncol`, `yticks`,
                `ylabels`, `ylim_min`, `ylim_max`, `flop_to_min_params`,
                `flop_to_max_params`, `ignore_flops`, `plot_full`,
                `param_cut_offs` and `color_theme`.
        """
        # Only the keys are used: `_build_result_table` matches checkpoint
        # file names against these FLOP budget strings.
        self.flop_to_run_path = {
            flop: []
            for flop in (
                '1e14', '2e14', '5e14',
                '1e15', '2e15', '5e15',
                '1e16', '2e16', '5e16',
                '1e17', '2e17', '5e17',
                '1e18', '2e18', '3e18', '4e18', '5e18',
            )
        }

        self.plot_type = plot_type
        self.plot_cfg = plot_cfg
        self.result_table = self._build_result_table(plot_type)

        self.tick_font_size = "17"
        self.legend_font_size = "14"
        self.title_font_size = "21"
        self.label_font_size = "20"
        self.legend_loc = plot_cfg.legend_loc
        self.ncol = plot_cfg.ncol

        self.yticks = plot_cfg.yticks
        self.ylabels = plot_cfg.ylabels
        self.ylim_min = plot_cfg.ylim_min
        self.ylim_max = plot_cfg.ylim_max
        self.flop_to_min_params = plot_cfg.flop_to_min_params
        self.flop_to_max_params = plot_cfg.flop_to_max_params

        # A budget listed in `ignore_flops` may have no runs on disk; dropping
        # it must not fail in that case.
        for flop in plot_cfg.ignore_flops:
            self.result_table.pop(flop, None)

        # Cut off the parameters: param_cut_offs[budget] is
        # (number of smallest models to drop, number of largest models to drop).
        if not plot_cfg.plot_full:
            new_result_table = defaultdict(dict)
            for flop_budget, per_param in self.result_table.items():
                params = sorted(per_param)
                if flop_budget in plot_cfg.param_cut_offs:
                    cut_offs = plot_cfg.param_cut_offs[flop_budget]
                    # Use an explicit stop index: `params[a:-0]` would be
                    # `params[a:0]`, i.e. empty, when nothing is cut on the
                    # right. Clamp at 0 so an oversized cut-off yields an
                    # empty list instead of wrapping around.
                    stop = max(len(params) - cut_offs[1], 0)
                    params = params[cut_offs[0]:stop]
                for param in params:
                    new_result_table[flop_budget][param] = per_param[param]
            self.result_table = new_result_table

        # One colour per remaining budget, palette order reversed.
        palette = sns.color_palette(plot_cfg.color_theme, len(self.result_table))
        self.colors = list(reversed(palette))

    def plot(self):
        """
        Print the fitted scaling exponents and save the iso-FLOP figure.

        The quadratic model from `_fit_cobb_douglas` is fitted on all points
        of `self.result_table`. From its coefficients the exponents "alpha"
        and "beta" and their constants are derived and printed together with
        a 1.96-standard-error (normal 95%) delta-method interval. Then one
        curve per FLOP budget (metric vs. parameter count, log-log axes) is
        drawn and written to
        `paper/figures/iso_flops_<plot_type>_vs_params.pdf`.
        """
        cobb_douglas_lr, _ = self._fit_cobb_douglas(self.result_table)
        # Coefficient order: constant, then PolynomialFeatures(2) on the two
        # log-regressors [n, d] -> n, d, n^2, n*d, d^2.
        _b_c, b_n, b_d, b_n2, b_nd, b_d2 = cobb_douglas_lr.params
        cov = cobb_douglas_lr.cov_params()

        # Stationary point of the fitted quadratic along n + d = const:
        # n* = (b_d - b_n) / z + n_exp * const, d* = (b_n - b_d) / z + d_exp * const.
        z = 2 * b_d2 - 2 * b_nd + 2 * b_n2
        d_exp = (2 * b_n2 - b_nd) / z
        d_g = np.exp((b_n - b_d) / z)
        n_exp = (2 * b_d2 - b_nd) / z
        n_g = np.exp((b_d - b_n) / z)
        half_z_sq = (b_d2 - b_nd + b_n2) ** 2

        print(f"\\alpha: {n_exp}")
        print(f"\\alpha constant: {np.log(n_g/6**n_exp)}")
        # Gradient of n_exp w.r.t. (b_c, b_n, b_d, b_n2, b_nd, b_d2).
        grad_coeff = [
            0, 0, 0,
            (0.5 * b_nd - b_d2) / half_z_sq,
            (b_d2 - b_n2) / (2 * half_z_sq),
            (-0.5 * b_nd + b_n2) / half_z_sq,
        ]
        se = self._delta_method_se(grad_coeff, cov)
        print(f"\\alpha upper CI: {n_exp + 1.96 * se}")
        print(f"\\alpha lower CI: {n_exp - 1.96 * se}")

        print(f"\\beta: {d_exp}")
        print(f"\\beta constant: {np.log(d_g/6**d_exp)}")
        # Gradient of d_exp; since d_exp == 1 - n_exp it is the negated
        # gradient of n_exp.
        grad_coeff = [
            0, 0, 0,
            (-0.5 * b_nd + b_d2) / half_z_sq,
            (b_n2 - b_d2) / (2 * half_z_sq),
            (0.5 * b_nd - b_n2) / half_z_sq,
        ]
        se = self._delta_method_se(grad_coeff, cov)
        print(f"\\beta upper CI: {d_exp + 1.96 * se}")
        print(f"\\beta lower CI: {d_exp - 1.96 * se}")

        # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8"
        # and later releases removed the old name; use whichever exists.
        style_name = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
        plt.style.use(style_name)
        sns.set_style("whitegrid")
        fig, ax = plt.subplots()

        plt.xscale("log")
        plt.yscale("log")

        plt.ylim(self.ylim_min, self.ylim_max)

        for i, flop_budget in enumerate(sorted(self.result_table, key=float)):
            per_param = self.result_table[flop_budget]
            params = sorted(per_param)
            losses = [per_param[param] for param in params]
            plt.errorbar(
                params,
                losses,
                marker="o",
                markersize=9.2,
                linewidth=3.8,
                markeredgewidth=1.6,
                markeredgecolor="#F7F7FF",
                label=flop_budget,
                color=self.colors[i],
            )

        if not self.plot_cfg.plot_full:
            self._plot_per_parabola(self.result_table, self.colors)

        if self.plot_type == "return":
            # plot expert score
            plt.axhline(
                y=self.plot_cfg.expert_score,
                color="black",
                linestyle="--",
                label="Expert",
                alpha=1,
                linewidth=2.2,
            )

        plt.yticks(ticks=self.yticks, labels=self.ylabels, fontsize=self.tick_font_size)
        plt.xticks(fontsize=self.tick_font_size)
        ax.minorticks_off()

        plt.legend(
            fontsize=self.legend_font_size,
            frameon=True,
            loc=self.legend_loc,
            ncol=self.ncol,
        )

        plt.xlabel("Parameters", fontsize=self.label_font_size)
        plt.ylabel(
            "Returns" if self.plot_type != "loss" else "Dev loss",
            fontsize=self.label_font_size,
        )

        out_path = f"paper/figures/iso_flops_{self.plot_type}_vs_params.pdf"
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path)
        plt.close()

    @staticmethod
    def _delta_method_se(grad_coeff, cov):
        """
        Delta-method standard error sqrt(g^T Cov g) of a function of the fit.

        `cov` is the covariance matrix of the coefficient *estimates* (what
        statsmodels' `cov_params()` returns), which already shrinks with the
        sample size, so it must not be divided by the number of observations
        again.
        """
        grad = np.asarray(grad_coeff, dtype=float)
        return np.sqrt(grad @ np.asarray(cov) @ grad)

    def _plot_per_parabola(self, result_table, colors):
        """
        Draw a dashed quadratic fit for every FLOP budget on the current axes.

        For each budget a degree-2 polynomial is least-squares fitted on the
        log-space data from `_flop_to_data`, its R^2 is printed, and the
        fitted curve is evaluated on 50 log-spaced parameter counts (the
        `np.linspace` default) covering both the observed range and the
        configured `flop_to_min_params` / `flop_to_max_params` bounds.

        Args:
            result_table: mapping FLOP budget -> {parameter count: metric}.
            colors: one colour per budget, indexed in ascending budget order.
        """
        per_budget = self._flop_to_data(result_table)

        for idx, flop_budget in enumerate(sorted(per_budget, key=float)):
            log_inputs = per_budget[flop_budget]["X"]
            log_targets = per_budget[flop_budget]["y"]

            # Last column of X is log(parameter count); widen the plotted
            # range to the configured bounds for this budget.
            lower = min(
                np.min(np.exp(log_inputs[:, -1])),
                self.flop_to_min_params[flop_budget],
            )
            upper = max(
                np.max(np.exp(log_inputs[:, -1])),
                self.flop_to_max_params[flop_budget],
            )

            quadratic = PolynomialFeatures(2, include_bias=False)
            design = quadratic.fit_transform(log_inputs)
            model = LinearRegression().fit(design, log_targets)
            print(
                f"Per-parabola R^2 (flop: {flop_budget}):",
                model.score(design, log_targets),
            )

            # Evaluation grid, evenly spaced in log(parameter count).
            grid = np.log(
                [
                    [float(flop_budget), np.exp(log_param)]
                    for log_param in np.linspace(np.log(lower), np.log(upper))
                ]
            )
            grid_design = quadratic.fit_transform(grid)
            predicted = model.predict(grid_design)

            # Column 1 of the expanded features is still log(parameter count);
            # map both axes back from log space and order by x.
            curve = sorted(
                (
                    (np.exp(row[1]), np.exp(log_pred))
                    for row, log_pred in zip(grid_design, predicted)
                ),
                key=lambda point: point[0],
            )
            curve_x = [point[0] for point in curve]
            curve_y = [point[1] for point in curve]
            plt.plot(curve_x, curve_y, "--", linewidth=1.5, color=colors[idx])

    def _flop_to_data(self, result_table):
        """
        Build log-space regression arrays for each FLOP budget.

        Args:
            result_table: mapping FLOP budget -> {parameter count: metric}.

        Returns:
            defaultdict mapping each budget to {"X": ..., "y": ...}, where
            the rows of "X" are [log(budget), log(parameter count)] and "y"
            holds the log of the corresponding metric, in the iteration order
            of the inner mapping.
        """
        flop_to_data = defaultdict(dict)
        for flop_budget, per_param in result_table.items():
            inputs = [[float(flop_budget), n_params] for n_params in per_param]
            targets = [per_param[n_params] for n_params in per_param]

            entry = flop_to_data[flop_budget]
            entry["X"] = np.log(inputs)
            entry["y"] = np.log(np.array(targets))

        return flop_to_data

    def _build_result_table(self, plot_type: str):
        """
        Collect one metric per (FLOP budget, parameter count) from EXP_DIR.

        Args:
            plot_type: "loss" stores each checkpoint's `eval_loss`; any other
                value stores the mean of the returns listed (one float per
                line) in `eval_returns_model_<flop>.txt` next to the
                checkpoint, pooled over all checkpoints that share the same
                budget and parameter count.

        Returns:
            defaultdict mapping FLOP budget string -> {params: metric}, where
            `params` is the checkpoint's `params` entry.
        """
        result_table = defaultdict(dict)

        # NOTE(review): torch.load unpickles the file, so EXP_DIR must only
        # contain trusted checkpoints.
        if plot_type == 'loss':
            for flop, _run_dir, run_path in self._iter_checkpoints():
                data = torch.load(run_path)
                result_table[flop][data['params']] = data['eval_loss']
            return result_table

        for flop, run_dir, run_path in self._iter_checkpoints():
            returns_path = os.path.join(run_dir, f'eval_returns_model_{flop}.txt')
            # Without its returns file a checkpoint contributes nothing; skip
            # it instead of reusing the previous checkpoint's returns.
            if not os.path.exists(returns_path):
                print(f"Skipping {run_path}: missing {returns_path}")
                continue
            with open(returns_path, 'r') as f:
                returns = [float(line.strip()) for line in f if line.strip()]

            data = torch.load(run_path)
            result_table[flop].setdefault(data['params'], []).extend(returns)

        # take mean
        for per_param in result_table.values():
            for param in per_param:
                per_param[param] = np.mean(per_param[param])

        return result_table

    def _iter_checkpoints(self):
        """
        Yield (flop, run_dir, run_path) for every `*<flop>.tar` file found
        one directory level below EXP_DIR, for the budgets in
        `self.flop_to_run_path`.
        """
        for run_id in os.listdir(EXP_DIR):
            run_dir = os.path.join(EXP_DIR, run_id)
            # Stray files in EXP_DIR are not runs.
            if not os.path.isdir(run_dir):
                continue
            for run_file in os.listdir(run_dir):
                for flop in self.flop_to_run_path:
                    if run_file.endswith(f'{flop}.tar'):
                        yield flop, run_dir, os.path.join(run_dir, run_file)
                        break

    def _fit_cobb_douglas(self, loss_mean):
        """
        Fit a quadratic Cobb-Douglas function on all (flop, param, metric)
        points by ordinary least squares in log space.

        The two regressors are the parameter count and
        flop / (6 * parameter count); presumably the latter is the token
        count under the C = 6ND approximation (confirm against the paper).

        Args:
            loss_mean: mapping FLOP budget -> {parameter count: metric}.

        Returns:
            (results, X): the fitted statsmodels OLS results and the degree-2
            polynomial feature matrix (without the constant column).
        """
        flops_per_param_unit = 6
        regressors = []
        targets = []
        for flop_budget, per_param in loss_mean.items():
            for n_params, metric in per_param.items():
                regressors.append(
                    [n_params, float(flop_budget)/(flops_per_param_unit * n_params)]
                )
                targets.append(metric)

        log_regressors = np.log(regressors)
        log_targets = np.log(np.array(targets))

        expander = PolynomialFeatures(2, include_bias=False)
        design = expander.fit_transform(log_regressors)
        with_constant = sm.add_constant(design, has_constant='add')
        fitted = sm.OLS(log_targets, with_constant).fit()

        return fitted, design

@hydra.main(version_base=None, config_path="../../../conf", config_name="nethack_config")
def main(cfg: DictConfig):
    """
    Build the iso-FLOP plot selected by `cfg.plot_type`.

    The "loss" plot reads its settings from `cfg.loss.iso_flop`; every other
    plot type reads them from `cfg.returns.iso_flop`.
    """
    if cfg.plot_type == "loss":
        plot_cfg = cfg.loss
    else:
        plot_cfg = cfg.returns

    IsoFLOPPlotter(cfg.plot_type, plot_cfg.iso_flop).plot()
    
# Run the Hydra entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()