'''
Generative Model for Batch Constrained Q-learning (BCQ)

Follows examples from:
- https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
- https://github.com/sfujim/BCQ/blob/4876f7e5afa9eb2981feec5daf67202514477518/discrete_BCQ/discrete_BCQ.py#L5
'''

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold, GridSearchCV
from sklearn.utils import check_array
import logging
from .Utils.rl_utils import ACTIONS, STATE_SPACE, TAU

# Configure the root logger once at import time: INFO level, with each record
# rendered as "<timestamp> - <level> - <message>".
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

class GenerativeModel:
    """Logistic-regression model of the behaviour policy for discrete BCQ.

    The classifier is fitted on (state, action index) pairs parsed from a
    replay-memory file. It is then used to mask the Q-values of actions the
    data makes unlikely (``get_qmax``) and to pick the most likely action
    for a state (``get_action``).

    Attributes:
        replay_path: Path of the replay-memory file the data is read from.
        states: Training states; a float32 array of shape
            ``(n_samples, STATE_SPACE)`` after a successful load, an empty
            list before that.
        actions: Training labels as int32 indices into ``ACTIONS``.
        clf: The fitted ``LogisticRegression``, or ``None`` if untrained.
    """

    # Class-level default so instances created without __init__ still know
    # where to load from.
    replay_path = "replay_memory.txt"

    def __init__(self, replay_path: str = "replay_memory.txt") -> None:
        """Load the replay memory at *replay_path* and train the classifier.

        Args:
            replay_path: File with one transition per line. Defaults to
                ``replay_memory.txt`` in the current working directory.

        Raises:
            FileNotFoundError: If *replay_path* does not exist.
            ValueError: If the file holds no usable transition or the data
                fails validation.
        """
        self.replay_path = replay_path
        self.states = []
        self.actions = []
        self.clf = None
        self._load_and_train()

    def _load_and_train(self) -> None:
        """Parse the replay memory and fit the logistic-regression model.

        Every non-blank line without NUL bytes is evaluated into a
        transition whose first item is the state feature vector and whose
        second item is the action (looked up in ``ACTIONS``). States whose
        length differs from ``STATE_SPACE`` are skipped.

        Raises:
            FileNotFoundError: If the replay file is missing.
            ValueError: If no valid transition is found, an action is not
                in ``ACTIONS``, or validation fails.
        """
        # Collect into locals: after a previous successful run the instance
        # attributes are numpy arrays and cannot be appended to, so using
        # them directly would make this method impossible to call twice.
        states = []
        actions = []
        try:
            with open(self.replay_path, "r") as file:
                for line in file:
                    record = line.strip()
                    # Skip blank lines (eval("") raises SyntaxError) and
                    # lines corrupted with NUL bytes.
                    if not record or '\x00' in line:
                        continue
                    # SECURITY: eval() executes arbitrary code found in the
                    # file. Only load replay files from a trusted source;
                    # if every transition is a plain Python literal, switch
                    # to ast.literal_eval.
                    transition = eval(record)
                    state_feat = transition[0]
                    if len(state_feat) == STATE_SPACE:  # Validate state size
                        states.append(state_feat)
                        actions.append(ACTIONS.index(transition[1]))

            if not states:
                raise ValueError(f"No valid transitions found in {self.replay_path}")

            # Convert to numpy arrays with compact dtypes.
            self.states = np.array(states, dtype=np.float32)
            self.actions = np.array(actions, dtype=np.int32)

            self._validate_data()

            # Publish the classifier only once fitting succeeded, so that
            # ``self.clf is None`` reliably means "not trained".
            clf = LogisticRegression(penalty='l1', C=0.1, solver='liblinear', max_iter=1000)
            clf.fit(self.states, self.actions)
            self.clf = clf
            logging.info("Generative model trained with %d samples", len(self.states))
        except FileNotFoundError:
            logging.error("%s not found", self.replay_path)
            raise
        except Exception as e:
            logging.error("Error loading or training generative model: %s", e)
            raise

    def _validate_data(self) -> None:
        """Validate states and actions for NaNs, infinities, and shape.

        Raises:
            ValueError: If ``check_array`` rejects the data, the state
                dimension differs from ``STATE_SPACE``, or the number of
                states and actions differ.
        """
        try:
            self.states = check_array(self.states, dtype=np.float32, ensure_2d=True)
            self.actions = check_array(self.actions, dtype=np.int32, ensure_2d=False)
            if self.states.shape[1] != STATE_SPACE:
                raise ValueError(f"Expected state dimension {STATE_SPACE}, got {self.states.shape[1]}")
            if len(self.states) != len(self.actions):
                raise ValueError(f"Mismatch: {len(self.states)} states vs {len(self.actions)} actions")
            logging.info("Data validation passed: %d samples, state shape %s", len(self.states), self.states.shape)
        except Exception as e:
            logging.error("Data validation failed: %s", e)
            raise

    def _action_probabilities(self, current_state: np.ndarray, n_actions: int) -> np.ndarray:
        """Return action probabilities indexed by action index.

        ``predict_proba`` has one column per class seen during training,
        ordered like ``clf.classes_``. Actions that never occurred in the
        replay memory have no column; they get probability 0 here so the
        result always lines up with a Q-value vector of length *n_actions*.
        """
        seen_proba = self.clf.predict_proba([current_state])[0]
        seen_actions = np.asarray(self.clf.classes_, dtype=int)
        full = np.zeros(n_actions, dtype=float)
        full[seen_actions] = seen_proba
        return full

    def get_qmax(self, current_state: np.ndarray, q_values: np.ndarray,
                 tau: "float | None" = None) -> float:
        """Compute the maximum Q-value among actions the model deems likely.

        Args:
            current_state: State feature vector.
            q_values: Q-values, one per action (last axis indexes actions).
            tau: Probability threshold an action must exceed to be
                eligible. Defaults to the module-level ``TAU``.

        Returns:
            The largest Q-value over the eligible actions.

        Raises:
            ValueError: If the model has not been trained.
        """
        if self.clf is None:
            raise ValueError("Generative model not trained. Call _load_and_train first.")

        threshold = TAU if tau is None else tau
        q_values = np.asarray(q_values, dtype=float)
        probs = self._action_probabilities(current_state, q_values.shape[-1])

        allowed = probs > threshold
        if not allowed.any():
            # No action clears the absolute threshold. Falling through would
            # return the -1e8 mask value as a Q target, so keep the most
            # likely action(s) eligible instead.
            allowed = probs >= probs.max()

        # Mask improbable actions with a large negative value before the max.
        return float(np.max(np.where(allowed, q_values, -1e8)))

    def get_action(self, current_state: np.ndarray) -> int:
        """Predict the most likely action index for the given state.

        Returns:
            The index into ``ACTIONS`` of the most probable action.

        Raises:
            ValueError: If the model has not been trained.
        """
        if self.clf is None:
            raise ValueError("Generative model not trained. Call _load_and_train first.")

        p = self.clf.predict_proba([current_state])[0]
        # Columns follow clf.classes_, which skips actions absent from the
        # training data, so map the best column back to its action index.
        return int(self.clf.classes_[np.argmax(p)])

    def grid_search(self) -> dict:
        """Perform grid search to optimize LogisticRegression hyperparameters.

        Returns:
            ``{'best_score': ..., 'best_params': ...}`` on success, or an
            empty dict when no data is loaded or the search fails (the
            failure is logged).
        """
        # Test for emptiness by length: ``.any()`` asks whether any element
        # is non-zero, which wrongly rejects all-zero data and does not
        # exist on the plain lists held before loading.
        if len(self.states) == 0 or len(self.actions) == 0:
            logging.error("No data available for grid search. Load data first.")
            return {}

        model = LogisticRegression(max_iter=1000)
        cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=1)  # Reduced splits/repeats
        # Define solver-specific parameter grids to avoid invalid combinations
        param_grid = [
            {
                'solver': ['liblinear'],
                'penalty': ['l1', 'l2'],
                'C': [1e-3, 1e-2, 1e-1, 1, 10],  # Reduced C values
            },
            {
                'solver': ['newton-cg', 'lbfgs'],
                'penalty': ['l2'],
                'C': [1e-3, 1e-2, 1e-1, 1, 10],
            },
        ]

        search = GridSearchCV(model, param_grid, scoring='accuracy', n_jobs=2, cv=cv, verbose=1, error_score='raise')
        try:
            result = search.fit(self.states, self.actions)
            logging.info("Grid search completed - Best Score: %s, Best Hyperparameters: %s",
                         result.best_score_, result.best_params_)
            return {'best_score': result.best_score_, 'best_params': result.best_params_}
        except Exception as e:
            logging.error("Grid search failed: %s", e)
            return {}

if __name__ == "__main__":
    # Running the module as a script builds the model, which loads the
    # replay memory and trains the classifier as part of construction.
    model = GenerativeModel()
    # To tune hyperparameters as well, uncomment the lines below:
    # result = model.grid_search()
    # if result:
    #     print(f"Best Score: {result['best_score']}, Best Params: {result['best_params']}")