import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal, Categorical
from collections import deque
import matplotlib.pyplot as plt

"""
Global constants
"""
# Seed shared by every random number generator used in this module.
SEED = 42
# Step budget; PPO derives its nominal batch size from it (MAX_STEPS * 2).
MAX_STEPS = 500

# Sizes handed to the Actor/Critic constructors.
# NOTE(review): 4 observations / 2 actions / 500 steps looks like CartPole-v1 — confirm.
STATE_DIM = 4
ACTION_DIM = 2

# Seed Python, NumPy and PyTorch so network initialisation and sampling are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

"""
Import Self-Defined Module
"""

from Networks import Actor, Critic


class DPO:
    """Policy trained from pairwise trajectory preferences with a DPO-style loss.

    Two actors with identical architecture are kept: ``actor`` is optimised,
    while ``actor_ref`` is a snapshot that is re-synchronised with ``actor``
    at the end of every ``train`` call.  The loss itself uses the
    log-probabilities recorded at sampling time (stored in ``memory``) as the
    reference terms, not a forward pass through ``actor_ref``.
    """

    def __init__(self):
        self.actor = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        self.actor_ref = Actor(STATE_DIM, ACTION_DIM)
        self.actor_ref.load_state_dict(self.actor.state_dict())

        self.lr = 0.001
        self.max_grad_norm = 0.5
        self.dpo_epochs = 1
        # Scale applied to the log-ratio difference before the sigmoid.
        self.beta = 0.01

        self.actor_optim = optim.AdamW(self.actor.parameters(), lr=self.lr)

    def load_model(self, pth):
        """Load actor weights from ``pth`` and copy them into the reference actor."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))
        self.actor_ref.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Sample an action for ``state``.

        The actor output is passed to ``Categorical`` as class probabilities.

        :return: ``(action, log_prob)`` as a Python int and float.
        """
        state = torch.tensor(state, dtype=torch.float32)
        with torch.no_grad():
            prob = self.actor(state)
        dist = Categorical(prob)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item()

    def evaluate(self, states, actions):
        """Return ``(log_probs, entropy)`` of ``actions`` under the current actor."""
        prob = self.actor(states)
        dist = Categorical(prob)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        return log_probs, entropy

    @staticmethod
    def _trajectory_tensors(states, actions, log_probs):
        """Convert one stored trajectory into (float32, int64, float32) tensors."""
        return (
            torch.tensor(np.array(states), dtype=torch.float32),
            torch.tensor(np.array(actions), dtype=torch.int64),
            torch.tensor(np.array(log_probs), dtype=torch.float32),
        )

    def train(self, memory):
        """Run ``dpo_epochs`` gradient steps on a batch of trajectory pairs.

        :param memory: dict with per-pair lists under the keys ``states_0``,
            ``actions_0``, ``log_probs_0``, ``states_1``, ``actions_1``,
            ``log_probs_1`` and ``probs``.  ``probs[i]`` is the target
            probability that trajectory 1 of pair ``i`` is preferred.
        :return: the mean loss of the last epoch as a float, or ``0.0`` when
            there is nothing to train on (empty batch or ``dpo_epochs <= 0``).
        """
        # Tensor conversion does not depend on the epoch, so do it once.
        pairs = [
            (
                self._trajectory_tensors(s0, a0, lp0),
                self._trajectory_tensors(s1, a1, lp1),
                torch.as_tensor(prob, dtype=torch.float32),
            )
            for s0, a0, lp0, s1, a1, lp1, prob in zip(
                memory['states_0'], memory['actions_0'], memory['log_probs_0'],
                memory['states_1'], memory['actions_1'], memory['log_probs_1'],
                memory['probs'],
            )
        ]
        if not pairs or self.dpo_epochs <= 0:
            return 0.0

        for _ in range(self.dpo_epochs):
            losses = []
            for (states_0, actions_0, old_lp_0), (states_1, actions_1, old_lp_1), target in pairs:
                new_lp_0, _ = self.evaluate(states_0, actions_0)
                new_lp_1, _ = self.evaluate(states_1, actions_1)

                # isolate the effect of KL
                diff = torch.sum(new_lp_1 - old_lp_1) - torch.sum(new_lp_0 - old_lp_0)

                # Cross-entropy between the target preference and
                # sigmoid(beta * diff), evaluated on the logit so that a
                # saturated sigmoid cannot produce log(0).
                losses.append(F.binary_cross_entropy_with_logits(self.beta * diff, target))

            loss_actor = torch.stack(losses).mean()

            self.actor_optim.zero_grad()
            loss_actor.backward()
            # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optim.step()

        # if online DPO
        self.actor_ref.load_state_dict(self.actor.state_dict())

        return loss_actor.item()

class ZPG:
    """Zeroth-order policy update driven by a preference probability.

    ``perturb_actor`` builds ``actor_perturb`` by adding Gaussian noise to the
    parameters of every ``nn.Linear`` layer of ``actor``.  ``train`` then turns
    a preference probability into a scalar, scales the stored noise by it and
    hands the result to the optimiser as the gradient.
    """

    def __init__(self, expertise=0.01):
        self.actor = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        self.actor_perturb = Actor(STATE_DIM, ACTION_DIM)
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        self.lr = 0.02
        self.max_grad_norm = 0.5
        self.perturbation_dist = 0.3
        self.expertise = expertise
        self.link_function = 'BT'

        self.optim = optim.AdamW(self.actor.parameters(), lr=self.lr)
        # Noise drawn by the latest perturb_actor() call; this is what train() reads.
        self.perturbation_vector = []
        # Legacy attribute name, kept so existing external readers do not break.
        self.perturbation_vectors = []

    @staticmethod
    def _linear_parameters(model):
        """Return the parameters of every ``nn.Linear`` submodule, in module order."""
        return [
            param
            for module in model.modules()
            if isinstance(module, nn.Linear)
            for param in module.parameters()
        ]

    def load_model(self, pth):
        """Load actor weights from ``pth`` and copy them into the perturbed actor."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))
        self.actor_perturb.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Return the index of the largest actor output for ``state`` (greedy)."""
        state = torch.tensor(state, dtype=torch.float32)
        with torch.no_grad():
            prob = self.actor(state)
        action = torch.argmax(prob).item()
        return action

    def perturb_actor(self):
        """Reset ``actor_perturb`` to ``actor`` and add fresh Gaussian noise.

        Noise with standard deviation ``perturbation_dist`` is added to each
        ``nn.Linear`` parameter and remembered for the next ``train`` call.

        :return: list of the noise tensors, one per perturbed parameter.
        """
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        perturb_vec = []
        with torch.no_grad():
            for param in self._linear_parameters(self.actor_perturb):
                noise = torch.randn_like(param) * self.perturbation_dist
                param.add_(noise)
                perturb_vec.append(noise)

        self.perturbation_vector = list(perturb_vec)
        return perturb_vec

    def train(self, prob):
        """Apply one optimiser step along the stored perturbation.

        :param prob: scalar preference probability in [0, 1]; it is mapped to
            a reward difference through the log-odds divided by ``expertise``.
        :raises RuntimeError: if ``perturb_actor`` has not been called yet.
        """
        if not self.perturbation_vector:
            raise RuntimeError("ZPG.train() called before perturb_actor()")

        # Log-odds of the preference; the 1e-15 terms keep prob == 0 or 1 finite.
        # NOTE(review): the original code evaluated this same expression for
        # every value of ``link_function``; only the 'BT' form is implemented.
        reward_diff = np.log((prob + 1e-15) / (1 - prob + 1e-15)) / self.expertise
        # ``**`` instead of np.pow, which only exists in NumPy >= 2.0.
        scale = reward_diff / self.perturbation_dist ** 2

        self.optim.zero_grad()
        with torch.no_grad():
            for param, noise in zip(self._linear_parameters(self.actor),
                                    self.perturbation_vector):
                # The optimiser minimises, so the ascent direction is negated.
                param.grad = -(noise * scale)

        # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim.step()

        # Only the legacy attribute is reset here; ``perturbation_vector``
        # stays in place until the next perturb_actor(), as before.
        self.perturbation_vectors = []
        return

class ZSPO:
    """Zeroth-order sign-based policy update.

    A noisy copy of the actor (``actor_perturb``) is produced by
    ``perturb_actor``; ``train`` then moves the actor along or against that
    noise depending on which side of 0.5 the preference probability falls.
    """

    def __init__(self):
        policy = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        noisy_policy = Actor(STATE_DIM, ACTION_DIM)
        noisy_policy.load_state_dict(policy.state_dict())
        self.actor = policy
        self.actor_perturb = noisy_policy

        self.lr = 0.03
        self.max_grad_norm = 0.5
        self.perturbation_dist = 0.3

        self.optim = optim.AdamW(policy.parameters(), lr=self.lr)
        self.perturbation_vector = []

    @staticmethod
    def _linear_params(model):
        """Collect the parameters of each ``nn.Linear`` submodule in module order."""
        collected = []
        for layer in model.modules():
            if isinstance(layer, nn.Linear):
                collected.extend(layer.parameters())
        return collected

    def load_model(self, pth):
        """Restore the actor from ``pth`` and mirror it into the perturbed copy."""
        weights = torch.load(pth, weights_only=True)
        self.actor.load_state_dict(weights)
        self.actor_perturb.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Greedy action: index of the largest actor output for ``state``."""
        obs = torch.tensor(state, dtype=torch.float32)
        with torch.no_grad():
            action_probs = self.actor(obs)
        return torch.argmax(action_probs).item()

    def perturb_actor(self):
        """Re-sync ``actor_perturb`` with ``actor`` and add Gaussian noise.

        Each ``nn.Linear`` parameter receives noise with standard deviation
        ``perturbation_dist``.  The noise is stored on the instance.

        :return: list with one noise tensor per perturbed parameter.
        """
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        drawn = []
        with torch.no_grad():
            for weight in self._linear_params(self.actor_perturb):
                delta = torch.randn_like(weight) * self.perturbation_dist
                weight.add_(delta)
                drawn.append(delta)

        self.perturbation_vector = list(drawn)
        return list(drawn)

    def train(self, prob):
        """Take one optimiser step using the sign of ``2 * prob - 1``.

        :param prob: preference probability; above 0.5 moves the actor along
            the stored noise, below 0.5 against it, exactly 0.5 gives a zero
            gradient.
        :return: None
        """
        stored_noise = list(self.perturbation_vector)
        targets = self._linear_params(self.actor)

        # One zero-initialised estimate per Linear parameter; entries without
        # matching stored noise stay zero.
        estimates = [torch.zeros_like(weight) for weight in targets]
        for estimate, delta in zip(estimates, stored_noise):
            estimate += delta * np.sign(2 * prob - 1)

        self.optim.zero_grad()
        with torch.no_grad():
            for weight, estimate in zip(targets, estimates):
                # Negated because the optimiser performs descent.
                weight.grad = - 1.0 * estimate

        # nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim.step()

        self.perturbation_vectors = []
        return


class ES:
    """Accept/reject search over actor parameters.

    ``perturb_actor`` proposes a noisy copy of the actor; ``train`` adopts the
    proposal when the mean preference for it exceeds 0.5.
    """

    def __init__(self):
        policy = Actor(STATE_DIM, ACTION_DIM)  # Actor Network
        candidate = Actor(STATE_DIM, ACTION_DIM)
        candidate.load_state_dict(policy.state_dict())
        self.actor = policy
        self.actor_perturb = candidate
        self.perturbation_dist = 0.3
        self.perturbation_vectors = []

    def load_model(self, pth):
        """Restore the actor from ``pth`` and mirror it into the candidate copy."""
        weights = torch.load(pth, weights_only=True)
        self.actor.load_state_dict(weights)
        self.actor_perturb.load_state_dict(self.actor.state_dict())

    def select_action(self, state):
        """Greedy action: index of the largest actor output for ``state``."""
        obs = torch.tensor(state, dtype=torch.float32)
        with torch.no_grad():
            action_probs = self.actor(obs)
        return torch.argmax(action_probs).item()

    def perturb_actor(self):
        """Re-sync ``actor_perturb`` with ``actor`` and add Gaussian noise.

        Every parameter of every ``nn.Linear`` submodule receives noise with
        standard deviation ``perturbation_dist``.

        :return: list with one noise tensor per perturbed parameter.
        """
        self.actor_perturb.load_state_dict(self.actor.state_dict())

        drawn = []
        with torch.no_grad():
            linear_layers = (
                layer for layer in self.actor_perturb.modules()
                if isinstance(layer, nn.Linear)
            )
            for layer in linear_layers:
                for weight in layer.parameters():
                    delta = torch.randn_like(weight) * self.perturbation_dist
                    weight.add_(delta)
                    drawn.append(delta)

        self.perturbation_vector = list(drawn)
        return list(drawn)

    def train(self, prob):
        """Adopt the perturbed actor when it is preferred on average.

        :param prob: preference probability (or a collection of them) for the
            perturbed actor; its mean is compared against 0.5.
        :return: None
        """
        candidate_preferred = np.mean(prob) > 0.5
        if candidate_preferred:
            self.actor.load_state_dict(self.actor_perturb.state_dict())
            self.actor_perturb.load_state_dict(self.actor.state_dict())

        self.perturbation_vectors = []
        return

class PPO:
    """Clipped-surrogate PPO with separate actor and critic networks."""

    def __init__(self):
        self.actor = Actor(STATE_DIM,ACTION_DIM)                       # Actor Network
        self.critic = Critic(STATE_DIM)                                 # Critic Network

        self.lr = 1e-3
        self.gamma = 0.99
        self.gae_lambda = 0.95
        self.clip_eps = 0.2
        self.max_grad_norm = 0.5
        self.entropy_coef = 0.0
        self.batch_size = MAX_STEPS * 2
        self.mini_batch_size = 64
        self.ppo_epochs = 1

        self.actor_optim = optim.AdamW(self.actor.parameters(), lr=self.lr)
        self.critic_optim = optim.AdamW(self.critic.parameters(), lr=self.lr)

    def load_model(self, pth):
        """Load actor weights from ``pth`` (the critic is left untouched)."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))

    def select_action(self, state):
        """Sample an action for a single ``state``.

        The state is given a leading batch dimension before the forward pass;
        the actor output is passed to ``Categorical`` as class probabilities.

        :return: ``(action, log_prob)`` as a Python int and float.
        """
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            prob = self.actor(state)
        dist = Categorical(prob)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action.item(), log_prob.item()

    def evaluate(self, states, actions):
        """Return ``(log_probs, entropy, values)`` for a batch under the current networks."""
        policy = self.actor(states)
        dist = Categorical(policy)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        values = self.critic(states)
        return log_probs, entropy, values

    def compute_gae(self, rewards, dones, values, next_value):
        """Compute generalised advantage estimates and returns for one rollout.

        :param rewards: per-step rewards.
        :param dones: per-step termination flags (0/1 or bool).
        :param values: per-step value estimates, same length as ``rewards``.
        :param next_value: value estimate for the state after the last step.
        :return: ``(advantages, returns)`` as two lists of ``len(rewards)``.
        """
        # Copy into a list so array-like inputs are appended to, not broadcast-added.
        values = list(values) + [next_value]
        steps = len(rewards)
        advantages = [0.0] * steps
        gae = 0
        for step in reversed(range(steps)):
            not_done = 1 - dones[step]
            delta = rewards[step] + self.gamma * values[step + 1] * not_done - values[step]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[step] = gae
        returns = [adv + val for adv, val in zip(advantages, values[:-1])]
        return advantages, returns

    def train(self, memory):
        """Run ``ppo_epochs`` passes of mini-batch PPO updates.

        :param memory: dict with equally long sequences under the keys
            ``states``, ``actions``, ``log_probs``, ``returns`` and
            ``advantages``.
        :return: ``(actor_loss, critic_loss)`` of the last mini-batch as
            floats, or ``(0.0, 0.0)`` when no update was performed (empty
            batch or ``ppo_epochs <= 0``).
        """
        batch_size = len(memory['states'])
        if batch_size == 0 or self.ppo_epochs <= 0:
            return 0.0, 0.0

        states = torch.tensor(np.array(memory['states']), dtype=torch.float32)
        actions = torch.tensor(np.array(memory['actions']), dtype=torch.int64)
        old_log_probs = torch.tensor(np.array(memory['log_probs']), dtype=torch.float32)
        returns = torch.tensor(np.array(memory['returns']), dtype=torch.float32)
        advantages = torch.tensor(np.array(memory['advantages']), dtype=torch.float32)
        # The sample standard deviation of a single element is NaN, which
        # would poison every parameter, so only normalise real batches.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.ppo_epochs):
            idxs = np.arange(batch_size)
            np.random.shuffle(idxs)
            for start in range(0, batch_size, self.mini_batch_size):
                end = start + self.mini_batch_size
                mb_idx = idxs[start:end]

                mb_states = states[mb_idx]
                mb_actions = actions[mb_idx]
                mb_old_log_probs = old_log_probs[mb_idx]
                mb_returns = returns[mb_idx]
                mb_advantages = advantages[mb_idx]

                log_probs, entropy, values = self.evaluate(mb_states, mb_actions)

                ratios = torch.exp(log_probs - mb_old_log_probs)
                surr1 = ratios * mb_advantages
                surr2 = torch.clamp(ratios, 1 - self.clip_eps, 1 + self.clip_eps) * mb_advantages
                actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coef * entropy.mean()

                # The Critic's output shape is not visible in this file; if it
                # is (N, 1), comparing it with (N,) returns would broadcast to
                # (N, N).  Aligning the shapes makes the loss element-wise.
                values = values.reshape(mb_returns.shape)
                critic_loss = F.mse_loss(values, mb_returns)

                self.actor_optim.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.actor_optim.step()

                self.critic_optim.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.critic_optim.step()

        return actor_loss.item(), critic_loss.item()


