import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.logger import logger
from utils.neural_networks import DeterministicNN_IQN
from agents.diffusion import Diffusion
from agents.model import MLP
from dataset.experience_replay import ExperienceReplay
from torch.distributions import uniform
from sklearn.neighbors import KDTree
from agents.helpers import EMA
import numpy as np


class RADAC_Neutral:
    """
    RADAC Neutral: RADAC with expected value (mean) instead of CVaR.
    
    This is an ablation version where:
    - Double Distributional Critic (IQN) is kept the same
    - Actor update uses expected value (mean) instead of CVaR
    - Q-learning style: maximize E[Q] instead of CVaR[Q]
    - All other components (BC loss, EMA, target updates) remain the same
    """
    def __init__(self,
                 state_dim,
                 action_dim,
                 max_action,
                 device,
                 discount,
                 tau,
                 n_quantiles,
                 embedding_dim,
                 beta_schedule='linear',
                 n_timesteps=100,
                 lr_actor=3e-4,
                 lr_critic=3e-4,
                 ema_decay=0.995,
                 lr_maxt=1000,
                 grad_norm=1.0,
                 eta=1.0,
                 lr_decay=False,
                 step_start_ema=1000,
                 update_ema_every=5,
                 q_clip_range=None,  # e.g. (-300, 300)
                 lambda_bc=1.0,      # Weight for BC loss
                 eta_warmup_steps=0,  # Warmup steps to keep eta=0
                 eta_ramp_steps=1e5
                 ):
        """
        Build the diffusion actor, its EMA copy, both IQN critics and their targets.

        Parameters
        ----------
        state_dim, action_dim : int
            Sizes forwarded to the policy MLP, the diffusion wrapper and the critics.
        max_action : float
            Forwarded to ``Diffusion`` and stored on the agent.
        device : torch.device or str
            Device every network is moved to.
        discount : float
            Factor applied to the bootstrapped target quantiles in ``train``.
        tau : float
            Interpolation weight of the soft target-critic update in ``train``.
        n_quantiles : int
            Number of quantile levels drawn per update; also passed to the
            critics as ``tau_embed_dim``.
        embedding_dim : int
            Forwarded to the critics as ``embedding_dim``.
        beta_schedule, n_timesteps
            Forwarded to ``Diffusion``.
        lr_actor, lr_critic : float
            Adam learning rates.
        ema_decay : float
            Decay handed to the ``EMA`` helper.
        lr_maxt : int
            ``T_max`` of the cosine schedules (only used when ``lr_decay``).
        grad_norm : float
            Max gradient L2 norm; values <= 0 disable clipping in ``train``.
        eta : float
            Final weight of the Q term in the actor loss.
        lr_decay : bool
            Attach ``CosineAnnealingLR`` schedulers to all three optimizers.
        step_start_ema, update_ema_every : int
            First step and period of the EMA actor update.
        q_clip_range : tuple or None
            Optional ``(low, high)`` clamp for the critic target.
        lambda_bc : float
            Weight of the behaviour-cloning loss.
        eta_warmup_steps, eta_ramp_steps
            Steps with the Q weight held at 0, then the length of the linear
            ramp up to ``eta``.
        """
        # --- scalar hyper-parameters and bookkeeping -------------------
        self.device = device
        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.n_quantiles = n_quantiles
        self.grad_norm = grad_norm
        self.eta = eta
        self.lr_decay = lr_decay
        self.q_clip_range = q_clip_range
        self.lambda_bc = lambda_bc
        self.eta_warmup_steps = eta_warmup_steps
        self.eta_ramp_steps = eta_ramp_steps
        self.step_start_ema = step_start_ema
        self.update_ema_every = update_ema_every
        self.step = 0  # global update counter, advanced once per train() iteration
        self.last_eps_act = None
        self.last_eps_act_history = []

        # --- 1. actor: diffusion policy wrapped around an MLP ----------
        self.policy_model = MLP(
            state_dim=state_dim,
            action_dim=action_dim,
            device=device
        )
        diffusion = Diffusion(
            state_dim=state_dim,
            action_dim=action_dim,
            model=self.policy_model,
            max_action=max_action,
            beta_schedule=beta_schedule,
            n_timesteps=n_timesteps
        )
        self.actor = diffusion.to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr_actor)

        # Exponential-moving-average copy of the actor.
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.actor)

        # --- 2. double IQN critic and the matching target networks -----
        def make_critic():
            # Fresh layer-size lists on every call so the two critics never
            # share a list object.
            return DeterministicNN_IQN(
                dim_state=state_dim,
                dim_action=action_dim,
                layers_state=[256],
                layers_action=[256],
                layers_f=[256],
                embedding_dim=embedding_dim,
                tau_embed_dim=n_quantiles
            ).to(device)

        self.critic1 = make_critic()
        self.critic2 = make_critic()
        self.critic1_target = copy.deepcopy(self.critic1)
        self.critic2_target = copy.deepcopy(self.critic2)

        self.critic1_optimizer = torch.optim.Adam(self.critic1.parameters(), lr=lr_critic)
        self.critic2_optimizer = torch.optim.Adam(self.critic2.parameters(), lr=lr_critic)

        # --- optional cosine learning-rate decay ------------------------
        if lr_decay:
            self.actor_lr_scheduler = CosineAnnealingLR(
                self.actor_optimizer, T_max=lr_maxt, eta_min=0.)
            self.critic1_lr_scheduler = CosineAnnealingLR(
                self.critic1_optimizer, T_max=lr_maxt, eta_min=0.)
            self.critic2_lr_scheduler = CosineAnnealingLR(
                self.critic2_optimizer, T_max=lr_maxt, eta_min=0.)

        # Quantile levels are drawn from U(0, 1) (used for the expected value).
        self.distr_taus_uniform = uniform.Uniform(0., 1.)

    # ----------------------------
    # KDTree construction utility
    # ----------------------------
    def build_kdtree(self, action_array, kappa: float = 3.0):
        """
        Index the dataset actions and calibrate the distance scale used to
        flag out-of-distribution actions during ``train``.

        Parameters
        ----------
        action_array : np.ndarray, shape (N, act_dim)
            Dataset actions (keep them on a common ±1 scale).
        kappa : float
            Stored as ``self.kappa``. ``train`` counts a sampled action as
            out-of-distribution when its nearest-neighbour distance exceeds
            ``kappa * self.global_med_dist``.
        """
        self.kdtree = KDTree(action_array, leaf_size=40)

        # Query two neighbours per point: column 0 is the point matching
        # itself, column 1 is the closest other entry in the dataset.
        neighbour_dists, _ = self.kdtree.query(action_array, k=2)
        nearest_other = neighbour_dists[:, 1]

        # Median nearest-neighbour spacing of the dataset.
        self.global_med_dist = np.median(nearest_other)
        self.kappa = kappa

    def train(self, replay_buffer: ExperienceReplay, iterations, batch_size=100, log_writer=None):
        """
        Run ``iterations`` critic/actor updates on mini-batches from the buffer.

        Every iteration:
          1. samples a batch,
          2. fits both IQN critics to a shared min-of-two-targets quantile target,
          3. updates the diffusion actor with
             ``lambda_bc * BC_loss + eta_t * (-mean Q)``,
          4. if ``build_kdtree`` was called, records the fraction of sampled
             actions that lie far from the dataset actions,
          5. updates the EMA actor and soft-updates the target critics.

        Parameters
        ----------
        replay_buffer : ExperienceReplay
            ``sample(batch_size)`` must return
            ``(state, action, next_state, reward, not_done)`` tensors
            (assumed to already live on ``self.device`` — TODO confirm).
        iterations : int
            Number of gradient steps to run.
        batch_size : int
            Batch size requested from the buffer.
        log_writer : optional
            Object exposing ``add_scalar(tag, value, step)``; when given,
            losses, Q statistics and (if clipping is enabled) pre-clip
            gradient norms are logged every step.

        Returns
        -------
        dict
            Per-iteration lists under the keys ``bc_loss``, ``actor_loss``,
            ``critic_loss``, ``q_mean_val`` and ``Q_mean``.
        """
        metric = {
            'bc_loss': [],
            'actor_loss': [],
            'critic_loss': [],
            'q_mean_val': [],  # Expected value (instead of CVaR)
            'Q_mean': []
        }
        clip_grads = self.grad_norm > 0

        for _ in range(iterations):
            # ================
            # 1. Sample batch
            # ================
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # ================
            # 2. Critic update (Double Critic) - Same as RADAC
            # ================
            tau_k = self.distr_taus_uniform.sample((self.n_quantiles, 1)).to(self.device)
            tau_k_ = self.distr_taus_uniform.sample((self.n_quantiles, 1)).to(self.device)

            current_q1 = self.critic1.get_sampled_Z(state, tau_k, action)
            current_q2 = self.critic2.get_sampled_Z(state, tau_k, action)

            with torch.no_grad():
                # The next action comes from the EMA actor.
                next_action = self.ema_model.sample(next_state)

                target_q1 = self.critic1_target.get_sampled_Z(next_state, tau_k_, next_action)
                target_q2 = self.critic2_target.get_sampled_Z(next_state, tau_k_, next_action)

                # Column vectors broadcast against the (batch, n_quantiles)
                # targets. Using -1 instead of `batch_size` keeps this valid
                # when the buffer returns a different number of rows.
                reward = reward.reshape(-1, 1)
                not_done = not_done.reshape(-1, 1)

                # Element-wise min of the two targets curbs over-estimation.
                target_min = torch.min(target_q1, target_q2)
                target_min = reward + not_done * self.discount * target_min

                # Optional clamp of the target values.
                if self.q_clip_range is not None:
                    low_clip, high_clip = self.q_clip_range
                    target_min = torch.clamp(target_min, min=low_clip, max=high_clip)

            loss1 = self.quantile_huber_loss(target_min, current_q1, tau_k)
            loss2 = self.quantile_huber_loss(target_min, current_q2, tau_k)
            critic_loss = 0.5 * (loss1 + loss2)

            # Pre-clip gradient norms, captured at the real clipping calls so
            # the logged values describe the gradients that were applied.
            critic1_grad_norm = critic2_grad_norm = actor_grad_norm = None

            # Optimize Critic1
            self.critic1_optimizer.zero_grad()
            loss1.backward()
            if clip_grads:
                critic1_grad_norm = nn.utils.clip_grad_norm_(
                    self.critic1.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.critic1_optimizer.step()
            if self.lr_decay:
                self.critic1_lr_scheduler.step()

            # Optimize Critic2
            self.critic2_optimizer.zero_grad()
            loss2.backward()
            if clip_grads:
                critic2_grad_norm = nn.utils.clip_grad_norm_(
                    self.critic2.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.critic2_optimizer.step()
            if self.lr_decay:
                self.critic2_lr_scheduler.step()

            # ================
            # 3. Actor update (Expected Value instead of CVaR)
            # ================
            sampled_action = self.actor(state)

            # BC Loss
            bc_loss = self.actor.loss(action, state)

            # Expected value: average the critics' quantiles at levels drawn
            # uniformly from [0, 1].
            tau_actor = self.distr_taus_uniform.sample((self.n_quantiles, 1)).to(self.device)
            actor_q1 = self.critic1.get_sampled_Z(state, tau_actor, sampled_action)
            actor_q2 = self.critic2.get_sampled_Z(state, tau_actor, sampled_action)

            # Use min of two critics (same as RADAC)
            actor_q = torch.min(actor_q1, actor_q2)
            q_mean_val = actor_q.mean()  # Expected value (mean) instead of CVaR

            # Q weight schedule: 0 during warm-up, then a linear ramp to eta.
            if self.step < self.eta_warmup_steps:
                local_eta = 0.0
            else:
                denom = max(1.0, float(self.eta_ramp_steps))
                progress = min(1.0, (self.step - self.eta_warmup_steps) / denom)
                local_eta = self.eta * progress

            # Maximize expected Q-value (Q-learning style)
            q_loss = - q_mean_val

            # The BC term carries its own coefficient as well.
            actor_loss = self.lambda_bc * bc_loss + local_eta * q_loss

            self.actor_optimizer.zero_grad()
            # NOTE: this backward pass also accumulates into the critics'
            # .grad buffers (the Q term flows through them); they are cleared
            # by the critics' zero_grad() at the start of the next iteration.
            actor_loss.backward()
            if clip_grads:
                actor_grad_norm = nn.utils.clip_grad_norm_(
                    self.actor.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.actor_optimizer.step()
            if self.lr_decay:
                self.actor_lr_scheduler.step()

            # ─────────────────────────────
            # eps_act (out-of-distribution action rate): compute, store, log
            # ─────────────────────────────
            if hasattr(self, 'kdtree'):
                # detach() is required: sampled_action is part of the actor's
                # autograd graph, and on a CPU device .cpu() returns the same
                # tensor, so .numpy() would raise without it.
                sa = sampled_action.detach().cpu().numpy()  # (B, act_dim)
                dist, _ = self.kdtree.query(sa, k=1)
                eps_act_batch = (
                    dist[:, 0] > self.kappa * self.global_med_dist
                ).mean().item()

                self.last_eps_act = eps_act_batch
                self.last_eps_act_history.append(eps_act_batch)

                if log_writer is not None:
                    # NOTE: logged before self.step is advanced, so this tag is
                    # one step behind the scalars written further below.
                    log_writer.add_scalar("eps_act", eps_act_batch, self.step)

            # ================
            # 4. Update the EMA model
            # ================
            if self.step >= self.step_start_ema and self.step % self.update_ema_every == 0:
                self.ema.update_model_average(self.ema_model, self.actor)

            # ================
            # 5. Soft update of the target critics
            # ================
            for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )
            for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )

            self.step += 1

            # ================
            # 6. Logging
            # ================
            current_q_mean = 0.5 * (current_q1.mean().item() + current_q2.mean().item())

            if log_writer is not None:
                if clip_grads:
                    log_writer.add_scalar('Actor Grad Norm', float(actor_grad_norm), self.step)
                    log_writer.add_scalar('Critic1 Grad Norm', float(critic1_grad_norm), self.step)
                    log_writer.add_scalar('Critic2 Grad Norm', float(critic2_grad_norm), self.step)

                log_writer.add_scalar('BC Loss', bc_loss.item(), self.step)
                log_writer.add_scalar('Actor Loss', actor_loss.item(), self.step)
                log_writer.add_scalar('Critic Loss', critic_loss.item(), self.step)
                log_writer.add_scalar('Q Mean Val (Expected Value)', q_mean_val.item(), self.step)
                log_writer.add_scalar('Q Mean', current_q_mean, self.step)

            metric['bc_loss'].append(bc_loss.item())
            metric['actor_loss'].append(actor_loss.item())
            metric['critic_loss'].append(critic_loss.item())
            metric['q_mean_val'].append(q_mean_val.item())
            metric['Q_mean'].append(current_q_mean)

        return metric

    def quantile_huber_loss(self, target, current, tau_k):
        batch_size, num_quantiles = target.size()
        target_ = target.unsqueeze(2).expand(batch_size, -1, num_quantiles)
        current_ = current.unsqueeze(1).expand(batch_size, num_quantiles, -1)
        tau_ = tau_k.unsqueeze(0).expand(batch_size, num_quantiles, num_quantiles)
        huber_loss = F.smooth_l1_loss(current_, target_, reduction='none')
        quantile_loss = torch.abs(tau_ - (target_ - current_).detach().le(0).float()) * huber_loss
        return quantile_loss.sum(dim=1).mean()

    def sample_action(self, state):
        if not isinstance(state, torch.Tensor):
            state = torch.FloatTensor(state).to(self.device)
        else:
            state = state.to(self.device)
        state = state.reshape(1, -1)

        with torch.no_grad():
            action = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def save_model(self, dir, id=None):
        if id is not None:
            torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
            torch.save(self.critic1.state_dict(), f'{dir}/critic1_{id}.pth')
            torch.save(self.critic2.state_dict(), f'{dir}/critic2_{id}.pth')
        else:
            torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
            torch.save(self.critic1.state_dict(), f'{dir}/critic1.pth')
            torch.save(self.critic2.state_dict(), f'{dir}/critic2.pth')

    def load_model(self, dir, id=None):
        if id is not None:
            self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth'))
            self.critic1.load_state_dict(torch.load(f'{dir}/critic1_{id}.pth'))
            self.critic2.load_state_dict(torch.load(f'{dir}/critic2_{id}.pth'))
        else:
            self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
            self.critic1.load_state_dict(torch.load(f'{dir}/critic1.pth'))
            self.critic2.load_state_dict(torch.load(f'{dir}/critic2.pth'))

