from collections import OrderedDict
import re
import os
import copy

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm

from eval import get_run_metrics, read_run_dir, get_model_from_run, get_data_sampler, get_task_sampler, gen_standard, eval_batch
from plot_utils import basic_plot, collect_results, relevant_model_names

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Notebook-wide plot styling.
sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

# Directory handed to read_run_dir / collect_results and used as the root of run_path.
run_dir = "../models"

df = read_run_dir(run_dir)
df  # list all the runs in our run_dir

# Task to evaluate: exactly one assignment should be active. The name is also
# the sub-directory of run_dir used when building run_path below.
task = "linear_regression"
#task = "sparse_linear_regression"
#task = "decision_tree"
#task = "relu_2nn_regression"

# Alternative run: linear_regression_with_mlp_instead_of_attention
# run_id = "810c3f7e-9ef7-4eb2-8aad-2bd656a4f964"
run_id = "d96d42cb-ffa2-408b-b7fd-cabc259b3c87"  # if you train more models, replace with the run_id from the table above

# Alternative run: noisy training
# run_id = "6c62e90e-9a15-4a8d-8ecf-cc4c6e20ac38"

# Location of the selected run: <run_dir>/<task>/<run_id>.
run_path = os.path.join(run_dir, task, run_id)
# Flip to True to call get_run_metrics on run_path before collecting results.
recompute_metrics = False

if recompute_metrics:
    get_run_metrics(run_path)  # these are normally precomputed at the end of training

def valid_row(r):
    """Tell whether row ``r`` of the run table is the selected run.

    The row must expose ``task`` and ``run_id`` attributes. Both are compared
    against the module-level ``task`` and ``run_id``, which are looked up when
    the function is called, so it has to be called while those names still
    hold the selected task name and run id.
    """
    if not (r.task == task):
        return False
    return r.run_id == run_id

# Collect results for the rows accepted by valid_row (i.e. the selected run).
metrics = collect_results(run_dir, df, valid_row=valid_row)
# First call only needs the run configuration, hence only_conf=True and the
# discarded first return value.
_, conf = get_model_from_run(run_path, only_conf=True)
n_dims = conf.model.n_dims

# Load the model itself for the evaluations below.
# NOTE(review): this import sits mid-file because the script is a notebook
# export; build_evals is only needed further down.
from eval import build_evals

# step=500000 presumably selects the checkpoint saved at that training step -- TODO confirm.
model, conf = get_model_from_run(run_path, step=500000)
# Move to the GPU and switch to eval mode; a CUDA device is required from here on.
model = model.cuda().eval()

def load_avg_attention_matrices(model):
    """Overwrite the ``mlp_conv`` weights of the first 12 backbone layers from disk.

    For each layer index 0..11 the tensor stored in
    ``attention_weights_layer_<idx>.pt`` (resolved against the current working
    directory) is loaded with ``torch.load`` and assigned to
    ``model._backbone.h[idx].attn.mlp_conv.weight.data``.

    The model is modified in place and also returned for convenience.
    The files are presumably the averaged attention matrices saved elsewhere
    in this notebook -- confirm before relying on that.
    """
    checkpoint_names = ["attention_weights_layer_{}.pt".format(n) for n in range(12)]
    for layer_idx, checkpoint_name in enumerate(checkpoint_names):
        averaged_weights = torch.load(checkpoint_name)
        conv = model._backbone.h[layer_idx].attn.mlp_conv
        conv.weight.data = averaged_weights
    return model


def get_model_error(model, xs, ys, prefix_size=None, device='cuda'):
    """Return the model's mean squared error on the last point of each sequence.

    Args:
        model: callable invoked as ``model(xs, ys)``; its output must have the
            same shape as ``ys`` (one prediction per position).
        xs: input tensor, moved to ``device`` before the forward pass.
        ys: target tensor indexed as ``[batch, position]``.
        prefix_size: accepted for signature compatibility; currently unused.
        device: device the forward pass runs on.

    Returns:
        Python float: squared error of the prediction at the final position
        (the one made with every earlier point as context), averaged over the
        batch.
    """
    # Evaluation only: skip autograd bookkeeping so no graph is built or
    # retained by the forward pass.
    with torch.no_grad():
        pred = model(xs.to(device), ys.to(device)).detach()

    # pred holds one prediction per position, i.e. per number of in-context
    # examples; only the last position is scored.
    sq_error = (ys.cpu() - pred.cpu()).square()
    final_pred_sq_error = sq_error[:, -1]

    return final_pred_sq_error.mean().item()


def get_linear_regression_error(xs, ys, prefix_size=None):
    """Mean squared error of a pseudo-inverse least-squares baseline on the last point.

    Every batch element is treated as its own regression problem: the weights
    are fitted with ``torch.linalg.pinv`` on all points except the last one,
    then used to predict every point of that sequence. Only the squared error
    at the last (held-out) point is kept and averaged over the batch.

    ``prefix_size`` is accepted but not used. Returns a Python float.
    """
    per_sequence_preds = []
    for idx, x_seq in enumerate(xs):
        # Fit on the context (everything but the final point) ...
        fitted_w = torch.linalg.pinv(x_seq[:-1]) @ ys[idx][:-1]
        # ... and predict all points, including the held-out one.
        per_sequence_preds.append(x_seq @ fitted_w)

    stacked_preds = torch.stack(per_sequence_preds)
    last_point_sq_error = (ys.cpu() - stacked_preds.cpu()).square()[:, -1]
    return last_point_sq_error.mean().item()


# numerically stable version of solving a system of equations
# using the default linalg.lstsq => least squares solver
def get_stable_linear_regression_error(xs, ys, prefix_size=None):
    """Least-squares baseline error on the last point, plus the worst conditioning seen.

    Every batch element is treated as its own regression problem: the weights
    are fitted with ``torch.linalg.lstsq`` on all points except the last one,
    then used to predict every point of that sequence. Only the squared error
    at the last (held-out) point is kept and averaged over the batch.

    Args:
        xs: tensor indexed as ``[batch, point, feature]``.
        ys: tensor indexed as ``[batch, point]``.
        prefix_size: accepted for signature compatibility; currently unused.

    Returns:
        ``(mean_sq_error, worst_condition_number)`` as Python floats. The
        second value is the largest ``torch.linalg.cond`` over the context
        matrices, never reported below 1.
    """
    linear_regression_pred = []
    # Kept as a plain float: the previous int/tensor mix made the final
    # ``.item()`` fail whenever no context matrix had a condition number > 1.
    worst_condition_number = 1.0
    for i in range(xs.shape[0]):
        context_xs = xs[i][:-1]  # ignore the last sample when fitting
        w_hat = torch.linalg.lstsq(context_xs, ys[i][:-1]).solution
        y_hat = xs[i] @ w_hat
        linear_regression_pred.append(y_hat)
        worst_condition_number = max(
            worst_condition_number, float(torch.linalg.cond(context_xs))
        )

    linear_regression_pred = torch.stack(linear_regression_pred)
    sq_error = (ys.cpu() - linear_regression_pred.cpu()).square()
    # Score only the prediction for the held-out last point.
    final_pred_sq_error = sq_error[:, -1]
    mean_sq_error = final_pred_sq_error.mean()
    return mean_sq_error.item(), worst_condition_number


# Settings of the 'standard' evaluation, as produced by build_evals for this run.
evaluation_kwargs = build_evals(conf)['standard']
(
    data_name,
    task_name,
    batch_size,
    prompting_strategy,
    n_points,
) = [
    evaluation_kwargs[setting]
    for setting in ('data_name', 'task_name', 'batch_size', 'prompting_strategy', 'n_points')
]

# Samplers for inputs and for regression tasks, plus the prompt generator.
data_sampler = get_data_sampler(data_name, n_dims)
task_sampler = get_task_sampler(task_name, n_dims, batch_size)
generating_func = gen_standard
num_eval_examples = 1280

# Noise levels to evaluate; only the noiseless setting is active.
std_devs_list = [0]  # , 0.5, 1, 2, 4, 8, 16, 32, 64, 128]

# One list of per-batch results per noise level, keyed '<std_dev>_noise'.
# error_ratio_list additionally tracks the worst condition number per batch.
model_error_list = {}
linear_regression_error_list = {}
error_ratio_list = {'condition_number': []}
for std_dev in std_devs_list:
    noise_key = '{}_noise'.format(std_dev)
    model_error_list[noise_key] = []
    linear_regression_error_list[noise_key] = []
    error_ratio_list[noise_key] = []

# Disabled scratch code: evaluates a single batch with the unmodified model
# and then inspects the attention weights recorded on layer 0.
# xs, xs_p = generating_func(data_sampler, n_points, batch_size)
# task = task_sampler()
# device = "cuda"

# ys = task.evaluate(xs)

# model_error = get_model_error(model, xs, ys)

# model._backbone.h[0].attn.attn_weights

# Replace the model's mlp_conv weights with the matrices stored in
# attention_weights_layer_<idx>.pt. Those files are written by the saving code
# at the end of this file, so they must already exist from an earlier run.
model = load_avg_attention_matrices(model)

def _evaluate_and_record(batch_index, eval_xs, eval_ys):
    """Score the model and the least-squares baseline on one batch and log the result.

    Appends the model error, the baseline error, their ratio (baseline / model)
    and the worst condition number to the module-level '0_noise' /
    'condition_number' result lists, prints a one-line summary, and returns
    ``(model_error, linear_regression_error, worst_condition_number)``.
    """
    model_error = get_model_error(model, eval_xs, eval_ys)
    linear_regression_error, worst_condition_number = get_stable_linear_regression_error(eval_xs, eval_ys)
    model_error_list['0_noise'].append(model_error)
    linear_regression_error_list['0_noise'].append(linear_regression_error)
    error_ratio_list['0_noise'].append(linear_regression_error/model_error)
    error_ratio_list['condition_number'].append(worst_condition_number)

    # The scored prediction is the last point, so every earlier point is context.
    print("Batch: {} | {} in-context examples | Model Error: {} | Linear Regression Error: {}".format(
        batch_index, eval_xs.shape[1] - 1, model_error, linear_regression_error))
    return model_error, linear_regression_error, worst_condition_number


# Not read inside this loop (get_model_error has its own default device);
# kept so the name stays defined for other cells.
device = "cuda"

for i in range(num_eval_examples // batch_size):
    xs, xs_p = generating_func(data_sampler, n_points, batch_size)
    # Use a dedicated name: rebinding `task` would clobber the task-name string
    # that valid_row and run_path depend on.
    sampled_task = task_sampler()

    ys = sampled_task.evaluate(xs)

    # 1) Prompts in their sampled order.
    model_error, linear_regression_error, worst_condition_number = _evaluate_and_record(i, xs, ys)

    print("\n\n--------------------------------------------------------")
    print("With random permutations of the prefix")
    print("--------------------------------------------------------\n\n")

    # 2) Same prompts with the context points shuffled independently per
    #    sequence; the final (query) point stays in place.
    n_sequences, n_context = xs.shape[0], xs.shape[1] - 1
    rand_permuted_xs = xs.clone()
    rand_permuted_ys = ys.clone()
    for batch_idx in range(n_sequences):
        rand_permutation = torch.randperm(n_context)
        rand_permuted_xs[batch_idx, :-1] = xs[batch_idx, :-1][rand_permutation]
        rand_permuted_ys[batch_idx, :-1] = ys[batch_idx, :-1][rand_permutation]

    model_error, linear_regression_error, worst_condition_number = _evaluate_and_record(
        i, rand_permuted_xs, rand_permuted_ys)

# Each loop iteration contributed two rows: original order, then permuted.
df_model = pd.DataFrame(model_error_list)
df_linear_regression = pd.DataFrame(linear_regression_error_list)
df_error_ratio = pd.DataFrame(error_ratio_list)

"""
# if I'm just averaging for the whole batch size of 64
# take out the 12 layers of attention
avg_attention_weights = []
for layer_idx in range(12):
    avg_attention_weights.append(model._backbone.h[layer_idx].attn.attn_weights[0][0] / model._backbone.h[layer_idx].attn.attn_weights[0][1])

for idx, attention_matrix in enumerate(avg_attention_weights):
    torch.save(attention_matrix, "attention_weights_layer_{}.pt".format(idx))
"""

"""
# this is the correct thing to do, where I average within a batch as well
"""
# Average the attention weights recorded on each of the 12 backbone layers and
# save one tensor per layer as attention_weights_layer_<idx>.pt.
# NOTE(review): attn_weights[0] looks like an (accumulated weights, number of
# forward passes) pair, with the first axis of the weights being the batch of
# 64 sequences -- confirm against the model code.
avg_attention_matrices = []
for layer_idx in range(12):
    attention_matrix = model._backbone.h[layer_idx].attn.attn_weights[0][0]
    num_passes = model._backbone.h[layer_idx].attn.attn_weights[0][1]

    # Sum the first 64 slices along axis 0 into a single 2-D matrix.
    per_batch_attention_matrix = torch.zeros_like(attention_matrix[0, :, :])
    for batch_idx in range(64):
        per_batch_attention_matrix += attention_matrix[batch_idx, :, :].detach()

    # average across batches
    per_batch_attention_matrix = per_batch_attention_matrix/(64 * num_passes)

    # Write the averaged matrix back into each of the 64 slices. attention_matrix
    # is not a copy, so this also overwrites the weights stored on the model.
    for batch_idx in range(64):
        attention_matrix[batch_idx, :, :] = per_batch_attention_matrix

    avg_attention_matrices.append(attention_matrix)
    # Written to the current working directory; load_avg_attention_matrices
    # reads the same file names.
    torch.save(attention_matrix, "attention_weights_layer_{}.pt".format(layer_idx))
