import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import os

from meta_test_algo.network import es_policy2
from meta_test_algo.base import base

class CMA:
    """Minimal diagonal-covariance evolution-strategy sampler.

    Keeps a search mean, a per-dimension variance vector ``C`` and a global
    step size ``sigma``.  ``ask`` draws a population of Gaussian candidates
    and ``tell`` adapts ``C`` from the best-rewarded part of that population.
    """

    def __init__(self, dim, sigma=0.5, popsize=10):
        """Set up the sampler.

        Args:
            dim: Dimensionality of the search space.
            sigma: Global step size multiplying every sampled perturbation.
            popsize: Number of candidates drawn per ``ask`` call.
        """
        self.dim = dim
        self.sigma = sigma
        self.popsize = popsize
        self.mean = np.zeros(dim)
        self.C = np.ones(dim)  # diagonal variance, one entry per dimension
        self.gen = 0  # number of ask() calls made so far

    def ask(self):
        """Draw one generation of candidates.

        Returns:
            A ``(samples, z)`` pair, both of shape ``(popsize, dim)``: the
            candidates and the standard-normal noise that produced them.
            ``z`` is meant to be handed back to ``tell`` with the rewards.
        """
        self.gen += 1
        z = np.random.randn(self.popsize, self.dim)
        samples = self.mean + self.sigma * z * np.sqrt(self.C)
        return samples, z

    def tell(self, rewards, z):
        """Adapt the per-dimension variance from the best-rewarded candidates.

        Args:
            rewards: One reward per row of ``z``; higher is better.
            z: Noise matrix, one row per candidate (as returned by ``ask``).
        """
        order = np.argsort(rewards)[::-1]  # indices sorted best-first
        if len(order) == 0:
            # Nothing was evaluated, so there is nothing to learn from.
            return

        # Keep at least one elite.  With popsize == 1 the integer division
        # yields 0, and the mean/variance of an empty selection is NaN, which
        # would poison ``C`` (and every later sample) permanently.
        n_elite = max(1, self.popsize // 2)
        top_z = z[order[:n_elite]]

        # NOTE(review): the factor 0 pins the search mean at the origin, so
        # only the variance adapts.  This looks intentional (samples are used
        # as zero-centred perturbations by the caller) -- confirm.
        self.mean = 0 * np.mean(top_z, axis=0)

        # Blend the previous variance with the elite noise variance 50/50.
        cov_update = np.var(top_z, axis=0)
        self.C = 0.5 * self.C + 0.5 * cov_update

class EvolutionStrategies_CMA(base):
    """Evolution-strategy adaptation of the policy's first and last layers.

    The parameters of ``policy.first_layer`` and ``policy.last_layer`` are
    flattened into one NumPy vector (``head_param``).  Each adaptation step
    samples perturbations from a ``CMA`` sampler, scores every perturbed
    policy with ``evaluate_head`` and moves ``head_param`` along the
    reward-weighted average of those perturbations.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 es_params,
                 **kwargs):
        """Build the policy, flatten its head parameters and create the sampler.

        Args:
            obs_dim: Forwarded to ``base`` and ``es_policy2``.
            action_dim: Forwarded to ``base`` and ``es_policy2``.
            net_size: Forwarded to ``base`` and ``es_policy2``.
            latent_action_dim: Forwarded to ``base`` and ``es_policy2``.
            device: Torch device the policy is moved to.
            es_params: Mapping with the keys ``'n_rollouts'``,
                ``'noise_sigma'``, ``'lr'`` and ``'elite_frac'``.
            **kwargs: Forwarded to ``base``; must contain
                ``'max_path_length'`` and ``'reward_scale'``.

        Raises:
            KeyError: If a required key is missing from ``kwargs`` or
                ``es_params``.
        """
        super().__init__(obs_dim,
                         action_dim,
                         net_size,
                         latent_action_dim,
                         device,
                         **kwargs)

        self.max_path_length = kwargs['max_path_length']
        self.reward_scale = kwargs['reward_scale']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        self.n_rollouts = es_params['n_rollouts']
        self.noise_sigma = es_params['noise_sigma']
        self.lr = es_params['lr']
        # NOTE(review): elite_frac is stored but never read in this class;
        # the CMA sampler always uses the top half of the population.
        self.elite_frac = es_params['elite_frac']

        # Initialisation of the policy network.
        self.policy = es_policy2(obs_dim, action_dim, net_size, latent_action_dim).to(self.device)

        # Store every parameter of first_layer + last_layer as one flat vector.
        self.all_head_params = self._head_parameters()
        self.head_param = nn.utils.parameters_to_vector(self.all_head_params).detach().cpu().numpy()  # shape: (total params,)
        self.param_shape = self.head_param.shape[0]

        # One candidate per rollout; noise_sigma is the sampler's step size.
        self.cma = CMA(self.param_shape, self.noise_sigma, self.n_rollouts)

    def _head_parameters(self):
        """Return the first_layer and last_layer parameters, in that order."""
        return list(self.policy.first_layer.parameters()) + list(self.policy.last_layer.parameters())

    def _set_head_params(self, flat_params):
        """Copy a flat NumPy parameter vector into the policy's head layers."""
        # torch.tensor always copies, so the policy weights never alias the
        # NumPy vector that is later updated in place.
        nn.utils.vector_to_parameters(
            torch.tensor(flat_params, dtype=torch.float32).to(self.device),
            self._head_parameters()
        )

    def es_adapt_head(self, env):
        """Run one evolution-strategy update of the head parameters.

        Args:
            env: Environment handed unchanged to ``evaluate_head``.

        Returns:
            Total number of environment steps reported by ``evaluate_head``
            over all evaluated candidates.
        """
        total_steps = 0
        solutions, z = self.cma.ask()
        rewards = []
        for perturbation in solutions:
            self._set_head_params(self.head_param + perturbation)
            # evaluate_head is not defined in this class (presumably inherited
            # from base); it is used here as returning (reward, steps).
            episode_reward, steps = self.evaluate_head(env)
            rewards.append(episode_reward)
            total_steps += steps
        self.cma.tell(rewards, z)

        rewards = np.asarray(rewards, dtype=np.float64).ravel()
        spread = rewards.std()
        # Only update when the rewards actually differ.  A zero (or NaN)
        # spread would turn the normalisation into a division by zero and
        # write NaN into head_param, corrupting the policy for good.
        if spread > 0:
            normalized = (rewards - rewards.mean()) / spread
            # Reward-weighted average of the sampled perturbations.
            grad_estimate = normalized @ solutions / len(solutions)
            self.head_param += self.lr * grad_estimate

        # Leave the policy at the (possibly updated) unperturbed parameters.
        self._set_head_params(self.head_param)

        return total_steps

    def collet_data_and_train_filter(self, env):
        """Run one head-adaptation step and return the environment steps used.

        The method name (including its spelling) is kept as is because
        callers outside this file may depend on it.
        """
        env_steps = self.es_adapt_head(env)
        return env_steps
 