from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm
import numpy as np
import json
import pickle
import sys

sys.path.append('/root/weiminwu/in-context-learning-fork/in-context-learning/src/')

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names
from models import get_relevant_baselines, GDModel, DecisionTreeModel
from base_models import NeuralNetwork, ParallelNetworks

from samplers import get_data_sampler
from tasks import get_task_sampler


# Location of the trained run to evaluate: <run_dir>/<task>/<run_id>.
# NOTE: `task` is rebound further down to a sampled task object; run_path is
# built from the string value first.
run_dir = "../models"
task = "decision_tree"
run_id = "6c67f151-9ba6-4eb1-84a0-b439a71c4545"
run_path = os.path.join(run_dir, task, run_id)

# Experiment directory: data.pkl is read from here and best.txt written here.
save_path = (
    '/root/weiminwu/in-context-learning-fork/in-context-learning/src/'
    'evaluation/results/DT/exp_3'
)

def Square_Error(ys, pred):
    """Return the coefficient of determination (R^2) of ``pred`` against ``ys``.

    Despite the name this is not a squared error: it computes
    ``1 - SS_res / SS_tot`` with ``SS_res = sum((ys - pred) ** 2)`` and
    ``SS_tot = sum((ys - mean(ys)) ** 2)``.

    Args:
        ys: Ground-truth values (tensor or array-like). Integer inputs are
            promoted to the default float dtype.
        pred: Predicted values. Must have the same shape as ``ys`` or
            broadcast to it (e.g. a scalar baseline prediction).

    Returns:
        A 0-d tensor. 1 is a perfect fit. When ``ys`` is constant
        (``SS_tot == 0``) R^2 is undefined and the result is ``-inf``
        (or ``nan`` if ``pred`` is also exact).

    Raises:
        ValueError: if ``pred`` would broadcast ``ys`` to a larger shape,
            e.g. ``(n,)`` vs ``(n, 1)`` -> ``(n, n)``, which would otherwise
            silently produce a meaningless score.
    """
    ys = torch.as_tensor(ys)
    pred = torch.as_tensor(pred)
    # Allow pred to broadcast *to* ys, but never ys up to something bigger.
    if torch.broadcast_shapes(ys.shape, pred.shape) != ys.shape:
        raise ValueError(
            f"pred shape {tuple(pred.shape)} does not match ys shape {tuple(ys.shape)}"
        )
    # torch.mean rejects integer tensors, so promote them first.
    if not ys.is_floating_point():
        ys = ys.to(torch.get_default_dtype())
    y_mean = torch.mean(ys)
    SS_tot = torch.sum((ys - y_mean) ** 2)
    SS_res = torch.sum((ys - pred) ** 2)
    R_square = 1 - SS_res / SS_tot
    return R_square

# compute similarity
def similarity(x, x_test):
    """Cosine similarity between ``x`` and ``x_test``.

    Both arguments get a leading axis and are compared along dim 1, so for
    the 1-D vectors passed by the caller below the result is a 1-element
    tensor. ``eps=1e-6`` guards the normalisation against (near-)zero vectors.
    """
    lhs = x.unsqueeze(0)
    rhs = x_test.unsqueeze(0)
    return torch.nn.functional.cosine_similarity(lhs, rhs, dim=1, eps=1e-6)

# Load the trained model with its config, then build the data and task
# samplers described by the config's training section.
model, conf = get_model_from_run(run_path)

train_conf = conf.training
n_dims = conf.model.n_dims
batch_size = train_conf.batch_size
data_sampler = get_data_sampler(train_conf.data, n_dims)
task_sampler = get_task_sampler(
    train_conf.task, n_dims, batch_size, **train_conf.task_kwargs
)

# Draw one task instance (this rebinds `task`, which held the task-name string).
task = task_sampler()

# Part1: 100% accurate
# Experiment sizes: n_batches batches; each row of a batch holds n_pairs
# candidate (x, y) pairs followed by prompt_length query points.
n_batches, prompt_length, n_pairs = 3, 76, 1000
# NOTE(review): these replace the n_dims / batch_size read from conf above,
# after data_sampler and task_sampler were already built with the conf
# values -- confirm the two agree.
n_dims, batch_size = 20, 64

# best case
# Load the saved exp_3 batches from save_path.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# point save_path at trusted data.
with open(save_path + '/data.pkl', 'rb') as file:
    data = pickle.load(file)
xs_list, ys_list = data['x'], data['y']
# Convert the first n_batches entries of each list from numpy arrays to
# tensors in place; any entries beyond n_batches are left untouched.
for time_idx in tqdm(range(n_batches)):
    xs_list[time_idx], ys_list[time_idx] = (
        torch.from_numpy(xs_list[time_idx]),
        torch.from_numpy(ys_list[time_idx]),
    )

# Part 1, best case. For every prompt length j = 1..prompt_length, each prompt
# holds the j-1 candidate pairs most cosine-similar to the query point (in
# descending similarity), then the query itself; later positions stay zero.
# The model's prediction at the query position (j-1) is recorded.
actual_points_best_prompt = [[] for _ in range(prompt_length)]
predicted_points_best_prompt = [[] for _ in range(prompt_length)]
for batch_idx in tqdm(range(n_batches)):
    # Evaluate on the saved exp_3 batch loaded above instead of resampling.
    # Previously the loaded data was ignored and fresh samples were appended
    # onto xs_list / ys_list. Each row holds n_pairs candidates followed by
    # prompt_length queries. as_tensor accepts numpy arrays or tensors.
    xs = torch.as_tensor(xs_list[batch_idx])
    ys = torch.as_tensor(ys_list[batch_idx])
    # Take sizes from the data itself rather than the hard-coded globals.
    cur_batch, _, cur_dims = xs.shape
    rows = torch.arange(cur_batch).unsqueeze(1)  # (batch, 1) for row-wise gathers
    pool_xs = xs[:, :n_pairs, :]
    pool_ys = ys[:, :n_pairs]
    for j in range(1, prompt_length + 1):
        query_pos = n_pairs + j - 1
        x_test = xs[:, query_pos, :]
        y_test = ys[:, query_pos]
        # Vectorised form of calling similarity() on every (candidate, query)
        # pair (same formula and eps). Result shape: (batch, n_pairs).
        sims = torch.nn.functional.cosine_similarity(
            pool_xs, x_test.unsqueeze(1), dim=2, eps=1e-6
        )
        selected = torch.topk(sims, j - 1, dim=1, largest=True).indices

        # float32 prompts, zero-padded to prompt_length, as before.
        prompt_xs = torch.zeros(cur_batch, prompt_length, cur_dims, dtype=torch.float32)
        prompt_ys = torch.zeros(cur_batch, prompt_length, dtype=torch.float32)
        prompt_xs[:, :j - 1, :] = pool_xs[rows, selected]
        prompt_ys[:, :j - 1] = pool_ys[rows, selected]
        prompt_xs[:, j - 1, :] = x_test
        prompt_ys[:, j - 1] = y_test

        with torch.no_grad():
            pred = model(prompt_xs, prompt_ys)

        actual_points_best_prompt[j - 1].append(prompt_ys[:, j - 1])
        predicted_points_best_prompt[j - 1].append(pred[:, j - 1])

# Score once, after all batches: R^2 per prompt position over every query.
# Previously this ran (and rewrote best.txt) inside the batch loop.
best = [
    Square_Error(torch.cat(actual), torch.cat(predicted))
    for actual, predicted in zip(actual_points_best_prompt, predicted_points_best_prompt)
]

with open(save_path + '/best.txt', 'w') as f:
    for value in best:
        f.write(f"{value}\n")