#!/usr/bin/env python
# coding: utf-8

# # Prerequisites

# In[13]:


print("Loading dependencies", flush=True)
import numpy as np
import matplotlib.pyplot as plt
from mma import splx2bf, approx, from_dump
from classif_helper import *
import pickle


from sys import argv

# Command line: <start> <stop> <num> <dataset>
# These values select which precomputed files are read and written below.
if len(argv) < 5:
	raise SystemExit(f"usage: {argv[0]} START STOP NUM DATASET")
start = int(argv[1])
stop = int(argv[2])
num = int(argv[3])
dataset = argv[4]  # argv entries are already str
# Numeric dataset ids are stored under "synthetic<id>"; any other name is used
# verbatim as the directory name.
dataset_dir = f"synthetic{dataset}" if dataset.isnumeric() else dataset
# NOTE(review): judging from the later plot's x-label this presumably holds the
# number of points of each approximation -- confirm.
with open(f"modules/{dataset_dir}/iterator_{start}_{stop}_{num}.np", "rb") as f:
	iterator = np.load(f)

# In[21]:
# Keyword arguments forwarded verbatim to every `module.image(...)` call below.
# NOTE(review): `box` presumably is the plotting window [[xmin, ymin], [xmax, ymax]]
# and `ps` the list of p values for which one image is computed each -- confirm
# against the mma documentation.
params = {
	"box": [[0., -0.1], [3.5, 3]],
	"bandwidth": 0.5,
	"dimension": 1,
	"resolution": [50, 50],
	"normalize": 1,
	"ps": [0, 1, 2, np.inf],
	"colorbar": True,
}
print(params)


def _mean_square(x, y):
	"""Mean of the squared pixel-wise difference (labelled "L2 norm"; no square root)."""
	return np.square(x - y).mean()


def _sup(x, y):
	"""Largest absolute pixel-wise difference."""
	return np.abs(x - y).max()


# (label, distance) pairs used to compare an image x with the reference image y.
# The "scaled" variants divide by the reference image's maximum value.
_NAMED_DISTANCES = [
	("L2 norm", _mean_square),
	("scaled L2 norm", lambda x, y: _mean_square(x, y) / y.max()),
	("sup norm", _sup),
	("scaled sup norm", lambda x, y: _sup(x, y) / y.max()),
]
distances = [fn for _, fn in _NAMED_DISTANCES]
distances_names = [label for label, _ in _NAMED_DISTANCES]

import os

print("Computing images...", flush=True)

# Directory holding this dataset's files. It must match the lookup used for the
# iterator: numeric ids live under "synthetic<id>", other names are used verbatim.
# Every path below goes through it so non-numeric datasets resolve consistently.
dataset_dir = f"synthetic{dataset}" if dataset.isnumeric() else dataset
run_tag = f"{start}_{stop}_{num}"


def _load_module(index):
	"""Load dump number *index* of this run and rebuild it with ``from_dump``.

	The file is opened in a ``with`` block so the handle is closed right away
	instead of leaking until garbage collection.
	"""
	# NOTE(review): pickle.load can execute arbitrary code; only load dumps
	# produced by a trusted run of this pipeline.
	with open(f"modules/{dataset_dir}/module_{run_tag}_{index}.pkl", "rb") as dump_file:
		return from_dump(pickle.load(dump_file))


print("- last imgs...", flush=True)
# The last dump is the reference that every approximation is compared against.
last_mod = _load_module(len(iterator) - 1)
last_imgs = []
for p in params["ps"]:
	plt.figure()
	last_imgs.append(last_mod.image(p=p, plot=True, **params))
	plt.savefig(f"test_{p}.png", dpi=200)
	# Close rather than just clear, so open figures don't pile up across iterations.
	plt.close("all")
del last_mod

print("- approximation images...", flush=True)
# errors[i, j, k] = distance k between the image of dump j and the reference
# image, both computed with p = params["ps"][i].
errors = np.zeros(shape=(len(params["ps"]), len(iterator), len(distances)))
# NOTE(review): `tqdm` is never imported explicitly; presumably it comes from
# `from classif_helper import *` -- confirm.
for j, _ in tqdm(enumerate(iterator), total=len(iterator)):
	current_mod = _load_module(j)
	for i, (p, reference_img) in enumerate(zip(params["ps"], last_imgs)):
		current_img = current_mod.image(p=p, plot=False, **params)
		for k, distance in enumerate(distances):
			errors[i, j, k] = distance(current_img, reference_img)

print("Saving errors...", flush=True)
errors_dir = f"errors/{dataset_dir}"
# Create the folder up front so a missing directory doesn't discard the
# (expensive) results computed above.
os.makedirs(errors_dir, exist_ok=True)
with open(f"{errors_dir}/errors_{run_tag}.pkl", "wb") as file:
	pickle.dump(errors, file)

print("Saving plots...", flush=True)
images_dir = f"images/{dataset_dir}"
os.makedirs(images_dir, exist_ok=True)
for k, name in enumerate(distances_names):
	plt.figure()
	for i, p in enumerate(params["ps"]):
		# The last dump is the reference module itself (presumably zero error), so it is omitted.
		plt.plot(iterator[:-1], errors[i, :-1, k], label=f"p={p}")
	plt.xlabel("Number of points")
	plt.ylabel(name)
	plt.legend()
	plt.savefig(f"{images_dir}/plot_{name}_cv_synthetic_module_{run_tag}.svg")
	plt.close()

print("Done !")




