import numpy as np
import sklearn as sk
import torch
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.model_selection import cross_val_score
from gym import Env, spaces
from sac.replay_memory import ReplayMemory
import xgboost as xgb

class xgb_env_cont(Env):
    """Gym environment for tuning XGBoost hyperparameters with a continuous agent.

    An action is a vector of ``N_PARAMS`` values in ``[-1, 1]``. It is mapped
    linearly onto the box ``[lb, ub]`` to obtain, in order: ``max_depth``,
    ``eta`` (learning rate), ``n_estimators``, ``gamma``, ``min_child_weight``,
    ``subsample``, ``colsample_bytree``, ``colsample_bylevel``, ``reg_alpha``
    and ``reg_lambda``. The reward is the mean 10-fold cross-validation score
    of an ``xgb.XGBClassifier`` built with those hyperparameters on ``(x, y)``.

    The state is the history of ``[action..., reward]`` rows pushed into a
    ``ReplayMemory`` buffer.
    """

    # Number of tuned hyperparameters, i.e. the length of an action vector.
    N_PARAMS = 10

    def __init__(self, x, y, obs_size, lb=(3, 0.001, 50, 0.05, 1, 0.6, 0.5, 0.5, 0, 0.01),
                 ub=(25, 0.1, 1200, 1, 7, 1, 1, 1, 1, 1), state_buffer_size=200):
        """Store data and bounds, then run one initial random evaluation.

        Args:
            x: Features passed to ``cross_val_score``.
            y: Targets passed to ``cross_val_score``.
            obs_size: Dimension of ``observation_space`` (a ``[0, 1]`` box).
            lb: Lower hyperparameter bounds, in action order (see class doc).
            ub: Upper hyperparameter bounds, in action order (see class doc).
            state_buffer_size: Capacity handed to ``ReplayMemory``.
        """
        # Stored as arrays so an action can be rescaled in one vector expression.
        self.lb = np.asarray(lb)
        self.ub = np.asarray(ub)
        self.obs_size = obs_size
        self.state_buffer_size = state_buffer_size
        self.x = x
        self.y = y
        # reset() reads lb/ub/x/y set above and performs one full CV evaluation.
        self.reset()
        self.observation_space = spaces.Box(low=np.zeros(self.obs_size),
                                            high=np.ones(self.obs_size), dtype=float)
        self.action_space = spaces.Box(low=np.full(self.N_PARAMS, -1, dtype=float),
                                       high=np.full(self.N_PARAMS, 1, dtype=float),
                                       dtype=float)

    def score_RF(self, max_depth, eta, estimators, gamma, min_child_weight, subsample, colsample_bytree,
                 colsample_bylevel, reg_alpha, reg_lambda, eval=False, std=False):
        """Return the mean 10-fold cross-validation score of an XGBoost classifier.

        Despite the name, the model is ``xgb.XGBClassifier``, not a random forest.
        Scoring uses ``cross_val_score``'s default scorer for the estimator.

        Args:
            max_depth: Maximum tree depth; rounded to the nearest integer.
            eta: Learning rate.
            estimators: Number of boosting rounds; rounded to the nearest integer.
            gamma, min_child_weight, subsample, colsample_bytree,
            colsample_bylevel, reg_alpha, reg_lambda: Passed through unchanged.
            eval: Unused; kept for interface compatibility.
            std: If true, also return the standard deviation of the fold scores.

        Returns:
            The mean fold score, or ``(mean, std)`` when ``std`` is true.
        """
        # int() guarantees Python ints for XGBoost's integer parameters. Depending
        # on the NumPy version, round() on a NumPy float scalar (as produced by
        # step()/reset()) can return a NumPy float instead of an int.
        model = xgb.XGBClassifier(max_depth=int(round(max_depth)), learning_rate=eta,
                                  n_estimators=int(round(estimators)),
                                  gamma=gamma, min_child_weight=min_child_weight, subsample=subsample,
                                  colsample_bytree=colsample_bytree, colsample_bylevel=colsample_bylevel,
                                  reg_alpha=reg_alpha, reg_lambda=reg_lambda)

        # Fixed seed: every hyperparameter set is scored on the same folds,
        # so rewards from different actions are comparable.
        kfold = KFold(n_splits=10, shuffle=True, random_state=69)
        results = cross_val_score(model, self.x, self.y, n_jobs=-1, cv=kfold)
        mean_score = results.mean()

        if std:
            return mean_score, results.std()
        return mean_score

    def _to_hyperparams(self, action):
        """Map an action in ``[-1, 1]`` linearly onto ``[lb, ub]`` as float32 values."""
        n = self.N_PARAMS
        # Only the first N_PARAMS entries are used, matching score_RF's arity.
        scaled = ((np.ravel(action) + 1) / 2)[:n]
        lb, ub = self.lb[:n], self.ub[:n]
        return (lb + scaled * (ub - lb)).astype(np.float32)

    def _evaluate(self, action):
        """Score the hyperparameters encoded by ``action`` via ``score_RF``."""
        return self.score_RF(*self._to_hyperparams(action))

    def _record(self, action, reward):
        """Push ``[action..., reward]`` into the buffer and refresh ``self.state``."""
        self.state_buffer.push(np.concatenate((action, [reward])))
        # NOTE(review): presumably return_all() yields every stored row, giving
        # a (num_rows, N_PARAMS + 1) array; confirm against ReplayMemory.
        self.state = np.asarray(self.state_buffer.return_all())
        return self.state

    def step(self, action, eval=False):
        """Evaluate one hyperparameter proposal.

        Args:
            action: ``N_PARAMS`` values, nominally in ``[-1, 1]``. Values outside
                that range are clipped. Extra singleton dimensions, e.g. shape
                ``(1, N_PARAMS)``, are flattened away so the action can be
                stored alongside its reward.
            eval: Unused; kept for interface compatibility.

        Returns:
            ``(next_state, reward, done)``. Note that there is no gym ``info`` dict.
        """
        # Flatten once so the stored row and the scaled values share one shape.
        action = np.clip(np.ravel(action), -1, 1)
        reward = self._evaluate(action)
        next_state = self._record(action, reward)
        # NOTE(review): assumes len(ReplayMemory) counts stored rows. reset()
        # stores one row, so the episode ends on the 10th step() call; confirm.
        done = len(self.state_buffer) > 10
        return next_state, reward, done

    def reset(self):
        """Start a new episode from one random, already-evaluated action.

        Returns:
            The initial state (buffer contents after a single push).
        """
        self.state_buffer = ReplayMemory(self.state_buffer_size)
        action = np.random.uniform(low=-1, high=1, size=self.N_PARAMS)
        reward = self._evaluate(action)
        return self._record(action, reward)

