import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import os

from meta_test_algo.network import es_policy2, q_function
from meta_test_algo.base import base
from meta_test_algo.buffer import ReplayBuffer

class Q_finetuning(base):
    """Fine-tune a policy with two Q-functions while collecting rollouts.

    The class owns an ``es_policy2`` policy, two Q-functions with one target
    copy each, an Adam optimizer over both Q-functions and a replay buffer.
    Each call to :meth:`es_adapt_head` first runs a number of Q-function /
    policy gradient updates on replayed data and then rolls out randomly
    perturbed copies of the policy "head" (``first_layer`` + ``last_layer``)
    to add new transitions to the buffer.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 es_params,
                 **kwargs):
        """Build the networks, optimizer, replay buffer and bookkeeping.

        Args:
            obs_dim: Observation dimensionality, forwarded to the networks.
            action_dim: Action dimensionality, forwarded to the networks.
            net_size: Hidden size forwarded to the networks.
            latent_action_dim: Forwarded to ``es_policy2`` and ``base``.
            device: Torch device on which the networks are placed.
            es_params: Mapping that must contain ``'n_rollouts'``,
                ``'noise_sigma'``, ``'lr'`` and ``'elite_frac'``.
            **kwargs: Must contain ``'max_path_length'`` and
                ``'reward_scale'``; all of it is also forwarded to ``base``.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        super().__init__(obs_dim,
                         action_dim,
                         net_size,
                         latent_action_dim,
                         device,
                         **kwargs)

        self.max_path_length = kwargs['max_path_length']
        # Stored but not read by any method defined in this class.
        self.reward_scale = kwargs['reward_scale']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        self.n_rollouts = es_params['n_rollouts']
        # NOTE(review): the configured value read above is overwritten with a
        # hard-coded 10 right here -- confirm whether this override is
        # intentional or a debugging leftover.
        self.n_rollouts = 10
        self.noise_sigma = es_params['noise_sigma']
        # `lr` and `elite_frac` are stored but not read by any method defined
        # in this class.
        self.lr = es_params['lr']
        self.elite_frac = es_params['elite_frac']

        # Network initialisation.
        self.policy = es_policy2(obs_dim, action_dim, net_size, latent_action_dim).to(self.device)
        self.q_function1 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.q_function2 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.target_q_function1 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.target_q_function2 = q_function(obs_dim + action_dim, net_size).to(self.device)
        # Target networks start as exact copies of their online counterparts.
        self.target_q_function1.load_state_dict(self.q_function1.state_dict())
        self.target_q_function2.load_state_dict(self.q_function2.state_dict())

        self.q_optimizer = optim.Adam(list(self.q_function1.parameters()) +
                                      list(self.q_function2.parameters()),
                                      lr=1e-4)

        # All first_layer + last_layer parameters flattened into one 1-D
        # numpy vector; shape: (total number of head parameters,).
        self.head_param = self._head_vector()
        self.param_shape = self.head_param.shape

        self.memory = ReplayBuffer(200000, self.device)
        self.batch_size = 512
        self.discount = 0.99
        self.steps = 0

    def _head_parameters(self):
        """Return the policy's first_layer and last_layer parameters as a list."""
        return (list(self.policy.first_layer.parameters())
                + list(self.policy.last_layer.parameters()))

    def _head_vector(self):
        """Return the current head parameters as a flat numpy array on the host."""
        flat = nn.utils.parameters_to_vector(self._head_parameters())
        return flat.detach().cpu().numpy()

    def _load_head_vector(self, vector):
        """Write a flat numpy vector back into the policy's head parameters."""
        # torch.tensor copies the data, so the parameters never alias the
        # numpy array that was passed in (e.g. self.head_param).
        flat = torch.tensor(vector, dtype=torch.float32).to(self.device)
        nn.utils.vector_to_parameters(flat, self._head_parameters())

    def evaluate_head(self, env):
        """Roll out the current policy for one episode and store the transitions.

        The episode ends when ``env.step`` reports ``done`` or after
        ``self.max_path_length`` steps, whichever comes first. Every
        transition is added to ``self.memory``.

        Args:
            env: Environment whose ``reset()`` returns an observation and
                whose ``step(action)`` returns a 4-tuple
                ``(next_obs, reward, done, info)``.

        Returns:
            Tuple ``(total_reward, env_step)``: the sum of the rewards and the
            number of environment steps taken.
        """
        obs = env.reset()
        env_step = 0
        total_reward = 0
        while env_step < self.max_path_length:
            env_step += 1
            # select_action is not defined in this class; presumably it is
            # inherited from `base` -- TODO confirm.
            action = self.select_action(obs)
            next_obs, reward, done, env_info = env.step(action)
            total_reward += reward
            self.memory.add(obs, action, reward, next_obs, done)
            obs = next_obs
            if done:
                break
        return total_reward, env_step

    def update_q_function(self):
        """Run one Q-function update followed by one policy update.

        A batch is sampled from ``self.memory``. The regression target is
        ``rewards + discount * (1 - dones) * min(target_q1, target_q2)``
        evaluated at the policy's action for the next observation. Both
        Q-functions are trained on the summed mean squared error, the target
        networks are soft-updated, and :meth:`update_policy` is called on the
        same batch of observations.

        Returns:
            ``0`` if the buffer holds fewer than ``self.batch_size``
            transitions (nothing is updated); otherwise the summed Q loss as a
            Python float.
        """
        if len(self.memory) < self.batch_size:
            return 0

        obs, actions, rewards, next_obs, dones = self.memory.sample(self.batch_size)

        # The bootstrap value must not contribute gradients.
        with torch.no_grad():
            next_actions = self.policy(next_obs)
            next_q = torch.min(self.target_q_function1(next_obs, next_actions),
                               self.target_q_function2(next_obs, next_actions))

        # Assumes rewards, dones and the Q outputs have broadcast-compatible
        # shapes -- TODO confirm against ReplayBuffer.sample.
        target_q = rewards + self.discount * (1 - dones) * next_q
        td_error1 = target_q - self.q_function1(obs, actions)
        td_error2 = target_q - self.q_function2(obs, actions)
        q_loss = (td_error1 ** 2).mean() + (td_error2 ** 2).mean()

        # Q-function update.
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        self.soft_update_target()
        self.update_policy(obs)
        return q_loss.item()

    def soft_update_target(self, tau=0.005):
        """Move each target Q-function towards its online Q-function.

        Every target parameter becomes
        ``tau * online + (1 - tau) * target``.

        Args:
            tau: Interpolation weight given to the online parameters.
        """
        network_pairs = (
            (self.target_q_function1, self.q_function1),
            (self.target_q_function2, self.q_function2),
        )
        # In-place updates under no_grad replace the original writes through
        # `.data`, so autograd still sees that the tensors changed.
        with torch.no_grad():
            for target_net, online_net in network_pairs:
                for target_param, param in zip(target_net.parameters(),
                                               online_net.parameters()):
                    target_param.copy_(tau * param + (1 - tau) * target_param)

    def _policy_gradient_step(self, obs, params, lr):
        """Take one gradient-descent step on ``-min(Q1, Q2)`` for ``params`` only."""
        actions = self.policy.grad_action(obs)
        q_values = torch.min(self.q_function1(obs, actions),
                             self.q_function2(obs, actions))
        policy_loss = -q_values.mean()
        # autograd.grad returns the gradients without touching `.grad`, so
        # the Q-function parameters accumulate nothing here.
        grads = torch.autograd.grad(policy_loss, params)
        # Apply the step in place on the parameters' own device; no host
        # round trip is needed.
        with torch.no_grad():
            for param, grad in zip(params, grads):
                param.sub_(lr * grad)

    def update_policy(self, obs, lr=1e-4):
        """Improve the policy on ``obs`` by plain gradient steps on the Q-value.

        Two sequential steps are taken, each with its own forward pass:
        first on the head (``first_layer`` + ``last_layer``), then on
        ``shared_layer`` using the already updated head.

        Args:
            obs: Batch of observations fed to ``self.policy.grad_action``.
            lr: Step size of both gradient steps.
        """
        self._policy_gradient_step(obs, self._head_parameters(), lr)

        # Shared layer update. The parameters are switched to
        # requires_grad=True and left that way afterwards.
        shared_params = list(self.policy.shared_layer.parameters())
        for param in shared_params:
            param.requires_grad = True
        self._policy_gradient_step(obs, shared_params, lr)

    def es_adapt_head(self, env, n_q_updates=1000):
        """Train on replayed data, then collect rollouts with a perturbed head.

        First :meth:`update_q_function` is called ``n_q_updates`` times. Then
        ``self.n_rollouts // 2`` noise vectors are drawn and, for each, the
        head is set to ``head_param + sigma * eps`` and ``head_param - sigma *
        eps`` and rolled out once with :meth:`evaluate_head`. Finally the head
        obtained from the gradient updates is restored and stored in
        ``self.head_param``.

        Args:
            env: Environment passed to :meth:`evaluate_head`.
            n_q_updates: Number of :meth:`update_q_function` calls made before
                the rollouts.

        Returns:
            Total number of environment steps taken by the rollouts.
        """
        for _ in range(n_q_updates):
            self.update_q_function()

        # Head as produced by the gradient updates above; this is what the
        # policy carries forward once the rollouts are done.
        backup_params = self._head_vector()

        total_steps = 0
        try:
            for _ in range(self.n_rollouts // 2):
                noise = self.noise_sigma * np.random.randn(*self.param_shape)
                # NOTE(review): perturbations are centred on self.head_param,
                # i.e. the vector stored at the end of the previous call (or
                # at construction), not on `backup_params` -- confirm this is
                # intended.
                # theta + eps, then theta - eps.
                for sign in (1.0, -1.0):
                    self._load_head_vector(self.head_param + sign * noise)
                    # NOTE(review): the rollout return is discarded; the
                    # rollouts only feed the replay buffer and no
                    # evolution-strategy update is applied.
                    _, steps = self.evaluate_head(env)
                    total_steps += steps
        finally:
            # Restore the unperturbed head even if a rollout raised, so the
            # policy is never left with perturbed weights.
            self.head_param = backup_params
            self._load_head_vector(backup_params)

        self.steps += 1
        return total_steps

    def collet_data_and_train_filter(self, env):
        """Run one :meth:`es_adapt_head` round and return its step count.

        The misspelled method name ("collet") is kept for existing callers.

        Args:
            env: Environment passed to :meth:`es_adapt_head`.

        Returns:
            Number of environment steps taken.
        """
        return self.es_adapt_head(env)
 