import copy
import math
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, TransformedDistribution, constraints

from torch.distributions.transforms import Transform

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler, DDPMSchedulerOutput
from typing import Optional

class TanhTransform(Transform):
    r"""
    Bijective squashing transform :math:`y = \tanh(x)` from the real line onto ``(-1, 1)``.

    Mathematically identical to
    ``ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])``,
    but computed directly because that composition can be numerically unstable.

    The inverse is deliberately not clamped, so inverting a value that saturated to exactly
    +-1 yields +-inf. Construct the transform with ``cache_size=1`` so that the pre-image of
    a freshly produced sample is taken from the cache instead of being recomputed.
    """
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent written as ``0.5 * (log(1 + x) - log(1 - x))``."""
        return 0.5 * (torch.log1p(x) - torch.log1p(-x))

    def __eq__(self, other):
        # The transform has no parameters, so any two instances are interchangeable.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # No clamping to the open interval here (it can hurt some algorithms);
        # rely on `cache_size=1` for saturated values instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # log|d tanh(x) / dx| = log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)).
        # The right-hand form avoids the cancellation in 1 - tanh(x)^2 for large |x|.
        log_two = math.log(2.)
        return 2. * (log_two - x - F.softplus(-2. * x))


class MLPNetwork(nn.Module):
    """Plain feed-forward network: two ReLU hidden layers of equal width and a linear head.

    :param input_dim: size of the last dimension of the input
    :param output_dim: size of the last dimension of the output
    :param hidden_size: width of both hidden layers
    """

    def __init__(self, input_dim, output_dim, hidden_size=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_dim),
        ]
        # A single Sequential attribute named `network` keeps the state_dict keys
        # (`network.0.*`, `network.2.*`, `network.4.*`) stable for saved checkpoints.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` along its last dimension."""
        out = self.network(x)
        return out


class Policy(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    A single MLP head predicts the mean and log standard deviation of every action
    dimension. Samples are drawn with the reparameterisation trick, squashed through
    ``tanh`` and scaled by ``max_action``.

    :param state_dim: size of the state vector
    :param action_dim: number of action dimensions
    :param max_action: scale applied to the squashed actions
    :param hidden_size: width of the MLP hidden layers
    """

    def __init__(self, state_dim, action_dim, max_action, hidden_size=256):
        super().__init__()
        self.action_dim = action_dim
        self.max_action = max_action
        # First half of the output is the mean, second half the log-std.
        self.network = MLPNetwork(state_dim, 2 * action_dim, hidden_size)

    def forward(self, x, get_logprob=False):
        """Sample a squashed action for each state in ``x``.

        ``x`` is split along ``dim=1``, so it is expected to be 2-D ``(batch, state_dim)``.

        :param x: batch of states
        :param get_logprob: also return the log-probability of the sampled action
        :return: tuple ``(action, logprob, mean_action, mu, logstd)`` where

            - ``action`` is ``tanh(sample) * max_action``;
            - ``logprob`` is the log-density of the *unscaled* tanh sample, summed over the
              last dimension with ``keepdim=True``, or ``None`` when ``get_logprob`` is False;
            - ``mean_action`` is ``tanh(mu) * max_action``;
            - ``mu`` is the pre-tanh mean;
            - ``logstd`` is the log-std clamped to ``[-20, 2]``.
        """
        mu, logstd = torch.chunk(self.network(x), 2, dim=1)
        logstd = logstd.clamp(-20, 2)

        gaussian = Normal(mu, torch.exp(logstd))
        # cache_size=1 lets log_prob reuse the pre-tanh sample instead of inverting tanh.
        squashed = TransformedDistribution(gaussian, [TanhTransform(cache_size=1)])
        action = squashed.rsample()

        logprob = squashed.log_prob(action).sum(dim=-1, keepdim=True) if get_logprob else None

        return action * self.max_action, logprob, torch.tanh(mu) * self.max_action, mu, logstd

def AvgL1Norm(x, eps=1e-8):
    """Rescale ``x`` so that its mean absolute value along the last dimension is 1.

    The denominator is clamped from below by ``eps``, so an all-zero row stays zero
    instead of turning into NaN.
    """
    scale = torch.clamp(torch.mean(torch.abs(x), dim=-1, keepdim=True), min=eps)
    return x / scale


# -----------------------------------------------------------
# Timestep embedding used in the DDPM++ and ADM architectures,
# from XXXX
class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding of scalar positions / diffusion timesteps.

    With ``k = dim // 2`` and frequencies
    ``f_j = (1 / max_positions) ** (j / (k - endpoint))`` for ``j = 0 .. k-1`` (so ``f_0 = 1``),
    each scalar ``x`` maps to ``[cos(x f_0), ..., cos(x f_{k-1}), sin(x f_0), ..., sin(x f_{k-1})]``.
    The output width is ``2 * (dim // 2)``, i.e. ``dim`` when ``dim`` is even.

    :param dim: embedding width (should be even)
    :param max_positions: controls the lowest frequency
    :param endpoint: if True the last frequency is exactly ``1 / max_positions``
    """

    def __init__(self, dim: int, max_positions: int = 10000, endpoint: bool = False):
        super().__init__()
        self.dim = dim
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed a 1-D tensor ``x`` of shape ``(batch,)`` into ``(batch, 2 * (dim // 2))``.

        Integer inputs (e.g. timestep indices) are promoted to float32 first.
        """
        # The frequencies are cast to x's dtype below. For an integer x that cast would
        # truncate every frequency < 1 to 0 and leave only one informative frequency.
        if not torch.is_floating_point(x):
            x = x.to(torch.float32)

        half = self.dim // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        # max(..., 1) avoids 0/0 (NaN) for the degenerate dim == 2, endpoint=True case.
        freqs = freqs / max(half - (1 if self.endpoint else 0), 1)
        freqs = (1 / self.max_positions) ** freqs
        x = torch.outer(x, freqs.to(x.dtype))
        return torch.cat([x.cos(), x.sin()], dim=1)


class DiffusionMLPNetwork(nn.Module):
    """Denoiser MLP: ``[x, embed(noise) (+ condition)] -> prediction with x's width``.

    :param input_dim: width of the (noisy) input ``x`` and of the output
    :param emb_dim: width of the sinusoidal timestep embedding (should be even, since the
        first linear layer expects exactly ``input_dim + emb_dim`` features)
    :param hidden_size: width of the two hidden layers
    :param timestep_emb_params: extra keyword arguments for ``PositionalEmbedding``
        (e.g. ``max_positions``); ``None`` uses its defaults
    """

    def __init__(self, input_dim, emb_dim=16, hidden_size=256, timestep_emb_params: Optional[dict] = None):
        super(DiffusionMLPNetwork, self).__init__()

        self.network = nn.Sequential(
                        nn.Linear(input_dim + emb_dim, hidden_size),
                        nn.ReLU(),
                        nn.Linear(hidden_size, hidden_size),
                        nn.ReLU(),
                        nn.Linear(hidden_size, input_dim),
                        )

        # `timestep_emb_params` defaults to None; unpacking None directly raises TypeError.
        self.map_noise = PositionalEmbedding(emb_dim, **(timestep_emb_params or {}))

    def forward(self, x: torch.Tensor, noise: torch.Tensor, condition: Optional[torch.Tensor] = None):
        """
        Forward pass of the diffusion model.
        :param x: Input tensor of shape (batch_size, input_dim)
        :param noise: Noise level / timestep tensor of shape (batch_size,)
        :param condition: Optional tensor added to the timestep embedding; must broadcast
            against (batch_size, emb_dim)
        :return: Output tensor of shape (batch_size, input_dim)
        """
        t = self.map_noise(noise)
        if condition is not None:
            # Out-of-place so broadcasting may also expand the embedding.
            t = t + condition
        return self.network(torch.cat([x, t], -1))


class DDPMScheduler_DADiff(DDPMScheduler):
    """``DDPMScheduler`` with an extra per-timestep loss weight."""

    def get_loss_coefficient(self, timestep: torch.Tensor) -> torch.Tensor:
        """
        Get the loss coefficient for a given timestep.

        The coefficient is
        ``beta_t ** 2 / (2 * alpha_t * (1 - alpha_prod_t) * variance)`` with
        ``variance = (1 - alpha_prod_t_prev) * beta_t / (1 - alpha_prod_t)``,
        ``alpha_t = alpha_prod_t / alpha_prod_t_prev`` and ``beta_t = 1 - alpha_t``.

        When the previous timestep is negative (the first step), ``alpha_prod_t_prev`` is
        taken as 1. This makes ``variance`` zero and the raw coefficient non-finite
        (``+inf`` when ``beta_t > 0``). The final clamp to ``[0, 1]`` maps ``+inf`` to 1.

        :param timestep: tensor (or int) of timestep indices
        :return: tensor of loss coefficients with the same shape as ``timestep``,
            clamped to ``[0, 1]``
        """
        t = torch.as_tensor(timestep)

        # Evaluate the schedule on the timesteps' device instead of assuming that
        # `alphas_cumprod` has already been moved there by some earlier call.
        alphas_cumprod = self.alphas_cumprod.to(t.device)
        prev_t = torch.as_tensor(self.previous_timestep(t), device=t.device)

        alpha_prod_t = alphas_cumprod[t]
        # No step precedes the first one: use alpha_prod_t_prev = 1 there. The index is
        # clamped so a negative value never wraps around to the last schedule entry.
        alpha_prod_t_prev = torch.where(
            prev_t >= 0,
            alphas_cumprod[prev_t.clamp(min=0)],
            torch.ones_like(alpha_prod_t),
        )

        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        loss_coeff = (current_beta_t ** 2) / (2 * current_alpha_t * (1 - alpha_prod_t) * variance)

        # clip the loss coefficient to avoid numerical instability
        loss_coeff = torch.clamp(loss_coeff, min=0, max=1)

        return loss_coeff


# trajectory inpainting diffusion model
class DiffusionModel(nn.Module):
    """Conditional ("inpainting") diffusion model over flat transitions ``[s, a, s']``.

    A transition is ``torch.cat([state, action, next_state], dim=-1)``. After noising, the
    ``state`` and ``action`` slots are reset to their clean values. The denoiser therefore
    only has to recover ``next_state``, and the loss is computed on that slot only.

    :param state_dim: width of ``state`` (and of ``next_state``)
    :param action_dim: width of ``action``
    :param model: denoiser, called as ``model(noisy_transition, timesteps)``
    :param scheduler: DDPM-style scheduler providing ``add_noise`` and ``config.prediction_type``
        (plus ``get_velocity`` / ``get_loss_coefficient`` where used)
    :param num_inference_timesteps: number of diffusion steps timesteps are drawn from;
        defaults to ``scheduler.config.num_train_timesteps``
    """

    def __init__(self,
            state_dim,
            action_dim,
            model: DiffusionMLPNetwork,
            scheduler: DDPMScheduler,
            num_inference_timesteps=None,
        ):
        super(DiffusionModel, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim

        self.model = model
        self.scheduler = scheduler

        if num_inference_timesteps is None:
            num_inference_timesteps = scheduler.config.num_train_timesteps
        self.num_inference_timesteps = num_inference_timesteps

    def forward(self,
            state_batch,
            action_batch,
            nextstate_batch=None,
            generator=None,
            **kwargs
        ):
        """Weighted denoising loss per transition, averaged over timesteps ``1 .. N-1``.

        Every transition is evaluated at every timestep ``t`` in ``[1, N)`` with
        ``N = num_inference_timesteps`` (``t = 0`` is skipped). Each per-timestep loss is
        multiplied by ``scheduler.get_loss_coefficient(t)``, and the results are averaged
        per transition.

        :param state_batch: (batch_size, state_dim)
        :param action_batch: (batch_size, action_dim)
        :param nextstate_batch: (batch_size, state_dim); required despite the None default
        :param generator: optional ``torch.Generator`` for the noise
        :param kwargs: ignored
        :return: tensor of shape (batch_size, 1)
        :raises ValueError: if ``nextstate_batch`` is None
        """
        if nextstate_batch is None:
            raise ValueError(
                "DiffusionModel.forward requires nextstate_batch to build the "
                "(state, action, next_state) transition"
            )

        batch_size = state_batch.shape[0]
        sample_num = self.num_inference_timesteps - 1

        # '(b s)' layout: row b * sample_num + j pairs transition b with timestep j + 1.
        timesteps = torch.arange(
            1, self.num_inference_timesteps,
            dtype=torch.long,
            device=state_batch.device
        )
        timesteps = einops.repeat(timesteps, 's -> (b s)', b=batch_size)
        state_batch = einops.repeat(state_batch, 'b d -> (b s) d', s=sample_num)
        action_batch = einops.repeat(action_batch, 'b d -> (b s) d', s=sample_num)
        nextstate_batch = einops.repeat(nextstate_batch, 'b d -> (b s) d', s=sample_num)

        loss, timesteps = self.single_step_pred(
            state_batch,
            action_batch,
            nextstate_batch,
            timesteps=timesteps,
            generator=generator
        )

        # Weight each per-timestep loss, then average over timesteps for each transition.
        loss = self.scheduler.get_loss_coefficient(timesteps) * loss
        loss = einops.rearrange(loss, '(b s) -> b s', s=sample_num)
        return loss.mean(dim=1, keepdim=True)

    def single_step_pred(self,
            state_batch,
            action_batch,
            nextstate_batch,
            timesteps=None,
            generator=None,
        ):
        """Noise a batch of transitions once and score the denoiser's prediction.

        :param state_batch: (batch_size, state_dim)
        :param action_batch: (batch_size, action_dim)
        :param nextstate_batch: (batch_size, state_dim)
        :param timesteps: optional (batch_size,) integer timesteps; drawn uniformly from
            ``[0, num_inference_timesteps)`` when None
        :param generator: optional ``torch.Generator`` used for the noise and, when they
            are sampled here, the timesteps
        :return: ``(loss, timesteps)``; ``loss`` has shape (batch_size,) and is the MSE on
            the next-state slot only
        :raises ValueError: if the scheduler's prediction type is unsupported
        """
        trajectory = torch.cat((state_batch, action_batch, nextstate_batch), dim=-1)

        noise = torch.randn(
            size=trajectory.shape,
            dtype=trajectory.dtype,
            device=trajectory.device,
            generator=generator
        )

        if timesteps is None:
            # Same generator as the noise, so a seeded generator makes the step reproducible.
            timesteps = torch.randint(
                0, self.num_inference_timesteps,
                (trajectory.shape[0],),
                device=trajectory.device,
                generator=generator
            )

        noisy_trajectory = self.scheduler.add_noise(trajectory, noise, timesteps)

        # Inpainting condition: keep the (state, action) prefix clean so that only the
        # next-state slot has to be denoised.
        cond_dim = self.state_dim + self.action_dim
        noisy_trajectory[:, :cond_dim] = trajectory[:, :cond_dim]

        pred = self.model(noisy_trajectory, timesteps)

        pred_type = self.scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        elif pred_type == 'v_prediction':
            target = self.scheduler.get_velocity(
                trajectory, noise, timesteps
            )
        else:
            raise ValueError(f'Unsupported prediction type: {pred_type}')

        # Only the (un-conditioned) next-state slot contributes to the loss.
        loss_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        loss_mask[:, cond_dim:] = True
        loss = self.mask_mse(pred, target, loss_mask)

        return loss, timesteps

    def train_loss(self,
            state_batch,
            action_batch,
            nextstate_batch,
            generator=None
        ):
        """
        Compute the training loss for the diffusion model at uniformly random timesteps.
        :param state_batch: a batch of states, shape (batch_size, state_dim)
        :param action_batch: a batch of actions, shape (batch_size, action_dim)
        :param nextstate_batch: a batch of next states, shape (batch_size, state_dim)
        :param generator: optional random number generator for reproducibility
        :return: scalar loss (mean over the batch)
        """
        loss, _ = self.single_step_pred(
            state_batch,
            action_batch,
            nextstate_batch,
            generator=generator
        )
        return loss.mean()

    def mask_mse(self, pred, target, mask):
        """
        Compute the per-row masked mean squared error.
        :param pred: predicted values, shape (batch_size, dim)
        :param target: target values, shape (batch_size, dim)
        :param mask: boolean mask selecting the entries to include; None means all entries
        :return: tensor of shape (batch_size,): mean squared error over the masked entries
            of each row (a small epsilon guards rows with an empty mask)
        """
        if mask is None:
            mask = torch.ones_like(pred, dtype=torch.bool)

        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * mask.type(loss.dtype)
        return loss.sum(dim=1) / (mask.sum(dim=1) + 1e-8)


class DoubleQFunc(nn.Module):
    """Twin critics: two independent MLPs, each mapping ``(state, action)`` to a scalar Q-value.

    :param state_dim: size of the state vector
    :param action_dim: size of the action vector
    :param hidden_size: width of each MLP's hidden layers
    """

    def __init__(self, state_dim, action_dim, hidden_size=256):
        super().__init__()
        in_features = state_dim + action_dim
        self.network1 = MLPNetwork(in_features, 1, hidden_size)
        self.network2 = MLPNetwork(in_features, 1, hidden_size)

    def forward(self, state, action):
        """Return ``(Q1(s, a), Q2(s, a))``; inputs are concatenated along ``dim=1``."""
        state_action = torch.cat([state, action], dim=1)
        q1 = self.network1(state_action)
        q2 = self.network2(state_action)
        return q1, q2


class DADiff(object):
    """SAC agent trained on source- and target-domain data, gating source transitions with
    a diffusion-model distance.

    A conditional diffusion model (``encoder``) is fit on target-domain transitions
    ``(s, a, s')``. Its target copy (``encoder_target``, updated by a tau-weighted moving
    average) scores source transitions. After ``config['start_gate_src_sample']``
    iterations, a source transition keeps weight 1 in the critic loss only if
    ``1 / (1 + distance)`` exceeds the ``config['likelihood_gate_threshold']`` quantile of
    its batch; otherwise its weight is 0. Target transitions always have weight 1.

    :param config: hyper-parameter dict. Keys read here: ``gamma``, ``tau``, ``action_dim``,
        ``state_dim``, ``update_interval``, ``hidden_sizes``, ``max_action``,
        ``diffusion_step``, ``temperature_opt``, ``alpha``, ``critic_lr``, ``actor_lr``,
        ``start_gate_src_sample``, ``likelihood_gate_threshold``
    :param device: torch device for all networks and tensors
    :param target_entropy: SAC entropy target; defaults to ``-config['action_dim']``
    """

    def __init__(self,
                 config,
                 device,
                 target_entropy=None,
                 ):
        self.config = config
        self.device = device
        self.discount = config['gamma']
        self.tau = config['tau']
        # Compare with None so an explicit target of 0.0 is honoured rather than replaced.
        self.target_entropy = target_entropy if target_entropy is not None else -config['action_dim']
        # Stored for callers; not read anywhere in this class.
        self.update_interval = config['update_interval']

        self.total_it = 0

        # aka critic
        self.q_funcs = DoubleQFunc(config['state_dim'], config['action_dim'], hidden_size=config['hidden_sizes']).to(self.device)
        self.target_q_funcs = copy.deepcopy(self.q_funcs)
        self.target_q_funcs.eval()
        for p in self.target_q_funcs.parameters():
            p.requires_grad = False

        # aka actor
        self.policy = Policy(config['state_dim'], config['action_dim'], config['max_action'], hidden_size=config['hidden_sizes']).to(self.device)

        # aka encoder: conditional diffusion model over [s, a, s'] transitions
        model = DiffusionMLPNetwork(
            input_dim=config['state_dim'] * 2 + config['action_dim'],
            hidden_size=config['hidden_sizes'],
            timestep_emb_params={
                'max_positions': config['diffusion_step'],
            }
        ).to(self.device)
        # NOTE(review): with 'squaredcos_cap_v2', diffusers presumably ignores
        # beta_start / beta_end; confirm before tuning them.
        diffusion_scheduler = DDPMScheduler_DADiff(
            num_train_timesteps=config['diffusion_step'],
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule='squaredcos_cap_v2',  # 'linear',
            prediction_type='epsilon',  # 'epsilon', 'sample', 'v_prediction'
        )
        self.encoder = DiffusionModel(
            state_dim=config['state_dim'],
            action_dim=config['action_dim'],
            model=model,
            scheduler=diffusion_scheduler,
        ).to(self.device)
        self.encoder_target = copy.deepcopy(self.encoder)
        self.encoder_target.eval()
        # Like the target critic, the target encoder is only moving-average updated.
        for p in self.encoder_target.parameters():
            p.requires_grad = False

        # aka temperature
        if config['temperature_opt']:
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        else:
            self.log_alpha = torch.log(torch.FloatTensor([config['alpha']])).to(self.device)

        self.q_optimizer = torch.optim.Adam(self.q_funcs.parameters(), lr=config['critic_lr'])
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=config['actor_lr'])
        self.temp_optimizer = torch.optim.Adam([self.log_alpha], lr=config['actor_lr'])
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=config['actor_lr'])

    def select_action(self, state, test=True):
        """Return the deterministic action (``test=True``) or a sampled action as a numpy array."""
        with torch.no_grad():
            action, _, mean, _, _ = self.policy(torch.Tensor(state).view(1, -1).to(self.device))
        if test:
            return mean.squeeze().cpu().numpy()
        else:
            return action.squeeze().cpu().numpy()

    def update_target(self):
        """moving average update of target networks (critic and encoder), rate ``tau``"""
        with torch.no_grad():
            for target_q_param, q_param in zip(self.target_q_funcs.parameters(), self.q_funcs.parameters()):
                target_q_param.data.copy_(self.tau * q_param.data + (1.0 - self.tau) * target_q_param.data)

            # update encoder
            for target_q_param, q_param in zip(self.encoder_target.parameters(), self.encoder.parameters()):
                target_q_param.data.copy_(self.tau * q_param.data + (1.0 - self.tau) * target_q_param.data)

    def update_q_functions(self, state_batch, action_batch, reward_batch, nextstate_batch, not_done_batch, importance_weighting, writer=None):
        """Return the (optionally per-sample weighted) twin-critic TD loss; does not step the optimizer."""
        with torch.no_grad():
            nextaction_batch, logprobs_batch, _, mean, logstd = self.policy(nextstate_batch, get_logprob=True)
            q_t1, q_t2 = self.target_q_funcs(nextstate_batch, nextaction_batch)
            # take min to mitigate positive bias in q-function training
            q_target = torch.min(q_t1, q_t2)
            value_target = reward_batch + not_done_batch * self.discount * (q_target - self.alpha * logprobs_batch)

        q_1, q_2 = self.q_funcs(state_batch, action_batch)
        if writer is not None and self.total_it % 2000 == 0:
            writer.add_scalar('train/q1', q_1.mean(), self.total_it)
            writer.add_scalar('train/logprob', logprobs_batch.mean(), self.total_it)
            writer.add_scalar('train/mean', mean.mean(), self.total_it)
            writer.add_scalar('train/logstd', logstd.mean(), self.total_it)
        if importance_weighting is None:
            loss = F.mse_loss(q_1, value_target) + F.mse_loss(q_2, value_target)
        else:
            # Zero-weighted samples still count in the denominator of `.mean()`.
            loss = (importance_weighting * (q_1 - value_target)**2).mean() + (importance_weighting * (q_2 - value_target)**2).mean()
        return loss

    def update_policy_and_temp(self, state_batch):
        """Return ``(policy_loss, temperature_loss)``; does not step the optimizers."""
        action_batch, logprobs_batch, _, _, _ = self.policy(state_batch, get_logprob=True)
        q_b1, q_b2 = self.q_funcs(state_batch, action_batch)
        qval_batch = torch.min(q_b1, q_b2)
        policy_loss = (self.alpha * logprobs_batch - qval_batch).mean()
        temp_loss = -self.alpha * (logprobs_batch.detach() + self.target_entropy).mean()
        return policy_loss, temp_loss

    def update_encoder(self, src_replay_buffer, tar_replay_buffer, batch_size, writer=None):
        """Run 10 gradient steps of the diffusion loss on target-domain batches of ``2 * batch_size``.

        ``src_replay_buffer`` is accepted for interface symmetry but not used.
        """
        epochs = 10

        encoder_loss_list = []
        for _ in range(epochs):
            state_batch, action_batch, nextstate_batch, _, _ = tar_replay_buffer.sample(batch_size * 2)

            encoder_loss = self.encoder.train_loss(state_batch, action_batch, nextstate_batch)

            self.encoder_optimizer.zero_grad()
            encoder_loss.backward()
            self.encoder_optimizer.step()

            encoder_loss_list.append(encoder_loss.item())

        if writer is not None and self.total_it % 2000 == 0:
            writer.add_scalar('train/encoder loss', np.mean(encoder_loss_list), global_step=self.total_it)

    def train(self, src_replay_buffer, tar_replay_buffer, batch_size=128, writer=None):
        """One training iteration.

        The sequence is: encoder update every 200 iterations; gate source samples by
        diffusion distance; critic step; target-network update; actor step; and a
        temperature step when ``config['temperature_opt']`` is set. Nothing is updated
        until both buffers hold at least ``batch_size`` transitions.
        """
        self.total_it += 1

        # update the encoder and the agent only given some certain amount of data
        if src_replay_buffer.size < batch_size or tar_replay_buffer.size < batch_size:
            return

        src_state, src_action, src_next_state, src_reward, src_not_done = src_replay_buffer.sample(batch_size)
        tar_state, tar_action, tar_next_state, tar_reward, tar_not_done = tar_replay_buffer.sample(batch_size)

        # update encoder
        if self.total_it % 200 == 0:
            self.update_encoder(src_replay_buffer, tar_replay_buffer, batch_size, writer)

        # derive representation deviation: per-source-transition weighted diffusion loss
        with torch.no_grad():
            distance = self.encoder_target(src_state, src_action, src_next_state)

        if self.total_it > self.config['start_gate_src_sample']:
            # Higher score = more target-like; keep the samples above the batch quantile.
            pho_distance = 1 / (1 + distance)
            kl_gate_threshold = torch.quantile(
                pho_distance,
                q=self.config['likelihood_gate_threshold']
            )
            accept_gate = (pho_distance > kl_gate_threshold).long()
        else:
            accept_gate = torch.ones_like(distance, dtype=torch.long, device=self.device)

        if writer is not None and self.total_it % 2000 == 0:
            writer.add_scalar('train/distance', distance.mean(), self.total_it)
            writer.add_scalar('train/src reward', src_reward.mean(), self.total_it)

        state = torch.cat([src_state, tar_state], 0)
        action = torch.cat([src_action, tar_action], 0)
        next_state = torch.cat([src_next_state, tar_next_state], 0)
        reward = torch.cat([src_reward, tar_reward], 0)
        not_done = torch.cat([src_not_done, tar_not_done], 0)
        importance_weighting = torch.cat([
            accept_gate,
            torch.ones_like(tar_reward, dtype=torch.long)
        ])

        q_loss_step = self.update_q_functions(state, action, reward, next_state, not_done, importance_weighting, writer)

        self.q_optimizer.zero_grad()
        q_loss_step.backward()
        self.q_optimizer.step()

        self.update_target()

        # update policy and temperature parameter (critic frozen meanwhile)
        for p in self.q_funcs.parameters():
            p.requires_grad = False
        pi_loss_step, a_loss_step = self.update_policy_and_temp(state)
        self.policy_optimizer.zero_grad()
        pi_loss_step.backward()
        self.policy_optimizer.step()

        if self.config['temperature_opt']:
            # zero_grad also discards the log_alpha gradient left by the policy loss.
            self.temp_optimizer.zero_grad()
            a_loss_step.backward()
            self.temp_optimizer.step()

        for p in self.q_funcs.parameters():
            p.requires_grad = True

    @property
    def alpha(self):
        """Entropy temperature ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def save(self, filename: str):
        """Write networks, ``log_alpha`` and optimizer states to ``filename + '.pth'``."""
        torch.save({
            'q_funcs': self.q_funcs.state_dict(),
            'target_q_funcs': self.target_q_funcs.state_dict(),
            'policy': self.policy.state_dict(),
            'encoder': self.encoder.state_dict(),
            'encoder_target': self.encoder_target.state_dict(),
            'log_alpha': self.log_alpha,
            'q_optimizer': self.q_optimizer.state_dict(),
            'policy_optimizer': self.policy_optimizer.state_dict(),
            'temp_optimizer': self.temp_optimizer.state_dict(),
            'encoder_optimizer': self.encoder_optimizer.state_dict()
        }, filename + '.pth')

    def load(self, filename: str):
        """Restore everything written by ``save`` from ``filename + '.pth'``."""
        checkpoint = torch.load(filename+'.pth', map_location=self.device)
        self.q_funcs.load_state_dict(checkpoint['q_funcs'])
        self.target_q_funcs.load_state_dict(checkpoint['target_q_funcs'])
        self.policy.load_state_dict(checkpoint['policy'])
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.encoder_target.load_state_dict(checkpoint['encoder_target'])
        # Copy into the existing tensor rather than rebinding the attribute:
        # `temp_optimizer` holds a reference to this exact tensor. Rebinding would leave
        # the optimizer updating an orphaned copy that `alpha` no longer reads.
        with torch.no_grad():
            self.log_alpha.copy_(checkpoint['log_alpha'])
        self.q_optimizer.load_state_dict(checkpoint['q_optimizer'])
        self.policy_optimizer.load_state_dict(checkpoint['policy_optimizer'])
        self.temp_optimizer.load_state_dict(checkpoint['temp_optimizer'])
        self.encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer'])

    def cal_distance(self, state, action, next_state):
        """Diffusion distance of each transition under ``encoder_target``, shape (batch, 1)."""
        with torch.no_grad():
            distance = self.encoder_target(state, action, next_state)

        return distance