import gym

import riskfolio as rp
import numpy as np
import sys
import os
sys.path.insert(1, '/spsa-main')
from spsa import _spsa, iterator


import warnings
warnings.filterwarnings("ignore")

class GaussianPolicy:
    """
    A simple linear-Gaussian policy for continuous actions.

    For each action dimension ``i`` the policy keeps two weight vectors of
    length ``state_dim``:

    * ``mean_weight[i]`` -- the action mean is ``state @ mean_weight[i]``;
    * ``std_weight[i]``  -- the action std is ``exp(state @ std_weight[i])``.

    Both are 1-D object arrays of length ``action_dim`` whose elements are
    numpy arrays.
    """

    def __init__(self, state_dim, weights=None, n_actions=None):
        """
        Initialize the mean and standard-deviation weights.

        Args:
            state_dim: The dimensionality of the state space.
            weights: Optional sequence of ``(mean_w, std_w)`` pairs, one per
                action dimension. When omitted, every weight vector is drawn
                with ``np.random.rand(state_dim)``.
            n_actions: Number of action dimensions. Defaults to the
                module-level ``action_dim`` global.
        """
        self.state_dim = state_dim
        # The global is only looked up when no explicit count is given, so the
        # class can be used without the module-level environment.
        self.action_dim = action_dim if n_actions is None else n_actions
        # np.empty with dtype=object fills with None, matching [None] * n.
        self.mean_weight = np.empty(self.action_dim, dtype=object)
        self.std_weight = np.empty(self.action_dim, dtype=object)
        for i in range(self.action_dim):
            if weights is not None:
                self.mean_weight[i] = weights[i][0]
                self.std_weight[i] = weights[i][1]
            else:
                self.mean_weight[i] = np.random.rand(state_dim)
                self.std_weight[i] = np.random.rand(state_dim)

    def sample_action(self, state):
        """
        Sample one action per action dimension from the Gaussian policy.

        Args:
            state: Array-like current state of length ``state_dim``.

        Returns:
            A list of ``action_dim`` samples, where sample ``i`` is drawn from
            ``N(state @ mean_weight[i], exp(state @ std_weight[i]))``.
        """
        return [
            np.random.normal(np.dot(state, mean_w), np.exp(np.dot(state, std_w)))
            for mean_w, std_w in zip(self.mean_weight, self.std_weight)
        ]

    def update_weights(self, reward, new_state, old_state, learning_rate, discount_factor):
        """
        Apply a TD-style update to the mean and std weights of every action.

        For each action ``i`` the TD error is
        ``reward - discount_factor * (new_state @ mean_weight[i])``.
        It scales an ``old_state``-proportional step on ``mean_weight[i]`` and
        an ``old_state * (new_state - old_state)`` step on ``std_weight[i]``.
        NOTE(review): this treats the mean weights as a value function --
        confirm this is the intended update rule. Not called in this file.

        Args:
            reward: The reward received from the environment.
            new_state: The new state after taking an action.
            old_state: The previous state before taking an action.
            learning_rate: The learning rate for updating weights.
            discount_factor: The discount factor for future rewards.
        """
        new_state = np.asarray(new_state, dtype=float)
        old_state = np.asarray(old_state, dtype=float)
        for i in range(self.action_dim):
            # The TD error must be computed per action: dotting the state with
            # the whole object array of weights only lines up by accident.
            td_error = reward - discount_factor * np.dot(new_state, self.mean_weight[i])
            step = learning_rate * td_error
            # Rebind rather than `+=` so arrays shared with the caller (e.g. the
            # optimizer's parameter vector) are not mutated in place.
            self.mean_weight[i] = self.mean_weight[i] + step * old_state
            self.std_weight[i] = self.std_weight[i] + step * old_state * (new_state - old_state)

    def update_parameters(self, new_weights):
        """
        Replace the policy weights with ``new_weights``.

        Args:
            new_weights: Sequence of ``(mean_w, std_w)`` pairs, one per action
                dimension, each shaped like the current weight vector.

        Raises:
            ValueError: If any new weight vector's shape differs from the
                current one. All pairs are checked before any is applied, so
                a failed call leaves the policy unchanged.
        """
        for i in range(self.action_dim):
            if (np.shape(new_weights[i][0]) != np.shape(self.mean_weight[i])
                    or np.shape(new_weights[i][1]) != np.shape(self.std_weight[i])):
                raise ValueError("Shape mismatch in update_parameters")
        for i in range(self.action_dim):
            self.mean_weight[i] = new_weights[i][0]
            self.std_weight[i] = new_weights[i][1]


def calculate_evar(returns, alpha=0.1, solver='CLARABEL'):
    """
    Compute the historical Entropic Value-at-Risk of the negated returns.

    Returns are negated (losses = -returns) and handed to riskfolio's
    historical EVaR estimator.

    Args:
        returns: Non-empty sequence of episode returns.
        alpha: Significance level forwarded to ``EVaR_Hist``.
        solver: Solver name forwarded to ``EVaR_Hist``. It was previously
            accepted but ignored.

    Returns:
        The result of ``rp.RiskFunctions.EVaR_Hist``. Callers in this file
        unpack it as ``(evar, z)``.

    Raises:
        ValueError: If ``returns`` is empty.
    """
    neg_returns = -np.asarray(returns, dtype=float)
    if neg_returns.size == 0:
        raise ValueError("returns must contain at least one value")
    return rp.RiskFunctions.EVaR_Hist(neg_returns, alpha=alpha, solver=solver)


# Register a gym environment for one simglucose virtual patient. The
# ``patient_name`` kwarg selects which patient to simulate; it must be one of
# 'adolescent#001' to 'adolescent#010', 'adult#001' to 'adult#010', or
# 'child#001' to 'child#010'.
from gym.envs.registration import register
register(
    id='simglucose-adolescent1-v0',
    entry_point='simglucose.envs:T1DSimEnv',
    kwargs={'patient_name': 'adolescent#001'}
)

# Module-level environment used by objective_function (rollouts) and main.
env = gym.make('simglucose-adolescent1-v0')
# Observation/action sizes; GaussianPolicy and main size the weight vectors
# from these globals.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# NOTE(review): this renders at import time, before any env.reset() --
# confirm the simglucose env supports rendering in that state.
env.render(mode='human')
# Episode returns from every rollout. objective_function appends to it and
# nothing in this file clears it, so it spans all evaluated parameter vectors.
returns = []

# No initial weights: GaussianPolicy draws random ones.
params = None

# Create the Gaussian policy object
policy_cont = GaussianPolicy(state_dim, params)

# Flag toggled by main() around its per-iteration evaluation;
# objective_function does not read it.
pvar = False

# Exponentially smoothed EVaR estimate (updated by objective_function) and
# its smoothing rate.
pol_evar = 0.0
beta = 0.1
# Mean of ``returns`` recorded after each objective evaluation.
avg_return_values = []

def _run_episode() -> float:
    """Run one rendered episode of ``policy_cont`` in ``env``; return its total reward."""
    observation = env.reset()
    print(observation)
    total_reward = 0
    done = False
    while not done:
        env.render(mode='human')
        # The policy yields one sample per action dimension; the environment
        # is driven with their mean as a single scalar action.
        action = np.mean(policy_cont.sample_action(observation))
        observation, reward, done, _ = env.step(action)
        # Accumulate before the loop re-checks `done`, so the reward of the
        # terminating step is part of the episode return.
        total_reward += reward
    return total_reward


def objective_function(params: np.ndarray, n_episodes: int = 1) -> float:
    """
    SPSA objective: load ``params`` into the policy, roll out, score with EVaR.

    Args:
        params: Indexable per action dimension; ``params[i][0]`` and
            ``params[i][1]`` become the policy's mean and std weight vectors
            for action ``i``.
        n_episodes: Number of episodes to roll out with these parameters.

    Returns:
        The updated exponentially smoothed EVaR estimate ``pol_evar``.

    Side effects:
        Overwrites ``policy_cont``'s weights, appends episode returns to the
        module-level ``returns`` list, appends the mean return to
        ``avg_return_values``, and updates the global ``pol_evar``.
        NOTE(review): ``returns`` is never cleared, so the EVaR covers every
        episode run so far (across all parameter vectors), not just these
        ``params`` -- confirm this is intended.
    """
    global pol_evar
    mean_w = np.asarray(policy_cont.mean_weight)
    std_w = np.asarray(policy_cont.std_weight)
    for i in range(action_dim):
        mean_w[i] = params[i][0]
        std_w[i] = params[i][1]
    policy_cont.mean_weight = mean_w
    policy_cont.std_weight = std_w

    for _ in range(n_episodes):
        returns.append(_run_episode())

    evar, _ = calculate_evar(returns, 0.1)
    # Exponential moving average of the per-call EVaR with rate `beta`.
    pol_evar = pol_evar + beta * (evar - pol_evar)
    avg_return_values.append(np.mean(returns))
    print("poly evar", pol_evar, " evar", evar)
    return pol_evar


def main():
    """
    Optimize the policy weights with SPSA, maximizing ``objective_function``.

    Runs at most ``max_iterations`` iterations of ``iterator.maximize``. At
    each iterate it re-evaluates and logs the objective, then loads the
    iterate into ``policy_cont``. Ctrl-C exits the process. The environment
    is closed however the loop ends.
    """
    global pvar
    f = objective_function
    max_iterations = 1000
    # One [mean_w, std_w] pair of uniform-random vectors per action dimension.
    x = [[np.random.rand(state_dim), np.random.rand(state_dim)] for _ in range(action_dim)]

    # NOTE(review): objective_values is collected but not written anywhere,
    # and nothing is written to directory_path yet.
    objective_values = []
    directory_path = './tests/mountaincar/alpha_0_9/'
    os.makedirs(directory_path, exist_ok=True)

    try:
        for cnt, variables in enumerate(iterator.maximize(f, x, lr=0.01, adam=True)):
            print("Iteration [ ", cnt, " ]")
            pvar = True
            obj_val = f(variables['x'])
            print(" obj. val : ", obj_val)
            objective_values.append(obj_val)
            pvar = False
            policy_cont.update_parameters(variables['x'])
            sys.stdout.flush()
            if cnt + 1 == max_iterations:
                break
    except KeyboardInterrupt:
        sys.exit()
    finally:
        # Release the environment (and its render viewer) on every exit path,
        # including the SystemExit raised above.
        env.close()


if __name__ == "__main__":
    main()
