from math import e
import time
from tracemalloc import start
from turtle import st
from typing import final
from attr import s
from matplotlib.artist import kwdoc
import numpy as np
import einops
import torch
from torch._functorch.eager_transforms import jacrev, jacfwd
from torch import vmap
from torch.utils.data import DataLoader
from tqdm import trange
import logging

import wandb
from GINN.shape_boundary_helper import ShapeBoundaryHelper
from data.simjeb_dataset import SimJebDataset
from evaluation.jeb_meter import JebMeter
from evaluation.simple_obst_meter import SimpleObstacleMeter
from notebooks.notebook_utils import plot_poly_or_multipolygon
from train.losses import closest_shape_diversity_loss, dirichlet_loss, expression_curvature_loss, strain_curvature_loss, interface_loss, eikonal_loss, envelope_loss, l1_loss, mse_loss, normal_loss_euclidean, obstacle_interior_loss
from train.opt.nys_newton_cg import NysNewtonCG
from train.opt.opt_util import get_opt, opt_step
from train.opt.adaptive_util import get_lambda_dict, grad_norm_sub_losses, lambda_balancing, scale_losses, get_initial_nu_dict, adaptive_penalty_update
from train.train_utils.autoclip import AutoClip
from models.model_utils import tensor_product_xz

from GINN.visualize.plotter_2d import Plotter2d
from GINN.visualize.plotter_3d import Plotter3d
from GINN.persistent_homology.ph_plotter import PHPlotter
from GINN.persistent_homology.ph_manager import PHManager
from GINN.problem_sampler import ProblemSampler
from GINN.helpers.timer_helper import TimerHelper
from train.train_utils.latent_sampler import get_z_corners, sample_z
from utils import get_model, get_stateless_net_with_partials, is_every_n_epochs_fulfilled, load_model_optim_sched, save_model_every_n_epochs, set_and_true
import cripser

class Trainer():
    """Train a latent-conditioned shape model ``model(x, z)`` from a set of weighted losses.

    ``__init__`` builds the helpers:

    * problem sampler;
    * 2D/3D plotter;
    * persistent-homology manager and plotter;
    * shape-boundary helper;
    * gradient auto-clipper;
    * metric meter.

    It also builds the loss bookkeeping. :meth:`train` runs the optimisation loop and logs
    to wandb asynchronously.

    Loss terms are keyed by their ``lambda_*`` config name:

    * ``self.loss_dispatcher`` maps each key to the bound method computing that term.
    * ``self.lambda_dict`` holds the weights from ``get_lambda_dict``. They are used as the
      penalty factor ``mu`` in augmented-Lagrangian mode and as plain weights otherwise.
    * ``self.aug_lagr_lambda_dict`` holds ``[multiplier, None]`` pairs in
      augmented-Lagrangian mode and ``None`` otherwise.
    * The term named by ``config['objective']`` is removed from all three and stored in
      ``self.objective_loss`` as ``(weight, loss_fn)``.
    """

    def __init__(self, config, model, mp_manager) -> None:
        """Build all training helpers and the loss bookkeeping.

        Args:
            config: experiment configuration, indexed by string keys. For
                ``config['problem'] == 'simjeb'`` the entry ``n_points_envelope`` is
                rounded down to a multiple of 3 in place.
            model: the torch module being trained.
            mp_manager: helper used here for the timer lock and asynchronous
                plotting / metric jobs.
        """
        self.config = config
        self.model = model
        self.mpm = mp_manager
        self.logger = logging.getLogger('trainer')
        self.device = config['device']
        self.netp = get_stateless_net_with_partials(self.model, nz=self.config['nz'])
        self.p_sampler = ProblemSampler(self.config)

        self.timer_helper = TimerHelper(self.config, lock=mp_manager.get_lock())
        self.plotter = Plotter2d(self.config) if config['nx'] == 2 else Plotter3d(self.config)
        self.mpm.set_timer_helper(self.timer_helper)  ## weak circular reference
        self.ph_manager = PHManager(self.config, self.model)
        self.ph_plotter = PHPlotter(self.config)

        self.shape_boundary_helper = ShapeBoundaryHelper(self.config, self.netp, self.mpm, self.plotter, self.timer_helper,
                                                         self.p_sampler.sample_from_interface()[0], self.device)
        self.auto_clip = AutoClip(config)

        # curvature loss: per-point parameter-gradient norms (used for gradnorm weighting)
        self.k_theta_gradnorm_fn = self.get_k_theta_gradnorm_func()

        # data: a shuffled index permutation that _get_data_batch consumes batch by batch
        if self._uses_data():
            self.data = SimJebDataset(config)
            self.shuffle_idcs = torch.randperm(len(self.data), device=self.device)
            self.cur_data_idx = 0

        self.meter = None
        if config['problem'] in ['simple_2d', 'double_obstacle']:
            self.meter = SimpleObstacleMeter.create_from_problem_sampler(config, self.p_sampler)
        elif config['problem'] == 'simjeb':
            self.meter = JebMeter(config)
            # round down to a multiple of 3 (mutates the shared config)
            self.config['n_points_envelope'] = self.config['n_points_envelope'] // 3 * 3

        self.p_surface = None         # surface points, (re)computed in compute_losses
        self.weights_surf_pts = None  # weights returned together with p_surface
        self.cur_plot_epoch = 0       # next epoch whose logs have not been sent to wandb yet
        self.log_history_dict = {}    # epoch -> log dict waiting for its async plots

        self.sub_grad_norm_dict = {}
        self.lambda_dict = get_lambda_dict(self.config)
        self.nu_dict = get_initial_nu_dict(self.lambda_dict)
        # order matters: init_objective_loss pops the objective from the dicts built before it
        self.init_loss_dispatcher()
        self.init_aug_lagr_lambas()
        self.init_objective_loss()

    def _uses_data(self):
        """Return True when the config enables the data loss (``lambda_data`` set and > 0)."""
        return ('lambda_data' in self.config) and (self.config['lambda_data'] > 0)

    def _get_plot_z(self, z, z_corners):
        """Return the latent codes to plot for the current training step.

        Without the data loss this is simply ``z``. With the data loss the first corner
        code, as many codes of ``z`` as fit into ``train_plot_n_shapes``, and the remaining
        corner code(s) are concatenated along dim 0. This works when there are 1 or 2
        data shapes in the corners.
        """
        if not self._uses_data():
            return z
        # keep the total (corners + sampled z) within train_plot_n_shapes
        n_z = max(self.config['train_plot_n_shapes'] - len(z_corners), 0)
        return torch.cat([z_corners[:1], z[:n_z], z_corners[1:]], dim=0)

    def train(self):
        """Run the optimisation loop for ``config['max_epochs']`` epochs.

        Each epoch does the following:

        * optionally fetch a data batch;
        * on the configured schedule, plot validation shapes and compute validation metrics;
        * plot the training shapes and PH diagrams;
        * compute shape metrics before the gradient step;
        * take one optimizer step via ``opt_step``;
        * step the scheduler and, if configured, the adaptive penalties;
        * queue the logs for wandb.

        At the end, all pending plots are awaited and logged.
        """
        # get optimizer and scheduler
        opt = get_opt(self.config['opt'], self.config, self.model.parameters(), self.aug_lagr_lambda_dict)
        sched = None
        if set_and_true('use_scheduler', self.config):
            def warm_and_decay_lr_scheduler(step: int):
                # lr factor gamma ** (step / decay_steps)
                return self.config['scheduler_gamma'] ** (step / self.config['decay_steps'])
            sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warm_and_decay_lr_scheduler)
        # maybe load model, optimizer and scheduler
        self.model, opt, sched = load_model_optim_sched(self.config, self.model, opt, sched)

        # get z
        z = sample_z(self.config, epoch=0, previous_z=None)
        z_corners = get_z_corners(self.config)
        self.z_corners = z_corners
        self.logger.info(f'Initial z: {z}')
        self.logger.info(f'z_corners: {z_corners}')
        if self.config['train_plot_n_shapes'] < len(z) + len(z_corners):
            print(f'WARNING: plotting only {self.config["train_plot_n_shapes"]} out of {len(z) + len(z_corners)} shapes, namely z_corner[0], zs, z_corner[1] (if available)')

        use_data = self._uses_data()
        epoch = -1  # keeps the final log_to_wandb call valid when max_epochs == 0
        # training/validation loop
        for epoch in (pbar := trange(self.config['max_epochs'], leave=True, position=0, colour="yellow")):
            cur_log_dict = {}
            self.mpm.update_epoch(epoch)

            batch = self._get_data_batch() if use_data else None

            ## validation
            if epoch > 0 and is_every_n_epochs_fulfilled(epoch, self.config, 'valid_every_n_epochs'):
                self.model.eval()
                z_val = sample_z(self.config, epoch, previous_z=None, is_validation=True)
                # plot validation shapes
                self.plotter.reset_output(self.p_sampler.recalc_output(self.netp.f_, self.netp.params_, z_val), epoch=epoch)
                self.mpm.plot(self.plotter.plot_shape, 'val_plot_shape', arg_list=[self.p_sampler.constr_pts_dict, 'Val Boundary'])
                # compute validation metrics
                if is_every_n_epochs_fulfilled(epoch, self.config, 'val_shape_metrics_every_n_epochs'):
                    mesh_or_contour = self.p_sampler.get_mesh_or_contour(self.netp.f_, self.netp.params_, z_val)
                    if mesh_or_contour is not None:
                        self.mpm.metrics(self.meter.get_average_metrics_as_dict, arg_list=[mesh_or_contour], kwargs_dict={'prefix': 'm_val_'})

            ## training
            self.model.train()
            z = sample_z(self.config, epoch, z)

            ## plot the current training shapes
            if self.plotter.do_plot():
                z_plot = self._get_plot_z(z, z_corners)
                self.plotter.reset_output(self.p_sampler.recalc_output(self.netp.f_, self.netp.params_, z_plot), epoch=epoch)
                self.mpm.plot(self.plotter.plot_shape, 'plot_shape', arg_list=[self.p_sampler.constr_pts_dict])

                if self.ph_manager is not None:
                    PH = [ph.get() for ph in self.ph_manager.calc_ph(z)]
                    PH = PH[:self.config['train_plot_n_shapes']]  # subsample to train_plot_n_shapes
                    self.mpm.plot(self.ph_plotter.plot_ph_diagram, 'plot_ph_diagram', arg_list=[PH], kwargs_dict={})

            ## compute metrics _before_ taking the GD step
            if epoch > 0 and is_every_n_epochs_fulfilled(epoch, self.config, 'shape_metrics_every_n_epochs'):
                mesh_or_contour = self.p_sampler.get_mesh_or_contour(self.netp.f_, self.netp.params_, z)
                if mesh_or_contour is not None:
                    self.mpm.metrics(self.meter.get_average_metrics_as_dict, arg_list=[mesh_or_contour], kwargs_dict={'prefix': 'm_train_'})

                if use_data:
                    mesh_or_contour = self.p_sampler.get_mesh_or_contour(self.netp.f_, self.netp.params_, z_corners)
                    if mesh_or_contour is not None:
                        self.mpm.metrics(self.meter.get_average_metrics_as_dict, arg_list=[mesh_or_contour], kwargs_dict={'prefix': 'm_corners_'})

            ## gradients and optimizer step
            loss_log_dict = opt_step(opt, epoch, self.model, self.compute_losses, z, z_corners, batch, self.auto_clip)
            cur_log_dict.update(loss_log_dict)
            cur_log_dict.update(self.sub_grad_norm_dict)
            cur_log_dict.update(self.lambda_dict)
            if set_and_true('use_scheduler', self.config):
                sched.step()

            if ('adaptive_penalty_updates' in self.config) and (self.config['adaptive_penalty_updates']):
                self.lambda_dict, self.aug_lagr_lambda_dict, self.nu_dict = adaptive_penalty_update(self.lambda_dict, self.aug_lagr_lambda_dict, self.nu_dict, loss_log_dict, self.config)

            ## Async Logging
            cur_log_dict.update({
                'neg_loss_div': (-1) * loss_log_dict['loss_div'] if 'loss_div' in loss_log_dict else 0.0,
                'grad_norm_post_clip': self.auto_clip.get_last_gradient_norm(),
                'grad_clip': self.auto_clip.get_clip_value(),
                'lr': self.config['lr'] if sched is None else sched.get_last_lr(),
                'epoch': epoch,
                })
            self.log_history_dict[epoch] = cur_log_dict
            self.log_to_wandb(epoch)
            save_model_every_n_epochs(self.model, opt, sched, self.config, epoch)
            pbar.set_description(f"{epoch}:" + " ".join([f"{k}:{v.item():.1e}" for k, v in loss_log_dict.items()]))

        ## Final logging
        self.log_to_wandb(epoch, await_all=True)
        ## Finished
        self.timer_helper.print_logbook()

    def _get_data_batch(self):
        """Return the next ``(pts, sdf, idcs)`` batch from the shuffled dataset.

        The batch size is ``config['data_bsize']``, or the whole dataset if that value is <= 0.
        Batches walk through ``self.shuffle_idcs``. When the current permutation runs out,
        a fresh permutation is drawn first. The batch is then completed from the start of
        that NEW permutation and the cursor continues from there. As a result every sample
        is visited exactly once per pass.
        """
        n_data = len(self.data)
        bsize = self.config['data_bsize'] if self.config['data_bsize'] > 0 else n_data
        start_idx = self.cur_data_idx
        end_idx = min(start_idx + bsize, n_data)
        batch_idcs = self.shuffle_idcs[start_idx:end_idx]

        if start_idx + bsize > n_data:
            # pass exhausted: reshuffle BEFORE taking the remainder so it comes from the new pass
            n_missing = bsize - len(batch_idcs)
            self.shuffle_idcs = torch.randperm(n_data, device=self.device)
            end_idx = min(n_missing, n_data)  # keep the cursor inside the permutation
            batch_idcs = torch.cat([batch_idcs, self.shuffle_idcs[:end_idx]])

        self.cur_data_idx = end_idx
        return self.data.pt_coords[batch_idcs], self.data.sdf_vals[batch_idcs], self.data.idcs[batch_idcs]

    @staticmethod
    def _safe_sqrt(x):
        """Elementwise ``sqrt(x)`` for x > 0 and 0 elsewhere, with a finite gradient at 0.

        ``torch.sqrt`` has an infinite derivative at 0. Multiplied by the zero gradient of a
        hinge (``clamp(..., min=0)``) this yields NaN. The double ``torch.where`` keeps the
        sqrt branch away from non-positive inputs so their gradient is exactly 0.
        """
        positive = x > 0
        safe_x = torch.where(positive, x, torch.ones_like(x))
        return torch.where(positive, torch.sqrt(safe_x), torch.zeros_like(x))

    def compute_losses(self, z, epoch, batch=None, z_corners=None):
        """Evaluate the objective and all constraint losses for latent codes ``z``.

        Args:
            z: latent codes of the training shapes.
            epoch: current epoch (controls surface-point warm-up and recomputation).
            batch: optional ``(pts, sdf, idcs)`` batch forwarded to every loss function.
            z_corners: not used by this method.

        Returns:
            A tuple ``(loss, sub_loss_dict, sub_loss_unweighted_dict, sub_al_loss_dict,
            al_vec_l2_dict)``:

            * ``loss`` is the sum of the weighted terms in ``sub_loss_dict``.
            * With ``use_augmented_lagrangian``, each constraint contributes
              ``0.5 * mu * c + lambda * sqrt(c)``.
            * Otherwise each constraint contributes ``lambda_dict[key] * c``.

        Raises:
            NotImplementedError: if ``weight_rescale_on`` is set and ``epoch > 0``.
        """
        # (re)compute surface points, used by the diversity and curvature losses
        if (self.config['lambda_div'] > 0 or self.config['lambda_curv'] > 0) and (epoch >= self.config['surf_pts_warmup_n_epochs']):
            if self.p_surface is None or epoch % self.config['surf_pts_recompute_every_n_epochs'] == 0:
                self.p_surface, self.weights_surf_pts = self.shape_boundary_helper.get_surface_pts(z)

        sub_loss_dict = {}
        sub_loss_unweighted_dict = {}
        sub_al_loss_dict = {}
        al_vec_l2_dict = {}

        objective = self.config['objective']
        objective_weight, objective_fn = self.objective_loss
        obj_loss, obj_loss_unweighted = objective_fn(z=z, lambda_vec=None, epoch=epoch, batch=batch)
        sub_loss_dict['loss_' + objective] = objective_weight * obj_loss  # lambda naming used for backwards compatibility
        sub_loss_unweighted_dict['loss_unweighted_' + objective] = obj_loss_unweighted

        for lambda_key, lambda_tuple in self.aug_lagr_lambda_dict.items():
            mu_value = self.lambda_dict[lambda_key]
            if lambda_tuple is None:
                # plain weighted-sum mode (init_aug_lagr_lambas stores None): weight by lambda_dict
                lambda_value, lambda_vec = mu_value, None
            else:
                lambda_value, lambda_vec = lambda_tuple
                if lambda_vec is not None:
                    lambda_vec = lambda_value * lambda_vec

            sub_loss, _ = self.loss_dispatcher[lambda_key](z=z, lambda_vec=lambda_vec, epoch=epoch, batch=batch)

            sub_loss_unweighted_dict[lambda_key.replace("lambda", "loss_unweighted")] = sub_loss
            if lambda_key == 'lambda_curv':  # WARNING: hack to scale curvature loss
                sub_loss = sub_loss * self.config['curv_weight']

            penalty_term = sub_loss
            # sqrt with zero gradient at 0: hinge-clamped losses would otherwise give NaN grads
            lambda_term = self._safe_sqrt(sub_loss)
            sub_al_loss_dict[lambda_key.replace("lambda", "lagrangian")] = lambda_value * lambda_term

            if self.config['use_augmented_lagrangian']:
                al_vec_l2_dict[lambda_key.replace("lambda", "l2")] = torch.norm(lambda_value)
                al_vec_l2_dict[lambda_key.replace("lambda", "mu")] = mu_value
                sub_loss_dict[lambda_key.replace("lambda", "loss")] = 0.5 * mu_value * penalty_term + lambda_value * lambda_term
            else:
                sub_loss_dict[lambda_key.replace("lambda", "loss")] = lambda_value * penalty_term

        if ('weight_rescale_on' in self.config) and (self.config['weight_rescale_on']) and (epoch > 0):
            # previously implemented via grad_norm_sub_losses + lambda_balancing
            raise NotImplementedError('Weight rescaling not implemented yet with new loss structure')

        loss = sum(sub_loss_dict.values())
        return loss, sub_loss_dict, sub_loss_unweighted_dict, sub_al_loss_dict, al_vec_l2_dict

    def log_to_wandb(self, epoch, await_all=False):
        """Send buffered per-epoch log dicts to wandb, strictly in epoch order.

        An epoch is logged once it has no async plots, or once its plots are ready. Ready
        plot results are merged into the log dict via ``mpm.pop_results_dict``. Epochs up
        to ``epoch`` are processed. If an epoch's plots are still pending, this either
        returns (``await_all=False``) or polls once per second until they are ready
        (``await_all=True``).
        """
        with self.timer_helper.record('plot_helper.pool await async results'):
            ## Wait for plots to be ready, then log them
            while self.cur_plot_epoch <= epoch:
                if not self.mpm.are_plots_available_for_epoch(self.cur_plot_epoch):
                    ## no plots for this epoch; just log the current losses
                    wandb.log(self.log_history_dict[self.cur_plot_epoch])
                    del self.log_history_dict[self.cur_plot_epoch]
                    self.cur_plot_epoch += 1
                elif self.mpm.plots_ready_for_epoch(self.cur_plot_epoch):
                    ## plots are available and ready
                    wandb.log(self.log_history_dict[self.cur_plot_epoch] | self.mpm.pop_results_dict(self.cur_plot_epoch))
                    del self.log_history_dict[self.cur_plot_epoch]
                    self.cur_plot_epoch += 1
                elif await_all:
                    ## plots are not ready yet - wait for them
                    self.logger.debug(f'Waiting for plots for epoch {self.cur_plot_epoch}')
                    time.sleep(1)
                else:
                    break

    def init_loss_dispatcher(self):
        """Map every ``lambda_*`` config key to the bound method computing that loss term."""
        self.loss_dispatcher = {
            # ginn losses
            'lambda_curv': self.loss_curv,
            'lambda_div': self.loss_div,
            'lambda_eikonal': self.loss_eikonal,
            'lambda_obst': self.loss_obst,
            'lambda_if_normal': self.loss_if_normal,
            'lambda_if': self.loss_if,
            'lambda_env': self.loss_env,
            'lambda_scc': self.loss_scc,
            # data losses
            'lambda_lip': self.loss_lip,
            'lambda_dirichlet': self.loss_dirichlet,
            'lambda_data': self.loss_data,
        }

    def init_objective_loss(self):
        """Split the loss named by ``config['objective']`` off as the objective.

        The objective must have a positive ``lambda_<objective>`` and no non-zero
        ``max_<objective>``. It is stored as ``(weight tensor, loss_fn)`` in
        ``self.objective_loss`` and removed from the dispatcher, the augmented-Lagrangian
        dict and ``lambda_dict``, so it is not also treated as a constraint.
        """
        assert self.loss_dispatcher is not None, 'Loss dispatcher must be initialized before objective loss'
        assert self.aug_lagr_lambda_dict is not None, 'Augmented lagrangian lambdas must be initialized before objective loss'
        max_objective = 'max_' + self.config['objective']
        assert not (max_objective in self.config and abs(self.config[max_objective]) > 0.), f'Max value for objective loss not allowed in config'
        lambda_objective = 'lambda_' + self.config['objective']
        assert (lambda_objective in self.config) and (self.config[lambda_objective] > 0), f'Weighting for objective loss not found in config or not positive'
        self.objective_loss = (self.config[lambda_objective] * torch.tensor(1.0, device=self.config['device']),
                               self.loss_dispatcher[lambda_objective])
        self.loss_dispatcher.pop(lambda_objective)
        self.aug_lagr_lambda_dict.pop(lambda_objective)
        self.lambda_dict.pop(lambda_objective)
        self.logger.info(f'Using objective loss: {self.config["objective"]}')

    def init_aug_lagr_lambas(self):
        """Create one augmented-Lagrangian entry per key of ``lambda_dict``.

        With ``use_augmented_lagrangian`` each entry is ``[tensor(1.0), None]`` (multiplier,
        lambda vector). Lambda vectors are no longer used but are kept for compatibility.
        Otherwise each entry is ``None``, and compute_losses then weights the term by
        ``lambda_dict``.
        """
        use_al = self.config['use_augmented_lagrangian'] if self.lambda_dict else False
        self.aug_lagr_lambda_dict = {
            lamb: [torch.tensor(1.0, device=self.config['device']), None] if use_al else None
            for lamb in self.lambda_dict
        }

    def loss_eikonal(self, z, lambda_vec, **kwargs):
        """Eikonal loss on domain samples (NN should have gradient norm 1 everywhere).

        Returns the ``(loss, al_loss)`` pair from ``eikonal_loss``.
        """
        xs_domain = self.p_sampler.sample_from_domain()
        y_x_eikonal = self.netp.vf_x(*tensor_product_xz(xs_domain, z))
        return eikonal_loss(y_x_eikonal, lambda_vec=lambda_vec)

    def loss_obst(self, z, lambda_vec, **kwargs):
        """Obstacle loss: model values at obstacle samples via ``obstacle_interior_loss``."""
        ys_obst = self.model(*tensor_product_xz(self.p_sampler.sample_from_obstacles(), z))
        return obstacle_interior_loss(ys_obst, lambda_vec=lambda_vec)

    def loss_env(self, z, lambda_vec, **kwargs):
        """Envelope loss: model values at envelope samples via ``envelope_loss``."""
        ys_env = self.model(*tensor_product_xz(self.p_sampler.sample_from_envelope(), z)).squeeze(1)
        return envelope_loss(ys_env, lambda_vec=lambda_vec)

    def loss_if(self, z, lambda_vec, **kwargs):
        """Interface loss: model values at interface points via ``interface_loss``."""
        ys_BC = self.model(*tensor_product_xz(self.p_sampler.sample_from_interface()[0], z)).squeeze(1)
        return interface_loss(ys_BC, lambda_vec=lambda_vec)

    def loss_if_normal(self, z, lambda_vec, **kwargs):
        """Interface normal loss: spatial gradients at interface points vs. target normals.

        The target normals are repeated ``ginn_bsize`` times along dim 0 to match the
        x-by-z tensor product.
        """
        pts_normal, target_normal = self.p_sampler.sample_from_interface()
        ys_normal = self.netp.vf_x(*tensor_product_xz(pts_normal, z)).squeeze(1)
        return normal_loss_euclidean(ys_normal,
                                     target_normal=torch.cat([target_normal] * self.config['ginn_bsize']),
                                     lambda_vec=lambda_vec)

    def loss_scc(self, z, lambda_vec, epoch, **kwargs):
        """Topology loss from ``ph_manager.calc_ph_loss_cripser``; the second value is a zero placeholder."""
        _success, loss_scc, _loss_super0, _loss_sub0 = self.ph_manager.calc_ph_loss_cripser(z, epoch)
        return loss_scc, torch.tensor(0.0, device=self.config['device'])

    # TODO implement augmented lagrangian (_al) terms for losses below
    def loss_data(self, batch, **kwargs):
        """MSE between model predictions at data points and their target values.

        ``batch`` is ``(x, y, idcs)`` as produced by ``_get_data_batch``. ``idcs`` selects
        the latent code from ``self.z_corners``, which is set in ``train``.
        """
        x, y, idcs = batch
        z_data = self.z_corners[idcs]
        y_pred = self.model(x, z_data).squeeze(1)
        return mse_loss(y_pred, y), torch.tensor(0.0)

    def loss_dirichlet(self, z, batch, **kwargs):
        """L1 loss between predictions at the batch points (for every ``z``) and the batch targets.

        Returns zero when ``lambda_dirichlet`` is not positive.
        """
        loss_dirichlet = torch.tensor(0.0)
        if self.config['lambda_dirichlet'] > 0:
            ## TODO: maybe implement surface point sampling (config['dirichlet_use_surface_points'])
            x, y, _ = batch
            y_pred = self.model(*tensor_product_xz(x, z)).squeeze(1)
            loss_dirichlet = l1_loss(y_pred, y)
        return loss_dirichlet, torch.tensor(0.0)

    def loss_lip(self, **kwargs):
        """Lipschitz loss from ``model.get_lipschitz_loss()``; zero when ``lambda_lip`` is not positive."""
        loss_lip = torch.tensor(0.0)
        if self.config['lambda_lip'] > 0:
            loss_lip = self.model.get_lipschitz_loss()
        return loss_lip, torch.tensor(0.0)

    # DIV LOSS

    def loss_div(self, z, lambda_vec, **kwargs):
        """Diversity loss between shapes, evaluated at the current surface points.

        The model is evaluated at all surface points for each shape and the result is
        reshaped to ``(ginn_bsize, k)``. NaN/Inf results are replaced by 0 with a warning.
        The loss is hinged at ``max_div``: ``clamp(loss - max_div, min=0)``.
        """
        loss_div = torch.tensor(0.0)
        loss_al_div = torch.tensor(0.0)
        if self.p_surface is None:
            self.logger.info('No surface points found - skipping diversity loss')
        else:
            y_div = self.model(*tensor_product_xz(self.p_surface.data, z)).squeeze(1)  # [(bz k)] whereas k is n_surface_points
            loss_div, loss_al_div = closest_shape_diversity_loss(einops.rearrange(y_div, '(bz k)-> bz k', bz=self.config['ginn_bsize']),
                                                                 lambda_vec=lambda_vec,
                                                                 weights=self.weights_surf_pts,
                                                                 norm_order=self.config['div_norm_order'],
                                                                 neighbor_agg_fn=self.config['div_neighbor_agg_fn'])
            if not torch.isfinite(loss_div):
                self.logger.warning(f'NaN or Inf loss_div: {loss_div}')
                loss_div = torch.tensor(0.0)
                loss_al_div = torch.tensor(0.0)

        loss_div = torch.clamp(loss_div - self.config['max_div'], min=0)
        return loss_div, loss_al_div

    # CURVATURE LOSS

    def loss_curv(self, z, lambda_vec, **kwargs):
        """Curvature loss at the current surface points.

        The weights are one of:

        * ``weights_surf_pts`` (the default);
        * all-ones, for ``softmax`` curvature expressions;
        * inverse parameter-gradient-norm weights normalised to sum to 1, when
          ``curvature_use_gradnorm_weights`` is set.

        Returns ``(loss hinged at max_curv, unweighted loss)``. Both are zeroed before
        epoch 5000 when ``curvature_after_5000_epochs`` is set.
        """
        loss_curv = torch.tensor(0.0)
        loss_curv_unweighted = torch.tensor(0.0)
        if self.p_surface is None:
            self.logger.debug('No surface points found - skipping curvature loss')
        else:
            # check this here, as for vmap-ed curvature it can't be checked there
            if self.weights_surf_pts is not None:
                total_weight = self.weights_surf_pts.sum()
                # reference built on the same device/dtype as the weights
                assert torch.allclose(total_weight, torch.ones_like(total_weight)), f"weights must sum to 1"

            weights = self.weights_surf_pts
            if 'softmax' in self.config['curvature_expression']:
                # NOTE(review): assumes weights_surf_pts is not None in this mode
                weights = torch.ones(self.weights_surf_pts.shape, device=self.config['device'])
            if set_and_true('curvature_use_gradnorm_weights', self.config):
                n_surf = len(self.p_surface.data)
                uniform = torch.ones(size=(n_surf, 1), device=self.config['device']) / n_surf  # unsqueezed for vmap
                gn = self.k_theta_gradnorm_fn(self.netp.params_, self.p_surface.data, self.p_surface.z_in(z), uniform)
                # inverse gradient-norm weighting, normalised to sum to 1
                weights = gn.sum() / gn
                weights = weights / weights.sum()

            y_x_surf = self.netp.vf_x(self.p_surface.data, self.p_surface.z_in(z)).squeeze(1)
            y_xx_surf = self.netp.vf_xx(self.p_surface.data, self.p_surface.z_in(z)).squeeze(1)
            loss_curv, loss_curv_unweighted = expression_curvature_loss(y_x_surf, y_xx_surf,
                                                                        expression=self.config['curvature_expression'],
                                                                        clip_max_value=self.config['strain_curvature_clip_max'],
                                                                        weights=weights)

            # hinge at max_curv; Python max() also maps a NaN loss to 0 (unlike torch.clamp)
            loss_curv = max(torch.tensor(0.0, device=self.config['device']), loss_curv - self.config['max_curv'])

        if set_and_true('curvature_after_5000_epochs', self.config) and kwargs['epoch'] < 5000:
            loss_curv = torch.tensor(0.0, device=self.config['device'])
            loss_curv_unweighted = torch.tensor(0.0, device=self.config['device'])

        return loss_curv, loss_curv_unweighted

    def get_k_theta_gradnorm_func(self):
        """Build ``f(params, x, z, weights) -> per-point gradient norms``.

        For each row of ``x``/``z``/``weights`` the returned function does the following:

        * compute the gradient, w.r.t. the (detached) parameters, of the curvature loss
          hinged at ``max_curv``;
        * flatten the per-parameter gradients;
        * return their L2 norm.

        The gradient comes from ``jacrev`` over the params and is vectorised with ``vmap``
        over dim 0 of ``x``, ``z`` and ``weights``.
        """

        # non-vectorized loss function
        def curvature_loss_wrapper(params, x, z, weights):
            # for netp calls, use the properties f_x_ and f_xx_ instead of the methods f_x and f_xx
            y_x = self.netp.f_x_(params, x, z).squeeze(1)
            y_xx = self.netp.f_xx_(params, x, z).squeeze(1)
            loss_curv, _ = expression_curvature_loss(y_x, y_xx,
                                                     expression=self.config['curvature_expression'],
                                                     clip_max_value=self.config['strain_curvature_clip_max'],
                                                     weights=weights)
            return torch.clamp(loss_curv - self.config['max_curv'], min=0)

        # compute gradient wrt to first argument, which is theta
        k_theta = jacrev(curvature_loss_wrapper, argnums=0)  # params, nx, nz, nx -> [ny, params]

        # vectorize
        vk_theta = vmap(k_theta, in_dims=(None, 0, 0, 0), out_dims=(0))  ## params, [bxz, nx], [bxz, nz] [bxz, nx] -> [bxz, params]

        def final_func(params_, x, z, weights):
            params = {key: tensor.detach() for key, tensor in params_.items()}
            res = vk_theta(params, x, z, weights)
            grads = torch.hstack([g.flatten(start_dim=1) for g in res.values()])  ## flatten batched grads per parameter
            return grads.norm(dim=1)

        return final_func