import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import os

from meta_test_algo.network import es_policy


class EvolutionStrategies_CEM:
    """Adapt a policy's head layer with an elite-weighted evolution-strategies step.

    The network is built by ``es_policy``. Its ``shared_layer`` can be loaded
    and frozen via ``load_shared_network``; only ``last_layer`` (the "head")
    is searched over. The head parameters are kept as a flat numpy vector in
    ``self.head_param`` and written into the network whenever a candidate is
    evaluated.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 es_params,
                 **kwargs):
        """Build the policy and snapshot its head parameters.

        Args:
            obs_dim, action_dim, net_size, latent_action_dim: forwarded to
                ``es_policy`` to construct the network.
            device: torch device the policy is moved to.
            es_params: dict with keys ``'n_rollouts'``, ``'noise_sigma'``,
                ``'lr'`` and ``'elite_frac'``.
            **kwargs: must contain ``'max_path_length'`` and ``'reward_scale'``.

        Raises:
            KeyError: if a required key is missing from ``es_params`` or ``kwargs``.
        """
        self.max_path_length = kwargs['max_path_length']
        # Stored on the instance; not read by any method of this class.
        self.reward_scale = kwargs['reward_scale']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        self.n_rollouts = es_params['n_rollouts']
        self.noise_sigma = es_params['noise_sigma']
        self.lr = es_params['lr']
        self.elite_frac = es_params['elite_frac']

        self.policy = es_policy(obs_dim, action_dim, net_size, latent_action_dim).to(self.device)
        self.head_param = self._get_head_params()
        self.param_shape = self.head_param.shape

    def _get_head_params(self):
        """Return the head layer's parameters as a flat (1-D) numpy array."""
        return nn.utils.parameters_to_vector(
            self.policy.last_layer.parameters()
        ).detach().cpu().numpy()

    def _set_head_params(self, flat_params):
        """Copy a flat parameter vector into the head layer (as float32)."""
        # torch.tensor always copies, so the network never aliases the numpy
        # array held in self.head_param.
        nn.utils.vector_to_parameters(
            torch.tensor(flat_params, dtype=torch.float32).to(self.device),
            self.policy.last_layer.parameters(),
        )

    def load_shared_network(self, policy_params):
        """Load a state dict into ``policy.shared_layer`` and freeze it."""
        self.policy.shared_layer.load_state_dict(policy_params)
        for param in self.policy.shared_layer.parameters():
            param.requires_grad = False

    def evaluate_head(self, env):
        """Run one episode with the current network and return its total reward.

        The episode ends when ``env.step`` reports ``done`` or after
        ``max_path_length`` steps, whichever comes first.
        """
        obs = env.reset()
        env_step = 0
        total_reward = 0
        while env_step < self.max_path_length:
            env_step += 1
            action = self.select_action(obs)
            obs, reward, done, _env_info = env.step(action)
            total_reward += reward
            if done:
                break
        return total_reward

    def es_adapt_head(self, env):
        """Perform one elite-weighted ES update of the head layer.

        The update has four steps:

        1. Sample ``n_rollouts`` Gaussian perturbations of ``head_param``
           (scaled by ``noise_sigma``).
        2. Score each perturbation with one episode.
        3. Standardize the returns.
        4. Move ``head_param`` by ``lr`` times the return-weighted mean of
           the top ``elite_frac`` fraction of perturbations.

        The updated head is then loaded into the network and re-read into
        ``head_param``.
        """
        if self.n_rollouts < 1:
            # No samples to estimate a direction from; leave the head unchanged.
            return

        noise_list = []
        rewards = []
        try:
            for _ in range(self.n_rollouts):
                epsilon = np.random.randn(*self.param_shape)
                self._set_head_params(self.head_param + self.noise_sigma * epsilon)
                rewards.append(self.evaluate_head(env))
                noise_list.append(epsilon)
        finally:
            # Never leave a perturbed head loaded, even if a rollout raises.
            self._set_head_params(self.head_param)

        noise = np.stack(noise_list)  # (n_rollouts, n_head_params)
        rewards = np.asarray(rewards, dtype=np.float64)
        # Standardize so the step size does not depend on the reward scale;
        # the epsilon keeps identical rewards from dividing by zero.
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-6)

        # At least one elite, and never more elites than rollouts.
        num_elites = min(self.n_rollouts, max(1, int(self.n_rollouts * self.elite_frac)))
        elite_indices = np.argsort(rewards)[-num_elites:]
        grad_estimate = rewards[elite_indices] @ noise[elite_indices] / num_elites

        # TODO: could the epsilons/rewards already sampled be reused for a
        # further gradient step? The epsilons would have to be recomputed
        # relative to the current head_param.
        self._set_head_params(self.head_param + self.lr * grad_estimate)
        # Re-read so head_param matches the network exactly (float32).
        self.head_param = self._get_head_params()

    def select_action(self, obs):
        """Return the policy's action for ``obs`` as a numpy array.

        Inference only: the forward pass runs under ``torch.no_grad()``.
        This class never freezes the head parameters, so without it the
        output would track gradients and ``.numpy()`` would raise.
        """
        obs = torch.tensor(obs, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            action = self.policy(obs)
        return action.detach().cpu().numpy()

    def collet_data_and_train_filter(self, env):
        """Adapt the head on ``env`` (one ``es_adapt_head`` step)."""
        self.es_adapt_head(env)

    # Correctly spelled alias; the original name is kept for existing callers.
    collect_data_and_train_filter = collet_data_and_train_filter

    def evaluation(self, env, n_eval_epi=3):
        """Return the mean total reward over ``n_eval_epi`` episodes.

        Raises:
            ValueError: if ``n_eval_epi`` is less than 1.
        """
        if n_eval_epi < 1:
            raise ValueError(f"n_eval_epi must be >= 1, got {n_eval_epi}")
        with torch.no_grad():
            returns = sum(self.evaluate_head(env) for _ in range(n_eval_epi))
        return returns / n_eval_epi