import copy, math
import numpy as np
import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.logger import Logger

from agents.trajectory import Trajectory
from agents.critic import Critic
from agents.model import MLP
from agents.helpers import EMA, kerras_boundaries

class ExponentialScheduler:
    """Exponentially decay a scalar from ``v_start`` towards ``v_final``.

    After ``k`` calls to :meth:`step` since the last :meth:`reset` the value is
    ``v_final + (v_start - v_final) * exp(-k / decay_steps)``.

    :param v_start: value the decay starts from
    :type v_start: float
    :param v_final: asymptotic value the decay approaches
    :type v_final: float
    :param decay_steps: time constant of the decay, measured in ``step`` calls
    :type decay_steps: int
    """

    def __init__(self, v_start=1.5, v_final=0.2, decay_steps=200):
        self.v_start, self.v_final = v_start, v_final
        self.decay_steps = decay_steps
        # Most recently computed value; starts at v_start before any step.
        self.value = v_start
        # Step counter at the last reset, and the running step counter.
        self.ini_frame_idx = 0
        self.current_frame_idx = 0

    def reset(self):
        """Restart the decay from the current step.

        Only the reference step is moved; ``value`` keeps its last computed
        number until the next call to :meth:`step`.
        """
        self.ini_frame_idx = self.current_frame_idx

    def step(self):
        """Advance the counter by one, recompute the value and return it.

        Example of choosing ``decay_steps``: with ``v_start=1``,
        ``v_final=0.01`` and ``decay_steps=10**6``, after ``10**7`` steps the
        value is ``0.01 + 0.99 * exp(-10)``, i.e. approximately ``0.01``.
        """
        self.current_frame_idx += 1
        elapsed = self.current_frame_idx - self.ini_frame_idx
        span = self.v_start - self.v_final
        self.value = self.v_final + span * math.exp(-1. * elapsed / self.decay_steps)
        return self.value

    def get_value(self):
        """Return the most recently computed value without advancing."""
        return self.value

# https://github.com/Kinyugo/consistency_models/blob/main/consistency_models/consistency_models.py
def timesteps_schedule(
    current_training_step: int,
    total_training_steps: int,
    initial_timesteps: int = 2,
    final_timesteps: int = 150,
) -> int:
    """Return the number of discretization timesteps for the current step.

    The squared timestep count is interpolated linearly in training progress
    between ``initial_timesteps**2`` and ``final_timesteps**2``; the result is
    ``ceil(sqrt(interpolated) - 1) + 1``.

    Parameters
    ----------
    current_training_step : int
        Current step in the training loop.
    total_training_steps : int
        Total number of steps the model will be trained for.
    initial_timesteps : int, default=2
        Timesteps at the start of training.
    final_timesteps : int, default=150
        Timesteps at the end of training.

    Returns
    -------
    int
        Number of timesteps at the current point in training.
    """
    start_sq = initial_timesteps**2
    growth = final_timesteps**2 - start_sq
    progress = current_training_step * growth / total_training_steps
    return math.ceil(math.sqrt(progress + start_sq) - 1) + 1


def improved_timesteps_schedule(
    current_training_step: int,
    total_training_steps: int,
    initial_timesteps: int = 10,
    final_timesteps: int = 1280,
) -> int:
    """Implements the improved timestep discretization schedule.

    Training is split into equally long stages; the timestep count starts at
    ``initial_timesteps`` and doubles at every stage boundary until it is
    capped at ``final_timesteps``. One is added to the capped count.

    Parameters
    ----------
    current_training_step : int
        Current step in the training loop.
    total_training_steps : int
        Total number of steps the model will be trained for.
    initial_timesteps : int, default=10
        Timesteps at the start of training.
    final_timesteps : int, default=1280
        Timesteps at the end of training.

    Returns
    -------
    int
        Number of timesteps at the current point in training.

    References
    ----------
    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)
    """
    # Number of stages: one per doubling from initial to final, plus the first.
    num_stages = math.log2(math.floor(final_timesteps / initial_timesteps)) + 1

    # With fewer training steps than stages the floor would be 0 and the
    # division below would raise ZeroDivisionError; use one-step stages instead.
    stage_length = max(1, math.floor(total_training_steps / num_stages))

    stage = math.floor(current_training_step / stage_length)
    # Past ceil(num_stages) doublings the count already exceeds
    # final_timesteps, so capping the exponent leaves the result unchanged
    # while keeping math.pow from overflowing on very large step counts.
    stage = min(stage, math.ceil(num_stages))

    num_timesteps = initial_timesteps * math.pow(2, stage)
    num_timesteps = min(num_timesteps, final_timesteps) + 1

    return int(num_timesteps)


def lognormal_timestep_distribution(
    num_samples: int,
    sigmas: Tensor,
    mean: float = -1.1,
    std: float = 2.0,
) -> Tensor:
    """Sample timestep indices with lognormal-shaped probabilities.

    Each interval between consecutive entries of ``sigmas`` is weighted by the
    difference of ``erf((log(sigma) - mean) / (std * sqrt(2)))`` at its two
    ends; the weights are normalized and indices are drawn with replacement.

    Parameters
    ----------
    num_samples : int
        Number of samples to draw.
    sigmas : Tensor
        Standard deviations of the noise.
    mean : float, default=-1.1
        Mean of the lognormal distribution.
    std : float, default=2.0
        Standard deviation of the lognormal distribution.

    Returns
    -------
    Tensor
        ``num_samples`` interval indices, each in ``[0, len(sigmas) - 2]``.

    References
    ----------
    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)
    """
    scale = std * math.sqrt(2)
    # erf is elementwise, so evaluate it once on all boundaries and then
    # difference neighbouring entries.
    erf_at_boundaries = torch.erf((torch.log(sigmas) - mean) / scale)
    interval_mass = erf_at_boundaries[1:] - erf_at_boundaries[:-1]
    probabilities = interval_mass / interval_mass.sum()

    return torch.multinomial(probabilities, num_samples, replacement=True)


class Agent(nn.Module):
    def __init__(self, 
                 state_dim, 
                 action_dim, 
                 max_action,
                 device,
                 discount,
                 tau,
                 max_q_backup=False,
                 sigma_max=80.0, 
                 sigma_min=0.002,  
                 rho=7, 
                 eta = 1.0,
                 n_time_steps=40, 
                 ema_decay=0.995,
                 step_start_ema=1000,
                 update_ema_every=5,
                 lr=3e-4,
                 lr_decay=False,
                 lr_maxt=1000,
                 grad_norm=1.0,
                 q_norm=False,
                 adaptive_ema=False,
                 steps_per_epoch=1000,
                 improved=False,
                 use_compile=True,  
                 ):
        """Build the actor, critic, their target/EMA copies and optimizers.

        :param state_dim: forwarded to ``MLP``, ``Trajectory`` and ``Critic``
        :param action_dim: forwarded to ``MLP``, ``Trajectory`` and ``Critic``
        :param max_action: forwarded to ``Trajectory``
        :param device: device the networks are moved to and on which ``train``
            creates its sampling tensors
        :param discount: factor applied to the bootstrapped target Q in ``train``
        :param tau: interpolation weight of the target-critic update in ``train``
        :param max_q_backup: if True, ``train`` takes the max target Q over 10
            sampled next actions instead of a single sample
        :param sigma_max: passed to ``kerras_boundaries`` and used for the
            DSM time sampling in ``train``
        :param sigma_min: passed to ``kerras_boundaries`` and used for the
            DSM time sampling in ``train``
        :param rho: passed to ``kerras_boundaries`` and used for the DSM time
            sampling in ``train``
        :param eta: multiplier on the advantage inside the exponential
            Q-weights in ``train``
        :param n_time_steps: forwarded to ``Trajectory``
        :param ema_decay: forwarded to ``EMA``
        :param step_start_ema: ``step_ema`` does nothing before this many steps
        :param update_ema_every: ``train`` calls ``step_ema`` when the step
            counter is a multiple of this
        :param lr: learning rate of the actor's Adam optimizer
        :param lr_decay: if True, create cosine-annealing schedulers for both
            optimizers
        :param lr_maxt: ``T_max`` of the cosine-annealing schedulers
        :param grad_norm: max gradient norm for clipping in ``train``; values
            ``<= 0`` disable clipping
        :param q_norm: if True, ``train`` divides the target Q by its standard
            deviation
        :param adaptive_ema: stored on the instance; not read by the methods
            visible in this file section
        :param steps_per_epoch: stored on the instance; not read by the
            methods visible in this file section
        :param improved: if True, ``train`` uses ``improved_timesteps_schedule``
            instead of ``timesteps_schedule``
        :param use_compile: if True and ``torch.compile`` exists, wrap the
            model, actor and critic with it
        """
        super().__init__()

        self.device = device
        self.model = MLP(state_dim, action_dim, device=device, use_time_s=True)

        self.state_dim = state_dim
        self.max_action = max_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.eta = eta
        self.max_q_backup = max_q_backup
        self.adaptive_ema = adaptive_ema
        self.steps_per_epoch = steps_per_epoch
        self.use_compile = use_compile

        # The actor receives the (not yet compiled) self.model created above.
        self.actor = Trajectory(state_dim=state_dim, action_dim=action_dim, model=self.model, max_action=max_action, n_time_steps=n_time_steps).to(device)
        self.critic = Critic(state_dim, action_dim).to(device)
        
        # use torch.compile to optimize the model
        # After this branch self.model/self.actor/self.critic are rebound to the
        # objects returned by torch.compile; the optimizers and the deep copies
        # below are created from those returned objects.
        if self.use_compile and hasattr(torch, 'compile'):
            self.model = torch.compile(self.model, mode='max-autotune')
            self.actor = torch.compile(self.actor, mode='max-autotune')
            self.critic = torch.compile(self.critic, mode='max-autotune')
            
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        # Target critic starts as an exact copy of the critic.
        self.critic_target = copy.deepcopy(self.critic)
        # NOTE(review): the critic learning rate is fixed at 3e-4 and ignores
        # the `lr` argument — confirm this is intentional.
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)

        self.lr_decay = lr_decay
        self.grad_norm = grad_norm
        self.q_norm = q_norm
        self.improved = improved

        # The schedulers only exist when lr_decay is True; train() checks the
        # same flag before stepping them.
        if lr_decay:
            self.actor_scheduler = CosineAnnealingLR(self.actor_optimizer, T_max=lr_maxt, eta_min=0.)
            self.critic_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=lr_maxt, eta_min=0.)

        # parameters for sampling
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho
        
        self.n_time_steps = n_time_steps
        
        # EMA
        # self.step counts gradient updates performed by train().
        self.step = 0
        self.step_start_ema = step_start_ema
        self.ema = EMA(ema_decay)
        # EMA actor starts as an exact copy of the actor.
        self.ema_actor = copy.deepcopy(self.actor)
        self.update_ema_every = update_ema_every

        # The EMA actor is never optimized directly, so freeze its parameters.
        for param in self.ema_actor.parameters():
            param.requires_grad = False
    
    def sample_t_idx(self, batch_size, device, n_time_steps=20, num_heun_step=1):
        indice_np = np.random.randint(0, n_time_steps - num_heun_step, batch_size)
        indices = torch.from_numpy(indice_np).to(device)
        return indices
    
    def sample_s_idx(self, indices, n_time_steps=20, num_heun_step=1):
        new_indices = torch.from_numpy(np.random.randint((indices + num_heun_step).cpu().detach().numpy(), n_time_steps, indices.shape[0])).to(indices.device)
        return new_indices

    def step_ema(self):
        """Update the EMA actor from the actor once the warm-up period is over.

        Does nothing while ``self.step`` is below ``self.step_start_ema``.
        """
        warmup_done = self.step >= self.step_start_ema
        if warmup_done:
            self.ema.update_model_average(self.ema_actor, self.actor)

    def train(self, replay_buffer, iterations, batch_size=100, log_writer=None):
        """Run ``iterations`` critic and actor updates on batches from ``replay_buffer``.

        Each iteration:

        1. regresses both critic heads onto ``reward + not_done * discount *
           min(target_q1, target_q2)``, where the next action comes from the
           EMA actor and the Q-values from the target critic;
        2. builds per-sample weights ``exp(eta * adv)`` (clamped to 50) from
           the advantage of the dataset action over a freshly sampled policy
           action;
        3. minimizes ``loss_ctm + lambda_dsm * loss_dsm`` for the actor, with
           ``lambda_dsm`` from :meth:`caculate_adpative_weight`;
        4. updates the EMA actor (every ``update_ema_every`` steps) and
           interpolates the target critic towards the critic with ``tau``.

        The learning-rate schedulers, when enabled, are stepped once per call,
        after all iterations.

        NOTE(review): this method shadows ``nn.Module.train(mode)``; calling
        ``agent.eval()`` (which invokes ``self.train(False)``) will therefore
        not behave like the standard PyTorch method.

        :param replay_buffer: object whose ``sample(batch_size)`` returns
            ``(state, action, next_state, reward, not_done)``
        :param iterations: number of gradient updates; also the horizon of the
            timestep discretization schedule, which is driven by the loop
            index and therefore restarts on every call
        :param batch_size: number of transitions per update
        :param log_writer: optional logger; ``add_scalar`` and
            ``add_histogram`` are called on it every iteration
        :return: dict mapping ``'ctm_loss'``, ``'dsm_loss'``,
            ``'mean_q_weights'``, ``'actor_loss'`` and ``'critic_loss'`` to a
            list with one float per iteration
        """
        metric = {'ctm_loss': [], 'dsm_loss': [], 'mean_q_weights': [], 'actor_loss': [], 'critic_loss': []}

        for i in range(iterations):
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

            # ---------------- Q Training ----------------
            current_q1, current_q2 = self.critic(state, action)
            # Q of the dataset action (from the pre-update critic), reused
            # below for the advantage.
            q_data = torch.min(current_q1, current_q2).detach()

            # The bootstrap target is a constant for the critic loss, so there
            # is no need to record an autograd graph while building it.
            with torch.no_grad():
                if self.max_q_backup:
                    # Sample 10 candidate next actions per state and back up
                    # the best one for each critic head.
                    next_state_rpt = torch.repeat_interleave(next_state, repeats=10, dim=0)
                    next_action_rpt = self.ema_actor(next_state_rpt)
                    target_q1, target_q2 = self.critic_target(next_state_rpt, next_action_rpt)
                    target_q1 = target_q1.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                    target_q2 = target_q2.view(batch_size, 10).max(dim=1, keepdim=True)[0]
                else:
                    next_action = self.ema_actor(next_state)
                    target_q1, target_q2 = self.critic_target(next_state, next_action)
                target_q = torch.min(target_q1, target_q2)
                target_q = reward + not_done * self.discount * target_q

                # normalize the target_q (scaled by its std, not centred)
                if self.q_norm:
                    target_q = target_q / (target_q.std() + 1e-6)

            critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self.grad_norm > 0:
                # clip_grad_norm_ returns the total norm measured BEFORE
                # clipping; the update itself uses the clipped gradients.
                critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.critic_optimizer.step()

            # ---------------- q-weights calculation ----------------
            # The weights are constants for the actor loss, so the forward
            # passes that produce them do not need a graph either.
            with torch.no_grad():
                new_action = self.actor(state)
                q1_new_action, q2_new_action = self.critic(state, new_action)
                q_policy = torch.min(q1_new_action, q2_new_action)

                adv = q_data - q_policy
                adv = adv / (adv.std() + 1e-6)
                adv = adv.clamp(min=0.0)
                q_weights = torch.exp(self.eta * adv).clamp(max=50.0)

            # ---------------- Policy Training ----------------
            if self.improved:
                N = improved_timesteps_schedule(i, iterations, initial_timesteps=10, final_timesteps=1280)
            else:
                N = timesteps_schedule(i, iterations, initial_timesteps=2, final_timesteps=150)
            boundaries = kerras_boundaries(self.rho, self.sigma_max, N, self.sigma_min).to(self.device)

            # CTM Loss
            z = torch.randn_like(action)

            # Jump length between t and u. timesteps_schedule starts at N == 2,
            # where np.random.randint(1, N - 1) would raise ValueError
            # (low >= high); the only valid jump length there is 1. For N >= 3
            # the sampled range is unchanged.
            num_heun_step = np.random.randint(1, max(N - 1, 2))
            t_idx = self.sample_t_idx(batch_size, self.device, N, num_heun_step)
            s_idx = self.sample_s_idx(t_idx, N, num_heun_step)
            u_idx = t_idx + num_heun_step
            t = boundaries[t_idx].view(-1, 1)
            u = boundaries[u_idx].view(-1, 1)
            s = boundaries[s_idx].view(-1, 1)

            # DSM Loss: half of the batch draws its time from a lognormal ...
            log_t_dsm1 = - 1.2 + torch.randn(batch_size//2, device=self.device) * 1.2
            t_dsm1 = torch.exp(log_t_dsm1)
            # ... the other half from the rho-warped interpolation between
            # sigma_max and sigma_min, restricted to the first 70% of the way.
            t_dsm2_idx = torch.rand(batch_size - batch_size//2, device=self.device) * 0.7
            t_dsm2 = (self.sigma_max ** (1 / self.rho) + t_dsm2_idx * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))) ** self.rho

            t_dsm = torch.cat([t_dsm1, t_dsm2], dim=0).view(-1, 1)

            # The EMA actor serves as the teacher for the CTM loss.
            loss_ctm = self.actor.loss_ctm(state, action, z, t, u, s, self.ema_actor, weights=q_weights)
            loss_dsm = self.actor.loss_dsm(state, action, z, t_dsm, weights=q_weights)

            self.actor_optimizer.zero_grad()

            lambda_dsm = self.caculate_adpative_weight(loss_ctm, loss_dsm, last_layer=self.actor.get_last_layer_weight())
            actor_loss = loss_ctm + lambda_dsm * loss_dsm

            actor_loss.backward()
            if self.grad_norm > 0:
                actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.actor_optimizer.step()

            # ---------------- Step Target Network ----------------
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            self.step += 1

            # ---------------- Log ----------------
            if log_writer is not None:
                # The grad-norm locals only exist when clipping is enabled.
                if self.grad_norm > 0:
                    log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step)
                    log_writer.add_scalar('Critic Grad Norm', critic_grad_norms.max().item(), self.step)
                log_writer.add_histogram('Q Weights Distribution', q_weights, self.step)
                log_writer.add_scalar('Mean Q Weights', q_weights.mean().item(), self.step)
                log_writer.add_scalar('Std Q Weights', q_weights.std().item(), self.step)
                log_writer.add_scalar('CTM Loss', loss_ctm.item(), self.step)
                log_writer.add_scalar('DSM Loss', loss_dsm.item(), self.step)
                log_writer.add_scalar('Actor Loss', actor_loss.item(), self.step)
                log_writer.add_scalar('Critic Loss', critic_loss.item(), self.step)
                log_writer.add_scalar('Target_Q Mean', target_q.mean().item(), self.step)
                log_writer.add_scalar('Lambda DSM', lambda_dsm.item(), self.step)

            metric['ctm_loss'].append(loss_ctm.item())
            metric['dsm_loss'].append(loss_dsm.item())
            metric['mean_q_weights'].append(q_weights.mean().item())
            metric['actor_loss'].append(actor_loss.item())
            metric['critic_loss'].append(critic_loss.item())

        # Schedulers advance once per train() call, not once per iteration.
        if self.lr_decay:
            self.actor_scheduler.step()
            self.critic_scheduler.step()

        return metric

    def caculate_adpative_weight(self, loss1, loss2, last_layer=None):
        loss1_grad = torch.autograd.grad(loss1, last_layer, retain_graph=True)[0]
        loss2_grad = torch.autograd.grad(loss2, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(loss1_grad) / (torch.norm(loss2_grad) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight
    
    def sample_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        state_rpt = torch.repeat_interleave(state, repeats=50, dim=0)
        with torch.no_grad():
            action = self.actor(state_rpt)
            q_value = self.critic_target.q_min(state_rpt, action).flatten()
            idx = torch.multinomial(F.softmax(q_value), 1)
        return action[idx].cpu().data.numpy().flatten()

    def save_model(self, dir, id=None):
        if id is not None:
            torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
        else:
            torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
            torch.save(self.critic.state_dict(), f'{dir}/critic.pth')

    def load_model(self, dir, id=None, map_location=None):
        if id is not None:
            self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth', map_location=map_location))
            self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth', map_location=map_location))
        else:
            self.actor.load_state_dict(torch.load(f'{dir}/actor.pth', map_location=map_location))
            self.critic.load_state_dict(torch.load(f'{dir}/critic.pth', map_location=map_location))