import logging
import numpy as np

import torch
import torch.optim as optim

from safelife.helper_utils import load_kwargs
from safelife.random import get_rng


from .utils import named_output, round_up
from .base_algo import BaseAlgo
from torch.nn import functional as F

logger = logging.getLogger(__name__)
USE_CUDA = torch.cuda.is_available()

class torch_wassdist:
    """Dual-form distance estimator over random cosine features.

    Two weight vectors, ``beta_1`` and ``beta_2``, live in a ``D``-dimensional
    random feature space. ``update`` fits them with stochastic steps on pairs
    sampled from two collections, and ``wd`` evaluates the resulting distance
    expression on two batches. (The class name suggests a Wasserstein
    distance; that reading comes from the name, not from anything verified
    here.)
    """
    # Same device choice as the module-level USE_CUDA flag, evaluated directly
    # so the class does not depend on a global defined elsewhere in the file.
    compute_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __init__(self, params):
        """Set up the random features and the update schedule.

        Parameters
        ----------
        params : object
            Must expose ``wd_dim`` (random feature dimension) and
            ``num_minibatches`` (number of ``update`` calls after which the
            dual weights are re-initialised).
        """
        self.sigma = 1.0   # scale divisor for the random projection matrix
        self.gamma = 1.0   # temperature dividing the exponent in the dual form
        self.T = 10        # not read anywhere in this class
        self.n = 1  # num of data points per batch to update kernel function
        self.xdim = 9      # dimensionality of each input vector
        self.alpha = 5e-2  # step size for the beta updates
        self.w_iter = int(0.8 * params.wd_dim)  # single updates per update() call
        self.D = params.wd_dim  # random feature dimension
        self.omega = torch.randn(
            self.D, self.xdim, device=self.compute_device, dtype=torch.float32
        ) * 1.0 / self.sigma
        self.bias = torch.rand(
            self.D, device=self.compute_device, dtype=torch.float32
        ) * 2 * np.pi
        self.t = 0  # single updates since the last re-initialisation
        self.b = 0  # update() calls since the last re-initialisation
        self.num_minibatches = params.num_minibatches

    def get_random_feature(self, x):
        """Return ``cos(omega @ x + bias) * sqrt(2 / D)`` for input ``x``."""
        projected = torch.matmul(self.omega, x) + self.bias
        return torch.cos(projected) * np.sqrt(2. / self.D)

    def update(self, X, Y):
        """Run ``w_iter`` stochastic updates on pairs sampled from X and Y.

        One element is drawn uniformly from each collection per iteration.
        After ``num_minibatches`` calls the counters are reset, so the next
        ``single_update`` re-initialises ``beta_1`` and ``beta_2``.
        """
        for _ in range(self.w_iter):
            x = X[int(np.random.uniform() * len(X))]
            # Index Y by its own length; using len(X) here would go out of
            # range (or never reach the tail) when the two sizes differ.
            y = Y[int(np.random.uniform() * len(Y))]

            self.single_update(x, y)

        self.b += 1

        if self.b >= self.num_minibatches:
            self.t = 0
            self.b = 0

    def single_update(self, x, y):
        """Apply one update of the dual weights from a single (x, y) pair.

        When ``t == 0`` the weights are (re)initialised to the random
        features of ``x`` and ``y``; otherwise both are moved by
        ``alpha * mean(weight * feature)``.
        """
        if self.t == 0:
            self.beta_1 = self.get_random_feature(x)
            self.beta_2 = self.get_random_feature(y)
        else:
            # map to random feature space
            zx = self.get_random_feature(x)
            zy = self.get_random_feature(y)

            cost = torch.sum((x - y) ** 2, dim=0)
            dual_gap = torch.matmul(self.beta_1, zx) - torch.matmul(self.beta_2, zy)
            coeff = torch.exp((dual_gap - cost) / self.gamma)
            weight = (1 - coeff)

            # update the functions
            self.beta_1 += self.alpha * torch.mean(weight * zx, dim=-1)
            self.beta_2 += self.alpha * torch.mean(weight * zy, dim=-1)
        self.t += 1

    def wd(self, x, y):
        """Evaluate the dual-form distance between batches ``x`` and ``y``.

        Requires ``beta_1``/``beta_2`` to exist, i.e. ``update`` (or
        ``single_update``) must have been called at least once.

        Returns the mean of the dual expression as a scalar tensor.
        """
        # NOTE(review): view(-1, self.D) reinterprets memory rather than
        # transposing, and only matches omega's (D, xdim) shape when the
        # input holds exactly D * xdim elements -- confirm this is intended.
        zx = self.get_random_feature(x.view(-1, self.D))
        zy = self.get_random_feature(y.view(-1, self.D))
        dual_gap = torch.matmul(self.beta_1, zx) - torch.matmul(self.beta_2, zy)
        wd = dual_gap + (1 / self.gamma) * torch.exp(
            (dual_gap - torch.sum((x - y) ** 2)) / self.gamma)
        return torch.mean(wd)


class PPO_Risk(BaseAlgo):
    """PPO trainer that regularizes a reward policy toward a "safe" policy.

    Two models are trained side by side: ``train_model_safe`` on the
    environments in ``training_envs_safe`` (rewarded with the negated
    ``delta_effect`` of each environment) and ``training_model`` on
    ``training_envs``, whose loss additionally includes a distance term
    between its policy and the safe policy, weighted by ``safe_reg``.
    """
    # Step counter for the safe policy; incremented in gen_training_batch_safe.
    num_steps_safe = 0
    # Not read or written by the methods visible in this file.
    num_episodes_safe = 0

    # Checkpointing settings; presumably consumed by BaseAlgo -- TODO confirm.
    checkpoint_freq = 1e6
    num_checkpoints = 3
    # Not referenced by the visible methods, which use ``self.data_logger``.
    data_logger_1 = None
    data_logger_2 = None

    # Step counter for the reward policy; incremented in gen_training_batch.
    num_steps = 0
    # Each batch is split into this many minibatches per epoch.
    num_minibatches = 4
    # Number of passes over each batch in train_batch / train_batch_safe.
    epochs_per_batch = 3

    entropy_reg = 0.01  # weight of the (negated) entropy bonus in the loss
    entropy_clip = 1.0  # upper clamp on the mean entropy before weighting
    vf_coef = 0.5  # weight of the value loss in the total loss
    # NOTE(review): no gradient clipping is applied in the visible training
    # loops, so this value appears unused here -- confirm.
    max_gradient_norm = 5.0
    eps_policy = 0.2  # PPO clipping for policy loss
    eps_value = 0.2  # PPO clipping for value loss
    # The next four settings are not read by the methods visible in this file.
    rescale_policy_eps = False
    min_eps_rescale = 1e-3  # Only relevant if rescale_policy_eps = True
    reward_clip = 0.0
    policy_rectifier = 'relu'  


    # Device on which models and all training tensors are placed.
    compute_device = torch.device('cuda' if USE_CUDA else 'cpu')

    # Environments for the reward policy; training_envs must be supplied
    # (asserted in __init__).
    training_envs = None
    testing_envs = None


    # Environments for the safe policy.
    training_envs_safe = None
    testing_envs_safe = None


    # Attribute paths listed for checkpointing; presumably read by the
    # load_checkpoint / save_checkpoint_if_needed helpers -- TODO confirm.
    checkpoint_attribs = (
        'num_steps',
        'training_model',
        'train_model_safe',
        'optimizer',
        'optimizer_safe',
        'data_logger.cumulative_stats',
        'champion_dict'
    )

    def __init__(self, training_model, train_model_safe, args, **kwargs):
        """Build the trainer, restore any checkpoint and run a first tournament.

        Parameters
        ----------
        training_model : torch.nn.Module
            Reward policy; called on a batch of states and expected to
            return ``(values, policies)``.
        train_model_safe : torch.nn.Module
            Safe policy with the same call convention.
        args : object
            Namespace providing ``gamma``, ``lmda``, ``learning_rate``,
            ``training_batch_size``, ``steps_per_env``, ``report_freq``,
            ``test_freq`` and ``safe_reg``. It is mutated: ``wd_dim`` and
            ``num_minibatches`` are written onto it for the distance
            estimator.
        **kwargs
            Forwarded to ``load_kwargs``; ``training_envs`` must be
            available on the instance afterwards.
        """
        load_kwargs(self, kwargs)
        assert self.training_envs is not None

        # Hyperparameters copied from the argument namespace.
        self.args = args
        self.gamma = args.gamma
        self.lmda = args.lmda
        self.learning_rate = args.learning_rate
        self.training_batch_size = args.training_batch_size
        self.steps_per_env = args.steps_per_env
        self.report_freq = args.report_freq
        self.test_freq = args.test_freq
        self.safe_reg = args.safe_reg

        # The estimator's feature dimension is tied to the minibatch size
        # (total steps per batch divided by the number of minibatches).
        self.args.wd_dim = int(self.steps_per_env * len(self.training_envs) / self.num_minibatches)
        self.args.num_minibatches = int(self.num_minibatches)

        self.training_model = training_model.to(self.compute_device)
        self.optimizer = optim.Adam(
            self.training_model.parameters(), lr=self.learning_rate)

        self.train_model_safe = train_model_safe.to(self.compute_device)

        self.optimizer_safe = optim.Adam(
            self.train_model_safe.parameters(), lr=self.learning_rate)

        self.load_checkpoint()

        self.wd_dual = torch_wassdist(self.args)

        # Initial tournament
        self.tournament(self.testing_envs)


    def _float_tensor(self, data):
        """Copy ``data`` into a float32 tensor on the compute device."""
        return torch.tensor(
            data,
            dtype=torch.float32,
            device=self.compute_device,
        )

    def _long_tensor(self, data):
        """Copy ``data`` into an int64 tensor on the compute device."""
        return torch.tensor(
            data,
            dtype=torch.int64,
            device=self.compute_device,
        )


    def distance_calculation(self, probs, probs_safe, type="wd-dual"):
        """Return a scalar distance between the reward and safe policies.

        Parameters
        ----------
        probs : torch.Tensor
            Action probabilities from the reward policy.
        probs_safe : torch.Tensor
            Action probabilities from the safe policy.
        type : str
            ``"jensen-shannon"``, ``"probability_L2"`` or ``"wd-dual"``.

        Raises
        ------
        ValueError
            If ``type`` is not one of the supported names.
        """
        if type == "jensen-shannon":
            mixture = 0.5 * (probs_safe + probs)
            log_mixture = torch.log(mixture + 1e-20)  # avoid log(0)
            # F.kl_div takes log-probabilities as input and probabilities
            # as target, so each term is KL(p || mixture).
            return 0.5 * (
                F.kl_div(log_mixture, probs_safe, reduction='batchmean') +
                F.kl_div(log_mixture, probs, reduction='batchmean'))

        if type == "probability_L2":
            squared_diff = (probs - probs_safe) ** 2
            return 0.5 * squared_diff.mean()

        if type == "wd-dual":
            # The estimator is fitted on detached copies so its own update
            # does not touch the autograd graph; the loss itself is then
            # evaluated on the original (differentiable) tensors.
            self.wd_dual.update(probs.detach(), probs_safe.detach())
            return self.wd_dual.wd(probs, probs_safe)

        # Previously an unknown name fell through to an unbound local.
        raise ValueError(f"Unknown distance type: {type!r}")

    @named_output('states actions rewards done policies values info')
    def take_one_step(self, envs):
        """Advance every environment by one action sampled from the reward policy.

        Each environment's current observation is read from ``last_obs``
        (or obtained via ``reset()`` the first time). After stepping, the
        new observation -- or the reset observation when the episode ended
        -- is stored back on ``env.last_obs``.

        Returns the pre-step states, sampled actions, rewards, done flags,
        policy probabilities, value estimates and info objects, one entry
        per environment.
        """
        states = [
            e.last_obs if hasattr(e, 'last_obs') else e.reset()
            for e in envs
        ]
        # Stack in numpy first: building a tensor straight from a list of
        # arrays is much slower in torch.
        tensor_states = torch.tensor(
            np.asarray(states), device=self.compute_device, dtype=torch.float32)
        # Pure inference, so skip building an autograd graph.
        with torch.no_grad():
            values, policies = self.training_model(tensor_states)
        values = values.cpu().numpy()
        policies = policies.cpu().numpy()
        actions = []
        rewards = []
        dones = []
        infos = []
        for policy, env in zip(policies, envs):
            action = get_rng().choice(len(policy), p=policy)
            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()
            env.last_obs = obs
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return states, actions, rewards, dones, policies, values, infos

    @named_output('states actions rewards done policies values')
    def take_one_step_safe(self, envs):
        """Advance every environment by one action sampled from the safe policy.

        Mirrors ``take_one_step`` but uses ``train_model_safe`` and replaces
        the environment reward with ``-float(env.delta_effect)``. The
        ``info`` returned by ``env.step`` is discarded.

        Returns the pre-step states, sampled actions, side-effect rewards,
        done flags, policy probabilities and value estimates, one entry per
        environment.
        """
        states = [
            e.last_obs if hasattr(e, 'last_obs') else e.reset()
            for e in envs
        ]
        # Stack in numpy first: building a tensor straight from a list of
        # arrays is much slower in torch.
        tensor_states = torch.tensor(
            np.asarray(states), device=self.compute_device, dtype=torch.float32)
        # Pure inference, so skip building an autograd graph.
        with torch.no_grad():
            values, policies = self.train_model_safe(tensor_states)
        values = values.cpu().numpy()
        policies = policies.cpu().numpy()
        actions = []
        rewards = []
        dones = []
        for policy, env in zip(policies, envs):
            action = get_rng().choice(len(policy), p=policy)
            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()
            env.last_obs = obs
            # NOTE(review): delta_effect is read after a possible reset();
            # confirm it still reflects the step that just finished.
            side_effect_0 = -float(env.delta_effect)
            actions.append(action)
            rewards.append(side_effect_0)
            dones.append(done)
        return states, actions, rewards, dones, policies, values

    @named_output('states actions action_prob returns advantages values')
    def gen_training_batch(self, steps_per_env, flat=True):
        """
        Roll out the reward policy and assemble one batch of training data.

        Every training environment is stepped ``steps_per_env`` times.
        Returns are discounted by ``gamma`` and bootstrapped from the value
        of the final observation; advantages start from the one-step TD
        error and accumulate backwards with factor ``lmda``. Both recursions
        are cut wherever an episode ended. ``num_steps`` is increased by
        the number of actions taken.

        Parameters
        ----------
        steps_per_env : int
            Number of steps to take per environment.
        flat : bool
            If True, each output tensor will have shape
            ``(steps_per_env * num_env, ...)``.
            Otherwise, shape will be ``(steps_per_env, num_env, ...)``.
        """
        envs = self.training_envs
        rollout = [self.take_one_step(envs) for _ in range(steps_per_env)]

        # Value estimate for the observation each environment ended on.
        last_obs = torch.tensor(
            [env.last_obs for env in envs],
            device=self.compute_device, dtype=torch.float32)
        bootstrap = self.training_model(last_obs)[0].detach().cpu().numpy()

        values = np.array([step.values for step in rollout] + [bootstrap])
        rewards = np.array([step.rewards for step in rollout])
        done = np.array([step.done for step in rollout])
        not_done = ~done

        # Discounted returns and advantages, accumulated back to front.
        discount, lmda = self.gamma, self.lmda
        returns = rewards.copy()
        returns[-1] += discount * bootstrap * not_done[-1]
        advantages = rewards + discount * not_done * values[1:] - values[:-1]
        for i in reversed(range(steps_per_env - 1)):
            returns[i] += discount * not_done[i] * returns[i + 1]
            advantages[i] += lmda * not_done[i] * advantages[i + 1]

        # Probability the policy assigned to each action actually taken.
        policies = np.array([step.policies for step in rollout])
        actions = np.array([step.actions for step in rollout])
        chosen = np.take_along_axis(
            policies, actions[..., np.newaxis], axis=-1)
        probs = chosen[..., 0]

        def to_tensor(data, dtype=torch.float32):
            # Optionally merge the (step, env) axes into a single batch axis.
            if flat:
                data = np.asanyarray(data)
                data = data.reshape(-1, *data.shape[2:])
            return torch.tensor(data, device=self.compute_device, dtype=dtype)

        self.num_steps += actions.size

        return (
            to_tensor([step.states for step in rollout]),
            to_tensor(actions, torch.int64),
            to_tensor(probs),
            to_tensor(returns),
            to_tensor(advantages),
            to_tensor(values[:-1]),
        )

    @named_output('states actions action_prob returns advantages values')
    def gen_training_batch_safe(self, steps_per_env, flat=True):
        """
        Roll out the safe policy and assemble one batch of training data.

        Every environment in ``training_envs_safe`` is stepped
        ``steps_per_env`` times with ``take_one_step_safe``. Returns are
        discounted by ``gamma`` and bootstrapped from the safe model's value
        of the final observation; advantages start from the one-step TD
        error and accumulate backwards with factor ``lmda``. Both recursions
        are cut wherever an episode ended. ``num_steps_safe`` is increased
        by the number of actions taken.

        Parameters
        ----------
        steps_per_env : int
            Number of steps to take per environment.
        flat : bool
            If True, each output tensor will have shape
            ``(steps_per_env * num_env, ...)``.
            Otherwise, shape will be ``(steps_per_env, num_env, ...)``.
        """
        envs = self.training_envs_safe
        rollout = [self.take_one_step_safe(envs) for _ in range(steps_per_env)]

        # Value estimate for the observation each environment ended on.
        last_obs = torch.tensor(
            [env.last_obs for env in envs],
            device=self.compute_device, dtype=torch.float32)
        bootstrap = self.train_model_safe(last_obs)[0].detach().cpu().numpy()

        values = np.array([step.values for step in rollout] + [bootstrap])
        rewards = np.array([step.rewards for step in rollout])
        done = np.array([step.done for step in rollout])
        not_done = ~done

        # Discounted returns and advantages, accumulated back to front.
        discount, lmda = self.gamma, self.lmda
        returns = rewards.copy()
        returns[-1] += discount * bootstrap * not_done[-1]
        advantages = rewards + discount * not_done * values[1:] - values[:-1]
        for i in reversed(range(steps_per_env - 1)):
            returns[i] += discount * not_done[i] * returns[i + 1]
            advantages[i] += lmda * not_done[i] * advantages[i + 1]

        # Probability the policy assigned to each action actually taken.
        policies = np.array([step.policies for step in rollout])
        actions = np.array([step.actions for step in rollout])
        chosen = np.take_along_axis(
            policies, actions[..., np.newaxis], axis=-1)
        probs = chosen[..., 0]

        def to_tensor(data, dtype=torch.float32):
            # Optionally merge the (step, env) axes into a single batch axis.
            if flat:
                data = np.asanyarray(data)
                data = data.reshape(-1, *data.shape[2:])
            return torch.tensor(data, device=self.compute_device, dtype=dtype)

        self.num_steps_safe += actions.size

        return (
            to_tensor([step.states for step in rollout]),
            to_tensor(actions, torch.int64),
            to_tensor(probs),
            to_tensor(returns),
            to_tensor(advantages),
            to_tensor(values[:-1]),
        )

    def calculate_loss(
            self, states, actions, old_policy, old_values, returns, advantages, report):
        """
        Compute the reward policy's loss on one minibatch.

        The loss is the clipped policy loss, plus ``vf_coef`` times the
        clipped value loss, minus the (clamped) entropy bonus, plus
        ``safe_reg`` times the distance between this policy and the safe
        policy evaluated on the same states.

        All parameters ought to be tensors on the appropriate compute device.
        ``old_policy`` holds the probabilities of the taken actions at
        rollout time and ``old_values`` the matching value estimates. When
        ``report`` is true and a data logger is set, the loss components are
        logged.

        Returns ``(entropy, loss)``: the per-sample policy entropy and the
        scalar total loss.
        """
        values, policy = self.training_model(states)
        a_policy = torch.gather(policy, -1, actions[..., np.newaxis])[..., 0]

        prob_diff = advantages.sign() * (1 - a_policy / old_policy)
        policy_loss = advantages.abs() * torch.clamp(prob_diff, min=-self.eps_policy)
        policy_loss = policy_loss.mean()

        v_clip = old_values + torch.clamp(
            values - old_values, min=-self.eps_value, max=+self.eps_value)
        value_loss = torch.max((v_clip - returns) ** 2, (values - returns) ** 2)
        value_loss = value_loss.mean()

        entropy = torch.sum(-policy * torch.log(policy + 1e-12), dim=-1)
        entropy_loss = torch.clamp(entropy.mean(), max=self.entropy_clip)
        entropy_loss *= -self.entropy_reg

        # The safe policy is only a fixed reference here, so evaluate it
        # without recording gradients; its value head is not needed.
        with torch.no_grad():
            _, policies_safe = self.train_model_safe(states)

        safe_loss = self.distance_calculation(policy, policies_safe)

        loss = policy_loss + value_loss * self.vf_coef + entropy_loss + self.safe_reg * safe_loss

        if report and self.data_logger is not None:
            loss_data = {
                "policy_loss": policy_loss.item(),
                "safe_loss": safe_loss.item(),
                "value_loss": value_loss.item(),
                "entropy_loss": entropy_loss.item(),
            }

            schedule_data = {
                "loss_reg": self.safe_reg,
            }

            logger.info(
                "n=%i: policy_loss=%0.3g, safe_loss=%0.3g",
                self.num_steps,
                loss_data['policy_loss'], loss_data['safe_loss'])

            self.data_logger.log_scalars(loss_data, self.num_steps, 'losses')
            self.data_logger.log_scalars(schedule_data, self.num_steps, 'schedules')

        return entropy, loss

    def calculate_loss_safe(
            self, states, actions, old_policy, old_values, returns, advantages):
        """
        Compute the safe policy's loss on one minibatch.

        Combines the clipped policy loss, ``vf_coef`` times the clipped
        value loss and the negated, clamped entropy bonus. There is no
        distance term. All parameters ought to be tensors on the
        appropriate compute device.

        Returns ``(entropy, loss)``: the per-sample policy entropy and the
        scalar total loss.
        """
        values, policy = self.train_model_safe(states)

        # Current probability of each action that was taken in the rollout.
        taken = actions[..., np.newaxis]
        a_policy = policy.gather(-1, taken)[..., 0]

        # Policy term: bounded below by -eps_policy.
        ratio_gap = advantages.sign() * (1 - a_policy / old_policy)
        clipped_gap = ratio_gap.clamp(min=-self.eps_policy)
        policy_loss = (advantages.abs() * clipped_gap).mean()

        # Value term: the worse of the clipped and unclipped squared errors.
        value_shift = (values - old_values).clamp(
            min=-self.eps_value, max=+self.eps_value)
        v_clip = old_values + value_shift
        clipped_err = (v_clip - returns) ** 2
        raw_err = (values - returns) ** 2
        value_loss = torch.max(clipped_err, raw_err).mean()

        # Entropy term: mean entropy capped at entropy_clip, then negated.
        entropy = (-policy * torch.log(policy + 1e-12)).sum(dim=-1)
        capped_entropy = entropy.mean().clamp(max=self.entropy_clip)
        entropy_loss = capped_entropy * -self.entropy_reg

        loss = policy_loss + value_loss * self.vf_coef + entropy_loss
        return entropy, loss

    def train_batch(self, batch, report=False):
        """Optimise the reward policy on ``batch``.

        Runs ``epochs_per_batch`` passes; each pass reshuffles the sample
        indices, splits them into ``num_minibatches`` equal rows and takes
        one optimizer step per row. ``report`` is forwarded to
        ``calculate_loss``.
        """
        order = np.arange(len(batch.states))

        for _epoch in range(self.epochs_per_batch):
            get_rng().shuffle(order)
            minibatches = order.reshape(self.num_minibatches, -1)
            for rows in minibatches:
                _, loss = self.calculate_loss(
                    batch.states[rows],
                    batch.actions[rows],
                    batch.action_prob[rows],
                    batch.values[rows],
                    batch.returns[rows],
                    batch.advantages[rows],
                    report,
                )
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def train_batch_safe(self, batch):
        """Optimise the safe policy on ``batch``.

        Runs ``epochs_per_batch`` passes; each pass reshuffles the sample
        indices, splits them into ``num_minibatches`` equal rows and takes
        one ``optimizer_safe`` step per row.
        """
        order = np.arange(len(batch.states))

        for _epoch in range(self.epochs_per_batch):
            get_rng().shuffle(order)
            minibatches = order.reshape(self.num_minibatches, -1)
            for rows in minibatches:
                _, loss_safe = self.calculate_loss_safe(
                    batch.states[rows],
                    batch.actions[rows],
                    batch.action_prob[rows],
                    batch.values[rows],
                    batch.returns[rows],
                    batch.advantages[rows],
                )
                self.optimizer_safe.zero_grad()
                loss_safe.backward()
                self.optimizer_safe.step()

    def train(self, steps):
        """Train both policies until ``num_steps`` has grown by ``steps``.

        Each iteration trains the safe policy on a fresh batch, then the
        reward policy on its own fresh batch, then checkpoints if needed
        and runs a tournament once ``num_steps`` crosses the next multiple
        of ``test_freq``. Loss reporting is requested on the first
        iteration and whenever ``num_steps_safe`` crosses the next multiple
        of ``report_freq``.
        """
        target_steps = self.num_steps + steps
        report_due = True  # always report on the first batch

        while self.num_steps < target_steps:
            # Thresholds are fixed before this iteration's rollouts advance
            # the step counters.
            report_at = round_up(self.num_steps_safe, self.report_freq)
            test_at = round_up(self.num_steps, self.test_freq)

            self.train_batch_safe(
                self.gen_training_batch_safe(self.steps_per_env))

            if self.num_steps_safe >= report_at:
                report_due = True

            self.train_batch(
                self.gen_training_batch(self.steps_per_env), report_due)
            report_due = False

            self.save_checkpoint_if_needed()

            if self.testing_envs and self.num_steps >= test_at:
                self.tournament(self.testing_envs)