from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm
import numpy as np
import json
import pickle
import sys

sys.path.append('/root/weiminwu/in-context-learning-fork/in-context-learning/src/')

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names
from models import get_relevant_baselines, GDModel, DecisionTreeModel
from base_models import NeuralNetwork, NeuralNetwork3, NeuralNetwork4, NeuralNetwork6, ParallelNetworks

from samplers import get_data_sampler
from tasks import get_task_sampler


# Trained runs are read from <run_dir>/<task>/<run_id> (joined into run_path below).
run_dir = "../models"

# Experiment selection: keep exactly one (task, run_id, save_path) group uncommented.
# task = "relu_3nn_regression"
# run_id = "b98285ff-feda-4a04-8914-3810a4bec3f3"
# save_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/results/3NN/exp_1'
# task = "relu_3nn_regression_relu"
# run_id = "b98285ff-feda-4a04-8914-3810a4bec3f3"
# save_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/results/3NN_r/exp_1'
# task = "relu_4nn_regression"
# run_id = "b98285ff-feda-4a04-8914-3810a4bec3f3"
# save_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/results/4NN/exp_1'
# task = "relu_4nn_regression_relu"
# run_id = "b98285ff-feda-4a04-8914-3810a4bec3f3"
# save_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/results/4NN_r/exp_1'
task = "relu_6nn_regression"
name_model = "relu_6nn_regression"  # NOTE(review): not referenced elsewhere in this file — confirm it is still needed
run_id = "b98285ff-feda-4a04-8914-3810a4bec3f3"
save_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/results/6NN/exp_1'
# task = "relu_6nn_regression_relu"
# run_id = "b98285ff-feda-4a04-8914-3810a4bec3f3"
# save_path = '/root/weiminwu/in-context-learning-fork/in-context-learning/src/evaluation/results/6NN_r/exp_1'
run_path = os.path.join(run_dir, task, run_id)


# Ensure the output directory exists before any pickles are written into it.
if os.path.exists(save_path):
    print(f"Directory already exists: {save_path}")
else:
    os.makedirs(save_path)
    print(f"Directory created: {save_path}")

def Square_Error(ys, pred):
    """Return the coefficient of determination R^2 of ``pred`` against ``ys``.

    Despite the name, this is not a squared error. It computes
    ``1 - SS_res / SS_tot`` where ``SS_res = sum((ys - pred) ** 2)`` and
    ``SS_tot = sum((ys - mean(ys)) ** 2)``, summed over every element.

    Args:
        ys: target tensor (must be a floating dtype, since ``mean`` is taken).
        pred: prediction tensor, broadcastable against ``ys``.

    Returns:
        A 0-d tensor: 1.0 for a perfect fit, 0.0 when ``pred`` equals the
        mean of ``ys``. If ``ys`` is constant, ``SS_tot`` is zero and the
        result is not finite.
    """
    total_sum = (ys - ys.mean()).pow(2).sum()
    residual_sum = (ys - pred).pow(2).sum()
    return 1 - residual_sum / total_sum

# Load the trained model together with the config it was trained with.
model, conf = get_model_from_run(run_path)

# Rebuild the data and task samplers from the run's training config.
training_conf = conf.training
n_dims = conf.model.n_dims
batch_size = training_conf.batch_size
data_sampler = get_data_sampler(training_conf.data, n_dims)
task_sampler = get_task_sampler(
    training_conf.task, n_dims, batch_size, **training_conf.task_kwargs
)

# Draw a single task instance; every evaluation below calls task.evaluate on it.
# NOTE: this rebinds `task`, which previously held the task-name string used for run_path.
task = task_sampler()

# Part 1: W1=1, W2=0.
# NOTE(review): sample_xs is called without w1/w2 here, so this relies on the
# sampler's defaults corresponding to W1=1, W2=0 — confirm in samplers.py.
n_batches = 3         # number of independent prompt batches
prompt_length = 76    # points per prompt (n_points passed to sample_xs)

# Per-position buckets that the Part-1 evaluation pass fills in.
actual_points_1 = [[] for _ in range(prompt_length)]
predicted_points_1 = [[] for _ in range(prompt_length)]

# Sample the prompts once and keep numpy copies, so the baseline can reuse identical data.
xs_list, ys_list = [], []
for _ in tqdm(range(n_batches)):
    sampled_xs = data_sampler.sample_xs(b_size=batch_size, n_points=prompt_length)
    xs_list.append(np.array(sampled_xs))
    sampled_ys = task.evaluate(sampled_xs)
    ys_list.append(np.array(sampled_ys))

# Persist to data.pkl (the baseline section reloads this file) and read it back.
data = dict(x=xs_list, y=ys_list)
with open(save_path + '/data.pkl', 'wb') as pkl_file:
    pickle.dump(data, pkl_file)

with open(save_path + '/data.pkl', 'rb') as fh:
    data = pickle.load(fh)

# Replace each reloaded numpy array with a tensor view of it, in place.
xs_list = data['x']
ys_list = data['y']
for time_idx in tqdm(range(n_batches)):
    xs_list[time_idx], ys_list[time_idx] = (
        torch.from_numpy(xs_list[time_idx]),
        torch.from_numpy(ys_list[time_idx]),
    )

# Run the trained model on each saved Part-1 batch and record, for every
# prompt position j, that batch's targets and predictions at position j.
for time_idx in tqdm(range(n_batches)):
    batch_xs, batch_ys = xs_list[time_idx], ys_list[time_idx]

    with torch.no_grad():
        batch_pred = model(batch_xs, batch_ys)
    for j, (actual_bucket, predicted_bucket) in enumerate(
        zip(actual_points_1, predicted_points_1)
    ):
        actual_bucket.append(list(np.array(batch_ys[:, j])))
        predicted_bucket.append(list(np.array(batch_pred[:, j])))

# R^2 per (position, batch): w_1_error[j] holds one list with a score per batch.
w_1_error = [[] for _ in range(prompt_length)]
for point_idx in range(prompt_length):
    batch_scores = [
        Square_Error(
            torch.tensor(actual_points_1[point_idx][time_idx]),
            torch.tensor(predicted_points_1[point_idx][time_idx]),
        )
        for time_idx in tqdm(range(n_batches))
    ]
    w_1_error[point_idx].append(batch_scores)

# Mean and std of R^2 across batches, one entry per prompt position.
error_arrays = [np.array(errors) for errors in w_1_error]
means = [np.mean(arr) for arr in error_arrays]
stds = [np.std(arr) for arr in error_arrays]

# Save the per-position summary as a pickle file.
data = {'means': means, 'stds': stds}
with open(save_path+'/w_1.pkl', 'wb') as pkl_file:
    print(data)
    pickle.dump(data, pkl_file)

# To reload the summary later:
# with open(save_path+'/w_1.pkl', 'rb') as file:
#     data = pickle.load(file)

#==========================================================
# Part 2: W1 = 0.9, W2 = 0.1.
# Fresh prompts come from sample_xs(..., w1=0.9, w2=0.1) — NOTE(review): these
# are presumably input-distribution mixture weights; confirm in samplers.py.
# The prompts are scored immediately and are not written to disk.
actual_points_w_9 = [[] for _ in range(prompt_length)]
predicted_points_w_9 = [[] for _ in range(prompt_length)]
for _ in tqdm(range(n_batches)):
    sampled_xs = data_sampler.sample_xs(b_size=batch_size, n_points=prompt_length, w1=0.9, w2=0.1)
    sampled_ys = task.evaluate(sampled_xs)

    with torch.no_grad():
        sampled_pred = model(sampled_xs, sampled_ys)
    for j, (actual_bucket, predicted_bucket) in enumerate(
        zip(actual_points_w_9, predicted_points_w_9)
    ):
        actual_bucket.append(list(np.array(sampled_ys[:, j])))
        predicted_bucket.append(list(np.array(sampled_pred[:, j])))

# R^2 per (position, batch): w_9[j] holds one list with a score per batch.
w_9 = [[] for _ in range(prompt_length)]
for point_idx in range(prompt_length):
    batch_scores = [
        Square_Error(
            torch.tensor(actual_points_w_9[point_idx][time_idx]),
            torch.tensor(predicted_points_w_9[point_idx][time_idx]),
        )
        for time_idx in tqdm(range(n_batches))
    ]
    w_9[point_idx].append(batch_scores)

# Mean and std of R^2 across batches, one entry per prompt position.
error_arrays = [np.array(errors) for errors in w_9]
means = [np.mean(arr) for arr in error_arrays]
stds = [np.std(arr) for arr in error_arrays]

# Save the per-position summary as a pickle file.
data = {'means': means, 'stds': stds}
with open(save_path+'/w_9.pkl', 'wb') as pkl_file:
    print(data)
    pickle.dump(data, pkl_file)

        
# Part 3: W1 = 0.7, W2 = 0.3.
# Fresh prompts come from sample_xs(..., w1=0.7, w2=0.3) — NOTE(review): these
# are presumably input-distribution mixture weights; confirm in samplers.py.
# The prompts are scored immediately and are not written to disk.
actual_points_w_7 = [[] for _ in range(prompt_length)]
predicted_points_w_7 = [[] for _ in range(prompt_length)]
for _ in tqdm(range(n_batches)):
    sampled_xs = data_sampler.sample_xs(b_size=batch_size, n_points=prompt_length, w1=0.7, w2=0.3)
    sampled_ys = task.evaluate(sampled_xs)

    with torch.no_grad():
        sampled_pred = model(sampled_xs, sampled_ys)
    for j, (actual_bucket, predicted_bucket) in enumerate(
        zip(actual_points_w_7, predicted_points_w_7)
    ):
        actual_bucket.append(list(np.array(sampled_ys[:, j])))
        predicted_bucket.append(list(np.array(sampled_pred[:, j])))

# R^2 per (position, batch): w_7[j] holds one list with a score per batch.
w_7 = [[] for _ in range(prompt_length)]
for point_idx in range(prompt_length):
    batch_scores = [
        Square_Error(
            torch.tensor(actual_points_w_7[point_idx][time_idx]),
            torch.tensor(predicted_points_w_7[point_idx][time_idx]),
        )
        for time_idx in tqdm(range(n_batches))
    ]
    w_7[point_idx].append(batch_scores)

# Mean and std of R^2 across batches, one entry per prompt position.
error_arrays = [np.array(errors) for errors in w_7]
means = [np.mean(arr) for arr in error_arrays]
stds = [np.std(arr) for arr in error_arrays]

# Save the per-position summary as a pickle file.
data = {'means': means, 'stds': stds}
with open(save_path+'/w_7.pkl', 'wb') as pkl_file:
    print(data)
    pickle.dump(data, pkl_file)
# Baseline
def _gd_baseline(model_class, in_size):
    """Return a ``(GDModel, kwargs)`` baseline spec for ``model_class``.

    Every GD baseline in this script shares the same optimizer settings; only
    the network class differs. A fresh kwargs dict is built per call so specs
    never share mutable state.

    Args:
        model_class: network class handed to GDModel as ``model_class``.
        in_size: network input width. It is passed ``n_dims`` below so it
            follows the run's feature dimension.
            NOTE(review): confirm against base_models.NeuralNetwork*.
    """
    return (
        GDModel,
        {
            "model_class": model_class,
            "model_class_args": {
                "in_size": in_size,
                "hidden_size": 100,
                "out_size": 1,
            },
            "opt_alg": "adam",
            "batch_size": 64,
            "lr": 5e-3,
            "num_steps": 100,
        },
    )


# Baseline specs keyed by conf.training.task. GD network input width follows
# the run's n_dims rather than a fixed 20, so non-20-dim runs also work.
baselines = {
    "relu_2nn_regression": [_gd_baseline(NeuralNetwork, n_dims)],
    "relu_3nn_regression": [_gd_baseline(NeuralNetwork3, n_dims)],
    "relu_4nn_regression": [_gd_baseline(NeuralNetwork4, n_dims)],
    "relu_6nn_regression": [_gd_baseline(NeuralNetwork6, n_dims)],
    "decision_tree": [
        (DecisionTreeModel, {"max_depth": 4}),
        # (DecisionTreeModel, {"max_depth": None}),
    ],
}

# Fail with an explicit message (still a KeyError) when the run's task has no baseline.
try:
    baseline_specs = baselines[conf.training.task]
except KeyError as err:
    raise KeyError(
        f"No baseline configured for task {conf.training.task!r}; "
        f"available: {sorted(baselines)}"
    ) from err
baseline = [model_cls(**kwargs) for model_cls, kwargs in baseline_specs]

# Evaluate the first configured baseline on the exact prompts saved in data.pkl,
# recording per-position targets and predictions for each batch.
actual_points_base = [[] for _ in range(prompt_length)]
predicted_points_base = [[] for _ in range(prompt_length)]
all_errors_base = []  # NOTE(review): never filled or read in this file

with open(save_path+'/data.pkl', 'rb') as fh:
    data = pickle.load(fh)

xs_list = data['x']
ys_list = data['y']

for time_idx in tqdm(range(n_batches)):
    batch_xs = torch.from_numpy(xs_list[time_idx])
    batch_ys = torch.from_numpy(ys_list[time_idx])
    # Not wrapped in torch.no_grad(), unlike the transformer calls above —
    # presumably because the GD baseline optimizes internally; confirm in models.py.
    batch_pred = baseline[0](batch_xs, batch_ys)

    for j, (actual_bucket, predicted_bucket) in enumerate(
        zip(actual_points_base, predicted_points_base)
    ):
        actual_bucket.append(list(np.array(batch_ys[:, j])))
        predicted_bucket.append(list(np.array(batch_pred[:, j].cpu())))

# R^2 per (position, batch): base_error[j] holds one list with a score per batch.
base_error = [[] for _ in range(prompt_length)]
for point_idx in range(prompt_length):
    batch_scores = [
        Square_Error(
            torch.tensor(actual_points_base[point_idx][time_idx]),
            torch.tensor(predicted_points_base[point_idx][time_idx]),
        )
        for time_idx in tqdm(range(n_batches))
    ]
    base_error[point_idx].append(batch_scores)

# Mean and std of R^2 across batches, one entry per prompt position.
error_arrays = [np.array(errors) for errors in base_error]
means = [np.mean(arr) for arr in error_arrays]
stds = [np.std(arr) for arr in error_arrays]

# Save the per-position summary as a pickle file.
data = {'means': means, 'stds': stds}
with open(save_path+'/baseline.pkl', 'wb') as pkl_file:
    print(data)
    pickle.dump(data, pkl_file)

