from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run, get_data_sampler, get_task_sampler, gen_standard, eval_batch
from plot_utils import basic_plot, collect_results, relevant_model_names

# Notebook setup: render matplotlib figures inline and have IPython reload
# imported modules (e.g. eval, plot_utils) automatically after they are edited.
%matplotlib inline
%load_ext autoreload
%autoreload 2

sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

# Root directory of trained runs; a single run lives at <run_dir>/<task>/<run_id>
# (see the os.path.join below).
run_dir = "../models"

df = read_run_dir(run_dir)
df  # list all the runs in our run_dir

# Task family of the run to evaluate; uncomment one alternative to switch.
# This string is also read by valid_row below when filtering results.
task = "linear_regression"
#task = "sparse_linear_regression"
#task = "decision_tree"
#task = "relu_2nn_regression"

# Run id of the model to evaluate (the "zero-pad linear regression" run).
run_id = "ca5857ab-ce77-47c8-96f6-a9aabc017485"

run_path = os.path.join(run_dir, task, run_id)
# Set to True to recompute this run's metrics instead of reusing saved ones.
recompute_metrics = False

if recompute_metrics:
    get_run_metrics(run_path)  # these are normally precomputed at the end of training

def valid_row(r):
    """Return whether result row ``r`` belongs to the selected task and run.

    The module-level ``task`` and ``run_id`` are looked up at call time, so
    rebinding either global changes which rows are accepted.
    """
    same_task = r.task == task
    return same_task and r.run_id == run_id

# Collect results for the selected run only (rows filtered through valid_row).
metrics = collect_results(run_dir, df, valid_row=valid_row)
# Read only the run's config (no weights) to get the input dimensionality.
_, conf = get_model_from_run(run_path, only_conf=True)
n_dims = conf.model.n_dims

# get model and check it on noisy data
from eval import build_evals

# Load the model at step=500000 (presumably the training-step checkpoint —
# confirm in eval.get_model_from_run); note this rebinds `conf` as well.
# The model is moved to the GPU and put in eval mode for inference.
model, conf = get_model_from_run(run_path, step=500000)
model = model.cuda().eval()


def get_model_error(model, xs, ys, prefix_size=None, device='cuda'):
    """Return the model's mean squared error on the last point of each prompt.

    Args:
        model: callable taking ``(xs, ys)`` on ``device`` and returning one
            prediction per position, aligned with ``ys``.
        xs: prompt inputs; moved to ``device`` for the forward pass.
        ys: prompt targets, indexed as ``ys[:, -1]`` for the query point
            (assumed shape (batch, n_points) — TODO confirm).
        prefix_size: currently unused; kept for interface compatibility.
        device: device the model lives on.

    Returns:
        The batch mean of the squared error at the final position, as a float.
    """
    # Inference only: no_grad avoids building an autograd graph at all, whereas
    # calling .detach() afterwards only discards a graph that was already built.
    with torch.no_grad():
        pred = model(xs.to(device), ys.to(device))

    # The model predicts at every position (every number of in-context examples);
    # keep only the last one, which has seen all preceding (x, y) pairs.
    final_sq_error = (ys.cpu() - pred.cpu())[:, -1].square()
    return final_sq_error.mean().item()


def get_linear_regression_error(xs, ys, prefix_size=None):
    """Return the pseudo-inverse least-squares baseline's MSE on each prompt's last point.

    Every prompt in the batch is treated as its own linear regression problem:
    weights are fit with the Moore-Penrose pseudo-inverse on all points except
    the last, then used to predict every point; only the held-out last point
    is scored.

    Args:
        xs: prompt inputs, assumed (batch, n_points, n_dims) — TODO confirm.
        ys: prompt targets, assumed (batch, n_points).
        prefix_size: currently unused; kept for interface compatibility.

    Returns:
        The batch mean of the squared error at the final position, as a float.
    """
    context_xs = xs[:, :-1]                  # (batch, n_points - 1, n_dims)
    context_ys = ys[:, :-1].unsqueeze(-1)    # (batch, n_points - 1, 1)
    # torch.linalg.pinv batches over leading dimensions, so a single call
    # replaces the per-prompt Python loop.
    w_hat = torch.linalg.pinv(context_xs) @ context_ys      # (batch, n_dims, 1)
    linear_regression_pred = (xs @ w_hat).squeeze(-1)       # (batch, n_points)

    # Score only the last position, which was excluded from the fit.
    final_sq_error = (ys.cpu() - linear_regression_pred.cpu())[:, -1].square()
    return final_sq_error.mean().item()


# Numerically stable variant of get_linear_regression_error: solves each
# system with torch.linalg.lstsq instead of forming the pseudo-inverse.
def get_stable_linear_regression_error(xs, ys, prefix_size=None):
    """Return the least-squares baseline's MSE on the last point and the worst condition number.

    Every prompt in the batch is treated as its own linear regression problem:
    weights are fit with torch.linalg.lstsq on all points except the last,
    then used to predict every point; only the held-out last point is scored.

    Args:
        xs: prompt inputs, assumed (batch, n_points, n_dims) — TODO confirm.
        ys: prompt targets, assumed (batch, n_points).
        prefix_size: currently unused; kept for interface compatibility.

    Returns:
        ``(mean_sq_error, worst_condition_number)`` as floats: the batch mean
        of the squared error at the final position, and the largest 2-norm
        condition number among the batch's context matrices ``xs[:, :-1]``.
    """
    context_xs = xs[:, :-1]                  # (batch, n_points - 1, n_dims)
    context_ys = ys[:, :-1].unsqueeze(-1)    # (batch, n_points - 1, 1)
    # lstsq batches over leading dimensions; solution is (batch, n_dims, 1).
    w_hat = torch.linalg.lstsq(context_xs, context_ys).solution
    linear_regression_pred = (xs @ w_hat).squeeze(-1)       # (batch, n_points)

    # Score only the last position, which was excluded from the fit.
    final_sq_error = (ys.cpu() - linear_regression_pred.cpu())[:, -1].square()

    # Take the max directly over the batch. The previous running max was seeded
    # with torch.Tensor(1), an *uninitialized* tensor, so garbage memory could
    # leak into the reported value.
    worst_condition_number = torch.linalg.cond(context_xs).max()
    return final_sq_error.mean().item(), worst_condition_number.item()


# Reuse the "standard" evaluation settings built from the run's config.
evaluation_kwargs = build_evals(conf)['standard']
data_name, task_name, batch_size, prompting_strategy, n_points = (
    evaluation_kwargs[setting]
    for setting in ('data_name', 'task_name', 'batch_size', 'prompting_strategy', 'n_points')
)

# Samplers for prompt inputs and for the tasks that label them.
data_sampler = get_data_sampler(data_name, n_dims)

task_sampler = get_task_sampler(task_name, n_dims, batch_size)
generating_func = gen_standard
num_eval_examples = 1280

# Standard deviations of the zero-mean Gaussian label noise to sweep over.
std_devs_list = [0, 0.5, 1, 2, 4, 8, 16, 32, 64, 128]

# Per-batch results, one list per noise level under the key '<std>_noise'.
# error_ratio_list additionally tracks the worst condition number per batch.
model_error_list, linear_regression_error_list = {}, {}
error_ratio_list = {'condition_number': []}
for std_dev in std_devs_list:
    noise_key = '{}_noise'.format(std_dev)
    for result_table in (model_error_list, linear_regression_error_list, error_ratio_list):
        result_table[noise_key] = []

# Compare the model with the least-squares baseline on fresh batches: first on
# clean labels, then with zero-mean Gaussian noise added to the in-context
# labels only (the scored query label stays clean).
device = "cuda"
# The scored query is the last of n_points points, so it sees n_points - 1 examples.
n_in_context = n_points - 1
for i in range(num_eval_examples // batch_size):
    xs, xs_p = generating_func(data_sampler, n_points, batch_size)
    # Named task_instance so the loop does not overwrite the module-level
    # `task` string that run_path and valid_row depend on.
    task_instance = task_sampler()

    ys = task_instance.evaluate(xs)

    model_error = get_model_error(model, xs, ys, device=device)
    linear_regression_error, worst_condition_number = get_stable_linear_regression_error(xs, ys)
    model_error_list['0_noise'].append(model_error)
    linear_regression_error_list['0_noise'].append(linear_regression_error)
    error_ratio_list['0_noise'].append(linear_regression_error/model_error)
    error_ratio_list['condition_number'].append(worst_condition_number)

    print("Batch: {} | {} in-context examples | Model Error: {} | Linear Regression Error: {}".format(i, n_in_context, model_error, linear_regression_error))

    # now let's try with noise
    for std in std_devs_list:
        if std == 0:
            # the noiseless case was handled above
            continue
        noise = torch.randn_like(ys) * std
        noise[:, -1] = 0. # don't add noise for the query sample
        noisy_ys = ys + noise
        model_error = get_model_error(model, xs, noisy_ys, device=device)
        # The condition number depends only on xs, so it was already recorded above.
        linear_regression_error, _ = get_stable_linear_regression_error(xs, noisy_ys)
        print("Batch: {} | Noise with 0 mean {} std | Model Error: {} | Linear Regression Error: {}".format(i, std, model_error, linear_regression_error))

        noise_key = '{}_noise'.format(std)
        model_error_list[noise_key].append(model_error)
        linear_regression_error_list[noise_key].append(linear_regression_error)
        error_ratio_list[noise_key].append(linear_regression_error/model_error)

df_model = pd.DataFrame(model_error_list)
df_linear_regression = pd.DataFrame(linear_regression_error_list)
df_error_ratio = pd.DataFrame(error_ratio_list)
