import math
import numpy as np
from IPython import display
import torch
from scipy.stats import qmc

def myadd(list):
    """Return the element-wise sum of every entry in ``list``.

    Parameters
    ----------
    list : sequence
        Terms to add (arrays or scalars accepted by ``np.add``). The name
        shadows the builtin but is kept so keyword callers keep working.

    Returns
    -------
    The sum of all terms. A single term is returned unchanged (the same
    object); an empty sequence sums to ``0.0``.
    """
    if len(list) == 0:
        return 0.0
    total = list[0]
    # Fold pairwise: ``np.add(*list)`` would treat a third positional argument
    # as the ``out`` buffer (dropping it from the sum and overwriting it) and
    # reject four or more terms.
    for term in list[1:]:
        total = np.add(total, term)
    return total

class SALPlotsTime():
    """Live dashboard for time-dependent safe active learning.

    Keeps a matplotlib-style figure (``fig``) with a grid of
    ``figure_size_y`` x ``figure_size_x`` subplots up to date while the models
    in ``gps`` are updated over time. On every update it evaluates the models
    on a Sobol point set (``model_input``) and on a regular plotting grid
    (``model_input_plots``). It then compares the model-based constraint
    estimates against the ground truth given by ``current_reality``, and
    records safety-classification and model-error statistics.

    Each entry ``c`` of ``constraints`` is a linear constraint whose value is
    ``c[0] + sum_k c[1][k] * output_k``. The uncertainty-aware model version
    additionally adds ``sum_k c[2][k] * std_k``. A point counts as safe when
    every constraint value is ``<= 0``.

    Contour / surface subplots are only drawn for 2-D inputs; the time-series
    and statistics subplots are drawn for any input dimension.
    """

    # Subplot grid used by ``plot``: figure_size_y rows by figure_size_x columns.
    figure_size_x = 4
    figure_size_y = 3

    def __init__(self,
                 fig,
                 xmin, xmax,
                 tstart, titer,
                 xpts, ypts, tpts,
                 ydata, fdata,
                 delta,
                 gps,
                 reality, current_reality, current_objective,
                 constraints,
                 device, current_time=0.0,
                 model_input=None,
                 fixed_reality=False):
        """Store the configuration, evaluate everything once and show the figure.

        Types and shapes below are inferred from how the values are used;
        confirm them against the callers.

        Args:
            fig: Figure to draw into (needs ``add_subplot``, ``get_axes``,
                ``delaxes``).
            xmin, xmax: Per-input lower/upper bounds; ``xmax.shape[0]`` is the
                input dimension.
            tstart, titer: Their sum is the upper end of the time colour scale.
            xpts, ypts, tpts: First two coordinates and times of measured points.
            ydata, fdata: Measured / true outputs per point (column 0 is plotted).
            delta: Per-input step of the regular plotting grid.
            gps: Multi-output model wrapper (``numoutputs``, callable on
                time-augmented inputs, ``NLL_eval``, ``models[i].NLL``).
            reality: Callable giving the initial true outputs for one point
                (one scalar argument per input).
            current_reality: Callable giving the current true outputs for one point.
            current_objective: Callable scoring one time-augmented input row.
            constraints: Linear constraints ``(offset, mean_coeffs, std_coeffs)``.
            device: Torch device used for intermediate tensors.
            current_time: Time value prepended to every model input.
            model_input: Evaluation points; defaults to 2**8 unscrambled Sobol
                points scaled into ``[xmin, xmax]``.
            fixed_reality: If True, ``current_reality`` is evaluated only once.
        """
        # save parameters
        self.fig = fig
        self.xmin = xmin
        self.xmax = xmax
        self.tstart = tstart
        self.titer = titer
        self.xpts = xpts
        self.ypts = ypts
        self.tpts = tpts
        self.ydata = ydata
        self.fdata = fdata
        self.delta = delta
        self.gps = gps
        self.nrgps = self.gps.numoutputs
        self.reality = reality
        self.current_reality = current_reality
        self.current_objective = current_objective
        self.constraints = constraints
        self.device = device
        self.current_time = current_time
        self.nrinputs = self.xmax.shape[0]
        self.fixed_reality = fixed_reality

        # histories: one entry per compute_model_results / compute_statistics call
        self.old_fx = []
        self.old_fx_plots = []
        self.old_fstd = []
        self.old_fstd_plots = []
        self.old_model_constraints = []
        self.old_model_constraints_plots = []
        self.old_optimistic_model_constraints = []
        self.old_optimistic_model_constraints_plots = []
        self.old_computed_current_reality_plots = []
        self.old_computed_current_reality = []
        self.old_computed_objective_plots = []
        self.model_safe_stats_list = []
        self.model_optimistic_safe_stats_list = []
        self.model_errors_stats_list = []

        # Regular plotting grid. cartesian_prod returns a 1-D tensor for a
        # single input, so force the (n_points, nrinputs) layout used later.
        self.model_input_plots = torch.cartesian_prod(*[
            torch.tensor(np.arange(self.xmin[i], self.xmax[i] + self.delta[i], self.delta[i]), device=self.device)
            for i in range(self.nrinputs)
        ]).reshape(-1, self.nrinputs)
        # Statistics points. Use `is None`: `==` on an array/tensor compares
        # element-wise and cannot be used as a condition.
        if model_input is None:
            sampler = qmc.Sobol(d=self.nrinputs, scramble=False)
            self.model_input = self.xmin + (self.xmax - self.xmin) * sampler.random_base2(m=8)
        else:
            self.model_input = model_input
        self.compute_model_results()

        if self.nrinputs == 2:
            self.X_arrange = np.arange(self.xmin[0], self.xmax[0] + self.delta[0], self.delta[0])
            self.Y_arrange = np.arange(self.xmin[1], self.xmax[1] + self.delta[1], self.delta[1])
            self.mesh_X, self.mesh_Y = np.meshgrid(self.X_arrange, self.Y_arrange)

        # ground truth at the initial time (drawn in red)
        self.compute_initial_reality()
        self.real_constraints = self._combine_constraints(self._columns(self.computed_reality))
        self.real_constraints_plots = self._combine_constraints(self._columns(self.computed_reality_plots))

        self.plot()

        self.display = display.display(self.fig, display_id=True)

    def update_points_and_models(self, xpts, ypts, tpts, ydata, fdata, gps, current_reality, current_objective, current_time=None):
        """Replace data and models, recompute results/statistics and redraw.

        If ``current_time`` is an int or float it becomes the new time;
        otherwise the stored time is advanced by one.
        """
        self.xpts = xpts
        self.ypts = ypts
        self.tpts = tpts
        self.ydata = ydata
        self.fdata = fdata
        self.gps = gps
        self.current_reality = current_reality
        self.current_objective = current_objective

        if isinstance(current_time, (float, int)):
            self.current_time = current_time
        else:
            self.current_time = self.current_time + 1

        self.compute_model_results()
        self.compute_statistics()
        self.plot()

        # Only refresh when a display handle was obtained in __init__.
        if self.display is not None:
            self.display.update(self.fig)  # update the png

        return None

    def _evaluate_rowwise(self, func, inputs):
        """Call ``func(*row)`` for every row of ``inputs``; return an (n, nrgps) numpy array."""
        values = torch.zeros(inputs.shape[0], self.nrgps, device=self.device)
        for i in range(inputs.shape[0]):
            values[i, :] = func(*list(inputs[i, :].detach().cpu().numpy()))
        return values.detach().cpu().numpy()

    def _with_time(self, inputs):
        """Prepend a float64 column filled with ``current_time`` to ``inputs``."""
        t = torch.tensor([self.current_time], device=inputs.device, dtype=torch.float64)
        tt = t * torch.ones(inputs.shape[0], 1, device=inputs.device, dtype=torch.float64)
        return torch.concatenate((tt, inputs), axis=1)

    @staticmethod
    def _columns(array):
        """Split a 2-D array into the list of its columns."""
        return [array[:, i] for i in range(array.shape[-1])]

    def _combine_constraints(self, means, stds=None):
        """Evaluate every linear constraint on per-output ``means`` (and optionally ``stds``).

        Returns one value per constraint:
        ``c[0] + sum(c[1] * means)``, plus ``sum(c[2] * stds)`` when ``stds`` is given.
        """
        values = []
        for c in self.constraints:
            value = c[0] + myadd([a * b for (a, b) in zip(c[1], means)])
            if stds is not None:
                value = value + myadd([a * b for (a, b) in zip(c[2], stds)])
            values.append(value)
        return values

    def compute_initial_reality(self):
        """Evaluate ``reality`` on the statistics points and on the plotting grid."""
        self.computed_reality = self._evaluate_rowwise(self.reality, self.model_input)
        self.computed_reality_plots = self._evaluate_rowwise(self.reality, self.model_input_plots)
        return None

    def compute_model_results(self):
        """Evaluate models, constraints, current reality and objective at ``current_time``.

        Results for the statistics points and for the plotting grid are stored
        on the instance and appended to the matching ``old_*`` histories.
        """
        # time dependent model inputs
        self.model_input_t = self._with_time(self.model_input)
        self.model_input_plots_t = self._with_time(self.model_input_plots)

        # evaluate models
        evaluation = self.gps(self.model_input_t)
        self.fx = [f.mean.cpu().detach().numpy() for f in evaluation]
        self.fstd = [torch.sqrt(f.variance + 1e-6).cpu().detach().numpy() for f in evaluation]
        self.old_fx.append(self.fx)
        self.old_fstd.append(self.fstd)

        evaluation_plots = self.gps(self.model_input_plots_t)
        self.fx_plots = [f.mean.cpu().detach().numpy() for f in evaluation_plots]
        self.fstd_plots = [torch.sqrt(f.variance + 1e-6).cpu().detach().numpy() for f in evaluation_plots]
        self.old_fx_plots.append(self.fx_plots)
        self.old_fstd_plots.append(self.fstd_plots)

        # evaluate constraints; the "optimistic" variants ignore the uncertainty
        self.current_model_constraints = self._combine_constraints(self.fx, self.fstd)
        self.old_model_constraints.append(self.current_model_constraints)
        self.current_optimistic_model_constraints = self._combine_constraints(self.fx)
        self.old_optimistic_model_constraints.append(self.current_optimistic_model_constraints)
        self.current_model_constraints_plots = self._combine_constraints(self.fx_plots, self.fstd_plots)
        self.old_model_constraints_plots.append(self.current_model_constraints_plots)
        self.current_optimistic_model_constraints_plots = self._combine_constraints(self.fx_plots)
        self.old_optimistic_model_constraints_plots.append(self.current_optimistic_model_constraints_plots)

        # with fixed_reality the current reality is computed only on the first call
        if not self.fixed_reality or not hasattr(self, "computed_current_reality"):
            self.computed_current_reality = self._evaluate_rowwise(self.current_reality, self.model_input)
        self.old_computed_current_reality.append(self.computed_current_reality)

        if not self.fixed_reality or not hasattr(self, "computed_current_reality_plots"):
            self.computed_current_reality_plots = self._evaluate_rowwise(self.current_reality, self.model_input_plots)
        objective = torch.zeros(self.model_input_plots.shape[0], device=self.device)
        for i in range(self.model_input_plots.shape[0]):
            objective[i] = self.current_objective(self.model_input_plots_t[i:(i + 1), :])
        self.computed_objective_plots = objective.detach().cpu().numpy()
        self.old_computed_current_reality_plots.append(self.computed_current_reality_plots)
        self.old_computed_objective_plots.append(self.computed_objective_plots)

        self.current_real_constraints = self._combine_constraints(self._columns(self.computed_current_reality))
        self.current_real_constraints_plots = self._combine_constraints(self._columns(self.computed_current_reality_plots))

        return None

    def and_over_list_of_arrays(self, list):
        """Element-wise AND over a list of arrays, computed as a running product.

        Folds pairwise with ``np.multiply``. Calling ``np.multiply(*list)``
        instead would treat a third array as the ``out`` buffer, dropping it
        from the product and overwriting it, and would reject four or more.
        A single array is returned unchanged.
        """
        result = list[0]
        for array in list[1:]:
            result = np.multiply(result, array)
        return result

    def or_over_list_of_arrays(self, list):
        """Element-wise OR (maximum) over a list of arrays; a single array is returned unchanged."""
        if len(list) > 1:
            return np.max(np.stack(list), axis=0)
        return list[0]

    def _safety_classification_stats(self, prefix, safe_real, unsafe_real, safe_pred, unsafe_pred):
        """Precision, recall, F1 and accuracy of predicted-safe vs. truly-safe points.

        Returned keys are ``prefix`` followed by ``_precision``, ``_recall``,
        ``_f1`` and ``_accuracy``. Undefined ratios are reported as ``0.0``.
        """
        true_positive = np.sum(self.and_over_list_of_arrays([safe_real, safe_pred]))
        true_negative = np.sum(self.and_over_list_of_arrays([unsafe_real, unsafe_pred]))
        false_positive = np.sum(self.and_over_list_of_arrays([unsafe_real, safe_pred]))
        false_negative = np.sum(self.and_over_list_of_arrays([safe_real, unsafe_pred]))
        precision = true_positive / (1.0 * true_positive + false_positive) if true_positive + false_positive > 0 else 0.0
        recall = true_positive / (1.0 * true_positive + false_negative) if true_positive + false_negative > 0 else 0.0
        if precision + recall <= 0:
            f1 = 0.0
        else:
            f1 = 2 * precision * recall / (precision + recall)
        accuracy = (true_positive + true_negative) / (1.0 * true_positive + true_negative + false_positive + false_negative)
        return {
            prefix + "_precision": precision,
            prefix + "_recall": recall,
            prefix + "_f1": f1,
            prefix + "_accuracy": accuracy,
        }

    def compute_statistics(self):
        """Append safety-classification and model-error statistics for the current state."""
        safe_real = self.and_over_list_of_arrays([c <= 0 for c in self.current_real_constraints])
        safe_model = self.and_over_list_of_arrays([c <= 0 for c in self.current_model_constraints])
        safe_model_optimistic = self.and_over_list_of_arrays([c <= 0 for c in self.current_optimistic_model_constraints])
        unsafe_real = 1 - safe_real
        unsafe_model = 1 - safe_model
        unsafe_model_optimistic = 1 - safe_model_optimistic

        self.model_safe_stats_list.append(
            self._safety_classification_stats("model_safe", safe_real, unsafe_real, safe_model, unsafe_model)
        )
        self.model_optimistic_safe_stats_list.append(
            self._safety_classification_stats("model_optimistic_safe", safe_real, unsafe_real, safe_model_optimistic, unsafe_model_optimistic)
        )

        errors = {}
        NLL_eval = self.gps.NLL_eval(self.model_input, self.computed_current_reality, self.current_time)
        n_safe = np.sum(safe_real)
        for i in range(self.nrgps):
            squared_error = np.square(self.computed_current_reality[:, i] - self.fx[i])
            if n_safe > 0:
                rmse_safe = math.sqrt(np.sum(squared_error * safe_real) / (1.0 * n_safe))
            else:
                rmse_safe = float('nan')  # no truly safe points: RMSE is undefined
            errors.update({
                "RMSE_safe_area_model_" + str(i): rmse_safe,
                "RMSE_all_area_model_" + str(i): math.sqrt(np.sum(squared_error) / (1.0 * self.fx[i].shape[0])),
                "NLL_train_" + str(i): self.gps.models[i].NLL,
                "NLL_eval_" + str(i): NLL_eval[i],
            })
        self.model_errors_stats_list.append(errors)

        return None

    def _to_grid(self, values):
        """Map a flat vector over ``model_input_plots`` onto the ``mesh_X`` layout.

        ``model_input_plots`` comes from ``cartesian_prod(x_values, y_values)``,
        so x varies slowest: reshape to (nx, ny) and transpose to get the
        (ny, nx) layout of ``np.meshgrid``. The shape is reversed so that
        non-square grids (nx != ny) also line up.
        """
        return np.transpose(values.reshape(self.mesh_X.shape[::-1]))

    def _contour_zero(self, axis, arrays, color, linestyle=None):
        """Draw the zero level set of every flat array in ``arrays``."""
        style = {} if linestyle is None else {'linestyles': linestyle}
        for values in arrays:
            axis.contour(self.mesh_X, self.mesh_Y, self._to_grid(values), [0], colors=color, **style)

    def _contour_zero_with_history(self, axis, current, history):
        """Draw current (solid), previous (dashed) and pre-previous (dotted) model boundaries in green."""
        self._contour_zero(axis, current, 'green')
        if len(history) > 1:
            self._contour_zero(axis, history[-2], 'green', 'dashed')
        if len(history) > 2:
            self._contour_zero(axis, history[-3], 'green', 'dotted')

    def _plot_recent_path(self, axis):
        """Connect the last six points with dotted segments that lighten towards older ones."""
        for j in range(5):
            stop = -j if j > 0 else None  # [-2:0] would be empty
            axis.plot(self.xpts[-2 - j:stop], self.ypts[-2 - j:stop], color=str(1 - 0.5 ** (j + 1)), linestyle='dotted')

    def _scatter_points(self, axis):
        """Scatter the measured points coloured by their time stamp."""
        return axis.scatter(self.xpts, self.ypts, c=self.tpts, s=50, cmap="plasma", vmin=0, vmax=self.tstart + self.titer)

    def _plot_health_statistics(self, axis, title, stats_list, prefix):
        """Plot precision/recall/F1/accuracy histories stored under ``prefix + '_<metric>'``."""
        axis.set_title(title)
        axis.set_ylim([0, 1])
        for metric, color in (('precision', 'red'), ('recall', 'blue'), ('f1', 'green'), ('accuracy', 'orange')):
            axis.plot([a[prefix + '_' + metric] for a in stats_list], color=color, label=metric)
        axis.legend(loc="lower left")

    def plot(self):
        """Redraw every subplot of ``self.fig`` from the current state."""
        plot_counter = 1

        for axis in self.fig.get_axes():
            self.fig.delaxes(axis)

        if self.nrinputs == 2:
            axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter)
            axis.set_title('point positions with model constraints')
            self._contour_zero(axis, self.real_constraints_plots, 'red')
            self._contour_zero_with_history(axis, self.current_model_constraints_plots, self.old_model_constraints_plots)
            self._plot_recent_path(axis)
            self._scatter_points(axis)
            self._contour_zero(axis, self.current_real_constraints_plots, 'yellow')
            plot_counter = plot_counter + 1

            axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter, projection='3d')
            axis.set_title('objective model')
            axis.plot_surface(self.mesh_X, self.mesh_Y, self._to_grid(self.fx_plots[0]), rstride=1, cstride=1, cmap='plasma')
            plot_counter = plot_counter + 1

            axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter, projection='3d')
            axis.set_title('objective std')
            axis.plot_surface(self.mesh_X, self.mesh_Y, self._to_grid(self.fstd_plots[0]), rstride=1, cstride=1, cmap='plasma')
            plot_counter = plot_counter + 1

            axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter)
            axis.set_title('objective for safe active learning')
            axis.contour(self.mesh_X, self.mesh_Y, self._to_grid(self.computed_objective_plots), 10)
            self._contour_zero_with_history(axis, self.current_model_constraints_plots, self.old_model_constraints_plots)
            self._plot_recent_path(axis)
            self._scatter_points(axis)
            self._contour_zero(axis, self.real_constraints_plots, 'red')
            self._contour_zero(axis, self.current_real_constraints_plots, 'yellow')
            plot_counter = plot_counter + 1

            axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter)
            axis.set_title('point positions with optimistic model constraints')
            self._contour_zero(axis, self.real_constraints_plots, 'red')
            self._contour_zero_with_history(axis, self.current_optimistic_model_constraints_plots, self.old_optimistic_model_constraints_plots)
            self._plot_recent_path(axis)
            self._scatter_points(axis)
            self._contour_zero(axis, self.current_real_constraints_plots, 'yellow')
            plot_counter = plot_counter + 1

            axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter, projection='3d')
            axis.set_title('constraint models')
            for c in self.current_model_constraints_plots:
                axis.plot_surface(self.mesh_X, self.mesh_Y, self._to_grid(c), rstride=1, cstride=1, cmap='plasma')
            plot_counter = plot_counter + 1

        axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter)
        axis.set_title('measurements (bullet) and reality (+) against safety bound')
        axis.scatter(self.tpts, self.ydata[:, 0], c=self.tpts, cmap='plasma', marker='o', vmin=0, vmax=self.tstart + self.titer)
        axis.plot((np.zeros(self.ydata.shape[0]) - self.constraints[0][0]), 'r')
        axis.scatter(self.tpts, self.fdata[:, 0], c=self.tpts, cmap='plasma', marker='+', vmin=0, vmax=self.tstart + self.titer)
        plot_counter = plot_counter + 1

        axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter)
        self._plot_health_statistics(axis, 'health statistics with safety model', self.model_safe_stats_list, 'model_safe')
        plot_counter = plot_counter + 1

        axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter)
        self._plot_health_statistics(axis, 'health statistics with optimistic model', self.model_optimistic_safe_stats_list, 'model_optimistic_safe')
        plot_counter = plot_counter + 1

        axis = self.fig.add_subplot(self.figure_size_y, self.figure_size_x, plot_counter)
        axis.set_title('model quality in RMSE')
        colors = ['green', 'red', 'blue', 'cyan', 'olive', 'magenta']
        if len(self.model_errors_stats_list) > 0:
            # each model contributes four keys (RMSE safe/all, NLL train/eval)
            for i in range(len(self.model_errors_stats_list[0]) // 4):
                color = colors[i % len(colors)]  # cycle instead of IndexError for > 6 models
                axis.plot([a['RMSE_safe_area_model_' + str(i)] for a in self.model_errors_stats_list], color=color, label='model ' + str(i) + ' in safe area')
                axis.plot([a['RMSE_all_area_model_' + str(i)] for a in self.model_errors_stats_list], color=color, linestyle='dashed', label='model ' + str(i) + ' everywhere')
        axis.legend(loc="lower left")
        plot_counter = plot_counter + 1

        return None
