import numpy as np
import sklearn as sk
import torch
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from gym import Env, spaces
from sac.replay_memory import ReplayMemory
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.model_selection import cross_val_score

class rf_env_cont(Env):
    """Gym environment whose actions are RandomForestClassifier hyperparameters.

    An action is a 5-vector in [-1, 1] that is mapped linearly onto the box
    ``[lb, ub]`` and interpreted, in order, as (n_estimators, max_depth,
    min_samples_split, min_samples_leaf, max_features).  The reward is the
    mean 10-fold cross-validation score of a forest built with those values
    on ``(x, y)``.  The state is the history of ``(action..., reward)`` rows
    recorded since the last ``reset``.
    """

    def __init__(self, x, y, obs_size, lb=(100, 3, 2, 1, 0.1),
                 ub=(1200, 30, 100, 100, 0.9), state_buffer_size=200,
                 max_steps=10):
        """
        Args:
            x: Features passed to ``cross_val_score``.
            y: Labels passed to ``cross_val_score``.
            obs_size: Shape used to build ``observation_space``.
            lb: Lower bounds of the five hyperparameters.
            ub: Upper bounds of the five hyperparameters.
            state_buffer_size: Capacity handed to ``ReplayMemory`` for the
                per-episode history.
            max_steps: Number of ``step`` calls after which ``done`` is
                True (default 10, matching the previous behaviour).
        """
        self.lb = np.asarray(lb)  # [100,3,2,1,0.1]
        self.ub = np.asarray(ub)  # [1200,30,100,100,0.9]
        self.obs_size = obs_size
        self.state_buffer_size = state_buffer_size
        self.max_steps = max_steps
        self.x = x
        self.y = y
        # reset() already trains and cross-validates one forest.
        self.reset()
        # NOTE(review): the state is the stacked (action, reward) history,
        # whose shape/range does not obviously match a [0, 1]^obs_size box —
        # confirm against the agent code.
        self.observation_space = spaces.Box(low=np.zeros(self.obs_size),
                                            high=np.ones(self.obs_size), dtype=float)
        self.action_space = spaces.Box(low=np.full(5, -1, dtype=float),
                                       high=np.full(5, 1, dtype=float), dtype=float)

    def _scale_action(self, action):
        """Map an action in [-1, 1] linearly onto [lb, ub] as a float32 array."""
        unit = np.squeeze((np.asarray(action) + 1) / 2)
        return (self.lb + unit * (self.ub - self.lb)).astype(np.float32)

    def _record(self, action, reward):
        """Append an ``(action..., reward)`` row to the history; refresh and return ``self.state``."""
        self.state_buffer.push(np.concatenate((action, [reward])))
        # presumably return_all() yields every pushed row (an (n, 6) array) — verify against ReplayMemory.
        self.state = np.asarray(self.state_buffer.return_all())
        return self.state

    def score_RF(self, estimators, max_depth, min_split, min_leaf, max_feat, eval=False):
        """Return the mean 10-fold cross-validation score of a random forest.

        The four count-like hyperparameters are rounded to Python ``int``:
        scikit-learn's parameter validation rejects non-integral types such
        as ``np.float32`` for them.

        Args:
            estimators: ``n_estimators`` (rounded).
            max_depth: ``max_depth`` (rounded).
            min_split: ``min_samples_split`` (rounded).
            min_leaf: ``min_samples_leaf`` (rounded).
            max_feat: ``max_features`` as a float fraction of the features.
            eval: Unused; kept for interface compatibility.

        Returns:
            Mean of the per-fold scores returned by ``cross_val_score``.
        """
        # round() on a Python float returns int; converting first avoids
        # NumPy-scalar results from round() on np.float32.
        model = RandomForestClassifier(
            n_estimators=round(float(estimators)),
            max_depth=round(float(max_depth)),
            min_samples_split=round(float(min_split)),
            min_samples_leaf=round(float(min_leaf)),
            max_features=float(max_feat),
        )
        # Fixed shuffle seed so every configuration is scored on the same folds.
        kfold = KFold(n_splits=10, shuffle=True, random_state=69)
        results = cross_val_score(model, self.x, self.y, n_jobs=-1, cv=kfold)
        return results.mean()

    def step(self, action, eval=False):
        """Score the forest encoded by ``action`` and append it to the history.

        Args:
            action: 5 values, clipped to [-1, 1].  Extra singleton dimensions
                (e.g. a leading batch axis of size 1) are flattened away.
            eval: Forwarded to ``score_RF``.

        Returns:
            ``(next_state, reward, done)`` — a 3-tuple, not Gym's usual
            4-tuple with ``info``.  ``done`` becomes True once ``max_steps``
            steps have been taken since the last ``reset``.
        """
        action = np.ravel(np.clip(action, -1, 1))
        real_action = self._scale_action(action)
        reward = self.score_RF(real_action[0], real_action[1], real_action[2],
                               real_action[3], real_action[4], eval)
        next_state = self._record(action, reward)
        # Count steps explicitly so termination does not depend on how the
        # history buffer reports its length once it reaches capacity.
        self._steps += 1
        done = self._steps >= self.max_steps
        return (next_state, reward, done)

    def reset(self):
        """Start a new episode from one uniformly random configuration.

        Clears the history and step counter, scores a random action, records
        it, and returns the resulting state.
        """
        self.state_buffer = ReplayMemory(self.state_buffer_size)
        self._steps = 0
        action = np.random.uniform(low=-1, high=1, size=5)
        real_action = self._scale_action(action)
        reward = self.score_RF(real_action[0], real_action[1], real_action[2],
                               real_action[3], real_action[4])
        return self._record(action, reward)

