import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import os

from meta_test_algo.network import es_policy2_nonlinear
from meta_test_algo.base import base

class EvolutionStrategies_MIRROR2_nonlinear(base):
    """Adapt the head layers of a nonlinear policy with mirrored evolution strategies.

    The trainable "head" is the concatenation of the parameters of
    ``policy.first_layer``, ``policy.second_layer``, ``policy.last_layer1``
    and ``policy.last_layer2`` (in that order), kept as one flat NumPy vector
    in ``self.head_param``. Each adaptation step evaluates mirrored
    perturbations ``theta + sigma * eps`` and ``theta - sigma * eps``, keeps
    the best pairs, and moves ``theta`` along the reward-weighted noise.

    Rollouts are delegated to ``self.evaluate_head(env)``, which is not
    defined in this class (presumably provided by ``base`` -- confirm) and is
    expected to return a ``(episode_return, env_steps)`` pair.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 es_params,
                 **kwargs):
        """Build the policy and snapshot its head parameters.

        Args:
            obs_dim: Observation dimension, forwarded to ``base`` and the policy.
            action_dim: Action dimension, forwarded to ``base`` and the policy.
            net_size: Network size argument, forwarded to ``base`` and the policy.
            latent_action_dim: Latent action dimension, forwarded to ``base``
                and the policy.
            device: Torch device the policy is moved to.
            es_params: Mapping with the keys ``'n_rollouts'``,
                ``'noise_sigma'``, ``'lr'`` and ``'elite_frac'``.
            **kwargs: Forwarded to ``base``; must contain
                ``'max_path_length'`` and ``'reward_scale'``.

        Raises:
            KeyError: If a required key is missing from ``es_params`` or
                ``kwargs``.
        """
        super().__init__(obs_dim,
                         action_dim,
                         net_size,
                         latent_action_dim,
                         device,
                         **kwargs)

        self.max_path_length = kwargs['max_path_length']
        self.reward_scale = kwargs['reward_scale']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        # Evolution-strategies hyper-parameters.
        self.n_rollouts = es_params['n_rollouts']    # rollouts per step; used in mirrored pairs
        self.noise_sigma = es_params['noise_sigma']  # scale applied to the Gaussian perturbation
        self.lr = es_params['lr']                    # step size of the parameter update
        self.elite_frac = es_params['elite_frac']    # fraction of pairs kept as elites

        self.policy = es_policy2_nonlinear(obs_dim, action_dim, net_size, latent_action_dim).to(self.device)

        # Flat float copy of every head parameter, shape (total_params,).
        self.head_param = nn.utils.parameters_to_vector(
            self._es_head_parameters()
        ).detach().cpu().numpy()
        self.param_shape = self.head_param.shape

    def _es_head_parameters(self):
        """Return the head parameters in the fixed order that defines the flat layout."""
        policy = self.policy
        return (list(policy.first_layer.parameters())
                + list(policy.second_layer.parameters())
                + list(policy.last_layer1.parameters())
                + list(policy.last_layer2.parameters()))

    def _es_load_head_vector(self, flat_params):
        """Write a flat NumPy parameter vector into the policy's head layers."""
        # torch.tensor always copies, so the policy never aliases the NumPy
        # buffer behind self.head_param (which is later updated in place).
        nn.utils.vector_to_parameters(
            torch.tensor(flat_params, dtype=torch.float32).to(self.device),
            self._es_head_parameters()
        )

    def es_adapt_head(self, env):
        """Run one mirrored-sampling ES update of the head parameters.

        Performs ``n_rollouts // 2`` mirrored pairs of rollouts, updates
        ``self.head_param`` in place and loads the result into the policy.

        Args:
            env: Environment handle, passed through to ``self.evaluate_head``.

        Returns:
            Sum of the step counts reported by ``self.evaluate_head`` over all
            rollouts of this update.
        """
        pair_list = []
        total_steps = 0
        for _ in range(self.n_rollouts // 2):
            epsilon = np.random.randn(*self.param_shape)

            # Positive perturbation: theta + sigma * eps.
            self._es_load_head_vector(self.head_param + self.noise_sigma * epsilon)
            r_pos, steps = self.evaluate_head(env)
            total_steps += steps

            # Mirrored perturbation: theta - sigma * eps.
            self._es_load_head_vector(self.head_param - self.noise_sigma * epsilon)
            r_neg, steps = self.evaluate_head(env)
            total_steps += steps

            r_diff = r_pos - r_neg
            r_mean = (r_pos + r_neg) / 2.0
            pair_list.append((epsilon, r_diff, r_mean))

        # With n_rollouts < 2 there is nothing to learn from; skipping the
        # update avoids mean/std of an empty array (NaN plus RuntimeWarning)
        # and leaves head_param untouched, as before.
        if pair_list:
            # Select the elite set: the pairs with the highest mean return.
            num_elites = max(1, int(len(pair_list) * self.elite_frac))
            elite_pairs = sorted(pair_list, key=lambda x: x[2])[-num_elites:]

            # Standardise the return differences across the elites.
            # NOTE(review): when all elite differences are equal (always the
            # case with a single elite pair) the centred values are all zero,
            # so this step changes nothing -- confirm this is intended.
            r_diffs = np.array([r_diff for _, r_diff, _ in elite_pairs])
            r_diffs = (r_diffs - r_diffs.mean()) / (r_diffs.std() + 1e-6)

            # Estimate the gradient as the weighted average of the noise vectors.
            grad_estimate = np.zeros_like(self.head_param)
            for (eps, _, _), r in zip(elite_pairs, r_diffs):
                grad_estimate += r * eps
            grad_estimate /= num_elites

            # Ascent step on the flat head vector.
            self.head_param += self.lr * grad_estimate

        # Leave the policy holding the unperturbed (updated) parameters rather
        # than the last perturbation that was evaluated.
        self._es_load_head_vector(self.head_param)

        return total_steps

    def collet_data_and_train_filter(self, env):
        """Run one ES adaptation step and return the environment steps it used.

        The method name (including its spelling) is kept unchanged because
        callers outside this file may rely on it.
        """
        env_steps = self.es_adapt_head(env)
        return env_steps
 