from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Callable
import numpy as np
import torch
import torch.nn as nn
from torch import tensor, float32, autograd
from collections import deque
from torch.optim.lr_scheduler import StepLR

def check_parameter_exists(model, parameter_name):
    """
    Return True if ``model`` has a parameter registered under exactly
    ``parameter_name``.

    The comparison is against the fully-qualified names yielded by
    ``model.named_parameters()`` (e.g. ``'layer.weight'``), so a bare
    submodule name such as ``'layer'`` never matches.
    """
    return any(name == parameter_name for name, _ in model.named_parameters())

def list_of_maps_to_map_of_lists(list_of_maps):
    """
    Transpose an iterable of dicts into a dict of lists.

    Keys appear in first-seen order. Each list collects that key's values in
    the order the dicts are visited. Dicts that lack a key contribute nothing
    to it, so the lists may have different lengths.
    """
    grouped = {}
    for mapping in list_of_maps:
        for key, value in mapping.items():
            grouped.setdefault(key, []).append(value)
    return grouped


class OuterLoopOptimizer(nn.Module):
    """
    Base class for outer loop optimization of reward weights.

    The outer loop owns a feature extractor (``self.extractor``) and a
    reward-weight network (``self.weight_net``). :meth:`forward` maps a batch
    of observations to one weight per reward component. :meth:`learn` drives
    the inner-loop agent. :meth:`step` records the current weights and task
    return, then calls the subclass hook :meth:`optimize_weights`.

    NOTE(review): ``optimize_weights`` carries ``@abstractmethod``, but this
    class does not use ``ABCMeta``, so the decorator is not enforced at
    instantiation. Concrete subclasses must still override it.
    """
    def __init__(
        self,
        inner_loop,
        n_reward_components: int,
        learning_rate: float = 1e-3,
        share_extractor: bool = False,
        n_epochs: int = 4,
        batch_size: int = 128,
        device = 'cpu',
        save_frequency = 1000,
        save_path = 'logs/',
        log_episode_reward = 10,
        reward_L2 = 0.,
        weight_net_class = None,
        exploration_module = None,
        max_weight_history = 500,
        explore_choice = 'argmax',
        explore_tmp = 1,
        oracle_weight = None,
        reset_method = 'none',
        old_model_update=True,
        lr=1e-2,
        start_with_oracle=False,
        extractor_layer=2,
        weight_history='coeff',
        obs_space_tensor=None,
        use_gradient=True,
    ):
        """
        Args:
            inner_loop: inner-loop RL agent. Must expose ``reward_components``;
                ``policy``, with ``features_extractor``, ``features_dim``,
                ``observation_space``, ``pi_features_extractor``,
                ``mlp_extractor.policy_net`` and ``action_net``;
                ``rollout_buffer``; ``logger``; and the
                ``pre_learn``/``on_learn``/``post_learn``/``save`` hooks.
            n_reward_components: passed to ``weight_net_class``. Also the
                length of the oracle weight vector built here.
            learning_rate: stored as ``self.learning_rate``. The optimizer
                built here uses ``lr`` instead.
            share_extractor: reuse ``inner_loop.policy.features_extractor``
                instead of building a separate extractor.
            n_epochs, batch_size: stored for subclasses.
            device: device for the extractor and weight tensors.
            save_frequency: :meth:`learn` saves when the iteration index is a
                multiple of this. 0 disables saving.
            save_path: directory prefix used by :meth:`save`.
            log_episode_reward: :meth:`learn` calls :meth:`monitor` every this
                many iterations. 0 disables periodic monitoring.
            reward_L2: coefficient for the :meth:`L2_reg` penalty, for
                subclasses.
            weight_net_class: callable
                ``(n_reward_components, features_dim, device)`` returning the
                reward-weight network.
            exploration_module: optional module used by :meth:`explore`.
            max_weight_history: maximum length of the weight and task-return
                histories.
            explore_choice, explore_tmp: forwarded to
                ``exploration_module.explore`` as ``method`` and ``tmp``.
            oracle_weight: weights for components ``1..``, used when
                ``start_with_oracle`` is set. Defaults to an empty list.
            reset_method: inner-policy reset that :meth:`explore` triggers.
                Values it does not recognise (e.g. ``'none'``) do nothing.
            old_model_update: stored only.
            lr: base learning rate. ``weight_net.coeff`` (if present) uses
                ``5 * lr``.
            start_with_oracle: reset the weights to ``[1, *oracle_weight]``.
            extractor_layer: 2 builds a two-layer MLP extractor. Any other
                value builds a single linear layer.
            weight_history: how :meth:`step` records weights. One of
                ``'coeff'``, ``'traj'`` or ``'space'``.
            obs_space_tensor: observations used when
                ``weight_history == 'space'``.
            use_gradient: when False, :meth:`step` only records history.
        """
        super().__init__()
        self.inner_loop = inner_loop
        self.n_reward_components = n_reward_components
        self.learning_rate = learning_rate
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.device = device
        self.save_frequency = save_frequency
        self.save_path = save_path
        self.log_episode_reward = log_episode_reward
        self.reward_components = inner_loop.reward_components
        self.reward_L2 = reward_L2
        self.exploration_module = exploration_module
        self.explore_choice = explore_choice
        self.explore_tmp = explore_tmp
        # Fresh list per instance instead of a shared mutable default.
        self.oracle_weight = [] if oracle_weight is None else oracle_weight
        self.last_task_performance = 0
        self.max_grad_norm = 0.5
        self.reset_method = reset_method
        self.old_model_update = old_model_update
        self.weight_history_type = weight_history
        self.obs_space_tensor = obs_space_tensor
        self.use_gradient = use_gradient

        # define weight network
        self.share_extractor = share_extractor
        if self.share_extractor:
            self.extractor = self.inner_loop.policy.features_extractor
        else:
            obs_dim = self.inner_loop.policy.observation_space.shape[0]
            features_dim = self.inner_loop.policy.features_dim
            if extractor_layer == 2:
                self.extractor = nn.Sequential(
                    nn.Linear(obs_dim, 128),
                    nn.ReLU(),
                    nn.Linear(128, features_dim),
                ).to(self.device)
            else:
                self.extractor = nn.Sequential(
                    nn.Linear(obs_dim, features_dim),
                ).to(self.device)

        self.weight_net = weight_net_class(n_reward_components, self.inner_loop.policy.features_dim, self.device)

        small_lr_group = []
        large_lr_group = []
        if self.share_extractor or 'simple' in self.weight_net.name:
            self.all_params = list(self.weight_net.parameters())
        else:
            small_lr_group += list(self.extractor.parameters())
            self.all_params = list(self.extractor.parameters()) + list(self.weight_net.parameters())  # no exploration module
        # Train an inner ``weight_net`` at the base lr. ``named_parameters()``
        # yields dotted names (``'weight_net.0.weight'``), so an exact-name
        # lookup of ``'weight_net'`` can never find a submodule; inspect the
        # attribute directly instead.
        inner_weight_net = getattr(self.weight_net, 'weight_net', None)
        if isinstance(inner_weight_net, nn.Module):
            small_lr_group += list(inner_weight_net.parameters())
        elif isinstance(inner_weight_net, nn.Parameter):
            small_lr_group.append(inner_weight_net)
        if check_parameter_exists(self.weight_net, 'coeff'):
            large_lr_group.append(self.weight_net.coeff)
        self.optimizer = torch.optim.Adam(
            [
                {'params': large_lr_group, 'lr': lr * 5},
                {'params': small_lr_group, 'lr': lr},
            ]
        )
        # self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.98)
        self.old_policy_gradient = None

        self.weight_history = deque(maxlen=max_weight_history)
        self.task_reward_history = deque(maxlen=max_weight_history)
        self.weight_gradient = 0.

        # Actor-side parameters of the inner policy; not used by this base
        # class directly. Note: pi_features_extractor is basically only a
        # FlattenExtractor.
        self.actor_params = (
            list(self.inner_loop.policy.pi_features_extractor.parameters())
            + list(self.inner_loop.policy.mlp_extractor.policy_net.parameters())
            + list(self.inner_loop.policy.action_net.parameters())
        )

        if start_with_oracle:
            weight = torch.ones(self.n_reward_components).to(self.device)
            weight[1:] = torch.tensor(self.oracle_weight).to(self.device)
            self.weight_net.reset_weight(weight)

        if self.exploration_module is None:
            self.exploration_dim = n_reward_components
        else:
            self.exploration_dim = self.exploration_module.weight_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map observations ``x`` to reward weights via ``extractor`` then ``weight_net``."""
        x = self.extractor(x)
        x = self.weight_net(x)
        return x

    def initialize(self):
        """Hook run at the start of :meth:`learn`. No-op in the base class."""
        return

    @abstractmethod
    def optimize_weights(
        self,
    ) -> np.ndarray:
        """Update the reward-weight parameters. Implemented by subclasses."""
        pass

    def step(self) -> None:
        """
        Record the current reward weights and task return, then run
        :meth:`optimize_weights` if ``self.use_gradient`` is set.

        Raises:
            ValueError: if ``self.weight_history_type`` is not ``'coeff'``,
                ``'traj'`` or ``'space'``.
        """
        history_type = self.weight_history_type
        if history_type == 'coeff':
            # Dummy input; presumably ignored by get_weight for coefficient-style nets -- TODO confirm.
            weights = self.weight_net.get_weight(torch.zeros(1))
        elif history_type == 'traj':
            data = self.inner_loop.rollout_buffer.sample(1000)
            with torch.no_grad():
                features = self.extractor(data.observations)
                weights = self.weight_net.get_train_data(features)  # TODO: can be either forward() or get_weight()
        elif history_type == 'space':
            with torch.no_grad():
                features = self.extractor(self.obs_space_tensor)
                weights = self.weight_net.get_train_data(features)
        else:
            raise ValueError(
                f"Unknown weight_history type {history_type!r}; "
                "expected 'coeff', 'traj' or 'space'"
            )

        weights_np = weights.detach().cpu().numpy()
        self.weight_history.append(weights_np)
        for i in range(self.exploration_dim):
            self.inner_loop.logger.record(f"explore/get_weight_{i}", weights_np[0, i])

        # Column 0 of the buffer returns is treated as the task return.
        self.task_reward_history.append(self.inner_loop.rollout_buffer.returns[..., 0].mean().item())

        # optimize weights
        if self.use_gradient:
            self.optimize_weights()

    def learn(
        self,
        total_timesteps: int,
        n_steps: int = 100,
        inner_loop_freq: int = 4,
        outer_loop_freq: int = 4,
        inner_loop_callback: Callable = None,
        progress_bar: bool = False,
        test_env = None,
        explore_frequency: int = 100,
        condition: str = 'none',
    ):
        """
        Run the bilevel training loop.

        Each iteration trains the inner loop via ``inner_loop.on_learn``.
        The loop also:
        - calls :meth:`step` every ``outer_loop_freq`` iterations;
        - saves and monitors periodically;
        - every ``explore_frequency`` iterations, if task performance improved
          by less than 0.05 since the last check, may resample reward weights
          via :meth:`explore`.

        Args:
            total_timesteps: budget passed through ``inner_loop.pre_learn``.
            n_steps: timesteps per iteration. The loop runs
                ``int(total_timesteps / n_steps)`` iterations.
            inner_loop_freq: currently unused.
            outer_loop_freq: period of the outer update. Non-positive values
                disable it. ``-1`` runs :meth:`step` at exploration checks
                instead.
            inner_loop_callback: callback forwarded to the inner loop.
            progress_bar: forwarded to ``inner_loop.pre_learn``.
            test_env: environment for :meth:`monitor`. ``None`` uses the inner
                loop's env.
            explore_frequency: iterations between exploration checks.
                0 disables them.
            condition: ``'none'`` always explores when triggered.
                ``'performance'`` explores with probability
                ``1 - task_performance``.
        """
        self.initialize()

        total_timesteps, inner_loop_callback = self.inner_loop.pre_learn(total_timesteps, callback=inner_loop_callback, progress_bar=progress_bar)

        total_episodes = int(total_timesteps / n_steps)

        for i in range(total_episodes):
            if i == 0:
                self.monitor(test_env)

            # collect rollout data; train inner loop policy
            continue_training = self.inner_loop.on_learn(
                total_timesteps=total_timesteps,
                callback=inner_loop_callback,
                iteration=i,
            )
            if continue_training is False:
                break

            # when outer_loop_freq = -1, try algorithm 2 (handled at exploration checks below)
            if outer_loop_freq > 0 and i % outer_loop_freq == outer_loop_freq - 1:
                self.step()  # update outer-loop gradient value
                self.weight_net.clip_coeff()

            # A frequency of 0 disables the action instead of raising ZeroDivisionError.
            if self.save_frequency and i % self.save_frequency == 0:
                self.save()

            if self.log_episode_reward and i % self.log_episode_reward == self.log_episode_reward - 1:
                self.monitor(test_env)

            if explore_frequency and i % explore_frequency == explore_frequency - 1:
                current_task_performance = self.current_task_performance
                P = 1 - current_task_performance
                delta_perf = current_task_performance - self.last_task_performance
                self.last_task_performance = current_task_performance

                self.inner_loop.logger.record(f"explore/task_perf", current_task_performance)

                # Only consider new weights when task performance has stalled.
                if delta_perf < 0.05:
                    do_random_sample = True

                    if outer_loop_freq == -1:
                        self.step()  # update outer-loop gradient value
                        # A usable gradient takes precedence over random sampling.
                        if self.weight_gradient > 1e-2:
                            do_random_sample = False

                    # Otherwise random sample; this can overwrite the previous weight.
                    if self.exploration_module is not None and do_random_sample:
                        if condition == 'none':
                            self.explore()
                        elif condition == 'performance':
                            if np.random.rand() < P:
                                self.inner_loop.logger.record(f"explore/set_weight", 1)
                                self.explore()
                            else:
                                self.inner_loop.logger.record(f"explore/set_weight", 0)

        if inner_loop_callback is not None:
            self.inner_loop.post_learn(callback=inner_loop_callback)

    def explore(self):
        """
        Train the exploration module on the recorded (weight, task return)
        history, then set the reward weights to the weight it proposes.
        Optionally resets parts of the inner policy afterwards.

        Logs ``train/weight_net_loss`` and ``explore/set_weight_{i}``.
        """
        # NOTE(review): with weight_history 'traj'/'space', each history entry
        # holds many rows, so ``weights`` and ``performance`` can differ in
        # length -- confirm exploration_module.train handles this.
        weights = torch.tensor(np.array(self.weight_history), device=self.device, dtype=torch.float32).reshape(-1, self.exploration_dim)
        performance = torch.tensor(np.array(self.task_reward_history), device=self.device).reshape(-1)
        loss = self.exploration_module.train(weights, performance=performance)
        self.inner_loop.logger.record(f"train/weight_net_loss", loss)

        # find the best reward weight
        best_weight = self.exploration_module.explore(num_samples=1000, method=self.explore_choice, tmp=self.explore_tmp)
        best_weight_np = best_weight.detach().cpu().numpy()
        for i in range(self.exploration_dim):
            self.inner_loop.logger.record(f"explore/set_weight_{i}", best_weight_np[i])
        self.weight_net.reset_weight(best_weight)

        if self.reset_method == 'policy':
            self.inner_loop.policy.reset_policy()
        elif self.reset_method == 'action':
            self.inner_loop.policy.reset_action()
        elif self.reset_method == 'logstd':
            self.inner_loop.policy.reset_logstd()
        elif self.reset_method == 'policy_logstd':
            self.inner_loop.policy.reset_policy_logstd()
        elif self.reset_method == 'full':
            self.inner_loop.policy.reset_full_policy()

    def save(self):
        """Save this module's state_dict and the inner loop under ``self.save_path``."""
        torch.save(self.state_dict(), self.save_path + '/outer_loop.pth')
        self.inner_loop.save(self.save_path + '/inner_loop')

    def monitor(self, env):
        """
        Evaluate the policy deterministically for one episode per
        sub-environment, and log reward and weight statistics.

        ``env`` is stepped until every sub-environment has reported done.
        Rewards are only accumulated for sub-environments that were still
        running before the step. Environments that finish early keep being
        stepped, but rewards they produce afterwards (e.g. after an automatic
        reset) are not added to their totals.

        Sets ``self.current_task_performance`` to the mean task reward.

        Args:
            env: vectorized environment. ``reset()`` returns a batched
                observation array and ``step()`` returns
                ``(obs, rewards, dones, infos)``. ``None`` uses
                ``self.inner_loop.env``.
        """
        if env is None:
            env = self.inner_loop.env
        obs = env.reset()

        n_envs = obs.shape[0]
        done = np.zeros(n_envs, dtype=bool)
        task_reward = np.zeros(n_envs, dtype=np.float32)
        total_reward = np.zeros(n_envs, dtype=np.float32)
        total_weights = []
        episode_length = np.ones(n_envs, dtype=np.int32)

        while not done.all():
            obs_tensor = torch.as_tensor(obs).float().to(self.device)

            with torch.no_grad():
                # Reward weights for the current observations, one row per sub-environment.
                reward_weights = self.forward(obs_tensor).cpu().numpy()
                total_weights.append(reward_weights)
                action, _ = self.inner_loop.policy.predict(obs, deterministic=True)

            # Sub-environments that had not finished before this step.
            active = ~done
            obs, rewards, done_, info = env.step(action)
            done = done | done_

            # Shaped reward: task reward * weight 0 + each component * its weight.
            info = list_of_maps_to_map_of_lists(info)
            step_total = np.array(rewards) * np.array(reward_weights[:, 0])
            for i, k in enumerate(self.reward_components):
                step_total = step_total + np.array(info[k]) * np.array(reward_weights[:, i + 1])

            task_reward += np.where(active, rewards, 0)
            total_reward += np.where(active, step_total, 0)

            episode_length = np.where(~done, episode_length + 1, episode_length)

        # Log rewards and weights (weights averaged over steps, per sub-environment).
        mean_weights = np.array(total_weights).mean(axis=0)
        for i in range(len(self.reward_components) + 1):
            self.inner_loop.logger.record(f"monitor/weight_{i}_mean", mean_weights[:, i].mean())
            self.inner_loop.logger.record(f"monitor/weight_{i}_std", mean_weights[:, i].std())

        self.inner_loop.logger.record(f"monitor/task_reward", np.mean(task_reward))
        self.inner_loop.logger.record(f"monitor/total_reward", np.mean(total_reward))
        self.inner_loop.logger.record("monitor/episode_length", np.mean(episode_length))
        self.current_task_performance = np.mean(task_reward)

    def L2_reg(self):
        """Return the sum of the (unsquared) L2 norms of ``self.all_params``."""
        l2_reg = torch.tensor(0.).to(self.device)
        for param in self.all_params:
            l2_reg += torch.norm(param)
        return l2_reg

    def log_lr(self):
        """Log the learning rate of every optimizer parameter group."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.inner_loop.logger.record(f"train/outer_loop_lr_{i}", param_group['lr'])

class SparseOuterLoop(OuterLoopOptimizer):
    """
    Baseline with the fixed weight vector ``[1, 0, ..., 0]``.

    Only component 0 is weighted and the weights never change.
    """
    def initialize(self):
        weight = torch.zeros(self.n_reward_components)
        weight[:1] = 1
        self.weight = weight.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One copy of the fixed weights per input row; x is otherwise ignored.
        n_rows = x.shape[0]
        return self.weight.unsqueeze(0).expand(n_rows, -1)

    def optimize_weights(self):
        """No-op: the weights are fixed."""
        return None

class OracleOuterLoop(OuterLoopOptimizer):
    """
    Oracle baseline with weight 1 on component 0.

    Components ``1..`` are fixed to ``self.oracle_weight``, and the weights
    never change.
    """
    def initialize(self):
        oracle = torch.tensor(self.oracle_weight).to(self.device)
        weight = torch.ones(self.n_reward_components, device=self.device)
        weight[1:] = oracle
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the fixed weight vector to one row per input row.
        n_rows = x.shape[0]
        return self.weight.unsqueeze(0).expand(n_rows, -1)

    def optimize_weights(self):
        """No-op: the weights are fixed."""
        return None

class OracleOuterLoop2(OuterLoopOptimizer):
    """
    Oracle variant with weight 0 on component 0.

    Components ``1..`` are fixed to ``self.oracle_weight``, and the weights
    never change.
    """
    def initialize(self):
        oracle = torch.tensor(self.oracle_weight).to(self.device)
        weight = torch.zeros(self.n_reward_components, device=self.device)
        weight[1:] = oracle
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the fixed weight vector to one row per input row.
        n_rows = x.shape[0]
        return self.weight.unsqueeze(0).expand(n_rows, -1)

    def optimize_weights(self):
        """No-op: the weights are fixed."""
        return None

class RandomOuterLoop(OuterLoopOptimizer):
    """
    Random-search baseline that starts from ``[1, 0, ..., 0]``.

    On every ``optimize_weights`` call, all weights (component 0 included)
    are replaced with samples drawn uniformly from ``[-1, 1)``.
    """
    def initialize(self):
        weight = torch.zeros(self.n_reward_components)
        weight[:1] = 1
        self.weight = weight.to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the current weight vector to one row per input row.
        n_rows = x.shape[0]
        return self.weight.unsqueeze(0).expand(n_rows, -1)

    def optimize_weights(self):
        # rand_like samples [0, 1); rescale to [-1, 1).
        uniform01 = torch.rand_like(self.weight)
        self.weight = uniform01 * 2 - 1

class MetaGradientOuterLoop(OuterLoopOptimizer):
    """
    Outer loop that updates ``self.all_params`` with a per-sample
    meta-gradient estimate built from inner-loop policy gradients.
    """
    def optimize_weights(self):
        """
        Update ``self.all_params`` for up to ``self.n_epochs`` epochs.

        Each epoch samples ``self.batch_size`` rows from the inner loop's
        rollout buffer. For each sample ``i`` it accumulates
        ``<g_new_i, g_old_i> * task_return_i * d(sum_k ret_ik * w_ik)/d(phi)``, where:
        - ``g_new``/``g_old`` are the gradients returned by
          ``inner_loop.compute_policy_old_logprob`` with flag False/True;
        - ``w`` are the reward weights from :meth:`forward`.

        The batch mean is written with a flipped sign (``p.grad = -g``), so
        the optimizer ascends it, then clipped to ``self.max_grad_norm``.
        From epoch index 3 onward, training stops early once the summed
        absolute gradient exceeds the previous epoch's.

        Logs ``train/outer_loop_loss`` (summed absolute gradient of the last
        epoch) and ``train/outer_loop_epoch``, and stores the former in
        ``self.weight_gradient``.
        """
        batch_size = self.batch_size
        last_loss = 1e5
        tmp_grad = 0.0
        epoch = -1  # stays -1 when n_epochs <= 0, so the logging below is defined

        for epoch in range(self.n_epochs):
            final_grads = [torch.zeros_like(p) for p in self.all_params]
            rollout_data = self.inner_loop.rollout_buffer.sample(batch_size)
            # Guard against the buffer returning fewer rows than requested.
            n_samples = min(batch_size, len(rollout_data.observations))
            tmp_grad = 0.0

            for i in range(n_samples):
                obs_i = rollout_data.observations[i:i+1]
                act_i = rollout_data.actions[i:i+1]
                ret_i = rollout_data.returns[i:i+1]
                task_return_i = ret_i[..., 0]

                reward_weights_i = self.forward(obs_i)
                weighted_return_i = torch.sum(ret_i * reward_weights_i, dim=1)

                _, new_grad_theta = self.inner_loop.compute_policy_old_logprob(obs_i, act_i, False)
                _, old_grad_theta = self.inner_loop.compute_policy_old_logprob(obs_i, act_i, True)
                dot_theta = sum((g1 * g2).sum() for g1, g2 in zip(new_grad_theta, old_grad_theta))

                # First-order gradient only: the result is written into p.grad
                # and never differentiated again, so no higher-order graph is
                # kept. allow_unused avoids a crash for parameters that do not
                # affect the output.
                grad_phi_i = torch.autograd.grad(
                    outputs=weighted_return_i,
                    inputs=list(self.all_params),
                    grad_outputs=torch.ones_like(weighted_return_i),
                    allow_unused=True,
                )

                multiplier = (dot_theta * task_return_i).detach()

                for j, g in enumerate(grad_phi_i):
                    if g is not None:
                        final_grads[j] += multiplier * g

            final_grads = [g / max(n_samples, 1) for g in final_grads]
            for p, g in zip(self.all_params, final_grads):
                p.grad = -g
                tmp_grad += g.abs().sum().item()

            torch.nn.utils.clip_grad_norm_(self.all_params, self.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()

            if epoch > 2 and last_loss < tmp_grad:
                break
            last_loss = tmp_grad

        self.inner_loop.logger.record(f"train/outer_loop_loss", tmp_grad)
        self.inner_loop.logger.record(f"train/outer_loop_epoch", epoch)

        self.weight_gradient = tmp_grad
        # self.scheduler.step()
        self.log_lr()

class ImplicitFunctionOuterLoop(OuterLoopOptimizer):
    """
    Outer loop that updates ``weight_net`` through an implicit-gradient
    estimate.

    The procedure:
    - Take the task-reward policy gradient with respect to
      ``self.actor_params``.
    - Multiply it by an approximate inverse of the damped Hessian of the
      shaped-reward objective, computed with conjugate gradient on
      Hessian-vector products.
    - Contract the result with the shaped-reward policy gradient.
    - Differentiate that scalar, plus an L2 penalty, with respect to the
      ``weight_net`` parameters.
    """

    def initialize(self):
        # NOTE(review): these Neumann-series settings are not read by
        # optimize_weights, which uses conjugate gradient instead.
        self.max_Neumann_epochs = 50
        self.Neumann_alpha = 0.001

    @staticmethod
    def _flatten(vecs):
        """Concatenate a sequence of tensors into a single 1-D tensor."""
        return torch.cat([v.reshape(-1) for v in vecs])

    @staticmethod
    def _unflatten(flat, params_template):
        """Split a 1-D tensor into consecutive pieces shaped like ``params_template``."""
        pieces = []
        offset = 0
        for p in params_template:
            numel = p.numel()
            pieces.append(flat[offset:offset + numel].reshape_as(p))
            offset += numel
        return pieces

    def _solve_damped_hessian(self, inner_grad, params, rhs, cg_iters=10, damping=1e-3):
        """
        Approximately solve ``(H + damping * I) x = rhs`` with at most
        ``cg_iters`` conjugate-gradient iterations.

        ``H @ v`` is obtained by differentiating ``inner_grad`` (a gradient
        w.r.t. ``params`` computed with ``create_graph=True``) once more
        against ``v``. The graph of ``inner_grad`` is retained for later use.
        Reusing ``inner_grad`` avoids recomputing the first-order gradient on
        every iteration.

        Returns ``x`` split into tensors shaped like ``params``.
        """
        rhs_flat = self._flatten(rhs)
        x = torch.zeros_like(rhs_flat)
        residual = rhs_flat.clone()
        direction = residual.clone()
        rs_old = torch.dot(residual, residual)

        for _ in range(cg_iters):
            hvp = autograd.grad(
                inner_grad,
                params,
                grad_outputs=self._unflatten(direction, params),
                retain_graph=True,
            )
            Hp = self._flatten(hvp) + damping * direction
            alpha = rs_old / (torch.dot(direction, Hp) + 1e-8)
            x += alpha * direction
            residual -= alpha * Hp
            rs_new = torch.dot(residual, residual)
            if torch.sqrt(rs_new) < 1e-8:
                break
            direction = residual + (rs_new / rs_old) * direction
            rs_old = rs_new

        return self._unflatten(x, params)

    def optimize_weights(self):
        """
        Perform one implicit-gradient update of ``weight_net``.

        Steps:
        (A) task-reward policy gradient on ``rollout_buffer``;
        (B) differentiable shaped-reward policy gradient on
            ``rollout_buffer_history``;
        (C) CG solve of the damped Hessian system;
        (D) L2 penalty;
        (E) gradient of ``<inner_grad, v> + L2`` w.r.t. the ``weight_net``
            parameters, with ``None``/NaN entries replaced by zeros;
        (F) clipped optimizer step when the summed absolute gradient exceeds
            1e-8;
        (G) logging.

        Stores the summed absolute gradient in ``self.weight_gradient``.
        """
        self.optimizer.zero_grad()

        # --- A. Outer grad (task reward) ---
        rollout_data = self.inner_loop.rollout_buffer.sample()
        J_task, adv = self.inner_loop.compute_policy_gradient(rollout_data, shaping=False, eval_old=False)
        outer_policy_grad = autograd.grad(J_task, self.actor_params, retain_graph=True)
        total_outer_grad_norm = torch.sqrt(sum(g.norm() ** 2 for g in outer_policy_grad)).item()

        # --- B. Inner grad (shaped reward), kept differentiable for C and E ---
        more_rollout_data = self.inner_loop.rollout_buffer_history.sample()
        J_shaped, _ = self.inner_loop.compute_policy_gradient(more_rollout_data, shaping=True, eval_old=False)
        # Note: no standardization and no advantage normalization here.
        inner_policy_grad = autograd.grad(J_shaped, self.actor_params, create_graph=True, retain_graph=True)

        # --- C. v ~= (H + damping * I)^-1 * outer_policy_grad via CG ---
        shaped_grad_approx = self._solve_damped_hessian(
            inner_policy_grad, self.actor_params, outer_policy_grad, cg_iters=10, damping=1e-3
        )

        # --- D. Reward L2 ---
        reward_L2 = self.reward_L2 * self.L2_reg()
        total_reward_L2 = reward_L2.item()

        # --- E. Gradient of S = sum_i <inner_policy_grad_i, v_i> (+ L2) w.r.t. weight_net ---
        S = sum((g_theta * v).sum() for g_theta, v in zip(inner_policy_grad, shaped_grad_approx))
        S_total = S + reward_L2

        reward_params = list(self.weight_net.parameters())
        raw_grads = autograd.grad(S_total, reward_params, retain_graph=True, allow_unused=True)
        # autograd.grad returns an immutable tuple, so build a new list,
        # replacing unused-parameter (None) and NaN-containing gradients with
        # zeros so the optimizer always sees a finite tensor.
        rgrads = [
            torch.zeros_like(p) if g is None or torch.isnan(g).any() else g
            for p, g in zip(reward_params, raw_grads)
        ]
        # weight_net parameters are the same objects held in self.all_params,
        # so this sets their gradients there as well.
        for p, g in zip(reward_params, rgrads):
            p.grad = g.clone()

        # Logged quantity: total absolute gradient before clipping.
        tmp_grad = sum(g.abs().sum().item() for g in rgrads)

        torch.nn.utils.clip_grad_norm_(reward_params, self.max_grad_norm)
        torch.nn.utils.clip_grad_norm_(self.all_params, self.max_grad_norm)

        # --- F. Update reward weight ---
        if tmp_grad > 1e-8:
            self.optimizer.step()

        # --- G. Logging ---
        shaped_vec = self._flatten(shaped_grad_approx)
        task_vec = self._flatten(outer_policy_grad)
        cos_sim = torch.dot(shaped_vec, task_vec) / (shaped_vec.norm() * task_vec.norm() + 1e-8)

        self.inner_loop.logger.record(f"train/outer_loop_loss", tmp_grad)
        self.inner_loop.logger.record(f"train/outer_loop_reward_L2", total_reward_L2)
        self.inner_loop.logger.record(f"train/total_outer_grad_norm", total_outer_grad_norm)
        # Norm of the approximate Hessian^-1 * g (key name kept for continuity).
        self.inner_loop.logger.record(f"train/Neumann_series_monitor", shaped_vec.norm().item())
        self.inner_loop.logger.record(f"train/cosine_similarity", cos_sim.item())
        self.inner_loop.logger.record(f"train/returns", adv.mean().item())

        self.weight_gradient = tmp_grad
        # self.scheduler.step()
        self.log_lr()
    