from typing import Union

from laplace.curvature import CurvatureInterface, BackPackGGN, CurvlinopsGGN
from laplace.curvature.asdl import AsdlGGN

from enum import Enum


class LLMFeatureType(Enum):
    """Closed set of feature-type choices, each mapped to a distinct integer.

    NOTE(review): the member names suggest they select which token
    position(s) of an LLM output are used as the feature (last token, first
    token, or an average) -- the consuming code is not in this file, so
    confirm against the caller.
    """

    # Integer values are arbitrary but distinct; anything that persists or
    # compares the raw values depends on them staying as they are.
    LAST_TOKEN = 1
    FIRST_TOKEN = 2
    AVERAGE = 3

######################### Laplace Configs #########################

class LaplaceConfig:
    """Bundle of hyperparameters stored as plain instance attributes.

    Every constructor argument except ``problem`` and ``representation`` is
    stored unchanged on the instance under the same name. The only logic
    here is the validation of ``prior_prec_structure``; what each value
    ultimately controls is decided by the code that reads the config.

    The per-argument notes below are inferred from the parameter names and
    should be confirmed against the consuming code.

    Args:
        batch_size: Mini-batch size.
        lr: Learning rate.
        lr_lora: Separate learning rate (presumably for LoRA parameters).
        wd: Weight decay.
        grad_clip: Gradient-clipping threshold.
        n_epochs: Number of training epochs.
        head_n_epochs: Number of epochs for the head (stored as given).
        marglik_mode: Marginal-likelihood mode, e.g. ``"posthoc"``.
        val_frequency: Validation interval.
        noise_var: Observation-noise variance.
        subset_of_weights: Which weights are treated, e.g. ``"all"``.
        hess_factorization: Hessian factorization, e.g. ``"kfac"``.
        prior_prec_structure: Must be ``"scalar"``, ``"layerwise"`` or
            ``"diagonal"``.
        posthoc_marglik_iters: Iteration count for the post-hoc mode.
        online_marglik_freq: Update frequency for the online mode.
        hessian_backend: A ``CurvatureInterface`` subclass (the class
            itself, not an instance).
        last_layer_name: Dotted module path naming the last layer.
        activation: Activation-function name.
        problem: Accepted for call-site compatibility; not stored.
        representation: Accepted for call-site compatibility; not stored.

    Raises:
        AssertionError: If ``prior_prec_structure`` is not one of the
            allowed values.
    """

    def __init__(
        self,
        batch_size: int = 100,
        lr: float = 0.001,
        lr_lora: float = 3e-4,
        wd: float = 0.1,
        grad_clip: float = 0.0,
        n_epochs: int = 1500,
        head_n_epochs: int = 100,
        marglik_mode: str = "posthoc",
        val_frequency: int = 100,
        noise_var: float = 0.0001,
        subset_of_weights: str = "all",
        hess_factorization: str = "kfac",
        prior_prec_structure: str = "layerwise",
        posthoc_marglik_iters: int = 500,
        online_marglik_freq: int = 5,
        hessian_backend: type[CurvatureInterface] = AsdlGGN,
        last_layer_name: str = "base_model.model.head.modules_to_save.default.2",
        activation: str = "tanh",
        problem=None,
        representation=None,
    ):
        # Validate up front with an explicit raise: a bare `assert` is
        # removed under `python -O`, which would let invalid values through.
        # AssertionError is kept so callers see the same exception type as
        # before.
        allowed_structures = ("scalar", "layerwise", "diagonal")
        if prior_prec_structure not in allowed_structures:
            raise AssertionError(
                f"prior_prec_structure must be one of {allowed_structures}, "
                f"got {prior_prec_structure!r}"
            )

        # `problem` and `representation` are intentionally not stored,
        # matching the previous behaviour of this constructor.
        self.activation = activation
        self.batch_size = batch_size
        self.lr = lr
        self.lr_lora = lr_lora
        self.wd = wd
        self.grad_clip = grad_clip
        self.n_epochs = n_epochs
        self.marglik_mode = marglik_mode
        self.noise_var = noise_var
        self.subset_of_weights = subset_of_weights
        self.hess_factorization = hess_factorization
        self.prior_prec_structure = prior_prec_structure
        self.posthoc_marglik_iters = posthoc_marglik_iters
        self.online_marglik_freq = online_marglik_freq
        self.hessian_backend = hessian_backend
        self.last_layer_name = last_layer_name
        self.head_n_epochs = head_n_epochs
        self.val_frequency = val_frequency

class LoraLaplaceConfig:
    """Bundle of hyperparameters stored as plain instance attributes.

    Every constructor argument is stored unchanged on the instance under
    the same name. The only logic here is the validation of
    ``prior_prec_structure``; what each value ultimately controls is decided
    by the code that reads the config.

    The per-argument notes below are inferred from the parameter names and
    should be confirmed against the consuming code.

    Args:
        batch_size: Mini-batch size.
        lr: Learning rate.
        lr_lora: Separate learning rate (presumably for LoRA parameters).
        wd: Weight decay.
        grad_clip: Gradient-clipping threshold.
        n_epochs: Number of training epochs.
        head_n_epochs: Number of epochs for the head (stored as given).
        marglik_mode: Marginal-likelihood mode, e.g. ``"posthoc"``.
        val_frequency: Validation interval.
        noise_var: Observation-noise variance.
        subset_of_weights: Which weights are treated, e.g. ``"all"``.
        hess_factorization: Hessian factorization, e.g. ``"kfac"``.
        prior_prec_structure: Must be ``"scalar"``, ``"layerwise"`` or
            ``"diagonal"``.
        posthoc_marglik_iters: Iteration count for the post-hoc mode.
        online_marglik_freq: Update frequency for the online mode.
        hessian_backend: A ``CurvatureInterface`` subclass (the class
            itself, not an instance).
        last_layer_name: Dotted module path naming the last layer.
        activation: Activation-function name.

    Raises:
        AssertionError: If ``prior_prec_structure`` is not one of the
            allowed values.
    """

    def __init__(
        self,
        batch_size: int = 100,
        lr: float = 1e-3,
        lr_lora: float = 3e-4,
        wd: float = 5e-4,
        grad_clip: float = 0.0,
        n_epochs: int = 100,
        head_n_epochs: int = 100,
        marglik_mode: str = "posthoc",
        val_frequency: int = 100,
        noise_var: float = 0.0001,
        subset_of_weights: str = "all",
        hess_factorization: str = "kfac",
        prior_prec_structure: str = "layerwise",
        posthoc_marglik_iters: int = 500,
        online_marglik_freq: int = 5,
        hessian_backend: type[CurvatureInterface] = AsdlGGN,
        last_layer_name: str = "base_model.model.head.modules_to_save.default.2",
        activation: str = "relu",
    ):
        # Validate up front with an explicit raise: a bare `assert` is
        # removed under `python -O`, which would let invalid values through.
        # AssertionError is kept so callers see the same exception type as
        # before.
        allowed_structures = ("scalar", "layerwise", "diagonal")
        if prior_prec_structure not in allowed_structures:
            raise AssertionError(
                f"prior_prec_structure must be one of {allowed_structures}, "
                f"got {prior_prec_structure!r}"
            )

        self.activation = activation
        self.batch_size = batch_size
        self.lr = lr
        self.lr_lora = lr_lora
        self.wd = wd
        self.grad_clip = grad_clip
        self.n_epochs = n_epochs
        self.marglik_mode = marglik_mode
        self.noise_var = noise_var
        self.subset_of_weights = subset_of_weights
        self.hess_factorization = hess_factorization
        self.prior_prec_structure = prior_prec_structure
        self.posthoc_marglik_iters = posthoc_marglik_iters
        self.online_marglik_freq = online_marglik_freq
        self.hessian_backend = hessian_backend
        self.last_layer_name = last_layer_name
        self.head_n_epochs = head_n_epochs
        self.val_frequency = val_frequency

######################### FSP Laplace Configs #########################
class FSPLaplaceConfig:
    """Plain attribute container for the FSP-Laplace hyperparameters.

    Each constructor argument except ``problem`` and ``representation`` is
    copied verbatim onto the instance under the same name; no validation or
    derived values are computed. ``problem`` and ``representation`` are
    accepted but discarded.

    Option strings noted by the author next to the defaults:

    * ``map_context_points``: "uniform", "uniform_bitstring", "bo_candidates"
    * ``cov_context_points``: "bo_candidates", "sobol"
    * ``params_sketch``: "" or "ssrft"

    These lists are informational only -- nothing in this class enforces
    them.
    """

    def __init__(
        self,
        prior_n_steps: int = 500,
        prior_lr: float = 0.1,
        prior_val_frequency: int = 100,
        lr: float = 0.01,
        n_epochs: int = 1500,
        val_frequency: int = 100,
        early_stopping_patience: int = 1000,
        jitter: float = 1e-10,
        noise_var: float = 0.0001,
        batch_size: int = 100,
        max_rank: int = 500,
        n_chunks: int = 1,
        map_context_points: str = "bo_candidates",
        cov_context_points: str = "bo_candidates",
        generator_seed: int = 0,
        context_points_batch_size: int = 100,  # author also noted 500
        n_context_points_cov: int = 10000,  # author also noted 25000
        params_sketch: str = "",
        params_sketch_dim: int = 100000,
        activation: str = "relu",
        problem=None,
        representation=None,
    ):
        # `problem` / `representation` are deliberately left unused here.
        self.activation = activation

        # Settings prefixed `prior_`.
        self.prior_n_steps = prior_n_steps
        self.prior_lr = prior_lr
        self.prior_val_frequency = prior_val_frequency

        # Main training-loop settings.
        self.lr = lr
        self.n_epochs = n_epochs
        self.val_frequency = val_frequency
        self.early_stopping_patience = early_stopping_patience

        # Numerical and sizing knobs.
        self.jitter = jitter
        self.noise_var = noise_var
        self.batch_size = batch_size
        self.max_rank = max_rank
        self.n_chunks = n_chunks

        # Context-point selection.
        self.map_context_points = map_context_points
        self.cov_context_points = cov_context_points
        self.generator_seed = generator_seed
        self.context_points_batch_size = context_points_batch_size
        self.n_context_points_cov = n_context_points_cov

        # Parameter sketching.
        self.params_sketch = params_sketch
        self.params_sketch_dim = params_sketch_dim


class LoraFSPLaplaceConfig:
    """Plain attribute container for the LoRA FSP-Laplace hyperparameters.

    Every constructor argument is copied verbatim onto the instance under
    the same name; no validation or derived values are computed.

    Option strings noted by the author next to the defaults:

    * ``map_context_points``: "uniform", "uniform_bitstring", "bo_candidates"
    * ``cov_context_points``: "bo_candidates", "sobol"
    * ``params_sketch``: "" or "ssrft"

    These lists are informational only -- nothing in this class enforces
    them.
    """

    def __init__(
        self,
        prior_n_steps: int = 500,
        prior_lr: float = 0.1,
        prior_val_frequency: int = 100,
        lr: float = 0.001,
        n_epochs: int = 100,
        head_n_epochs: int = 100,
        lr_lora: float = 3e-4,
        grad_clip: float = 0.0,
        val_frequency: int = 100,
        early_stopping_patience: int = 1000,
        jitter: float = 1e-10,
        noise_var: float = 0.0001,
        batch_size: int = 100,
        max_rank: int = 50,
        n_chunks: int = 1,
        map_context_points: str = "bo_candidates",
        cov_context_points: str = "bo_candidates",
        generator_seed: int = 0,
        context_points_batch_size: int = 100,  # author also noted 500
        n_context_points_cov: int = 10000,  # author also noted 25000
        params_sketch: str = "",
        params_sketch_dim: int = 100000,
        activation: str = "relu",
        last_layer_name: str = "base_model.model.head.modules_to_save.default.2",
    ):
        self.activation = activation

        # Settings prefixed `prior_`.
        self.prior_n_steps = prior_n_steps
        self.prior_lr = prior_lr
        self.prior_val_frequency = prior_val_frequency

        # Main training-loop settings.
        self.lr = lr
        self.n_epochs = n_epochs
        self.val_frequency = val_frequency
        self.early_stopping_patience = early_stopping_patience

        # Numerical and sizing knobs.
        self.jitter = jitter
        self.noise_var = noise_var
        self.batch_size = batch_size
        self.max_rank = max_rank
        self.n_chunks = n_chunks

        # Context-point selection.
        self.map_context_points = map_context_points
        self.cov_context_points = cov_context_points
        self.generator_seed = generator_seed
        self.context_points_batch_size = context_points_batch_size
        self.n_context_points_cov = n_context_points_cov

        # Parameter sketching.
        self.params_sketch = params_sketch
        self.params_sketch_dim = params_sketch_dim

        # LoRA / head-specific settings.
        self.lr_lora = lr_lora
        self.grad_clip = grad_clip
        self.head_n_epochs = head_n_epochs
        self.last_layer_name = last_layer_name



