# Code modified from spinningup repo.
# Refer[Original Code]: https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/core.py
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from copy import deepcopy
import itertools
import numpy as np
import torch
from torch.optim import Adam
import gymnasium as gym
import time
import metaworld


def combined_shape(length, shape=None):
    """Prepend ``length`` to ``shape`` to get a buffer shape.

    Args:
        length: Leading (number-of-entries) dimension.
        shape: ``None`` (scalar entries), a scalar size, or an iterable of sizes.

    Returns:
        tuple: ``(length,)``, ``(length, shape)`` or ``(length, *shape)``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Consecutive entries of ``sizes`` define the ``nn.Linear`` layers. Every
    linear layer is followed by an instance of ``activation``, except the last
    one, which is followed by ``output_activation``.

    Args:
        sizes: Sequence of layer widths, input width first.
        activation: Module class used after hidden layers.
        output_activation: Module class used after the final layer.
    """
    last_idx = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == last_idx else activation
        modules.extend((nn.Linear(fan_in, fan_out), act_cls()))
    return nn.Sequential(*modules)

def count_vars(module):
    """Return the total number of scalar parameters held by ``module``."""
    return sum(np.prod(param.shape) for param in module.parameters())


# Bounds for the policy's predicted log standard deviation; the actor clamps
# its log-std output into [LOG_STD_MIN, LOG_STD_MAX] before exponentiating.
LOG_STD_MAX = 2
LOG_STD_MIN = -20

class SquashedGaussianMLPActor(nn.Module):
    """Tanh-squashed diagonal Gaussian policy for bounded continuous actions.

    An MLP trunk feeds two linear heads that output the mean and log standard
    deviation of a Gaussian. Sampled (or mean) actions go through ``tanh`` and
    are then scaled by ``act_limit``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Return ``(action, logp)`` for ``obs``.

        Args:
            obs: Observation tensor whose last dimension is ``obs_dim``. It may
                be a single observation or a batch.
            deterministic: If True, use the Gaussian mean instead of sampling.
                Only used for evaluating the policy at test time.
            with_logprob: If False, skip the log-probability computation and
                return ``None`` in its place.

        Returns:
            pi_action: Actions in ``[-act_limit, act_limit]`` (last dim ``act_dim``).
            logp_pi: Log-probability of the squashed action, summed over the
                last (action) dimension, or ``None``.
        """
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = torch.clamp(self.log_std_layer(net_out), LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution; rsample keeps the sample differentiable.
        pi_distribution = Normal(mu, std)
        pi_action = mu if deterministic else pi_distribution.rsample()

        if with_logprob:
            logp_pi = pi_distribution.log_prob(pi_action).sum(dim=-1)
            # Tanh change-of-variables correction, using the numerically stable
            # identity log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))
            # (SAC paper, arXiv 1801.01290, appendix C).
            # Sum over the LAST axis, like the Gaussian term above. The old
            # axis=1 broke on an unbatched 1-D observation.
            correction = 2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))
            logp_pi = logp_pi - correction.sum(dim=-1)
        else:
            logp_pi = None

        pi_action = self.act_limit * torch.tanh(pi_action)
        return pi_action, logp_pi


class MLPQFunction(nn.Module):
    """Q(s, a) approximator: an MLP applied to the concatenated observation and action."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        """Return Q-values with the trailing singleton dimension removed."""
        joint = torch.cat([obs, act], dim=-1)
        # Critical: squeeze so the result is (batch,), not (batch, 1).
        return self.q(joint).squeeze(-1)

class MLPActorCritic(nn.Module):
    """Container for a squashed-Gaussian policy ``pi`` and twin critics ``q1``/``q2``.

    All three networks share the same hidden layer sizes and activation.
    """

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        super().__init__()
        self.device = torch.device('cpu')

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        # Assumes every action dimension shares the upper bound of dimension 0.
        act_limit = action_space.high[0]

        # Build order (pi, q1, q2) fixes both RNG use and parameter order.
        self.pi = SquashedGaussianMLPActor(
            obs_dim, act_dim, hidden_sizes, activation, act_limit
        ).to(self.device)
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(self.device)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(self.device)

    def act(self, obs, deterministic=False):
        """Return a NumPy action for ``obs`` without tracking gradients."""
        with torch.no_grad():
            action, _ = self.pi(obs, deterministic, with_logprob=False)
        return action.numpy()


class ReplayBuffer:
    """Fixed-capacity FIFO experience replay buffer for SAC agents.

    Transitions are written at a rotating write index. Once the buffer is
    full, each new transition overwrites the oldest one.
    """

    def __init__(self, obs_dim, act_dim, size):
        obs_shape = combined_shape(size, obs_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.obs2_buf = np.zeros(obs_shape, dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.max_size = size
        self.ptr = 0
        self.size = 0
        self.device = torch.device('cpu')

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the current index, then advance it (wrapping)."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.act_buf[slot] = act
        self.rew_buf[slot] = rew
        self.done_buf[slot] = done
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self, batch_size=32):
        """Draw ``batch_size`` transitions uniformly (with replacement) as float32 tensors."""
        idxs = np.random.randint(0, self.size, size=batch_size)
        sources = {
            'obs': self.obs_buf,
            'obs2': self.obs2_buf,
            'act': self.act_buf,
            'rew': self.rew_buf,
            'done': self.done_buf,
        }
        return {
            key: torch.as_tensor(buf[idxs], dtype=torch.float32).to(self.device)
            for key, buf in sources.items()
        }

class SAC:
    """Soft Actor-Critic trained on a reward shaped by learned reward/cost models.

    On top of standard SAC (twin Q-networks, polyak-averaged targets, entropy
    bonus), the agent periodically does the following:

    1. Rolls out a few trajectories.
    2. Labels each pair by true return and by a hand-coded constraint cost.
    3. Fits ``reward_model`` and ``cost_model`` (``Estimator`` networks over
       ``concat(obs, act)``) from those pairwise comparisons.

    SAC minibatches then use the shaped reward ``reward_model - lam * cost_model``.
    The multiplier ``lam`` is updated by projected gradient ascent on the mean
    estimated cost.
    """

    # Hand-coded constraint cost: each of the first COST_ACTION_DIMS action
    # components whose magnitude exceeds COST_THRESHOLD adds |a_k| - COST_THRESHOLD.
    COST_ACTION_DIMS = 2
    COST_THRESHOLD = 0.75
    # Step size of the ascent update applied to ``self.lam``.
    LAM_LR = 1e-2

    def __init__(self, env_name, actor_critic=MLPActorCritic, ac_kwargs=None, seed=0,
            steps_per_epoch=1000, epochs=1000, replay_size=int(1e6), gamma=0.99,
            polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=1000,
            update_after=1000, update_every=10, num_test_episodes=10, max_ep_len=1000,
            save_freq=1):
        """Build the environments, networks, optimizers and replay buffer.

        Args:
            env_name (str): Gymnasium environment id passed to ``gym.make``.
                Three instances are created: one for training, one for
                evaluation (``test_agent``) and one for preference rollouts
                (``sample``). Observation and action spaces are assumed to be
                1-D (their ``shape[0]`` is used as the dimension).

            actor_critic: The constructor method for a PyTorch Module with an
                ``act`` method, a ``pi`` module, a ``q1`` module, and a ``q2``
                module. ``act`` and ``pi`` accept batches of observations;
                ``q1`` and ``q2`` accept a batch of observations and a batch of
                actions. When called, they should return:

                ===========  ================  ======================================
                Call         Output Shape      Description
                ===========  ================  ======================================
                ``act``      (batch, act_dim)  | Numpy array of actions for each
                                               | observation.
                ``q1``       (batch,)          | Tensor with one current estimate of
                                               | Q* for the given observations and
                                               | actions (must be flattened!).
                ``q2``       (batch,)          | Tensor with the other estimate of
                                               | Q* (must be flattened!).
                ``pi``       (a, logp_pi)      | ``a``: (batch, act_dim) actions;
                                               | ``logp_pi``: (batch,) log-probs of
                                               | ``a``, differentiable w.r.t. ``a``.
                ===========  ================  ======================================

            ac_kwargs (dict | None): Extra kwargs for ``actor_critic``; ``None``
                means no extra kwargs.

            seed (int): Seed for the torch and numpy global RNGs.

            steps_per_epoch (int): Env interactions per epoch.

            epochs (int): Number of epochs; ``learn`` runs
                ``steps_per_epoch * epochs`` env steps in total.

            replay_size (int): Maximum length of the replay buffer.

            gamma (float): Discount factor, in [0, 1].

            polyak (float): Target-network interpolation factor rho, in [0, 1]
                and usually close to 1. The update is
                theta_targ <- rho * theta_targ + (1 - rho) * theta.

            lr (float): Learning rate for both the policy and the Q-networks.

            alpha (float): Entropy regularization coefficient.

            batch_size (int): SAC minibatch size.

            start_steps (int): Number of initial steps that use uniform-random
                actions instead of the policy, to help exploration.

            update_after (int): Env interactions to collect before the first
                update round.

            update_every (int): Env interactions between update rounds. Each
                round refits the reward/cost models, performs ONE SAC gradient
                step and evaluates the policy. The ratio of env steps to SAC
                gradient steps is therefore ``update_every`` : 1.

            num_test_episodes (int): Deterministic-policy episodes averaged by
                ``test_agent``.

            max_ep_len (int): Maximum episode length. Also the normalizer for
                the per-trajectory returns and costs in ``sample``.

            save_freq (int): Stored as ``self.save_freq`` but not used. ``learn``
                saves a checkpoint whenever an update round falls on a step
                divisible by 100.
        """
        if ac_kwargs is None:
            ac_kwargs = {}

        torch.manual_seed(seed)
        np.random.seed(seed)

        self.env_name = env_name
        self.env = gym.make(env_name)
        self.test_env = gym.make(env_name)
        # Dedicated env for preference rollouts, created once. It replaces a
        # hard-coded 'Swimmer-v5' env that was re-created (and never closed)
        # on every call to ``sample``.
        self.sample_env = gym.make(env_name)
        self.steps_per_epoch = steps_per_epoch
        self.epochs = epochs
        self.replay_size = replay_size
        self.gamma = gamma
        self.polyak = polyak
        self.lr = lr
        self.alpha = alpha
        self.batch_size = batch_size
        self.start_steps = start_steps
        self.update_after = update_after
        self.update_every = update_every
        self.num_test_episodes = num_test_episodes
        self.max_ep_len = max_ep_len
        self.save_freq = save_freq
        self.obs_dim = self.env.observation_space.shape
        self.act_dim = self.env.action_space.shape[0]

        # Action limit for clamping: critically, assumes all dimensions share the same bound!
        self.act_limit = self.env.action_space.high[0]

        # Actor-critic plus a frozen copy that is only updated by polyak averaging.
        self.ac = actor_critic(self.env.observation_space, self.env.action_space, **ac_kwargs)
        self.ac_targ = deepcopy(self.ac)
        for p in self.ac_targ.parameters():
            p.requires_grad = False

        # Must be a list. A bare itertools.chain is exhausted by the Adam
        # constructor below, which silently turned the Q freeze/unfreeze loops
        # in ``update`` into no-ops.
        self.q_params = list(itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()))

        self.replay_buffer = ReplayBuffer(obs_dim=self.obs_dim, act_dim=self.act_dim, size=self.replay_size)

        # Parameter counts of (pi, q1, q2).
        self.var_counts = tuple(count_vars(module) for module in [self.ac.pi, self.ac.q1, self.ac.q2])

        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=self.lr)
        self.q_optimizer = Adam(self.q_params, lr=self.lr)

        self.device = torch.device('cpu')
        # Both estimators score concat(obs, act), so their input width must be
        # obs_dim + act_dim. The previous hard-coded 10 only fit one env.
        estimator_in = self.obs_dim[0] + self.act_dim
        self.reward_model = Estimator(estimator_in).to(self.device)
        self.cost_model = Estimator(estimator_in).to(self.device)
        # Multiplier weighting the learned cost in the shaped reward r - lam * c.
        self.lam = 0.5

    def compute_loss_q(self, data):
        """Return the summed MSE of both Q-networks against the soft Bellman backup.

        Args:
            data (dict): Minibatch tensors keyed 'obs', 'act', 'rew', 'obs2',
                'done', as produced by ``ReplayBuffer.sample_batch``.
        """
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        q1 = self.ac.q1(o, a)
        q2 = self.ac.q2(o, a)

        with torch.no_grad():
            # Target actions come from the *current* policy; values from target nets.
            a2, logp_a2 = self.ac.pi(o2)
            q_pi_targ = torch.min(self.ac_targ.q1(o2, a2), self.ac_targ.q2(o2, a2))
            backup = r + self.gamma * (1 - d) * (q_pi_targ - self.alpha * logp_a2)

        loss_q1 = ((q1 - backup) ** 2).mean()
        loss_q2 = ((q2 - backup) ** 2).mean()
        return loss_q1 + loss_q2

    def compute_loss_pi(self, data):
        """Return the entropy-regularized policy loss ``mean(alpha * logp - min(Q1, Q2))``."""
        o = data['obs']
        pi, logp_pi = self.ac.pi(o)
        q_pi = torch.min(self.ac.q1(o, pi), self.ac.q2(o, pi))
        return (self.alpha * logp_pi - q_pi).mean()

    def update(self, data):
        """Run one gradient step on the Q-networks, then the policy, then polyak-update targets."""
        self.q_optimizer.zero_grad()
        loss_q = self.compute_loss_q(data)
        loss_q.backward()
        self.q_optimizer.step()

        # Freeze the Q-networks so the policy step does not compute their gradients.
        for p in self.q_params:
            p.requires_grad = False

        self.pi_optimizer.zero_grad()
        loss_pi = self.compute_loss_pi(data)
        loss_pi.backward()
        self.pi_optimizer.step()

        for p in self.q_params:
            p.requires_grad = True

        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                # In-place ops update the existing target tensors instead of
                # creating new ones.
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)

    def get_action(self, o, deterministic=False):
        """Return a NumPy action for a single observation ``o``."""
        return self.ac.act(torch.as_tensor(o, dtype=torch.float32).to(self.device),
                           deterministic)

    def _constraint_cost(self, action):
        """Return the hand-coded per-step cost of ``action`` as a float.

        Sums ``|a_k| - COST_THRESHOLD`` over the first ``COST_ACTION_DIMS``
        components whose magnitude exceeds ``COST_THRESHOLD``.
        """
        excess = np.abs(np.asarray(action)[:self.COST_ACTION_DIMS]) - self.COST_THRESHOLD
        return float(np.sum(excess[excess > 0]))

    def test_agent(self):
        """Return the mean per-episode return minus constraint cost of the deterministic policy.

        Runs ``num_test_episodes`` episodes on ``test_env``. Each episode stops
        on termination, truncation or ``max_ep_len`` steps.
        """
        total = 0.0
        for _ in range(self.num_test_episodes):
            o, _ = self.test_env.reset()
            terminated = truncated = False
            ep_len = 0
            while not (terminated or truncated or ep_len == self.max_ep_len):
                a = self.get_action(o, True)
                o, r, terminated, truncated, _ = self.test_env.step(a)
                total += r - self._constraint_cost(a)
                ep_len += 1
        return total / self.num_test_episodes

    def _estimate_returns(self, traj, reward_model, cost_model):
        """Sum both estimators' outputs over every (obs, act) pair in ``traj``.

        Uses one batched forward pass per model instead of one call per step.
        Gradients still flow into both models. Each result has shape (1,).
        """
        inputs = torch.as_tensor(
            np.array([np.concatenate((obs, act)) for obs, act in traj]),
            dtype=torch.float32, device=self.device)
        return reward_model(inputs).sum(dim=0), cost_model(inputs).sum(dim=0)

    def sample(self, reward_model, cost_model, num_traj=3):
        """Roll out ``num_traj`` stochastic-policy trajectories on ``sample_env``.

        Returns:
            tuple: ``(trajs, rewards_truth, rewards_estimate, costs_truth,
            costs_estimate, qposs, qvels)``.

            - ``trajs``: per-trajectory lists of (obs, action) pairs.
            - True and estimated return/cost: each divided by ``max_ep_len``.
              The estimates are tensors that still carry autograd graphs.
            - ``qposs`` / ``qvels``: per-step copies of ``env.unwrapped.data``
              qpos/qvel, recorded before each step.
        """
        env = self.sample_env
        trajs, rewards_truth, rewards_estimate = [], [], []
        costs_truth, costs_estimate, qposs, qvels = [], [], [], []
        for _ in range(num_traj):
            traj, qposs_1, qvels_1 = [], [], []
            r = 0.0
            c = 0.0
            n = 0
            terminated = truncated = False
            observation, _ = env.reset()
            while not (terminated or truncated or n == self.max_ep_len):
                action = self.get_action(observation)
                traj.append((observation, action))
                qposs_1.append(deepcopy(env.unwrapped.data.qpos))
                qvels_1.append(deepcopy(env.unwrapped.data.qvel))
                observation, reward, terminated, truncated, _ = env.step(action)
                r += reward
                c += self._constraint_cost(action)
                n += 1
            r_e, c_e = self._estimate_returns(traj, reward_model, cost_model)
            trajs.append(traj)
            rewards_truth.append(r / self.max_ep_len)
            rewards_estimate.append(r_e / self.max_ep_len)
            costs_truth.append(c / self.max_ep_len)
            costs_estimate.append(c_e / self.max_ep_len)
            qposs.append(qposs_1)
            qvels.append(qvels_1)
        return trajs, rewards_truth, rewards_estimate, costs_truth, costs_estimate, qposs, qvels

    def update_batch(self, batch):
        """Replace ``batch['rew']`` with the shaped learned reward ``r_hat - lam * c_hat``.

        Replaces the former per-row loop, which copy-constructed tensors and
        built autograd graphs, with one no-grad forward pass per model.
        Returns the same (mutated) dict.
        """
        with torch.no_grad():
            inputs = torch.cat((batch['obs'], batch['act']), dim=-1).to(self.device)
            r_hat = self.reward_model(inputs).squeeze(-1)
            c_hat = self.cost_model(inputs).squeeze(-1)
        batch['rew'] = (r_hat - self.lam * c_hat).to(self.device)
        return batch

    @staticmethod
    def _preference_losses(labels, rewards_estimate, costs_estimate):
        """Return ``(r_loss, c_loss)``: pairwise preference losses, averaged over ``labels``.

        ``exp(x) / (exp(x) + exp(y))`` equals ``sigmoid(x - y)``, so each
        comparison term is ``-logsigmoid(preferred - other)``. The cost loss
        adds ``-logsigmoid(s * c)`` per trajectory, where ``s = +1`` if the
        true cost is positive and ``-1`` otherwise. This pushes costly
        trajectories toward positive estimates, consistent with the comparison
        term and with ``lam`` growing with E[c]. The old form
        ``log(1 / (1 + exp(s * c)))`` had the sign inverted.

        Raises:
            ValueError: If ``labels`` is empty (fewer than two trajectories).
        """
        if not labels:
            raise ValueError("need at least two trajectories to build preference labels")
        n = len(labels)
        r_loss = 0
        c_loss = 0
        for i0, i1, reward_compare, cost_compare, sign_0, sign_1 in labels:
            r0, r1 = rewards_estimate[i0], rewards_estimate[i1]
            c0, c1 = costs_estimate[i0], costs_estimate[i1]
            r_margin = r0 - r1 if reward_compare == 0 else r1 - r0
            r_loss = r_loss - F.logsigmoid(r_margin) / n
            c_margin = c0 - c1 if cost_compare == 0 else c1 - c0
            c_loss = c_loss - (F.logsigmoid(c_margin)
                               + F.logsigmoid(sign_0 * c0)
                               + F.logsigmoid(sign_1 * c1)) / n
        return r_loss, c_loss

    def _update_reward_and_cost_models(self):
        """Update ``lam`` and take one optimizer step on each estimator from fresh rollouts."""
        trajs, rewards_truth, rewards_estimate, costs_truth, costs_estimate, _, _ = self.sample(
            self.reward_model, self.cost_model)
        # Projected ascent on lam using the mean estimated (normalized) cost.
        mean_cost = sum(c.item() for c in costs_estimate) / len(costs_estimate)
        self.lam = max(self.lam + self.LAM_LR * mean_cost, 0.0)

        labels = label(trajs, rewards_truth, costs_truth)
        r_loss, c_loss = self._preference_losses(labels, rewards_estimate, costs_estimate)

        self.reward_model.opt.zero_grad()
        r_loss.backward()
        self.reward_model.opt.step()

        self.cost_model.opt.zero_grad()
        c_loss.backward()
        self.cost_model.opt.step()

    def learn(self, log_dir='crlhf/swimmer2'):
        """Run the main interaction/training loop and return the policy module.

        Every ``update_every`` steps (once ``t >= update_after``), an update
        round does the following:

        1. Refits the reward and cost models.
        2. Performs one SAC step on a minibatch with shaped rewards.
        3. Appends the ``test_agent`` score as one line to ``<log_dir>/rewards.txt``.

        A round that lands on a step divisible by 100 also saves the policy
        ``state_dict`` to ``<log_dir>/policy_<t>.pt``.

        Args:
            log_dir (str): Output directory, created if missing. The default is
                the path that was previously hard-coded.

        Returns:
            nn.Module: ``self.ac.pi``.
        """
        import os

        os.makedirs(log_dir, exist_ok=True)
        total_steps = self.steps_per_epoch * self.epochs
        o, _ = self.env.reset()
        ep_len = 0
        with open(f'{log_dir}/rewards.txt', 'a', encoding='utf-8') as file:
            for t in range(total_steps):
                # Uniform-random actions until start_steps have elapsed, for exploration.
                if t > self.start_steps:
                    a = self.get_action(o)
                else:
                    a = self.env.action_space.sample()

                o2, r, terminated, truncated, _ = self.env.step(a)
                ep_len += 1
                # Hitting the max_ep_len horizon is not a real terminal state,
                # so don't cut the bootstrap for it.
                d = False if ep_len == self.max_ep_len else terminated
                self.replay_buffer.store(o, a, r, o2, d)
                # Super critical, easy to overlook step: advance the observation.
                o = o2

                # End the episode on termination, on env-side truncation
                # (previously ignored), or at max_ep_len.
                if terminated or truncated or ep_len == self.max_ep_len:
                    o, _ = self.env.reset()
                    ep_len = 0

                if t >= self.update_after and t % self.update_every == 0:
                    self._update_reward_and_cost_models()
                    batch = self.update_batch(self.replay_buffer.sample_batch(self.batch_size))
                    self.update(data=batch)
                    file.write(f"{self.test_agent()}\n")
                    file.flush()
                    if t % 100 == 0:
                        torch.save(self.ac.pi.state_dict(), f'{log_dir}/policy_{t}.pt')
        return self.ac.pi

class Estimator(nn.Module):
    """Small MLP that maps a feature vector to one score clamped into [-1, 1].

    In this file it is instantiated as both the learned reward model and the
    learned cost model over ``concat(obs, act)``. Each instance owns its own
    Adam optimizer, ``opt``.
    """

    def __init__(self, num_input):
        super().__init__()
        layers = [
            nn.Linear(num_input, 64), nn.ReLU(),
            nn.Linear(64, 10), nn.ReLU(),
            nn.Linear(10, 1),
        ]
        self.fc = nn.Sequential(*layers)
        # Not referenced anywhere in the visible source; kept as an attribute.
        self.mls = nn.MSELoss()
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Return ``fc(x)`` clamped to [-1, 1]; the last output dimension has size 1."""
        # Hard clamp (not tanh): outputs outside [-1, 1] get zero gradient.
        return self.fc(x).clamp(min=-1, max=1)


def label(trajs, rewards_truth, costs_truth):
    """Build comparison labels for every unordered pair of trajectories.

    Args:
        trajs: Sequence of trajectories. Only its length is used.
        rewards_truth: Per-trajectory true (normalized) return.
        costs_truth: Per-trajectory true (normalized) cost.

    Returns:
        list[tuple]: One ``(i, j, reward_compare, cost_compare, cost_abs_i,
        cost_abs_j)`` per pair ``i < j``, in lexicographic order, where

        - ``reward_compare`` is 0 if ``rewards_truth[i] >= rewards_truth[j]``, else 1;
        - ``cost_compare`` is 0 if ``costs_truth[i] >= costs_truth[j]``, else 1;
        - ``cost_abs_k`` is 1 if ``costs_truth[k] > 0``, else -1.
    """
    def sign_of(cost):
        return 1 if cost > 0 else -1

    return [
        (
            i,
            j,
            0 if rewards_truth[i] >= rewards_truth[j] else 1,
            0 if costs_truth[i] >= costs_truth[j] else 1,
            sign_of(costs_truth[i]),
            sign_of(costs_truth[j]),
        )
        for i, j in itertools.combinations(range(len(trajs)), 2)
    ]

if __name__ == '__main__':
    # Other tasks this script has been pointed at: 'Walker2d-v5', 'HalfCheetah-v5'.
    env_name = 'Swimmer-v5'
    SAC(env_name).learn()

