from typing import List, Sequence, Optional, Union, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from models.density_fields import KPlaneDensityField
from models.dynamic_field import D_field
from models.static_field import S_field

from models.lplane_field_hsv import Lplane_field
from ops.activations import init_density_activation
from raymarching.ray_samplers import (
    UniformLinDispPiecewiseSampler, UniformSampler,
    ProposalNetworkSampler, RayBundle, RaySamples
)
from raymarching.spatial_distortions import SceneContraction, SpatialDistortion
from utils.timer import CudaTimer
import pdb

class LowrankModel(nn.Module):
    """Radiance model decomposed into static, dynamic and lighting fields.

    Three fields are queried at the same ray samples:

    * ``s_field`` (``S_field``): static content, queried without timestamps.
    * ``d_field`` (``D_field``): dynamic content, queried with timestamps.
    * ``l_field`` (``Lplane_field``): lighting contribution. It is constructed with
      ``hues``, and its outputs include ``l_color``, ``density_il`` and ``hues``.

    Ray samples come from a ``ProposalNetworkSampler`` fed by one or more
    ``KPlaneDensityField`` proposal networks. ``RaySamples.get_weights_decomp``
    (defined elsewhere) turns the three densities into the per-field compositing
    weights used by the render methods below.
    """

    def __init__(self,
                 hues: Optional[torch.Tensor],
                 # boolean flags
                 is_ndc: bool,
                 is_contracted: bool,
                 aabb: torch.Tensor,
                 # Model arguments
                 multiscale_res: Sequence[int],
                 use_motion: bool = False,
                 solid_threshold: float = 0.0,
                 # Spatial distortion
                 global_translation: Optional[torch.Tensor] = None,
                 global_scale: Optional[torch.Tensor] = None,
                 # proposal-sampling arguments
                 num_proposal_iterations: int = 1,
                 use_same_proposal_network: bool = False,
                 proposal_net_args_list: Optional[List[Dict]] = None,
                 num_proposal_samples: Optional[Tuple[int]] = None,
                 num_samples: Optional[int] = None,
                 single_jitter: bool = False,
                 proposal_warmup: int = 5000,
                 proposal_update_every: int = 5,
                 use_proposal_weight_anneal: bool = True,
                 proposal_weights_anneal_max_num_iters: int = 1000,
                 proposal_weights_anneal_slope: float = 10.0,
                 num_frames: Optional[int] = None,
                 # lighting field
                 L_concat_features_across_scales: bool = False,
                 L_linear_decoder_layers: Optional[int] = 1,
                 L_grid_config: Union[str, List[Dict]] = None,
                 L_basis_color_dim: int = 32,
                 L_apperance_embdding_dim: int = 32,
                 L_decoder_layer: int = 5,
                 k: int = 5,
                 # static network
                 S_density_activation: Optional[str] = 'trunc_exp',
                 S_concat_features_across_scales: bool = False,
                 S_linear_decoder: bool = True,
                 S_linear_decoder_layers: Optional[int] = 1,
                 S_grid_config: Union[str, List[Dict]] = None,
                 S_basis_color_dim: int = 16,
                 S_decoder_layer: int = 5,
                 # dynamic network
                 D_density_activation: Optional[str] = 'trunc_exp',
                 D_concat_features_across_scales: bool = False,
                 D_linear_decoder: bool = True,
                 D_linear_decoder_layers: Optional[int] = 1,
                 D_grid_config: Union[str, List[Dict]] = None,
                 D_basis_color_dim: int = 4,
                 D_decoder_layer: int = 5,
                 **kwargs,
                 ):
        """Build the three fields, the proposal networks and the proposal sampler.

        Args:
            hues: tensor forwarded to ``Lplane_field``; may be ``None``.
            is_ndc: together with ``is_contracted``, selects
                ``UniformLinDispPiecewiseSampler`` as the initial sampler.
            is_contracted: when True, a ``SceneContraction`` spatial distortion is
                built and shared by every field and proposal network.
            aabb: scene bounds, passed to every field and proposal network.
            multiscale_res: passed to all three fields.
            use_motion, solid_threshold, num_frames: forwarded to ``D_field`` and
                ``Lplane_field``.
            global_translation, global_scale: forwarded to ``SceneContraction``.
            num_proposal_iterations .. proposal_weights_anneal_slope: proposal
                sampling / annealing configuration (see ``step_before_iter``).
            L_* / S_* / D_*: per-field configuration. The ``*_basis_color_dim``,
                ``*_decoder_layer`` and ``L_apperance_embdding_dim`` values are
                stored on ``self`` but not forwarded to any field in this class.
            k: forwarded to ``Lplane_field``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.multiscale_res = multiscale_res
        self.is_ndc = is_ndc
        self.use_motion = use_motion
        self.hues = hues
        self.is_contracted = is_contracted
        # lighting network
        self.L_concat_features_across_scales = L_concat_features_across_scales
        self.L_linear_decoder_layers = L_linear_decoder_layers
        self.L_grid_config = L_grid_config
        self.L_basis_color_dim = L_basis_color_dim
        self.L_apperance_embdding_dim = L_apperance_embdding_dim
        self.L_decoder_layer = L_decoder_layer
        # static network
        self.S_concat_features_across_scales = S_concat_features_across_scales
        self.S_linear_decoder = S_linear_decoder
        self.S_linear_decoder_layers = S_linear_decoder_layers
        self.S_density_act = init_density_activation(S_density_activation)
        self.S_grid_config = S_grid_config
        self.S_basis_color_dim = S_basis_color_dim
        self.S_decoder_layer = S_decoder_layer
        # dynamic network
        self.D_concat_features_across_scales = D_concat_features_across_scales
        self.D_linear_decoder = D_linear_decoder
        self.D_linear_decoder_layers = D_linear_decoder_layers
        self.D_density_act = init_density_activation(D_density_activation)
        self.D_grid_config = D_grid_config
        self.D_basis_color_dim = D_basis_color_dim
        self.D_decoder_layer = D_decoder_layer

        self.timer = CudaTimer(enabled=False)

        self.spatial_distortion: Optional[SpatialDistortion] = None
        if self.is_contracted:
            self.spatial_distortion = SceneContraction(
                order=float('inf'), global_scale=global_scale,
                global_translation=global_translation)
        self.d_field = D_field(
            aabb,
            grid_config=self.D_grid_config,
            concat_features_across_scales=self.D_concat_features_across_scales,
            multiscale_res=self.multiscale_res,
            spatial_distortion=self.spatial_distortion,
            density_activation=self.D_density_act,
            linear_decoder=self.D_linear_decoder,
            linear_decoder_layers=self.D_linear_decoder_layers,
            num_frames=num_frames,
            use_motion=use_motion,
            solid_threshold=solid_threshold,
        )
        self.s_field = S_field(
            aabb,
            grid_config=self.S_grid_config,
            concat_features_across_scales=self.S_concat_features_across_scales,
            multiscale_res=self.multiscale_res,
            spatial_distortion=self.spatial_distortion,
            density_activation=self.S_density_act,
            linear_decoder_layers=self.S_linear_decoder_layers,
            linear_decoder=self.S_linear_decoder,
        )
        # NOTE(review): the lighting field reuses the D_* concat/activation/decoder
        # settings; L_concat_features_across_scales and L_linear_decoder_layers are
        # stored above but never used. Confirm this is intended before changing it,
        # since existing checkpoints were trained with the current wiring.
        self.l_field = Lplane_field(
            aabb,
            grid_config=self.L_grid_config,
            concat_features_across_scales=self.D_concat_features_across_scales,
            multiscale_res=self.multiscale_res,
            spatial_distortion=self.spatial_distortion,
            density_activation=self.D_density_act,
            linear_decoder=self.D_linear_decoder,
            linear_decoder_layers=self.D_linear_decoder_layers,
            k=k,
            hues=self.hues,
            num_frames=num_frames,
            use_motion=use_motion,
            solid_threshold=solid_threshold,
        )

        # Proposal-sampling networks: one density function per proposal iteration.
        self.density_fns = []
        self.num_proposal_iterations = num_proposal_iterations
        self.proposal_net_args_list = proposal_net_args_list
        self.proposal_warmup = proposal_warmup
        self.proposal_update_every = proposal_update_every
        self.use_proposal_weight_anneal = use_proposal_weight_anneal
        self.proposal_weights_anneal_max_num_iters = proposal_weights_anneal_max_num_iters
        self.proposal_weights_anneal_slope = proposal_weights_anneal_slope
        self.proposal_networks = torch.nn.ModuleList()
        if use_same_proposal_network:
            assert len(self.proposal_net_args_list) == 1, "Only one proposal network is allowed."
            prop_net_args = self.proposal_net_args_list[0]
            network = KPlaneDensityField(
                aabb, spatial_distortion=self.spatial_distortion,
                density_activation=self.D_density_act, linear_decoder=self.D_linear_decoder,
                **prop_net_args)
            self.proposal_networks.append(network)
            # The single network is reused for every proposal iteration.
            self.density_fns.extend([network.get_density for _ in range(self.num_proposal_iterations)])
        else:
            for i in range(self.num_proposal_iterations):
                # Iterations beyond the configured list reuse its last entry.
                prop_net_args = self.proposal_net_args_list[min(i, len(self.proposal_net_args_list) - 1)]
                network = KPlaneDensityField(
                    aabb, spatial_distortion=self.spatial_distortion,
                    density_activation=self.D_density_act, linear_decoder=self.D_linear_decoder,
                    **prop_net_args,
                )
                self.proposal_networks.append(network)
            self.density_fns.extend([net.get_density for net in self.proposal_networks])

        # Ramps linearly from 0 to proposal_update_every over proposal_warmup steps,
        # clipped to [1, proposal_update_every].
        update_schedule = lambda step: np.clip(
            np.interp(step, [0, self.proposal_warmup], [0, self.proposal_update_every]),
            1,
            self.proposal_update_every,
        )
        if self.is_contracted or self.is_ndc:
            initial_sampler = UniformLinDispPiecewiseSampler(single_jitter=single_jitter)
        else:
            initial_sampler = UniformSampler(single_jitter=single_jitter)
        self.proposal_sampler = ProposalNetworkSampler(
            num_nerf_samples_per_ray=num_samples,
            num_proposal_samples_per_ray=num_proposal_samples,
            num_proposal_network_iterations=self.num_proposal_iterations,
            single_jitter=single_jitter,
            update_sched=update_schedule,
            initial_sampler=initial_sampler
        )

    def step_before_iter(self, step):
        """Set the proposal sampler's anneal factor for ``step`` (if annealing is on).

        The factor is ``bias(clip(step / N, 0, 1), slope)`` with
        ``bias(x, b) = b*x / ((b - 1)*x + 1)`` (https://arxiv.org/pdf/2111.12077.pdf eq. 18),
        where ``N = proposal_weights_anneal_max_num_iters``.
        """
        if self.use_proposal_weight_anneal:
            N = self.proposal_weights_anneal_max_num_iters
            train_frac = np.clip(step / N, 0, 1)
            bias = lambda x, b: (b * x) / ((b - 1) * x + 1)
            anneal = bias(train_frac, self.proposal_weights_anneal_slope)
            self.proposal_sampler.set_anneal(anneal)

    def step_after_iter(self, step):
        """Forward ``step`` to the proposal sampler's ``step_cb`` (if annealing is on)."""
        if self.use_proposal_weight_anneal:
            self.proposal_sampler.step_cb(step)

    @staticmethod
    def render_rgb(rgb: torch.Tensor, weights: torch.Tensor, bg_color: Optional[torch.Tensor]):
        """Composite per-sample values by summing ``weights * rgb`` over dim -2.

        When ``bg_color`` is given, the leftover weight ``1 - sum(weights)`` is
        filled with it.
        """
        comp_rgb = torch.sum(weights * rgb, dim=-2)
        accumulated_weight = torch.sum(weights, dim=-2)

        if bg_color is not None:
            comp_rgb = comp_rgb + (1.0 - accumulated_weight) * bg_color
        return comp_rgb

    @staticmethod
    def render_depth(weights: torch.Tensor, ray_samples: "RaySamples", rays_d: torch.Tensor):
        """Weighted sum of sample midpoints plus ``sum(weights) * rays_d[..., -1:]``.

        NOTE(review): the second term scales the accumulated weight by the last
        component of ``rays_d``; kept as-is, confirm intent.
        """
        steps = (ray_samples.starts + ray_samples.ends) / 2
        one_minus_transmittance = torch.sum(weights, dim=-2)
        depth = torch.sum(weights * steps, dim=-2) + one_minus_transmittance * rays_d[..., -1:]
        return depth

    @staticmethod
    def render_accumulation(weights: torch.Tensor):
        """Return the per-ray sum of ``weights`` over the sample axis (dim -2)."""
        accumulation = torch.sum(weights, dim=-2)
        return accumulation

    @staticmethod
    def _build_ray_bundle(rays_o, rays_d, near_far):
        """Split ``near_far`` into per-ray near/far and wrap everything in a RayBundle."""
        nears, fars = torch.split(near_far, [1, 1], dim=-1)
        if nears.shape[0] != rays_o.shape[0]:
            # Broadcast shared bounds (e.g. shape [1, 1]) to one value per ray.
            ones = torch.ones_like(rays_o[..., 0:1])
            nears = ones * nears
            fars = ones * fars
        return RayBundle(origins=rays_o, directions=rays_d, nears=nears, fars=fars)

    @staticmethod
    def _debug_check_finite(name: str, tensor: torch.Tensor) -> None:
        """Print a message and break into pdb if ``tensor`` has a NaN/Inf entry.

        NOTE(review): development aid. ``pdb.set_trace()`` blocks a non-interactive
        job, and ``.all()`` forces a host/device sync on CUDA tensors each call.
        """
        if not torch.isfinite(tensor).all():
            print(f"{name} has inf value")
            pdb.set_trace()

    def forward(self, rays_o, rays_d, bg_color, near_far: torch.Tensor, timestamps=None):
        """Render rays with the full static / dynamic / lighting decomposition.

        Args:
            rays_o: ray origins, ``[batch, 3]``.
            rays_d: ray directions, ``[batch, 3]``.
            bg_color: not used here (every composite is rendered with
                ``bg_color=None``); kept for interface compatibility.
            near_far: near/far bounds, ``[batch, 2]``. A leading dimension that
                differs from ``batch`` is broadcast to every ray.
            timestamps: ``[batch]``, passed to the dynamic and lighting fields.

        Returns:
            A dict of rendered quantities:

            * ``liv_rgb`` is the static + lighting + dynamic composite.
            * ``reh_rgb`` composites static colour with ``r_s_weight``.
            * ``diff_rgb`` is ``liv_s_rgb - reh_rgb``.
            * ``*_class`` / ``classes`` are per-ray soft / argmax assignments over
              (static, dynamic, lighting).
            * ``weight_*`` / ``wrong_weight_*`` are per-ray sums of field weights.
            * ``weights_list`` / ``ray_samples_list`` are added in training mode only.
            * ``prop_depth_{i}`` is added for every proposal iteration.
        """
        ray_bundle = self._build_ray_bundle(rays_o, rays_d, near_far)
        # The proposal sampler is called without timestamps, so proposal densities
        # do not depend on time.
        ray_samples, weights_list, ray_samples_list = self.proposal_sampler.generate_ray_samples(
            ray_bundle, density_fns=self.density_fns)
        static_field_out = self.s_field(ray_samples.get_positions(), ray_bundle.directions)
        dynamic_field_out = self.d_field(ray_samples.get_positions(), ray_bundle.directions, timestamps)
        light_field_out = self.l_field(ray_samples.get_positions(), ray_bundle.directions, timestamps)
        static_rgb, static_density = static_field_out["rgb"], static_field_out["density"]
        dynamic_rgb, dynamic_density = dynamic_field_out["rgb"], dynamic_field_out["density"]
        l_color = light_field_out['l_color']
        il_density = light_field_out['density_il']
        il_hues = light_field_out['hues']

        (s_weights, d_weights, l_weights,
         r_s_weight, l_s_weight, d_s_weight, total_weights) = ray_samples.get_weights_decomp(
            static_density, dynamic_density, il_density)
        # Appended after the sampler's proposal weights; prop_depth_* below only
        # reads the first num_proposal_iterations entries.
        weights_list.extend([r_s_weight, d_s_weight, l_s_weight, total_weights])
        ray_samples_list.append(ray_samples)

        # Per-sample share of each field in the total weight, composited with the
        # total weight into a per-ray soft class; argmax gives 0=static,
        # 1=dynamic, 2=lighting.
        eps = 1e-10
        dynamic_prob = d_weights / (total_weights + eps)
        static_prob = s_weights / (total_weights + eps)
        il_prob = l_weights / (total_weights + eps)
        dynamic_class = self.render_rgb(rgb=dynamic_prob, weights=total_weights, bg_color=None)
        static_class = self.render_rgb(rgb=static_prob, weights=total_weights, bg_color=None)
        il_class = self.render_rgb(rgb=il_prob, weights=total_weights, bg_color=None)
        classes = torch.argmax(
            torch.cat([static_class, dynamic_class, il_class], dim=-1), dim=-1, keepdim=True)

        only_illum = self.render_rgb(rgb=l_color, weights=l_s_weight, bg_color=None)
        ilumination = self.render_rgb(rgb=l_color, weights=l_weights, bg_color=None)
        liv_s_rgb = self.render_rgb(rgb=static_rgb, weights=s_weights, bg_color=None) + ilumination
        liv_d_rgb = self.render_rgb(rgb=dynamic_rgb, weights=d_weights, bg_color=None)
        liv_rgb = liv_s_rgb + liv_d_rgb
        reh_rgb = self.render_rgb(rgb=static_rgb, weights=r_s_weight, bg_color=None)
        diff_rgb = liv_s_rgb - reh_rgb

        wrong_weight_s = torch.sum(s_weights, dim=-2)
        wrong_weight_d = torch.sum(d_weights, dim=-2)

        depth = self.render_depth(weights=total_weights, ray_samples=ray_samples,
                                  rays_d=ray_bundle.directions)

        self._debug_check_finite("reh_rgb", reh_rgb)
        self._debug_check_finite("wrong_weight_s", wrong_weight_s)
        self._debug_check_finite("wrong_weight_d", wrong_weight_d)
        self._debug_check_finite("liv_rgb", liv_rgb)

        outputs = {
            "reh_rgb": reh_rgb,
            "il_rgb": only_illum,
            "diff_rgb": diff_rgb,
            'dynamic_class': dynamic_class,
            'static_class': static_class,
            'il_class': il_class,
            'classes': classes,
            'wrong_weight_s': wrong_weight_s,
            'wrong_weight_d': wrong_weight_d,
            'wrong_weight_d_o': torch.sum(d_weights, dim=-2),
            'wrong_weight_il': torch.sum(l_weights, dim=-2),
            'hues': il_hues,
            'weight_s': torch.sum(s_weights, dim=-2),
            'weight_d': torch.sum(d_weights, dim=-2),
            'weight_l': torch.sum(l_weights, dim=-2),
            "liv_rgb": liv_rgb,
            "liv_s_rgb": liv_s_rgb,
            "liv_d_rgb": liv_d_rgb,
            "liv_l_rgb": ilumination,
            "depth": depth,
        }
        # These use a lot of GPU memory, so we avoid storing them for eval.
        if self.training:
            outputs["weights_list"] = weights_list
            outputs["ray_samples_list"] = ray_samples_list
        for i in range(self.num_proposal_iterations):
            outputs[f"prop_depth_{i}"] = self.render_depth(
                weights=weights_list[i], ray_samples=ray_samples_list[i], rays_d=ray_bundle.directions)
        return outputs

    def _composite_liv(self, static_rgb, dynamic_rgb, l_color, s_weights, d_weights,
                       l_weights, total_weights, ray_samples, ray_bundle):
        """Composite static + lighting + dynamic colour and depth into ``liv_rgb``/``depth``."""
        ilumination = self.render_rgb(rgb=l_color, weights=l_weights, bg_color=None)
        liv_s_rgb = self.render_rgb(rgb=static_rgb, weights=s_weights, bg_color=None) + ilumination
        liv_d_rgb = self.render_rgb(rgb=dynamic_rgb, weights=d_weights, bg_color=None)
        liv_rgb = liv_s_rgb + liv_d_rgb
        depth = self.render_depth(weights=total_weights, ray_samples=ray_samples,
                                  rays_d=ray_bundle.directions)
        return {"liv_rgb": liv_rgb, "depth": depth}

    def _render_with_light_query(self, rays_o, rays_d, near_far, time_human, query_light):
        """Shared render path where the lighting field is queried via ``query_light``.

        ``query_light(positions, directions)`` must return a dict with
        ``l_color`` and ``density_il``. Static/dynamic fields are queried as in
        ``forward`` (dynamic at ``time_human``).
        """
        ray_bundle = self._build_ray_bundle(rays_o, rays_d, near_far)
        ray_samples, _, _ = self.proposal_sampler.generate_ray_samples(
            ray_bundle, density_fns=self.density_fns)
        static_field_out = self.s_field(ray_samples.get_positions(), ray_bundle.directions)
        dynamic_field_out = self.d_field(ray_samples.get_positions(), ray_bundle.directions, time_human)
        light_field_out = query_light(ray_samples.get_positions(), ray_bundle.directions)
        s_weights, d_weights, l_weights, _, _, _, total_weights = ray_samples.get_weights_decomp(
            static_field_out["density"], dynamic_field_out["density"], light_field_out['density_il'])
        return self._composite_liv(
            static_field_out["rgb"], dynamic_field_out["rgb"], light_field_out['l_color'],
            s_weights, d_weights, l_weights, total_weights, ray_samples, ray_bundle)

    def change_hue(self, rays_o, rays_d, near_far, time_human, hue_time):
        """Render ``liv_rgb``/``depth`` with ``l_field.decouple(..., hue_time, time_human)``.

        Presumably renders the scene at ``time_human`` with the lighting hue taken
        from ``hue_time`` (depends on ``Lplane_field.decouple``; confirm).
        """
        return self._render_with_light_query(
            rays_o, rays_d, near_far, time_human,
            lambda pts, dirs: self.l_field.decouple(pts, dirs, hue_time, time_human))

    def custom_hue(self, rays_o, rays_d, near_far, time_human, hue):
        """Render ``liv_rgb``/``depth`` with ``l_field.custom_hue(..., hue, time_human, time_human)``."""
        return self._render_with_light_query(
            rays_o, rays_d, near_far, time_human,
            lambda pts, dirs: self.l_field.custom_hue(pts, dirs, hue, time_human, time_human))

    def assemble_two_time(self, rays_o, rays_d, near_far, time_human, time_light):
        """Render ``liv_rgb``/``depth`` with ``l_field.decouple(..., time_light, time_human)``."""
        return self._render_with_light_query(
            rays_o, rays_d, near_far, time_human,
            lambda pts, dirs: self.l_field.decouple(pts, dirs, time_light, time_human))

    def no_light(self, rays_o, rays_d, near_far, time_human):
        """Render ``liv_rgb``/``depth`` from the static and dynamic fields only.

        The lighting field is still evaluated because its density is passed to
        ``RaySamples.get_weights_no_light``; its colour is not composited.
        """
        ray_bundle = self._build_ray_bundle(rays_o, rays_d, near_far)
        ray_samples, _, _ = self.proposal_sampler.generate_ray_samples(
            ray_bundle, density_fns=self.density_fns)
        static_field_out = self.s_field(ray_samples.get_positions(), ray_bundle.directions)
        dynamic_field_out = self.d_field(ray_samples.get_positions(), ray_bundle.directions, time_human)
        light_field_out = self.l_field(ray_samples.get_positions(), ray_bundle.directions, time_human)

        s_weights, d_weights, _ = ray_samples.get_weights_no_light(
            static_field_out["density"], dynamic_field_out["density"], light_field_out['density_il'])
        liv_s_rgb = self.render_rgb(rgb=static_field_out["rgb"], weights=s_weights, bg_color=None)
        liv_d_rgb = self.render_rgb(rgb=dynamic_field_out["rgb"], weights=d_weights, bg_color=None)
        liv_rgb = liv_s_rgb + liv_d_rgb

        depth = self.render_depth(weights=s_weights + d_weights, ray_samples=ray_samples,
                                  rays_d=ray_bundle.directions)
        return {"liv_rgb": liv_rgb, "depth": depth}

    def get_params(self, lr: float):
        """Return optimizer parameter groups (field, nn, other, hues), all at ``lr``.

        Each group concatenates the static, dynamic and lighting fields' entries,
        followed by the proposal networks' entries.
        """
        s_params = self.s_field.get_params()
        d_params = self.d_field.get_params()
        l_params = self.l_field.get_params()
        pn_params = [pn.get_params() for pn in self.proposal_networks]

        def _proposal(key):
            return [p for pnp in pn_params for p in pnp[key]]

        field_params = s_params["field"] + d_params["field"] + l_params["field"] + _proposal("field")
        # Each field's "nn" parameters appear exactly once. A parameter listed twice
        # in one group is stepped more than once per iteration by torch.optim (and
        # triggers its duplicate-parameter warning). The lighting field may not report
        # an "nn" group, hence .get.
        nn_params = s_params["nn"] + d_params["nn"] + l_params.get("nn", []) + _proposal("nn")
        other_params = s_params["other"] + d_params["other"] + l_params["other"] + _proposal("other")
        hues = l_params["hues"]
        return [
            {"params": field_params, "lr": lr},
            {"params": nn_params, "lr": lr},
            {"params": other_params, "lr": lr},
            {"params": hues, "lr": lr}
        ]
