from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Callable
import numpy as np
import torch
import torch.nn as nn
from torch import tensor, float32, autograd
from collections import deque
from torch.optim.lr_scheduler import StepLR
from .weight_net import SimpleRewardWeightv1, SimpleRewardWeightv2, DirectRewardWeight, RegulatedRewardWeightv1, RegulatedRewardWeightv2, \
    RegulatedRewardWeightv3, RegulatedRewardWeightv4, OracleRewardWeight, RegulatedRewardWeightv5, RegulatedRewardWeightv6, RegulatedRewardWeightv7, RegulatedRewardWeightv8, \
    TanhWeightv1, TanhWeightv2, SoftmaxWeightv1, SoftmaxWeightv2
from .exploration import SinglePrediction, RandomExploration, MaxDistanceExploration

def get_weight_net(weight_type, exploration, explore_reward_dim, device, lower, upper):
    """
    Resolve a weight-network class and build the optional exploration module.

    :param weight_type: key selecting the weight-network class
        (e.g. 'simplev1', 'tanhv2', 'regulatedv4', 'oracle').
    :param exploration: 'none', 'single', 'random' or 'maxdis'.
    :param explore_reward_dim: ``weight_dim`` passed to the exploration module.
    :param device: device passed to the exploration module.
    :param lower: lower bound passed to the exploration module; some weight
        types override it with -1.
    :param upper: upper bound passed to the exploration module.
    :returns: ``(weight_net_class, exploration_module)``; the module is ``None``
        when ``exploration == 'none'``.
    :raises ValueError: for an unknown ``weight_type`` or ``exploration``.
    """
    # weight_type -> (weight-net class, lower-bound override or None to keep `lower`)
    weight_nets = {
        'simplev1': (SimpleRewardWeightv1, None),
        'simplev2': (SimpleRewardWeightv2, -1),
        'tanhv1': (TanhWeightv1, None),
        'tanhv2': (TanhWeightv2, -1),
        'softmaxv1': (SoftmaxWeightv1, None),
        'softmaxv2': (SoftmaxWeightv2, -1),
        'direct': (DirectRewardWeight, None),
        'regulatedv1': (RegulatedRewardWeightv1, None),
        'regulatedv2': (RegulatedRewardWeightv2, -1),
        'regulatedv3': (RegulatedRewardWeightv3, None),
        'regulatedv4': (RegulatedRewardWeightv4, -1.0),
        'regulatedv5': (RegulatedRewardWeightv5, None),
        'regulatedv6': (RegulatedRewardWeightv6, -1.0),
        'regulatedv7': (RegulatedRewardWeightv7, None),
        'regulatedv8': (RegulatedRewardWeightv8, -1.0),
        'oracle': (OracleRewardWeight, None),
    }
    exploration_classes = {
        'single': SinglePrediction,
        'random': RandomExploration,
        'maxdis': MaxDistanceExploration,
    }

    if weight_type not in weight_nets:
        raise ValueError(f"Weight type {weight_type} not supported")
    weight_net_class, lower_override = weight_nets[weight_type]
    if lower_override is not None:
        lower = lower_override

    if exploration == 'none':
        return weight_net_class, None
    if exploration not in exploration_classes:
        raise ValueError(f"Exploration {exploration} not supported")
    exploration_module = exploration_classes[exploration](
        weight_dim=explore_reward_dim, hidden_dim=64, device=device, lower=lower, upper=upper
    )
    return weight_net_class, exploration_module


def check_parameter_exists(model, parameter_name):
    """
    Return True if ``model`` registers a parameter named exactly ``parameter_name``.

    Names are the fully-qualified ones yielded by ``model.named_parameters()``
    (e.g. ``'layer.weight'`` for a nested module), so the bare name of a
    submodule never matches.
    """
    return any(registered == parameter_name for registered, _ in model.named_parameters())

def list_of_maps_to_map_of_lists(list_of_maps):
    """
    Transpose a sequence of dicts into a dict of lists.

    Each key maps to the values it had in the input dicts, in input order.
    Keys appear in the order they were first seen. Dicts missing a key
    simply contribute nothing to that key's list.
    """
    merged = {}
    for record in list_of_maps:
        for field, entry in record.items():
            merged.setdefault(field, []).append(entry)
    return merged

class DummyOuterLoop(nn.Module):
    """Placeholder outer loop whose reward-weight vector is always all zeros."""

    def __init__(self, num_objectives):
        super().__init__()
        # Length of the vector returned by ``get_weight``.
        self.num_objectives = num_objectives

    def get_weight(self, state):
        """Return a zero vector of length ``num_objectives`` on ``state``'s device."""
        target_device = state.device
        return torch.zeros((self.num_objectives,), device=target_device)

class OuterLoop(nn.Module):
    """
    Base class for outer loop optimization of reward weights.

    The outer loop owns an observation ``extractor`` followed by a reward
    ``weight_net``. ``forward`` maps observations to per-component reward
    weights. Subclasses customize ``initialize``, ``optimize_weights`` and,
    optionally, ``forward``. ``step`` is the per-iteration driver: it records
    weight/performance history, optionally runs ``optimize_weights``, and
    optionally triggers exploration.
    """
    def __init__(
        self,
        n_reward_components: int,
        observation_dim: int = 48,
        learning_rate: float = 1e-3,
        share_extractor: bool = False,
        n_epochs: int = 4,
        batch_size: int = 128,
        device = 'cuda:0',
        save_frequency = 1000,
        save_path = 'logs/',
        log_episode_reward = 10,
        reward_L2 = 0.,
        weight_net_class = None,
        exploration = None,
        max_weight_history = 500,
        explore_choice = 'argmax',
        explore_tmp = 1,
        oracle_weight = None,
        reset_method = 'none',
        old_model_update=True,
        lr=1e-3,
        start_with_oracle=False,
        extractor_layer=2,
        weight_history='coeff',
        obs_space_tensor=None,
        use_gradient=True,
        condition='performance',
        reset_critic='none',
    ):
        """
        :param n_reward_components: number of reward components (length of the weight vector).
        :param observation_dim: input size of the observation extractor.
        :param learning_rate: stored as ``self.learning_rate``. NOTE(review): the optimizer
            built in ``set_inner_loop`` uses a fixed 1e-3 and does not read this value.
        :param share_extractor: accepted for compatibility; not read by this class.
        :param n_epochs: stored as ``self.n_epochs`` for subclasses' ``optimize_weights``.
        :param batch_size: stored as ``self.batch_size``.
        :param reward_L2: coefficient applied to ``L2_reg()`` by subclasses.
        :param weight_net_class: weight-type key resolved by ``get_weight_net`` (e.g. 'simplev1').
        :param exploration: exploration key resolved by ``get_weight_net``
            ('none', 'single', 'random', 'maxdis').
        :param max_weight_history: maximum length of the weight / performance history deques.
        :param oracle_weight: weights for components 1..n-1 (component 0 is the task-success
            weight); ``None`` means an empty list.
        :param reset_method: policy reset applied after exploration
            ('logstd', 'last_layer', 'actor', 'full_policy', anything else = none).
        :param lr: accepted for compatibility; not read by this class.
        :param start_with_oracle: initialize ``weight_net`` with ``[0, *oracle_weight]``.
        :param extractor_layer: 2 for a two-layer MLP extractor, otherwise a single linear layer.
        :param weight_history: what ``step`` records: 'coeff', 'traj' or 'space'.
        :param use_gradient: whether ``step`` calls ``optimize_weights``.
        :param condition: exploration trigger, 'none' (always) or 'performance'.
        :param reset_critic: critic reset applied after exploration ('critic', 'last_layer').
        """
        super().__init__()
        self.num_objectives = n_reward_components
        self.n_reward_components = n_reward_components
        self.learning_rate = learning_rate
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.device = device
        self.save_frequency = save_frequency
        self.save_path = save_path
        self.log_episode_reward = log_episode_reward
        self.reward_L2 = reward_L2
        self.explore_choice = explore_choice
        self.explore_tmp = explore_tmp
        # `None` default avoids a mutable list shared across instances.
        self.oracle_weight = [] if oracle_weight is None else oracle_weight
        self.last_task_performance = 0
        self.max_grad_norm = 0.5
        self.reset_method = reset_method
        self.reset_critic = reset_critic
        self.old_model_update = old_model_update
        self.weight_history_type = weight_history
        self.obs_space_tensor = obs_space_tensor
        self.use_gradient = use_gradient

        # define weight network
        if extractor_layer == 2:
            self.extractor = nn.Sequential(
                nn.Linear(observation_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 128),
            ).to(self.device)
        else:
            self.extractor = nn.Sequential(
                nn.Linear(observation_dim, 128),
            ).to(self.device)
        weight_net_class, exploration_module = get_weight_net(weight_net_class, exploration, self.num_objectives, device, 0, 1)
        self.exploration_module = exploration_module
        self.weight_net = weight_net_class(n_reward_components, 128, self.device)

        self.weight_history = deque(maxlen=max_weight_history)
        self.task_reward_history = deque(maxlen=max_weight_history)
        self.weight_gradient = 0.

        if start_with_oracle:
            # component 0 (task success) starts at 0; the rest come from oracle_weight
            weight = torch.zeros(self.n_reward_components).to(self.device)
            weight[1:] = torch.tensor(self.oracle_weight).to(self.device)
            self.weight_net.reset_weight(weight)

        if self.exploration_module is None:
            self.exploration_dim = n_reward_components
        else:
            self.exploration_dim = self.exploration_module.weight_dim

        self.condition = condition

    def set_inner_loop(self, inner_loop, writer):
        """
        Attach the inner loop and build the outer-loop optimizer and LR scheduler.

        :param inner_loop: inner-loop object; the parameters of ``inner_loop.policy.actor``
            are cached in ``self.actor_params``.
        :param writer: logger exposing ``add_scalar(tag, value, step)``.
        """
        self.writer = writer
        self.inner_loop = inner_loop
        self.actor_params = list(self.inner_loop.policy.actor.parameters())

        small_lr_group = []
        large_lr_group = []
        if self.weight_net.name == 'simple':
            self.all_params = list(self.weight_net.parameters())
        else:
            small_lr_group += list(self.extractor.parameters())
            # the exploration module is deliberately not part of all_params
            self.all_params = list(self.extractor.parameters()) + list(self.weight_net.parameters())

        # `weight_net.weight_net` is a sub-*module*: named_parameters() only yields names like
        # 'weight_net.<...>', never the bare 'weight_net', so an exact parameter-name check
        # never matches. Look it up among the direct children instead so its parameters
        # are actually handed to the optimizer.
        if 'weight_net' in dict(self.weight_net.named_children()):
            small_lr_group += list(self.weight_net.weight_net.parameters())
        if 'coeff' in dict(self.weight_net.named_parameters()):
            large_lr_group.append(self.weight_net.coeff)

        # NOTE(review): hard-coded; the constructor's `learning_rate` / `lr` are not used here.
        lr = 1e-3
        self.optimizer = torch.optim.Adam(
            [
                {'params': large_lr_group, 'lr': lr * 5},  # direct coefficients: 5x the base lr
                {'params': small_lr_group, 'lr': lr},
            ]
        )
        self.scheduler = StepLR(self.optimizer, step_size=5, gamma=0.98)
        self.old_policy_gradient = None

        self.initialize()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map observations to reward weights: ``weight_net(extractor(x))``."""
        x = self.extractor(x)
        x = self.weight_net(x)
        return x

    def initialize(self):
        """Hook called at the end of ``set_inner_loop``; no-op in the base class."""
        return

    def optimize_weights(self):
        """Hook for the outer-loop weight update; no-op in the base class."""
        pass

    def step(self, explore=False, success_buffer=None, it=0) -> bool:
        '''
        Record history reward weights, optionally update them, and optionally explore.

        :param explore: whether exploration may be triggered this iteration.
        :param success_buffer: recent task-success values; their mean is the current
            task performance.
        :param it: iteration index used for logging.
        :returns: True when ``explore`` is set and an exploration module exists
            (whether or not exploration actually fired), otherwise False.
        :raises ValueError: if ``weight_history_type`` is not 'coeff', 'traj' or 'space'.
        '''
        self.iter = it
        # record history reward weights
        if self.weight_history_type == 'coeff':
            x = torch.zeros(1)
            weights = self.weight_net.get_weight(x)

        elif self.weight_history_type == 'traj':
            # weights evaluated on observations sampled from the inner loop's storage
            data = self.inner_loop.storage.sample_obs(1000)
            with torch.no_grad():
                x = self.extractor(data)
                weights = self.weight_net.get_train_data(x) # TODO: can be either forward() or get_weight()

        elif self.weight_history_type == 'space':
            data = self.obs_space_tensor
            with torch.no_grad():
                x = self.extractor(data)
                weights = self.weight_net.get_train_data(x)

        else:
            # Without this branch `weights` would be unbound below.
            raise ValueError(f"Weight history type {self.weight_history_type} not supported")

        weights_np = weights.detach().cpu().numpy()
        self.weight_history.append(weights_np)
        self.task_reward_history.append(success_buffer)

        current_task_performance = np.mean(success_buffer)
        # exploration probability for the 'performance' condition: 1 - performance
        P = 1 - current_task_performance
        delta_perf = current_task_performance - self.last_task_performance
        self.last_task_performance = current_task_performance

        # optimize weights only while performance improved by less than 0.05
        if self.use_gradient:
            if delta_perf < 0.05:
                self.optimize_weights()
                # clip weights
                self.weight_net.clip_coeff()

        # explore
        if explore and self.exploration_module is not None:
            if self.condition == 'none':
                self.explore()
            elif self.condition == 'performance':
                self.writer.add_scalar(f"explore/task_perf", current_task_performance, self.iter)
                if delta_perf < 0.05 and np.random.rand() < P:
                    self.writer.add_scalar(f"explore/set_weight", 1, self.iter)
                    self.explore()
                else:
                    self.writer.add_scalar(f"explore/set_weight", 0, self.iter)

            return True
        return False

    def explore(self):
        """
        Pick new reward weights with the exploration module and load them into ``weight_net``.

        The exploration module is first trained on the recorded (weight, performance)
        history. The chosen weight is logged per component. Afterwards the
        policy and/or critic are optionally reset according to ``reset_method`` /
        ``reset_critic``.
        """
        # train exploration net
        weights = torch.tensor(np.array(self.weight_history), device=self.device, dtype=torch.float32).reshape(-1, self.exploration_dim)
        performance = torch.tensor(np.array(self.task_reward_history), device=self.device).reshape(-1)
        loss = self.exploration_module.train(weights, performance=performance)
        self.writer.add_scalar(f"outer_loop/weight_net_loss", loss, self.iter)

        # find the best reward weight
        best_weight = self.exploration_module.explore(num_samples=1000, method=self.explore_choice, tmp=self.explore_tmp)
        best_weight_np = best_weight.detach().cpu().numpy()
        for i in range(self.exploration_dim):
            self.writer.add_scalar(f"explore/set_weight_{i}", best_weight_np[i], self.iter)
        self.weight_net.reset_weight(best_weight)

        if self.reset_method == 'logstd':
            self.inner_loop.policy.reset_logstd()
        elif self.reset_method == 'last_layer':
            self.inner_loop.policy.reset_last_layer()
        elif self.reset_method == 'actor':
            self.inner_loop.policy.reset_actor()
        elif self.reset_method == 'full_policy':
            self.inner_loop.policy.reset_full_policy()

        if self.reset_critic == 'critic':
            self.inner_loop.policy.reset_critic()
        elif self.reset_critic == 'last_layer':
            self.inner_loop.policy.reset_critic_last_layer()

    def save(self):
        """Save this module's state dict and delegate to the inner loop's ``save``."""
        torch.save(self.state_dict(), self.save_path + '/outer_loop.pth')
        self.inner_loop.save(self.save_path + '/inner_loop')

    def L2_reg(self):
        """Return the sum of the (non-squared) L2 norms of all tensors in ``self.all_params``."""
        l2_reg = torch.tensor(0.).to(self.device)
        for param in self.all_params:
            l2_reg += torch.norm(param)
        return l2_reg

    def log_lr(self):
        """Log the current learning rate of every optimizer param group."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.writer.add_scalar(f"outer_loop/lr_{i}", param_group['lr'], self.iter)

class SparseOuterLoop(OuterLoop):
    '''
    Task-success reward only: the weight vector is fixed to [1, 0, ..., 0]
    and never updated.
    '''
    def initialize(self):
        weight = torch.zeros(self.n_reward_components, device=self.device)
        # only component 0 (task success) is active
        weight[:1] = 1.0
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One shared weight row per batch element (broadcast view, no copy).
        batch = x.shape[0]
        return self.weight.unsqueeze(0).expand(batch, -1)

    def optimize_weights(self):
        # Fixed weights: nothing to optimize.
        return None

class OracleOuterLoop(OuterLoop):
    '''
    Fixed weights: task-success weight is 1 and every other component is set
    to ``oracle_weight``. No gradient update is performed.
    '''
    def initialize(self):
        oracle = torch.tensor(self.oracle_weight).to(self.device)
        weight = torch.ones(self.n_reward_components, device=self.device)
        weight[1:] = oracle
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same weight row for every sample in the batch.
        n_rows = x.shape[0]
        return self.weight.expand(n_rows, -1)

    def optimize_weights(self):
        # Oracle weights stay fixed.
        return None

class OracleOuterLoop2(OuterLoop):
    '''
    Fixed weights: task-success weight is 0 and every other component is set
    to ``oracle_weight``. It is intended to reproduce training on the original
    (oracle-shaped) reward.
    '''
    def initialize(self):
        oracle = torch.tensor(self.oracle_weight).to(self.device)
        weight = torch.zeros(self.n_reward_components, device=self.device)
        weight[1:] = oracle
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the single weight vector across the batch dimension.
        batch = x.shape[0]
        return self.weight.unsqueeze(0).expand(batch, -1)

    def optimize_weights(self):
        # Oracle weights stay fixed.
        return None

class RandomOuterLoop(OuterLoop):
    '''
    Replaces the gradient update with a random draw. Each ``optimize_weights``
    call resamples every weight (including component 0) uniformly from [-1, 1).
    '''
    def initialize(self):
        # Start from task-success-only weights: [1, 0, ..., 0].
        weight = torch.zeros(self.n_reward_components, device=self.device)
        weight[:1] = 1.0
        self.weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]
        return self.weight.unsqueeze(0).expand(batch, -1)

    def optimize_weights(self):
        # Rescale a U[0, 1) sample to U[-1, 1).
        self.weight = 2 * torch.rand_like(self.weight) - 1

class MetaGradientOuterLoop(OuterLoop):
    """
    Outer loop that updates the reward network with a per-sample meta-gradient.

    For each sampled transition, the gradient of the weighted return
    ``sum(returns * weights)`` w.r.t. ``self.all_params`` is scaled by
    ``<new-policy grad, old-policy grad> * task_return``, where the task return
    is component 0 of the returns. The result is averaged over the batch, and
    the optimizer is stepped along its positive direction.
    """
    def optimize_weights(self):
        last_loss = 1e5
        # Defaults keep the logging below well-defined if no epoch runs.
        tmp_grad = 0.0
        epoch = -1

        for epoch in range(self.n_epochs):
            final_grads = [torch.zeros_like(p) for p in self.all_params]
            # Use the configured batch size instead of a hard-coded 128.
            rollout_data = self.inner_loop.rollout_buffer.sample(self.batch_size)
            # Never index past the rows the buffer actually returned.
            n_samples = min(self.batch_size, len(rollout_data.observations))
            if n_samples == 0:
                break
            tmp_grad = 0

            for i in range(n_samples):
                obs_i = rollout_data.observations[i:i + 1]
                act_i = rollout_data.actions[i:i + 1]
                ret_i = rollout_data.returns[i:i + 1]
                task_return_i = ret_i[..., 0]

                reward_weights_i = self.forward(obs_i)
                weighted_return_i = torch.sum(ret_i * reward_weights_i, dim=1)

                _, new_grad_theta = self.inner_loop.compute_policy_old_logprob(obs_i, act_i, False)
                _, old_grad_theta = self.inner_loop.compute_policy_old_logprob(obs_i, act_i, True)

                dot_theta = sum((g1 * g2).sum() for g1, g2 in zip(new_grad_theta, old_grad_theta))

                # First-order gradients suffice: they are only accumulated and written to
                # .grad, never differentiated again, so no higher-order graph is built.
                grad_phi_i = torch.autograd.grad(
                    outputs=weighted_return_i,
                    inputs=list(self.all_params),
                    grad_outputs=torch.ones_like(weighted_return_i),
                    retain_graph=True,
                )

                with torch.no_grad():
                    multiplier = dot_theta * task_return_i
                    for j, g in enumerate(grad_phi_i):
                        final_grads[j] += multiplier * g

            final_grads = [g / n_samples for g in final_grads]
            for p, g in zip(self.all_params, final_grads):
                # Negated so that the (minimizing) optimizer step moves parameters along +g.
                p.grad = -g
                tmp_grad += g.abs().sum().item()

            torch.nn.utils.clip_grad_norm_(self.all_params, self.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()

            # From the fourth epoch on, stop as soon as the summed |grad| grows again.
            if epoch > 2 and last_loss < tmp_grad:
                break
            last_loss = tmp_grad

        self.writer.add_scalar(f"outer_loop/loss", tmp_grad, self.iter)
        self.writer.add_scalar(f"outer_loop/epoch", epoch, self.iter)

        self.weight_gradient = tmp_grad
        self.scheduler.step()
        self.log_lr()

class ImplicitFunctionOuterLoop(OuterLoop):
    """
    Outer loop that updates the reward network with an implicit-function
    style hypergradient.

    Steps performed by ``optimize_weights``:
    1. Solve v ~= (H + damping*I)^-1 * grad_theta J_task by conjugate gradient,
       where H is the Hessian of J_shaped w.r.t. the actor parameters.
    2. Form S = <grad_theta J_shaped, v>.
    3. Write the gradient of S w.r.t. the reward-network parameters into
       their ``.grad``.
    4. Step the optimizer.
    """
    def initialize(self):
        # Neumann-series settings; NOTE(review): not read by the CG-based update in this class.
        self.Neumann_epochs = 10
        self.Neumann_alpha = 0.001

        self.policy_grad_buffer = [torch.zeros_like(p, requires_grad=False).to(self.device) for p in self.actor_params]
        self.policy_grad_buffer2 = [torch.zeros_like(p, requires_grad=False).to(self.device) for p in self.actor_params]

    @staticmethod
    def _flatten(vecs):
        """Concatenate a sequence of tensors into one 1-D tensor."""
        return torch.cat([v.reshape(-1) for v in vecs])

    @staticmethod
    def _unflatten(vec_flat, params_template):
        """Split a 1-D tensor into tensors shaped like ``params_template``."""
        vecs = []
        pointer = 0
        for p in params_template:
            numel = p.numel()
            vecs.append(vec_flat[pointer:pointer + numel].reshape_as(p))
            pointer += numel
        return vecs

    def _conjugate_gradient(self, first_grads, params, g, cg_iters=10, damping=1e-3):
        """
        Approximately solve (H + damping*I) x = g with conjugate gradient.

        :param first_grads: gradients of the loss w.r.t. ``params`` computed with
            ``create_graph=True``. H·p is obtained by differentiating them again,
            so the first-order gradient is not recomputed on every iteration.
        :param params: parameters the Hessian is taken with respect to.
        :param g: right-hand side, one tensor per parameter.
        :returns: the solution, one tensor per parameter.
        """
        g_flat = self._flatten(g)
        x = torch.zeros_like(g_flat)
        r = g_flat.clone()
        p = r.clone()
        rsold = torch.dot(r, r)

        for _ in range(cg_iters):
            p_vec = self._unflatten(p, params)
            Hp_vec = autograd.grad(first_grads, params, grad_outputs=p_vec, retain_graph=True)
            Hp = self._flatten(Hp_vec)
            Hp += damping * p
            alpha = rsold / (torch.dot(p, Hp) + 1e-8)
            x += alpha * p
            r -= alpha * Hp
            rsnew = torch.dot(r, r)
            if torch.sqrt(rsnew) < 1e-8:
                break
            p = r + (rsnew / rsold) * p
            rsold = rsnew

        return self._unflatten(x, params)

    def optimize_weights(self):
        self.optimizer.zero_grad()

        # --- A. Outer grad (task reward) ---
        J_task, adv = self.inner_loop.compute_gradient(shaping=False)
        outer_policy_grad = autograd.grad(J_task, self.actor_params, retain_graph=True)
        total_outer_grad_norm = torch.sqrt(sum(g.norm()**2 for g in outer_policy_grad)).item()

        # --- B. Inner grad (shaped reward) ---
        J_shaped, _ = self.inner_loop.compute_gradient(shaping=True, weight_net_forward=self.forward)
        # Note: no standardization and no advantage normalization here.
        # create_graph=True keeps this gradient differentiable, both for the Hessian-vector
        # products in step C and for the reward-network gradient in step E.
        inner_policy_grad = autograd.grad(J_shaped, self.actor_params, create_graph=True, retain_graph=True)

        # --- C. v ~= (H + damping*I)^-1 * outer_policy_grad via conjugate gradient ---
        shaped_grad_approx = self._conjugate_gradient(
            inner_policy_grad, self.actor_params, outer_policy_grad, cg_iters=10, damping=1e-3
        )

        # --- D. Reward L2 (only logged; not added to the objective) ---
        reward_L2 = self.reward_L2 * self.L2_reg()
        total_reward_L2 = reward_L2.item()

        # --- E. Gradient of S = sum_i <inner_policy_grad_i, v_i> w.r.t. the reward network ---
        S = sum((g_theta * v).sum() for g_theta, v in zip(inner_policy_grad, shaped_grad_approx))

        # NOTE(review): only weight_net parameters receive a gradient here. Extractor parameters,
        # if present in self.all_params / the optimizer, are therefore not updated — confirm intended.
        reward_params = list(self.weight_net.parameters())
        # autograd.grad returns a tuple; use a list so None / NaN entries can be replaced.
        rgrads = list(autograd.grad(S, reward_params, retain_graph=True, allow_unused=True))
        for i, rg in enumerate(rgrads):
            # Unused parameters give None; NaNs would poison the update. Use zeros instead.
            if rg is None or torch.isnan(rg).any():
                rgrads[i] = torch.zeros_like(reward_params[i])

        for p, g in zip(reward_params, rgrads):
            p.grad = g.clone()

        # logging quantity: total absolute gradient (before clipping)
        tmp_grad = sum(g.abs().sum().item() for g in rgrads)

        torch.nn.utils.clip_grad_norm_(reward_params, self.max_grad_norm)
        torch.nn.utils.clip_grad_norm_(self.all_params, self.max_grad_norm)

        # --- F. Update reward weight (skipped for near-zero gradients) ---
        if tmp_grad > 1e-2:
            self.optimizer.step()

        # --- G. Logging ---
        shaped_vec = self._flatten(shaped_grad_approx)
        task_vec = self._flatten(outer_policy_grad)

        cos_sim = torch.dot(shaped_vec, task_vec) / (shaped_vec.norm() * task_vec.norm() + 1e-8)

        self.writer.add_scalar(f"train/outer_loop_loss", tmp_grad, self.iter)
        self.writer.add_scalar(f"train/outer_loop_reward_L2", total_reward_L2, self.iter)
        self.writer.add_scalar(f"train/total_outer_grad_norm", total_outer_grad_norm, self.iter)
        # norm of the approximate (H + damping*I)^-1 * g
        self.writer.add_scalar(f"train/Neumann_series_monitor", shaped_vec.norm().item(), self.iter)
        self.writer.add_scalar(f"train/cosine_similarity", cos_sim.item(), self.iter)
        self.writer.add_scalar(f"train/returns", adv.mean().item(), self.iter)

        self.weight_gradient = tmp_grad
        self.scheduler.step()
        self.log_lr()
    