"""
Utility functions and classes for managing model parameters and logging.
"""
import json
import os
import argparse
import numpy as np

from gmpy2 import mpz

from module.model import RsetWrapper
from module.threshold_guess import ThresholdGuess

class NumpyEncoder(json.JSONEncoder):
    """ NumpyEncoder is a custom JSON encoder for handling numpy data types.

    * ``numpy.ndarray`` values are replaced by the placeholder string
      ``"ARRAY"``; array contents are deliberately not serialized.
    * NumPy scalars the standard encoder rejects (``np.integer``,
      ``np.floating`` and ``np.bool_`` instances such as ``np.int64`` or
      ``np.float32``) are converted to the equivalent Python scalar.
    * Any other object falls through to ``json.JSONEncoder.default``,
      which raises ``TypeError``.

    Args:
        json.JSONEncoder: The base JSON encoder class.
    """
    def default(self, o):
        if isinstance(o, np.ndarray):
            # Only mark that an array was here; its values are not written out.
            return "ARRAY"
        if isinstance(o, (np.integer, np.floating, np.bool_)):
            # NumPy scalars (e.g. np.int64, np.float32, np.bool_) are not
            # subclasses of int/bool, so the base encoder cannot handle them.
            # .item() yields the matching built-in Python scalar.
            return o.item()
        return super().default(o)

### Class to manage model and experiment parameters
class Hparams:
    """
    Class to manage model and experiment parameters.

    Normally built from parsed command-line arguments. Constructing it with
    ``args=None`` yields an "empty" instance whose attributes are all ``None``;
    such an instance is typically populated afterwards via :meth:`set_state`.
    """
    def __init__(self, args: "argparse.Namespace | None" = None):
        """ Initialize from command-line arguments and create output folders.

        Args:
            args (argparse.Namespace | None): Parsed command-line arguments.
                If None, every attribute is set to None and no directory is
                created.
        """
        # Define every attribute up front so an instance built without args
        # (e.g. one about to receive set_state) never raises AttributeError.
        self.model_name = None
        self.model_class = None
        self.model_params = None
        self.encoder = None
        self.encoder_params = None
        self.rs = None
        self.rng = None
        self.retrain = None
        self.tune = None
        self.selection = None
        self.reset_results = None
        self.io_params = None

        if args is None:
            return

        self.model_name = args.model_parser

        self.rs = args.random_state
        # Generator seeded with the run's random_state.
        self.rng = np.random.default_rng(self.rs)

        self.retrain = args.retrain
        self.tune = args.tune
        self.selection = args.selection
        self.reset_results = args.reset_results

        self.io_params = {
            'output_dir': args.output_dir,
            'result_dir': args.result_dir,
            'model_dir': args.model_dir,
            'param_dir': args.param_dir,
        }
        output_dir = self.io_params['output_dir']
        os.makedirs(output_dir, exist_ok=True)
        # Result, model and param folders live directly under output_dir.
        for key in ('result_dir', 'model_dir', 'param_dir'):
            os.makedirs(f"{output_dir}/{self.io_params[key]}", exist_ok=True)

    def get_state(self):
        """ Get the current state of the Hparams instance. We only save relevant state.

        Not included: ``encoder_params``, ``retrain``, ``tune``, ``selection``
        and ``reset_results``.

        Returns:
            dict: A dictionary containing the current state of the Hparams instance.
        """
        return {
            'model_name': self.model_name,
            'model_class': self.model_class,
            'model_params': self.model_params,
            'encoder': self.encoder,
            'io_params': self.io_params,
            'rs': self.rs,
            'rng': self.rng,
        }

    def set_state(self, state):
        """ Set the state of the Hparams instance.

        Args:
            state (dict): A dictionary with the keys produced by
                :meth:`get_state`. A missing key raises ``KeyError``.
        """
        self.model_name = state['model_name']
        self.model_class = state['model_class']
        self.model_params = state['model_params']
        self.encoder = state['encoder']
        self.io_params = state['io_params']
        self.rs = state['rs']
        self.rng = state['rng']

    def rset_params_args_init(self, args):
        """ Initialize the parameters for the Rset model.

        Sets ``model_class`` to ``RsetWrapper``. It also sets ``model_params``
        to ``{'config': {...}}``, where the config is built from the
        ``rset_lamb``, ``rset_depth_budget`` and ``rset_eps`` arguments.

        Args:
            args (argparse.Namespace): The command line arguments.
        """
        self.model_class = RsetWrapper
        config = {
            "regularization": args.rset_lamb,
            "depth_budget": args.rset_depth_budget,
            "rashomon_bound_adder": args.rset_eps,
        }

        self.model_params = {
            'config': config,
        }

    def thres_guess_args_init(self, args) -> None:
        """ Initialize the parameters for the ThresholdGuess model.

        Stores the keyword arguments in ``encoder_params`` and instantiates
        ``self.encoder`` from them. Backward selection is always enabled, and
        ``self.rs`` is used as the random state.

        Args:
            args (argparse.Namespace): The command line arguments.
        """
        param = {
            'max_depth': args.enc_max_depth,
            'n_estimators': args.enc_n_estimators,
            'learning_rate': args.enc_lr,
        }
        self.encoder_params = {
            "guess_model_param": param,
            "back_select": True,
            "random_state": self.rs
        }
        self.encoder = ThresholdGuess(**self.encoder_params)

class BinSequence:
    """Growable binary sequence stored in a single ``gmpy2.mpz`` integer.

    Element ``i`` of the sequence is bit ``i`` of ``self.x``; ``self.len`` is
    the number of elements. Appending a 0 only bumps the length because
    untouched bits of ``self.x`` already read as 0.
    """
    def __init__(self, value=0, length=0):
        """Create a sequence with bit pattern ``value`` and length ``length``.

        ``value`` is not checked against ``length``. If bits are set at
        positions ``>= length``, :meth:`to_array` will raise ``IndexError``.
        """
        self.x = mpz(value)
        self.len = length

    def copy(self):
        """Return a new BinSequence with the same bits and length."""
        return BinSequence(self.x, self.len)

    def append_0(self):
        """Append a 0 element; returns ``self`` so calls can be chained."""
        self.len += 1
        return self

    def append_1(self):
        """Append a 1 element; returns ``self`` so calls can be chained."""
        self.x = self.x.bit_set(self.len)
        self.len += 1
        return self

    def all_bits_set(self):
        """Return True iff the sequence is non-empty and ``x == 2**len - 1``.

        That is, every position ``0..len-1`` is 1 and no higher bit is set.
        An empty sequence returns False.
        """
        if self.len == 0:
            return False
        # 2**len built as an mpz with only bit `len` set.
        tmp = mpz(0)
        tmp = tmp.bit_set(self.len)
        return tmp == self.x + 1

    def to_array(self):
        """Return a float numpy array of length ``len`` with 1.0 at set bits."""
        array = np.zeros((self.len,))
        # Walk set bits in increasing order; bit_scan1 yields None when no
        # further 1 bit exists at or after the given index.
        next_1 = self.x.bit_scan1(0)
        while next_1 is not None:
            array[next_1] = 1
            next_1 = self.x.bit_scan1(next_1 + 1)
        return array

    def to_mpz(self):
        """Return the underlying ``mpz`` holding the bits."""
        return self.x

    def from_array(self, arr):
        """Append each element of ``arr`` to this sequence and return ``self``.

        An element counts as 1 when ``item == 1``; anything else appends 0.
        """
        # NOTE(review): elements are read positionally via arr[i] rather than
        # by iterating arr. For label-indexed containers (e.g. a pandas
        # Series) the two differ, so the positional form is kept as-is.
        for i in range(len(arr)):
            item = arr[i]
            if item == 1:
                self.x = self.x.bit_set(self.len)
            self.len += 1
        return self
