"""
    @author: ksreenivasan
    my own little helper script to get simple eval metrics
"""


from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run, get_data_sampler, get_task_sampler, gen_standard, eval_batch, build_evals
from plot_utils import basic_plot, collect_results, relevant_model_names


def get_model_error(model, xs, ys, prefix_size=None, device='cuda', layer_activations=None):
    """Mean squared error of ``model`` on the final (query) position of each sequence.

    Args:
        model: object whose ``predict(xs, ys, layer_activations=...)`` returns
            predictions that are subtracted element-wise from ``ys``.
        xs: batch inputs; moved to ``device`` before prediction.
        ys: batch targets; moved to ``device`` for prediction, compared on CPU.
        prefix_size: unused; kept for signature compatibility.
        device: device the inputs (and any layer activations) are moved to.
        layer_activations: optional iterable of tensors, moved to ``device`` and
            forwarded to ``model.predict``.

    Returns:
        Python float: the squared error at column ``-1`` averaged over the batch.
    """
    moved_activations = None
    if layer_activations is not None:
        moved_activations = [tensor.to(device) for tensor in layer_activations]

    predictions = model.predict(
        xs.to(device), ys.to(device), layer_activations=moved_activations
    ).detach()

    # Errors exist for every prefix length; only the last column (the query point,
    # which has seen all other examples in context) is scored.
    squared_errors = (ys.cpu() - predictions.cpu()).square()
    return squared_errors[:, -1].mean().item()

def get_linear_regression_error(xs, ys, prefix_size=None):
    """Pseudo-inverse least-squares baseline error on the final (query) point.

    Each sequence in the batch is treated as its own linear regression problem:
    weights are fitted with ``torch.linalg.pinv`` on every example except the
    last, then used to predict all positions. Only the last position is scored.

    Args:
        xs: inputs; ``xs[b]`` is the matrix of examples (rows) for sequence ``b``.
        ys: targets; ``ys[b]`` is aligned with the rows of ``xs[b]``.
        prefix_size: unused; kept for signature compatibility.

    Returns:
        Python float: squared error at the last position averaged over the batch.
    """
    def _predict_sequence(x_seq, y_seq):
        # Fit on everything but the final example, then predict every position.
        weights = torch.linalg.pinv(x_seq[:-1]) @ y_seq[:-1]
        return x_seq @ weights

    batch_predictions = torch.stack(
        [_predict_sequence(xs[b], ys[b]) for b in range(xs.shape[0])]
    )
    residual_sq = (ys.cpu() - batch_predictions.cpu()).square()
    # Score only the last position: the held-out query example.
    return residual_sq[:, -1].mean().item()


# Numerically stable variant of get_linear_regression_error: each system is
# solved with the least-squares solver torch.linalg.lstsq (default driver)
# instead of forming an explicit pseudo-inverse.
def get_stable_linear_regression_error(xs, ys, prefix_size=None):
    """Least-squares baseline error on the final (query) point, plus worst conditioning.

    For each sequence ``i`` in the batch, weights are fitted with
    ``torch.linalg.lstsq`` on all examples except the last and used to predict
    every position. Only the squared error at the last position is averaged
    over the batch.

    Args:
        xs: inputs; ``xs[i]`` is the matrix of examples (rows) for sequence ``i``.
        ys: targets; ``ys[i]`` is aligned with the rows of ``xs[i]``.
        prefix_size: unused; kept for signature compatibility.

    Returns:
        Tuple ``(mean_sq_error, worst_condition_number)`` of Python floats. The
        second value is the largest ``torch.linalg.cond`` over the per-sequence
        fitting matrices ``xs[i][:-1]``.
    """
    linear_regression_pred = []
    # Condition numbers are >= 1, so 0.0 is a safe starting point. Do NOT use
    # torch.Tensor(1): that allocates an *uninitialized* tensor whose garbage
    # value can end up being reported as the maximum.
    worst_condition_number = 0.0
    for i in range(xs.shape[0]):
        design = xs[i][:-1]  # all but the final (query) example
        w_hat = torch.linalg.lstsq(design, ys[i][:-1]).solution
        linear_regression_pred.append(xs[i] @ w_hat)
        worst_condition_number = max(
            worst_condition_number, torch.linalg.cond(design).item()
        )

    linear_regression_pred = torch.stack(linear_regression_pred)
    sq_error = (ys.cpu() - linear_regression_pred.cpu()).square()
    # Score only the last position, i.e. the held-out query example.
    mean_sq_error = sq_error[:, -1].mean()
    return mean_sq_error.item(), worst_condition_number

def get_evaluation_df(model, conf):
    """Take a model and return a df comparing it to least squares at several noise levels.

    Draws ``num_eval_examples // batch_size`` batches using the ``'standard'``
    evaluation settings from ``build_evals(conf)``. For each batch, two errors are
    computed: the model error (``get_model_error``) and the least-squares baseline
    error (``get_stable_linear_regression_error``). Both are computed first on the
    clean targets, then again after adding zero-mean Gaussian noise to every target
    except the query (last) one, once for each standard deviation in
    ``std_devs_list``.

    Args:
        model: model passed to ``get_model_error``.
        conf: run configuration. ``conf.model.n_dims`` and ``conf.training.task``
            are read, and ``conf`` itself is passed to ``build_evals``.

    Returns:
        pandas.DataFrame with one row per batch and two-level columns:
        - Top level: ``'model'``, ``'linear_regression'``, or ``'error_ratio'``
          (least-squares error / model error).
        - Second level: ``'<std>_noise'``.
        - ``'error_ratio'`` additionally has ``'condition_number'``, the worst
          conditioning seen on the clean batch.
    """
    evaluation_kwargs = build_evals(conf)['standard']
    data_name = evaluation_kwargs['data_name']
    task_name = evaluation_kwargs['task_name']
    batch_size = evaluation_kwargs['batch_size']
    n_points = evaluation_kwargs['n_points']
    n_dims = conf.model.n_dims

    data_sampler = get_data_sampler(data_name, n_dims)
    task_sampler = get_task_sampler(task_name, n_dims, batch_size)
    num_eval_examples = 1280

    std_devs_list = [0, 0.5, 1, 2, 4, 8, 16, 32, 64, 128]
    noise_keys = {std: '{}_noise'.format(std) for std in std_devs_list}
    chain_of_thought_tasks = (
        'relu_2nn_regression_chainofthought',
        'relu_4nn_regression_chainofthought',
    )

    model_error_list = {key: [] for key in noise_keys.values()}
    linear_regression_error_list = {key: [] for key in noise_keys.values()}
    error_ratio_list = {'condition_number': []}
    error_ratio_list.update({key: [] for key in noise_keys.values()})

    def _record(std, model_error, linear_regression_error):
        # Append one batch's results for noise level `std` to all three tables.
        key = noise_keys[std]
        model_error_list[key].append(model_error)
        linear_regression_error_list[key].append(linear_regression_error)
        if model_error != 0:
            ratio = linear_regression_error / model_error
        else:
            # Avoid ZeroDivisionError when the model predicts the query exactly.
            ratio = float('inf') if linear_regression_error else float('nan')
        error_ratio_list[key].append(ratio)

    for i in range(num_eval_examples // batch_size):
        xs, _ = gen_standard(data_sampler, n_points, batch_size)
        task = task_sampler()

        if conf.training.task in chain_of_thought_tasks:
            ys, layer_activations = task.evaluate(xs)
        else:
            ys = task.evaluate(xs)
            layer_activations = None

        model_error = get_model_error(model, xs, ys, layer_activations=layer_activations)
        linear_regression_error, worst_condition_number = get_stable_linear_regression_error(xs, ys)
        _record(0, model_error, linear_regression_error)
        error_ratio_list['condition_number'].append(worst_condition_number)

        # The query point is the last of n_points, so it sees n_points - 1 examples.
        print("Batch: {} | {} in-context examples | Model Error: {} | Linear Regression Error: {}".format(
            i, n_points - 1, model_error, linear_regression_error))

        # now let's try with noise
        for std in std_devs_list:
            if std == 0:
                # clean targets were already evaluated above
                continue
            noise = torch.randn_like(ys) * std
            noise[:, -1] = 0.  # don't add noise for the query sample
            noisy_ys = ys + noise
            model_error = get_model_error(model, xs, noisy_ys, layer_activations=layer_activations)
            # xs is unchanged, so the condition number equals the clean one; discard it.
            linear_regression_error, _ = get_stable_linear_regression_error(xs, noisy_ys)
            print("Batch: {} | Noise with 0 mean {} std | Model Error: {} | Linear Regression Error: {}".format(
                i, std, model_error, linear_regression_error))
            _record(std, model_error, linear_regression_error)

    df_model = pd.DataFrame(model_error_list)
    df_linear_regression = pd.DataFrame(linear_regression_error_list)
    df_error_ratio = pd.DataFrame(error_ratio_list)

    print("Model error df:\n {}\n\n".format(df_model))
    print("Linear Regression error df:\n {}\n\n".format(df_linear_regression))
    print("Linear Regression error/ Model Error df:\n {}\n\n".format(df_error_ratio))

    return pd.concat(
        {
            'model': df_model,
            'linear_regression': df_linear_regression,
            'error_ratio': df_error_ratio,
        },
        axis=1,
    )


# Efficient version of the plotting script: the query input is swept in batches,
# so one forward pass scores `batch_size` query values at once.
def plot_1d_function(model, conf=None, *, num_batches=1000, device="cuda"):
    """Record the model's query prediction as the query's first coordinate is swept.

    Setup:
    - One batch is sampled, and its first sequence is copied into every row, so all
      rows share the same in-context examples.
    - The task weights ``w_b`` are overwritten with 0.5 and the targets are computed
      once.

    Sweep:
    - The first coordinate of the last (query) point is set to ``idx / 100`` for
      ``idx`` in ``[-batch_size * num_batches, batch_size * num_batches)``.
    - Values are filled in one batch row at a time, and the model's prediction at
      the last position is recorded.

    Args:
        model: callable as ``model(xs, ys)`` returning per-position predictions.
        conf: run configuration. It is used the same way as in ``get_evaluation_df``
            (``build_evals(conf)['standard']`` and ``conf.model.n_dims``) to build
            the data and task samplers. Required.
        num_batches: the sweep spans ``2 * num_batches`` forward passes.
        device: device the model inputs are moved to.

    Returns:
        pandas.DataFrame with columns ``x`` (query value) and ``y_hat`` (prediction).

    Raises:
        ValueError: if ``conf`` is not provided.
    """
    if conf is None:
        raise ValueError(
            "plot_1d_function requires `conf` to build its data and task samplers"
        )

    evaluation_kwargs = build_evals(conf)['standard']
    batch_size = evaluation_kwargs['batch_size']
    n_points = evaluation_kwargs['n_points']
    n_dims = conf.model.n_dims
    data_sampler = get_data_sampler(evaluation_kwargs['data_name'], n_dims)
    task_sampler = get_task_sampler(evaluation_kwargs['task_name'], n_dims, batch_size)

    xs, _ = gen_standard(data_sampler, n_points, batch_size)
    # make xs the same for every row of the batch
    xs[:] = xs[0]

    task = task_sampler()
    # NOTE(review): assumes the task exposes a `w_b` weight tensor (e.g. linear
    # regression). With every weight set to 0.5, the ground truth is meant to be
    # roughly 0.5 * x. Confirm for n_dims > 1: only the first query coordinate is
    # swept below, and the other coordinates keep their sampled values.
    task.w_b = 0.5 * torch.ones_like(task.w_b)
    # Targets are computed once, before the query inputs are overwritten.
    ys_device = task.evaluate(xs).to(device)

    query_ids = range(-batch_size * num_batches, batch_size * num_batches)
    x_list = []
    y_hat_list = []
    with torch.no_grad():
        for batch_idx, start in enumerate(range(0, len(query_ids), batch_size)):
            batch_ids = query_ids[start:start + batch_size]
            for row, idx in enumerate(batch_ids):
                x_query = idx / 100
                x_list.append(x_query)
                # overwrite the first coordinate of this row's last (query) point
                xs[row][-1, 0] = x_query
            print("Processing Batch {} | idx={}".format(batch_idx, batch_ids[-1]))
            pred = model(xs.to(device), ys_device).detach()
            # A trailing partial batch only fills its first rows; keep just those.
            y_hat_list += pred[:len(batch_ids), -1].tolist()

    # my_df.to_csv("results/1d_function_values_longprefix.csv", index=False)
    return pd.DataFrame({'x': x_list, 'y_hat': y_hat_list})
