# Based on the TD3 algorithm from https://github.com/sfujim/TD3/blob/master/TD3.py
import copy
import numpy as np
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Normalise the inputs (states and goals)
clip_range = 5     # limit applied to the normalised values
clip_obs = 200     # limit applied to the raw values before normalisation
clip_return = 50   # not used by the code visible in this part of the file


def process_inputs(o, g, o_mean, o_std, g_mean, g_std):
    """Clip, standardise and concatenate observations and goals.

    Each of ``o`` and ``g`` is first clipped to ``[-clip_obs, clip_obs]``,
    then shifted by its mean and divided by ``std + 1e-6`` (the small constant
    avoids a division by zero), and finally clipped to
    ``[-clip_range, clip_range]``.

    Args:
        o: Observations; must have an axis 1 because the result is
            concatenated along ``axis=1``.
        g: Goals, concatenated after the observations along axis 1.
        o_mean, o_std: Statistics used to standardise ``o`` (scalars or
            anything that broadcasts against ``o``).
        g_mean, g_std: Statistics used to standardise ``g``.

    Returns:
        A NumPy array holding the normalised observations followed by the
        normalised goals along axis 1.
    """

    def _standardise(raw, mean, std):
        # Bound the raw values, standardise, then bound the result again.
        bounded = np.clip(raw, -clip_obs, clip_obs)
        scaled = (bounded - mean) / (std + 1e-6)
        return np.clip(scaled, -clip_range, clip_range)

    normalised_parts = [_standardise(o, o_mean, o_std), _standardise(g, g_mean, g_std)]
    return np.concatenate(normalised_parts, axis=1)

# Define Vectorised linear for neural network (from https://github.com/tinkoff-ai/CORL/blob/main/algorithms/offline/sac_n.py)
class VectorisedLinear(nn.Module):
    """Linear layer that keeps one independent weight/bias pair per ensemble member.

    All members are evaluated in a single batched matrix multiplication.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        ensemble_size: Number of independent weight/bias pairs.
    """

    def __init__(self, in_features, out_features, ensemble_size):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        weight_shape = (ensemble_size, in_features, out_features)
        bias_shape = (ensemble_size, 1, out_features)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        self.bias = nn.Parameter(torch.empty(bias_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise every member's weight and the shared-shape bias tensor."""
        # Each member's weight matrix is initialised separately.
        negative_slope = math.sqrt(5)
        for member in range(self.ensemble_size):
            nn.init.kaiming_uniform_(self.weight[member], a=negative_slope)

        # NOTE(review): each weight slice is stored as (in_features, out_features),
        # so torch's fan-in helper reads the second dimension (out_features) as
        # the fan-in here -- confirm this is the intended initialisation scale.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every member's affine map to its slice of ``x``.

        Shapes (as documented by the original author):
            x:      [ensemble_size, batch_size, input_size]
            weight: [ensemble_size, input_size, out_size]
            output: [ensemble_size, batch_size, out_size]
        """
        projected = torch.matmul(x, self.weight)
        return projected + self.bias

# Define the neural network for an ensemble of critics
class VectorisedCritic(nn.Module):
    """Ensemble of Q-networks evaluated together through ``VectorisedLinear`` layers.

    Args:
        state_dim: Size of the state part of the critic input.
        action_dim: Size of the action part of the critic input.
        num_critics: Number of ensemble members.
        hidden_dim: Widths of the first and second hidden layers.
    """

    def __init__(self, state_dim, action_dim, num_critics=2, hidden_dim=(256,256)):
        super().__init__()

        input_width = state_dim + action_dim
        self.l1 = VectorisedLinear(input_width, hidden_dim[0], num_critics)
        self.l2 = VectorisedLinear(hidden_dim[0], hidden_dim[1], num_critics)
        self.l3 = VectorisedLinear(hidden_dim[1], 1, num_critics)
        self.num_critics = num_critics

    def forward(self, state, action):
        """Return the Q-value of ``(state, action)`` from every ensemble member.

        The state and action are joined along the last dimension and copied
        once per critic along a new leading dimension; the trailing size-1
        output dimension is squeezed away before returning.
        """
        joint = torch.cat([state, action], dim=-1)
        # Every critic receives its own copy of the same batch.
        hidden = torch.repeat_interleave(joint.unsqueeze(0), self.num_critics, dim=0)

        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))

        return self.l3(hidden).squeeze(-1)

# Define the neural network for an actor
class Actor(nn.Module):
    """Deterministic policy network: two ReLU hidden layers and a tanh output.

    Args:
        state_dim: Size of the network input.
        action_dim: Size of the produced action.
        max_action: Factor multiplied onto the tanh output.
        hidden_dim: Widths of the first and second hidden layers.
    """

    def __init__(self, state_dim, action_dim, max_action, hidden_dim=(256, 256)):
        super().__init__()

        first_width = hidden_dim[0]
        second_width = hidden_dim[1]
        self.l1 = nn.Linear(state_dim, first_width)
        self.l2 = nn.Linear(first_width, second_width)
        self.l3 = nn.Linear(second_width, action_dim)
        self.max_action = max_action

    def forward(self, state):
        """Map ``state`` to an action whose components lie in ``max_action * [-1, 1]``."""
        hidden = state
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))

        # tanh bounds the output to [-1, 1]; max_action rescales it.
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed

class Agent(object):
    """TD3-based agent with a critic ensemble and a demonstration-weighted behaviour cloning loss.

    The actor is trained on ``-lambda1 * Q + lambda2 * BC_loss`` where the
    behaviour cloning (BC) loss is computed on demonstration transitions and
    weighted according to ``method``:

    * ``"EnsQfilter"``: binary accept/reject using the ensemble-mean Q-values.
    * ``"SPReDP"``: probabilistic weights from a z-score of the ensemble
      means and standard deviations.
    * ``"SPReDE"``: exponential weights scaled by the ensemble inter-quartile
      ranges.
    * ``"Nonpara_pairwise"`` / ``"Nonpara_cross"``: fraction of pairwise /
      crosswise ensemble comparisons in which the demonstration action wins.

    Args:
        state_dim: Size of the state vector.
        goal_dim: Size of the goal vector; networks take ``state_dim + goal_dim`` inputs.
        action_dim: Size of the action vector.
        max_action: Action bound used by the actor and for clamping target actions.
        hidden_dim: Hidden layer widths of the actor and the critics.
        method: Behaviour cloning weighting scheme, one of ``SUPPORTED_METHODS``.
        ensemble_size: Number of critics in the ensemble.
        lambda1: Weight of the Q-value term in the actor loss.
        lambda2: Weight of the behaviour cloning term in the actor loss.
        batch_size_buffer: Mini-batch size sampled from the replay buffer.
        batch_size_demo: Mini-batch size sampled from the demonstrations.
        gamma: Discount factor of the TD target.
        tau: Polyak averaging coefficient for the target networks.
        lr: Learning rate of both Adam optimisers.
        policy_noise: Scale of the Gaussian noise added to target actions.
        noise_clip: Bound applied to that noise.
        policy_freq: The actor and the target networks are updated once every
            ``policy_freq`` critic updates.
        device: Torch device on which networks and batches are placed.
    """

    # Values of ``method`` for which a behaviour cloning loss is defined.
    SUPPORTED_METHODS = ("EnsQfilter", "SPReDP", "SPReDE", "Nonpara_pairwise", "Nonpara_cross")

    def __init__(self, state_dim, goal_dim, action_dim, max_action, hidden_dim=(256, 256), method="SPReDP", ensemble_size = 10,
                 lambda1=1, lambda2=1, batch_size_buffer=1024, batch_size_demo=128, gamma=0.98, tau=0.005, lr=1e-3,
                 policy_noise=0.2, noise_clip=0.5, policy_freq=2, device="cuda:0"):

        self.actor = Actor(state_dim + goal_dim, action_dim, max_action, hidden_dim).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = VectorisedCritic(state_dim + goal_dim, action_dim, ensemble_size, hidden_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.method = method
        self.ensemble_size = ensemble_size
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.batch_size_buffer = batch_size_buffer
        self.batch_size_demo = batch_size_demo
        self.gamma = gamma
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.device = device

        # Per-actor-update diagnostics: weight of the first demonstration in
        # the mini-batch (weighting methods) and accepted fraction (EnsQfilter).
        self.weight_history = []
        self.accept_history = []
        # Number of critic updates performed so far.
        self.total_it = 0

    def choose_action(self, state):
        """Return the actor's action for one input as a flat NumPy array.

        ``state`` is reshaped to a single row and fed to the actor unchanged.
        NOTE(review): presumably the caller has already normalised the state
        and concatenated the goal, since no processing happens here -- confirm.
        """
        with torch.no_grad():
            state = torch.Tensor(state.reshape(1, -1)).to(self.device)
            action = self.actor(state)

        return action.cpu().numpy().flatten()

    def train(self, replay_buffer, demos, normalizers=(0, 1, 0, 1), iterations=2):
        """Run ``iterations`` critic updates (and the due actor/target updates).

        Args:
            replay_buffer: Sequence of transitions accepted by ``random.sample``;
                each transition is indexed as ``(state, action, reward,
                next_state, goal, done)``.
            demos: Demonstration transitions in the same layout.
            normalizers: ``(o_mean, o_std, g_mean, g_std)`` passed to
                ``process_inputs``.
            iterations: Number of update steps to perform.

        Raises:
            ValueError: If an actor update is due and ``self.method`` is not
                one of ``SUPPORTED_METHODS``, or (from ``random.sample``) if a
                buffer holds fewer transitions than the requested batch size.
        """
        for _ in range(iterations):
            self.total_it += 1

            # Sample a mini-batch from the replay buffer, then one from the
            # demonstration buffer (this order fixes the random stream).
            obs_input, action, reward, next_input, done = self._sample_batch(
                replay_buffer, self.batch_size_buffer, normalizers)
            demos_input, demos_action, demos_reward, demos_next_input, demos_done = self._sample_batch(
                demos, self.batch_size_demo, normalizers)

            # Critic updates
            target_q = self._td_target(next_input, action, reward, done)
            current_q = self.critic(obs_input, action)
            demos_target_q = self._td_target(demos_next_input, demos_action, demos_reward, demos_done)
            demos_current_q = self.critic(demos_input, demos_action)

            # Both mini-batches from the replay buffer and demonstration buffer are used for critic updates;
            # every ensemble member regresses onto the same target.
            critic_loss = (F.mse_loss(current_q, target_q.unsqueeze(0).expand_as(current_q))
                           + F.mse_loss(demos_current_q, demos_target_q.unsqueeze(0).expand_as(demos_current_q)))
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Actor updates
            # The actor is less frequently updated: once per policy_freq updates of critics
            if self.total_it % self.policy_freq != 0:
                continue

            policy_actions = self.actor(obs_input)
            # Use the mean of ensemble to evaluate the policy action for actor updates
            Q = torch.mean(self.critic(obs_input, policy_actions), dim=0)

            BC_loss = self._bc_loss(demos_input, demos_action)

            actor_loss = -self.lambda1 * Q.mean() + self.lambda2 * BC_loss
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the target networks via Polyak averaging
            self._soft_update(self.critic, self.critic_target)
            self._soft_update(self.actor, self.actor_target)

    def _sample_batch(self, buffer, batch_size, normalizers):
        """Sample transitions and return ``(input, action, reward, next_input, done)`` tensors on the device."""
        batch = random.sample(buffer, batch_size)
        # Transition fields are read by index so that extra trailing fields are tolerated.
        state, action, reward, next_state, goal, done = (
            np.array([transition[field] for transition in batch]) for field in range(6)
        )

        def to_device(array):
            return torch.Tensor(array).to(self.device)

        def to_network_input(observation):
            normalised = process_inputs(observation, goal, o_mean=normalizers[0], o_std=normalizers[1],
                                        g_mean=normalizers[2], g_std=normalizers[3])
            return to_device(normalised)

        return (to_network_input(state), to_device(action), to_device(reward),
                to_network_input(next_state), to_device(done))

    def _td_target(self, next_input, action, reward, done):
        """Compute the TD target from the target networks without tracking gradients."""
        with torch.no_grad():
            # Clipped Gaussian noise added to the target action; ``action`` only supplies the shape.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_input) + noise).clamp(-self.max_action, self.max_action)
            target_q = self.critic_target(next_input, next_action)
            # Take the minimum of a subset of size 2 for the target to overcome the estimation bias during critic updates
            indices = torch.randperm(target_q.size(0))
            target_q = target_q[indices[:2]].min(0)[0]
            # NOTE(review): assumes reward and done are one value per transition so they
            # line up with the per-sample target -- confirm against the buffer contents.
            return reward + (1 - done) * self.gamma * target_q

    def _bc_loss(self, demos_input, demos_action):
        """Return the behaviour cloning loss for ``self.method`` and record its diagnostic."""
        if self.method not in self.SUPPORTED_METHODS:
            raise ValueError(
                f"Unsupported method {self.method!r}; expected one of {self.SUPPORTED_METHODS}")

        demos_policy_actions = self.actor(demos_input)
        # Ensemble Q-values of the policy's action and of the demonstrated action in the demo states.
        demos_Q_set = self.critic(demos_input, demos_policy_actions)
        Q_dem_set = self.critic(demos_input, demos_action)

        # EnsQ-filter uses the mean of ensemble as the criterion for binary imitation decisions
        if self.method == "EnsQfilter":
            demos_Q = torch.mean(demos_Q_set, dim=0)
            Q_dem = torch.mean(Q_dem_set, dim=0)
            mask = torch.ge(Q_dem, demos_Q).reshape(self.batch_size_demo, 1).repeat(1, self.action_dim)
            BC_loss = F.mse_loss(torch.masked_select(demos_policy_actions, mask),
                                 torch.masked_select(demos_action, mask))
            # Record the percentage of accepted demonstrations in the mini-batch
            percent_accept = mask.sum(dim=0)[0].detach().cpu().item() / self.batch_size_demo
            self.accept_history.append(percent_accept)
            return BC_loss

        # NOTE(review): the SPReD weights below are not detached, so the BC gradient
        # also flows through the weights into the actor -- confirm this is intended.
        if self.method == "SPReDP":
            # SPReD-P uses probabilistic weights for smooth imitation
            demos_Q_mean = torch.mean(demos_Q_set, dim=0)
            Q_dem_mean = torch.mean(Q_dem_set, dim=0)
            demos_Q_std = torch.std(demos_Q_set, dim=0)
            Q_dem_std = torch.std(Q_dem_set, dim=0)
            z_score = (Q_dem_mean - demos_Q_mean) / torch.sqrt(Q_dem_std**2 + demos_Q_std**2)
            # Standard normal CDF evaluated at the z-score.
            weights = 0.5 * (1 + torch.erf(z_score / torch.sqrt(torch.tensor(2.0))))
        elif self.method == "SPReDE":
            # SPReD-E uses exponential weights for smooth imitation
            demos_Q_mean = torch.mean(demos_Q_set, dim=0)
            Q_dem_mean = torch.mean(Q_dem_set, dim=0)
            demos_Q_qt = torch.quantile(demos_Q_set, 0.75, dim=0) - torch.quantile(demos_Q_set, 0.25, dim=0)
            Q_dem_qt = torch.quantile(Q_dem_set, 0.75, dim=0) - torch.quantile(Q_dem_set, 0.25, dim=0)
            # Temperature: ten times the average inter-quartile range of the two ensembles.
            beta = (demos_Q_qt + Q_dem_qt) / 2 * 10
            weights = torch.exp((Q_dem_mean - demos_Q_mean) / beta) - 1
            weights = torch.clamp(weights, min=0, max=1)
        elif self.method == "Nonpara_pairwise":
            # Nonparametric methods pairwise and crosswise compare Q-values,
            # which validate the Gaussian assumption in SPReD-P.
            # Pairwise: member i's demonstration Q against member i's policy Q.
            count = torch.sum(Q_dem_set > demos_Q_set, dim=0)
            weights = count / self.ensemble_size
        else:  # "Nonpara_cross"
            # Crosswise: every member's demonstration Q against every member's policy Q.
            count = torch.sum(Q_dem_set.unsqueeze(1) > demos_Q_set.unsqueeze(0), dim=(0, 1))
            total = self.ensemble_size * self.ensemble_size
            weights = count / total

        # Per-sample squared error averaged over action dimensions, then weighted.
        se = F.mse_loss(demos_policy_actions, demos_action, reduction='none')
        se = torch.mean(se, dim=1)
        BC_loss = torch.mean(se * weights)
        # The weight for the first sample in each mini-batch is recorded to have an overview of its behavior
        self.weight_history.append(weights[0].item())
        return BC_loss

    def _soft_update(self, online, target):
        """Move ``target``'s parameters a fraction ``tau`` towards ``online``'s."""
        for param, target_param in zip(online.parameters(), target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)