import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torch
import torch.nn as nn
import random

"""
Global constants
"""
SEED = 42
MAX_STEPS = 1000

STATE_DIM = 11
ACTION_DIM = 3
ACTION_HIGH = torch.FloatTensor(np.ones(ACTION_DIM))
ACTION_LOW = - torch.FloatTensor(np.ones(ACTION_DIM))

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

from Networks import Actor, Critic

class ZSPO:
    """Preference-driven zeroth-order optimiser for ``actor``.

    Usage implied by the methods below: ``perturb_actor()`` writes a randomly
    perturbed copy of ``actor`` into ``actor_perturb``; the two policies are
    compared elsewhere; ``train(prob)`` then moves ``actor`` toward that
    perturbation when ``prob > 0.5``.
    """

    def __init__(self):
        self.actor = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        # Scratch network that receives the random perturbation.
        self.actor_perturb = Actor(STATE_DIM, ACTION_DIM)
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        self.lr = 0.05
        self.max_grad_norm = 0.5  # only referenced by the disabled clipping line in train()
        self.perturbation_dist = 0.3  # std-dev of the Gaussian parameter noise

        self.optim = optim.AdamW(self.actor.parameters(), lr=self.lr)
        # Noise tensors from the latest perturb_actor() call, in
        # _perturbable_params() order; cleared by train().
        self.perturbation_vectors = []

    @staticmethod
    def _perturbable_params(model):
        """Return the parameters that are perturbed/updated, in a fixed order.

        First every parameter of every ``nn.Linear`` submodule (in
        ``model.modules()`` order), then every parameter registered directly on
        ``model`` itself. ``perturb_actor`` and ``train`` share this one
        ordering so that noise tensors line up with parameters.
        """
        params = [
            param
            for module in model.modules()
            if isinstance(module, nn.Linear)
            for param in module.parameters()
        ]
        # A parameter name without '.' can only belong to the root module, so
        # this selects the parameters registered directly on ``model``.
        params.extend(param for name, param in model.named_parameters() if '.' not in name)
        return params

    def load_model(self, pth):
        """Load actor weights from ``pth`` and re-sync ``actor_perturb``."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))
        self.actor_perturb.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Return the deterministic action ``tanh(mean)`` for one state as a NumPy array."""
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            mean, _ = self.actor(state)
        action = torch.tanh(mean)
        return action.squeeze(0).numpy()

    def perturb_actor(self):
        """Reset ``actor_perturb`` to ``actor`` and add Gaussian noise to it.

        Each perturbable parameter receives independent noise with standard
        deviation ``perturbation_dist``.

        :return: list of the noise tensors (a copy of ``self.perturbation_vectors``).
        """
        self.actor_perturb.load_state_dict(self.actor.state_dict())
        perturb_vec = []
        with torch.no_grad():
            for param in self._perturbable_params(self.actor_perturb):
                noise = torch.randn_like(param) * self.perturbation_dist
                param.add_(noise)
                perturb_vec.append(noise)

        self.perturbation_vectors = perturb_vec.copy()
        return perturb_vec.copy()

    def train(self, prob):
        """Move ``actor`` toward the last perturbation if it was preferred.

        :param prob: preference probability for the perturbed actor
            (presumably P(perturbed preferred over current) — confirm with the
            caller). An optimizer step happens only when ``prob > 0.5``; the
            pseudo-gradient is ``perturbation * (2 * prob - 1)``.
        :return: None. ``perturbation_vectors`` is always cleared.
        """
        perturb_vec = self.perturbation_vectors

        if prob > 0.5:
            scale = 2 * prob - 1
            self.optim.zero_grad()
            with torch.no_grad():
                for i, param in enumerate(self._perturbable_params(self.actor)):
                    # Parameters without a matching noise tensor get a zero estimate.
                    estimate = torch.zeros_like(param)
                    if i < len(perturb_vec):
                        estimate += perturb_vec[i] * scale
                    # The optimizer descends, so negate to ascend along the estimate.
                    param.grad = -1.0 * estimate

            # Clipping is intentionally disabled:
            # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.optim.step()

        self.perturbation_vectors = []

class ZPG:
    """Zeroth-order policy gradient driven by pairwise preference probabilities.

    ``perturb_actor()`` writes a Gaussian-perturbed copy of ``actor`` into
    ``actor_perturb``. ``train(prob)`` converts the preference probability into
    a reward difference (logit of ``prob`` divided by ``expertise``). It then
    takes one optimizer step along ``reward_diff * noise / perturbation_dist**2``.
    """

    def __init__(self, expertise = 1):
        self.actor = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        # Scratch network that receives the random perturbation.
        self.actor_perturb = Actor(STATE_DIM, ACTION_DIM)
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        self.lr = 0.03
        self.max_grad_norm = 0.5  # only referenced by the disabled clipping line in train()
        self.perturbation_dist = 0.3  # std-dev of the Gaussian parameter noise
        self.expertise = expertise  # divides the logit in train()
        # Only 'BT' (presumably Bradley-Terry) is implemented; other values
        # currently use the same formula.
        self.link_function = 'BT'

        self.optim = optim.AdamW(self.actor.parameters(), lr=self.lr)

        # Noise tensors from the latest perturb_actor() call, in
        # _perturbable_params() order; cleared by train().
        self.perturbation_vectors = []

    @staticmethod
    def _perturbable_params(model):
        """Return the parameters that are perturbed/updated, in a fixed order.

        First every parameter of every ``nn.Linear`` submodule (in
        ``model.modules()`` order), then every parameter registered directly on
        ``model`` itself. ``perturb_actor`` and ``train`` share this one
        ordering so that noise tensors line up with parameters.
        """
        params = [
            param
            for module in model.modules()
            if isinstance(module, nn.Linear)
            for param in module.parameters()
        ]
        # A parameter name without '.' can only belong to the root module, so
        # this selects the parameters registered directly on ``model``.
        params.extend(param for name, param in model.named_parameters() if '.' not in name)
        return params

    def load_model(self, pth):
        """Load actor weights from ``pth`` and re-sync ``actor_perturb``."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))
        self.actor_perturb.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Return the deterministic action ``tanh(mean)`` for one state as a NumPy array."""
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            mean, _ = self.actor(state)
        action = torch.tanh(mean)
        return action.squeeze(0).numpy()

    def perturb_actor(self):
        """Reset ``actor_perturb`` to ``actor`` and add Gaussian noise to it.

        Each perturbable parameter receives independent noise with standard
        deviation ``perturbation_dist``.

        :return: list of the noise tensors (a copy of ``self.perturbation_vectors``).
        """
        self.actor_perturb.load_state_dict(self.actor.state_dict())
        perturb_vec = []
        with torch.no_grad():
            for param in self._perturbable_params(self.actor_perturb):
                noise = torch.randn_like(param) * self.perturbation_dist
                param.add_(noise)
                perturb_vec.append(noise)

        self.perturbation_vectors = perturb_vec.copy()
        return perturb_vec.copy()

    def train(self, prob):
        """Take one zeroth-order gradient step from a single preference observation.

        :param prob: preference probability for the perturbed actor
            (presumably P(perturbed preferred over current) — confirm with the
            caller).
        :return: None. ``perturbation_vectors`` is always cleared.
        """
        perturb_vec = self.perturbation_vectors

        # Invert the logistic link: logit(prob) / expertise is the implied
        # reward difference. The 1e-15 terms avoid log(0) / division by zero
        # at prob in {0, 1}.
        reward_diff = np.log((prob + 1e-15) / (1 - prob + 1e-15)) / self.expertise
        # `**` instead of np.pow: np.pow only exists in NumPy >= 2.0.
        coeff = reward_diff / self.perturbation_dist ** 2

        self.optim.zero_grad()
        with torch.no_grad():
            for i, param in enumerate(self._perturbable_params(self.actor)):
                # Parameters without a matching noise tensor get a zero estimate.
                estimate = torch.zeros_like(param)
                if i < len(perturb_vec):
                    estimate += coeff * perturb_vec[i]
                # The optimizer descends, so negate to ascend along the estimate.
                param.grad = -1.0 * estimate

        # Clipping is intentionally disabled:
        # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim.step()

        self.perturbation_vectors = []

class DPO:
    """Trajectory-level, DPO-style preference learning for a tanh-Gaussian actor.

    ``train`` fits ``sigmoid(beta * diff)`` to preference probabilities.
    ``diff`` is the summed log-probability ratio (current actor vs. the
    log-probs recorded in memory) of trajectory 1 minus that of trajectory 0.
    """

    def __init__(self):
        self.actor = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        # Re-synced with ``actor`` after every train() call; train() itself
        # uses the log-probs stored in memory as the reference.
        self.actor_ref = Actor(STATE_DIM, ACTION_DIM)
        self.actor_ref.load_state_dict(self.actor.state_dict())

        self.lr = 0.001
        self.max_grad_norm = 0.5  # only referenced by the disabled clipping line in train()
        self.dpo_epochs = 1
        # Scale applied to the log-ratio difference inside the sigmoid
        # (previously a hard-coded 0.001 in train()).
        self.beta = 0.001

        self.optim = optim.AdamW(self.actor.parameters(), lr=self.lr)

    def load_model(self, pth):
        """Load actor weights from ``pth`` and re-sync ``actor_ref``."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))
        self.actor_ref.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Sample a tanh-squashed Gaussian action for one state.

        :return: ``(action, log_prob)`` — the action as a NumPy array and its
            log-density including the tanh change-of-variables correction.
        """
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            mean, std = self.actor(state)
        dist = torch.distributions.Normal(mean, std)
        # mean/std come from a no_grad forward pass, so this sample carries no gradient.
        pre_tanh = dist.rsample()
        action = torch.tanh(pre_tanh)

        # log |det d(tanh)/dx| = sum log(1 - tanh(x)^2)
        logp = dist.log_prob(pre_tanh).sum(-1) - torch.log1p(-action.pow(2) + 1e-6).sum(-1)
        return action.squeeze(0).numpy(), logp.item()

    def evaluate(self, states, actions):
        """Return per-step log-probs of already-squashed ``actions`` and the Gaussian entropy.

        Actions are clamped just inside (-1, 1) so ``atanh`` stays finite.
        """
        mean, std = self.actor(states)
        dist = torch.distributions.Normal(mean, std)
        x = torch.clamp(actions, -0.999999, 0.999999)
        pre_tanh = torch.atanh(x)
        logp = dist.log_prob(pre_tanh).sum(-1) - torch.log1p(-x.pow(2) + 1e-6).sum(-1)
        entropy = dist.entropy().sum(dim=-1)
        return logp, entropy

    def train(self, memory):
        """Run ``dpo_epochs`` full-batch updates on trajectory preference pairs.

        :param memory: dict of per-pair lists under 'states_0', 'actions_0',
            'log_probs_0', 'states_1', 'actions_1', 'log_probs_1' and 'probs'.
            Index i across the lists describes pair i. ``probs[i]`` is the
            BCE target for ``sigmoid(beta * diff)``, i.e. the probability that
            trajectory 1 is preferred.
        :return: mean loss of the last epoch as a float (NaN if no epoch ran).
        """
        batch_size = len(memory['probs'])

        all_states_0, all_actions_0, all_log_probs_0 = memory['states_0'], memory['actions_0'], memory['log_probs_0']
        all_states_1, all_actions_1, all_log_probs_1 = memory['states_1'], memory['actions_1'], memory['log_probs_1']
        probs = memory['probs']

        def as_tensor(data):
            return torch.tensor(np.array(data), dtype=torch.float32)

        last_loss = float('nan')
        for _ in range(self.dpo_epochs):
            losses = []
            for i in range(batch_size):
                new_log_probs_0, _ = self.evaluate(as_tensor(all_states_0[i]), as_tensor(all_actions_0[i]))
                new_log_probs_1, _ = self.evaluate(as_tensor(all_states_1[i]), as_tensor(all_actions_1[i]))
                old_log_probs_0 = as_tensor(all_log_probs_0[i])
                old_log_probs_1 = as_tensor(all_log_probs_1[i])

                # Summed log-ratio of trajectory 1 minus that of trajectory 0.
                diff = torch.sum(new_log_probs_1 - old_log_probs_1) - torch.sum(new_log_probs_0 - old_log_probs_0)
                # Binary cross-entropy between sigmoid(beta * diff) and probs[i],
                # computed from the logit. The explicit log(sigmoid(.)) form
                # returned inf once the sigmoid saturated to 0 or 1.
                target = torch.as_tensor(probs[i], dtype=diff.dtype)
                losses.append(F.binary_cross_entropy_with_logits(self.beta * diff, target))

            loss_actor = torch.stack(losses).mean()

            self.optim.zero_grad()
            loss_actor.backward()
            # Clipping is intentionally disabled:
            # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.optim.step()
            last_loss = loss_actor.item()

        # if online DPO
        self.actor_ref.load_state_dict(self.actor.state_dict())

        return last_loss

class ES:
    """Preference-driven accept/reject search over the actor's parameters.

    ``perturb_actor()`` builds a Gaussian-perturbed candidate in
    ``actor_perturb``. ``train(prob)`` adopts that candidate as the new
    ``actor`` when ``prob > 0.5`` and otherwise keeps the current one.
    """

    def __init__(self):
        self.actor = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        # Candidate network; starts as an exact copy of the actor.
        self.actor_perturb = Actor(STATE_DIM, ACTION_DIM)
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        # Std-dev of the Gaussian noise added to each perturbed parameter.
        self.perturbation_dist = 0.3

        # Noise tensors of the latest perturbation; cleared by train().
        self.perturbation_vectors = []

    def load_model(self, pth):
        """Load actor weights from ``pth`` and copy them into the candidate."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))
        self.actor_perturb.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Return ``tanh(mean)`` for a single state, as a NumPy array."""
        obs = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            mean, _std = self.actor(obs)
        return torch.tanh(mean).squeeze(0).numpy()

    def _noise_targets(self):
        """Yield the candidate's parameters in the order noise is drawn for them.

        First the parameters of every ``nn.Linear`` submodule, then any
        parameter registered directly on the root module (a name with no '.').
        """
        net = self.actor_perturb
        for module in net.modules():
            if isinstance(module, nn.Linear):
                yield from module.parameters()
        for name, param in net.named_parameters():
            if '.' not in name:
                yield param

    def perturb_actor(self):
        """Reset the candidate to the actor, then add Gaussian noise in place.

        :return: a fresh list holding the noise tensors that were added.
        """
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        noises = []
        with torch.no_grad():
            for target in self._noise_targets():
                delta = self.perturbation_dist * torch.randn_like(target)
                target.add_(delta)
                noises.append(delta)

        self.perturbation_vectors = list(noises)
        return list(noises)

    def train(self, prob):
        """Adopt the candidate when ``prob > 0.5``; always forget the last noise.

        :param prob: preference probability for the candidate (presumably
            P(candidate preferred) — confirm with the caller).
        """
        accepted = prob > 0.5
        if accepted:
            # Copy candidate -> actor, then re-sync candidate from the actor.
            self.actor.load_state_dict(self.actor_perturb.state_dict())
            self.actor_perturb.load_state_dict(self.actor.state_dict())

        self.perturbation_vectors = []

class PPO:
    """Clipped-objective PPO with a tanh-squashed Gaussian actor and a value critic."""

    def __init__(self):
        self.actor = Actor(STATE_DIM,ACTION_DIM)                       # Actor Network
        self.critic = Critic(STATE_DIM)                                 # Critic Network

        self.lr = 1e-3
        self.gamma = 0.99         # discount factor
        self.gae_lambda = 0.95    # GAE smoothing factor
        self.clip_eps = 0.2       # ratio clipping range [1 - eps, 1 + eps]
        self.max_grad_norm = 0.5  # only referenced by the disabled clipping lines in train()
        self.entropy_coef = 0.00
        self.ppo_epochs = 1

        self.actor_optim = optim.AdamW(self.actor.parameters(), lr=self.lr)
        self.critic_optim = optim.AdamW(self.critic.parameters(), lr=self.lr)

    def load_model(self, pth):
        """Load actor weights from ``pth`` (the critic is not restored)."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))

    def select_action(self, state):
        """Sample a tanh-squashed Gaussian action for one state.

        :return: ``(action, log_prob)`` — the action as a NumPy array and its
            log-density including the tanh change-of-variables correction.
        """
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            mean, std = self.actor(state)
        dist = torch.distributions.Normal(mean, std)
        # mean/std come from a no_grad forward pass, so this sample carries no gradient.
        pre_tanh = dist.rsample()
        action = torch.tanh(pre_tanh)

        # log |det d(tanh)/dx| = sum log(1 - tanh(x)^2)
        logp = dist.log_prob(pre_tanh).sum(-1) - torch.log1p(-action.pow(2) + 1e-6).sum(-1)
        return action.squeeze(0).numpy(), logp.item()

    def evaluate(self, states, actions):
        """Return per-step log-probs of squashed ``actions``, Gaussian entropy and critic values.

        Actions are clamped just inside (-1, 1) so ``atanh`` stays finite.
        """
        mean, std = self.actor(states)
        dist = torch.distributions.Normal(mean, std)
        x = torch.clamp(actions, -0.999999, 0.999999)
        pre_tanh = torch.atanh(x)
        logp = dist.log_prob(pre_tanh).sum(-1) - torch.log1p(-x.pow(2) + 1e-6).sum(-1)
        entropy = dist.entropy().sum(dim=-1)
        values = self.critic(states).squeeze(-1)
        return logp, entropy, values

    def compute_gae(self, rewards, dones, values, next_value):
        """Compute Generalized Advantage Estimates and the matching returns.

        :param rewards: per-step rewards.
        :param dones: per-step terminal flags (1/True stops bootstrapping).
        :param values: per-step value estimates (any sequence, e.g. list or array).
        :param next_value: value estimate of the state after the last step.
        :return: ``(advantages, returns)`` as lists, with ``returns = advantages + values``.
        """
        # list(...) so that NumPy/tensor inputs are extended rather than broadcast-added.
        values = list(values) + [next_value]
        advantages = []
        gae = 0
        for step in reversed(range(len(rewards))):
            not_done = 1 - dones[step]
            delta = rewards[step] + self.gamma * values[step + 1] * not_done - values[step]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages.append(gae)
        advantages.reverse()  # built back-to-front; O(n) instead of insert(0, ...)
        returns = [adv + val for adv, val in zip(advantages, values[:-1])]
        return advantages, returns

    def train(self, memory):
        """Run ``ppo_epochs`` full-batch PPO updates on one rollout.

        :param memory: dict of per-step lists under 'states', 'actions',
            'log_probs' (sampling-time log-probs), 'returns' and 'advantages'.
        :return: ``(actor_loss, critic_loss)`` of the last completed epoch as
            floats, or ``(nan, nan)`` if no epoch completed (NaN detected in
            the first epoch, or ``ppo_epochs == 0``).
        """
        def as_tensor(key):
            return torch.tensor(np.array(memory[key]), dtype=torch.float32)

        states = as_tensor('states')
        actions = as_tensor('actions')
        old_log_probs = as_tensor('log_probs')
        returns = as_tensor('returns')
        advantages = as_tensor('advantages')

        centered = advantages - advantages.mean()
        if advantages.numel() > 1:
            advantages = centered / (advantages.std() + 1e-8)
        else:
            # The sample std of a single element is NaN, which would turn the
            # advantage into NaN; a single centred advantage is simply 0.
            advantages = centered

        actor_loss_value = critic_loss_value = float('nan')
        for epoch in range(self.ppo_epochs):
            log_probs, entropy, values = self.evaluate(states, actions)

            # Check for NaN values
            if torch.isnan(log_probs).any() or torch.isnan(values).any():
                print(f"Warning: NaN detected at epoch {epoch}")
                print(f"NaN in log_probs: {torch.isnan(log_probs).sum().item()}")
                print(f"NaN in values: {torch.isnan(values).sum().item()}")
                break

            ratios = torch.exp(log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
            actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coef * entropy.mean()
            critic_loss = F.mse_loss(values, returns)

            self.actor_optim.zero_grad()
            actor_loss.backward()
            # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optim.step()

            self.critic_optim.zero_grad()
            critic_loss.backward()
            # nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
            self.critic_optim.step()

            actor_loss_value = actor_loss.item()
            critic_loss_value = critic_loss.item()

        return actor_loss_value, critic_loss_value
