"""
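This file defines the base ``Environment`` class and ``TestGeneratedEnvEnvironment``, a concrete environment that loads COVID simulation datasets and evaluates user-generated code against them.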
"""
from sklearn.metrics import accuracy_score
import pandas as pd
import importlib.util
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import importlib
import sys
import random
import numpy as np
# import openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import random
import numpy as np
from sklearn.exceptions import NotFittedError
import time
from timeout_decorator import timeout
import sys
import io
from executor import Executor
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
from sklearn.metrics import f1_score 
import json
import ast


def get_env(env_name, config, logger, seed):
    """Construct the environment instance used by this module.

    Always builds a ``TestGeneratedEnvEnvironment``; ``env_name`` selects which
    dataset it loads when ``reset()`` is called.
    """
    environment = TestGeneratedEnvEnvironment(config, logger, env_name, seed)
    return environment

def maybe_fit_preprocessor(preprocessor, X_train):
    """Fit *preprocessor* on *X_train* only when it has not been fitted yet.

    Fitted-ness is probed by transforming the first row of ``X_train``; an
    unfitted sklearn estimator signals this by raising ``NotFittedError``, in
    which case the preprocessor is fitted on the whole training frame. Any
    other exception from ``transform`` propagates unchanged.
    """
    probe_row = X_train.iloc[:1]
    try:
        preprocessor.transform(probe_row)
    except NotFittedError:
        preprocessor.fit(X_train)

class Environment:
    """Minimal base class for environments.

    Stores its constructor arguments and exposes no-op ``reset``/``step``
    hooks for subclasses to override.
    """

    def __init__(self, config, logger, env_name, seed):
        self.config, self.logger = config, logger
        self.env_name, self.seed = env_name, seed

    def log(self, message):
        """Send *message* to the logger at INFO level, tagged with the env name.

        Does nothing when no logger was supplied.
        """
        if self.logger is None:
            return
        self.logger.info(f"[Environment: {self.env_name}] {message}")

    def reset(self):
        """No-op hook; subclasses override it."""
        return None

    def step(self, action):
        """No-op hook; subclasses override it to consume *action*."""
        return None

    
class TestGeneratedEnvEnvironment(Environment):
    """Environment that loads COVID simulation datasets and evaluates generated code.

    The dataset is selected by ``env_name`` in :meth:`reset`:

    * ``'covid-gcp'`` (exact match) -> ``libs.covid_cgp.data_utils.load_data``
    * names containing ``'Covid-scenario'`` -> ``libs.covasim.publichealth``
    * names containing ``'PNAS'`` -> ``libs.covasim.env_pnas`` (``'new_york'`` data)

    Any other name makes ``reset`` raise ``Exception``.
    """

    def __init__(self, config, logger, env_name, seed):
        super().__init__(config, logger, env_name, seed)
        self.seed_value = None
        self.description = None
        self.attribute_names = None
        # Header prepended to every piece of user code before it is exec'd, so
        # generated code can use these names without importing them itself.
        self.prepend_code_libraries = 'from pandas import DataFrame, Series\nfrom sklearn.base import BaseEstimator\nfrom sklearn.compose import ColumnTransformer\nfrom typing import Tuple\nimport pandas as pd\nimport numpy as np\nfrom sklearn.preprocessing import LabelEncoder, StandardScaler\nfrom sklearn.impute import SimpleImputer\nfrom sklearn.preprocessing import OneHotEncoder'

    def set_seed(self, seed_value):
        """Record *seed_value* and seed both Python's and NumPy's global RNGs."""
        self.seed_value = seed_value
        random.seed(seed_value)
        np.random.seed(seed_value)

    def _make_executor(self, variables):
        """Create an ``Executor`` exposing *variables* to user code, with the standard import header."""
        return Executor(variables=variables,
                        prepend_code_libraries=self.prepend_code_libraries)

    def reset(self, config=None, logger=None):
        """Load the dataset selected by ``env_name`` and build the code executor.

        ``config`` and ``logger`` are accepted for interface compatibility but
        are not used.

        Raises:
            Exception: if ``env_name`` matches no supported dataset.
        """
        if self.env_name == 'covid-gcp':
            from libs.covid_cgp.data_utils import load_data
            (self.train_loader, self.val_loader, self.train_data, self.test_data,
             self.description, self.attribute_names) = load_data(regress_only_on_y=True)
            self.executor = self._make_executor({
                'train_loader': self.train_loader,
                'val_loader': self.val_loader,
                'train_data': self.train_data.copy(),
            })
        elif 'Covid-scenario' in self.env_name:
            from libs.covasim.publichealth import load_data
            self.train_data, self.val_data, self.test_data, self.description = load_data()
            self.executor = self._make_executor({
                'train_data': self.train_data,
                'val_data': self.val_data,
                'test_data': self.test_data,
            })
        elif 'PNAS' in self.env_name:
            from libs.covasim.env_pnas import load_data
            self.train_data, self.val_data, self.test_data = load_data('new_york')
            self.executor = self._make_executor({
                'train_data': self.train_data,
                'val_data': self.val_data,
                'test_data': self.test_data,
            })
        else:
            raise Exception(f'Environment {self.env_name} not found')

    def _ground_truth_env(self):
        """Instantiate the ground-truth simulator env matching ``env_name``.

        ``'Covid-scenario'`` is checked before ``'PNAS'``.

        Raises:
            Exception: if ``env_name`` matches neither family.
        """
        if 'Covid-scenario' in self.env_name:
            from libs.covasim.publichealth import CovidSEnv
            return CovidSEnv()
        if 'PNAS' in self.env_name:
            from libs.covasim.env_pnas import PNASEnv
            return PNASEnv()
        raise Exception(f'Environment {self.env_name} not found')

    def evaluate_simulator_code(self, StateDifferential, config=None, logger=None):
        """Evaluate *StateDifferential* on the train/val/test splits via the ground-truth env.

        Args:
            StateDifferential: forwarded unchanged to
                ``evaluate_simulator_code_wrapper``.
            config: forwarded to the evaluator. Defaults to a fresh empty dict
                on each call, so callees can never share mutable state across calls.
            logger: forwarded to the evaluator.

        Raises:
            Exception: if ``env_name`` is not a supported simulator family.
        """
        config = {} if config is None else config
        gt_env = self._ground_truth_env()
        return gt_env.evaluate_simulator_code_wrapper(
            StateDifferential=StateDifferential,
            train_data=self.train_data,
            val_data=self.val_data,
            test_data=self.test_data,
            config=config,
            logger=logger,
            env_name=self.env_name,
        )

    def evaluate_simulator_code_on_test_dataset(self, StateDifferential, config=None, logger=None):
        """Exec the state-differential source and evaluate ``d_state__dt`` on the test split.

        Args:
            StateDifferential: source code (str) that must define a top-level
                ``d_state__dt`` function. Every ``'np.'`` is textually rewritten
                to ``'jnp.'`` (``jax.numpy``) before execution.
            config: forwarded to the evaluator. Defaults to a fresh empty dict
                on each call.
            logger: forwarded to the evaluator.

        Raises:
            Exception: if ``env_name`` is not a supported simulator family.
        """
        import types

        config = {} if config is None else config
        gt_env = self._ground_truth_env()
        # NOTE(review): plain textual replace, so existing "jnp." becomes "jjnp."
        # and any identifier ending in "np." is rewritten too.
        jax_source = StateDifferential.replace('np.', 'jnp.')
        code_string = f'{self.prepend_code_libraries}\nimport jax.numpy as jnp\n{jax_source}'
        user_code_module = types.ModuleType("user_code")
        # NOTE(review): exec runs generated code in-process with full privileges;
        # only safe for trusted or externally sandboxed input.
        exec(code_string, user_code_module.__dict__)
        return gt_env.evaluate_simulator_code_on_dataset(
            user_code_module.d_state__dt, self.test_data,
            config=config, logger=logger, env_name=self.env_name,
        )

    def get_obs(self):
        """Return ``str(self.attribute_names)`` when ``config['use_description']`` is truthy, else None."""
        if self.config.get('use_description', False):
            return f"{self.attribute_names}"
        return None

    # A timeout (raising StopIteration after 3 minutes) was once applied here:
    # @timeout(60*3, timeout_exception=StopIteration)
    def execute_user_code_lines(self, code_string):
        """Run *code_string* through the executor built in :meth:`reset` and return its result."""
        return self.executor.execute_user_code_lines(code_string)

    def execute_final_user_code(self, code_string):
        """Exec user code defining ``train_model`` and score it on the test split.

        The user code must define ``train_model(X_train, y_train)`` returning
        ``(model, preprocessor)``; ``preprocessor`` may be None.

        Returns:
            Macro-averaged F1 of the model's test predictions against
            ``self.test_labels``.
        """
        import types

        code_string = f'{self.prepend_code_libraries}\n{code_string}'
        user_code_module = types.ModuleType("user_code")
        # NOTE(review): exec runs generated code in-process with full privileges;
        # only safe for trusted or externally sandboxed input.
        exec(code_string, user_code_module.__dict__)

        self.log('Training model...')
        t0 = time.perf_counter()
        # NOTE(review): reset() never assigns train_labels/test_labels, so this
        # raises AttributeError unless they are set elsewhere; confirm with callers.
        model, preprocessor = user_code_module.train_model(self.train_data.copy(), self.train_labels)
        self.log(f"Model trained. Elapsed time: {time.perf_counter() - t0}s")

        if preprocessor is None:
            predictions = model.predict(self.test_data)
        else:
            # Apply the same (fitted) transformation to the test features.
            maybe_fit_preprocessor(preprocessor, self.train_data.copy())
            test_data_processed = preprocessor.transform(self.test_data)
            try:
                predictions = model.predict(test_data_processed)
            except ValueError:
                # Presumably the model expects raw features (e.g. it does its
                # own preprocessing); retry on the untransformed test data.
                predictions = model.predict(self.test_data)
        # The reward is the macro-averaged F1 score, not accuracy.
        return f1_score(self.test_labels, predictions, average='macro')

    def step(self, code_string):
        """Score *code_string* once and finish the episode.

        A ``StopIteration`` (the exception the disabled timeout decorator was
        configured to raise) yields a reward of 0.

        Returns:
            ``(None, reward, True, None)`` as ``(state, reward, done, info)``.
        """
        try:
            reward = self.execute_final_user_code(code_string)
        except StopIteration:
            reward = 0
        return None, reward, True, None


import unittest

class TestEnvironment(unittest.TestCase):
    """Smoke tests for the environments defined in this module."""

    @unittest.skip(
        "TestGeneratedEnvEnvironment.reset() has no 'data-science-iris' dataset; "
        "re-enable once a data-science environment exists."
    )
    def test_environment_iris(self):
        config = {'dataset_id': 'iris'}
        # No DataScienceEnvironment exists in this module, and ``seed`` is an
        # attribute rather than a method, so use the real constructor and set_seed().
        env = TestGeneratedEnvEnvironment(config, None, 'data-science-iris', 42)
        env.set_seed(42)
        env.reset()

        user_code_string = r"""
from sklearn.linear_model import LogisticRegression

def train_model(X_train, y_train):
    model = LogisticRegression()
    model.fit(X_train, y_train)
    return model, None"""
        state_out, reward, done, info = env.step(user_code_string)
        expected_accuracy = 1.0  # Set this to an expected value for your specific dataset and model
        self.assertAlmostEqual(reward, expected_accuracy, delta=0.01)  # Adjust delta as needed

    def test_unknown_environment_raises(self):
        """reset() must reject an environment name it does not recognise."""
        env = TestGeneratedEnvEnvironment({}, None, 'no-such-env', 0)
        with self.assertRaisesRegex(Exception, 'not found'):
            env.reset()
