# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy
import os
import random
import time
from dataclasses import dataclass
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter
from utils.buffers import ReplayBuffer

from adversary.Adversary import ImagePoison, Discrete, Continuous, exp_cos, Dazer, cos_dist_np, l2dist, log_dist
from adversary.OuterLoop import SleeperNets, Learned_Inception
from adversary.InnerLoop import BadRLMiddleMan, TrojDRLMiddleMan, BadBots, OnCeption
from env.adversarial_mpd import DAZE_Outer
from adversary import patterns
from adversary.Adversary import Null_Action
from utils.models import Agent, Agent_ANN, QNetwork
from utils.utils import Args, make_env, load_dict_from_yaml, AppendWrap


@dataclass
class Args:
    """Command-line arguments for the TD3 continuous-action training script.

    The script parses this dataclass with ``tyro`` and then overlays the keys
    of the YAML file named by ``config`` onto the parsed instance, so values
    from that file take precedence over the defaults declared here.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """value assigned to `torch.backends.cudnn.deterministic`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "Hopper-v4"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of sample from the replay memory"""
    policy_noise: float = 0.2
    """the scale of policy noise"""
    exploration_noise: float = 0.1
    """the scale of exploration noise"""
    # ``25e3`` is a float literal; wrap it so the default matches the ``int``
    # annotation (same convention as ``buffer_size`` above).
    learning_starts: int = int(25e3)
    """timestep to start learning"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    noise_clip: float = 0.5
    """noise clip parameter of the Target Policy Smoothing Regularization"""

    config: str = "configs/mujoco.yaml"
    """path of the YAML file whose keys are overlaid onto these arguments"""


def make_env(env_id, seed, idx, capture_video, run_name, daze = False):
    """Return a zero-argument factory that builds one wrapped environment.

    Nothing is constructed until the returned callable is invoked.

    Args:
        env_id: Id handed to ``gym.make``.
        seed: Seed applied to the environment's action space.
        idx: Position of this environment among its siblings; only index 0
            is ever video-recorded.
        capture_video: When truthy (and ``idx == 0``) the environment is
            created with ``render_mode="rgb_array"`` and wrapped in
            ``RecordVideo`` writing under ``videos/{run_name}``.
        run_name: Sub-folder name used for recorded videos.
        daze: When truthy, the environment is additionally wrapped in
            ``AppendWrap(env, 2)`` and then ``Null_Action``.

    Returns:
        A callable taking no arguments and returning the wrapped environment.
    """
    def thunk():
        record_this_env = capture_video and idx == 0
        if not record_this_env:
            env = gym.make(env_id)
        else:
            env = gym.wrappers.RecordVideo(
                gym.make(env_id, render_mode="rgb_array"),
                f"videos/{run_name}",
            )

        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)

        if not daze:
            return env

        print("Append+Null")
        return Null_Action(AppendWrap(env, 2))

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """State-action value network: three linear layers with ReLU in between.

    The first layer's input width is the flattened observation size plus the
    flattened action size, both read from ``env.single_observation_space``
    and ``env.single_action_space``.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_dim = np.prod(env.single_action_space.shape)
        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        """Return one value per row for observations ``x`` and actions ``a``.

        ``x`` and ``a`` are concatenated along dimension 1 before the first
        layer; no activation is applied to the final layer's output.
        """
        joint = torch.cat((x, a), dim=1)
        hidden = F.relu(self.fc1(joint))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


class Actor(nn.Module):
    """Deterministic policy network producing actions inside the action bounds.

    Two ReLU hidden layers feed a ``tanh`` output, which is then rescaled with
    the ``action_scale`` and ``action_bias`` buffers derived from the bounds of
    ``env.single_action_space``.
    """

    def __init__(self, env):
        super().__init__()
        self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))

        # Affine map from the tanh range [-1, 1] onto [low, high]; stored as
        # buffers so they follow the module across devices and into state_dict.
        space = env.single_action_space
        half_range = (space.high - space.low) / 2.0
        midpoint = (space.high + space.low) / 2.0
        self.register_buffer("action_scale", torch.tensor(half_range, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor(midpoint, dtype=torch.float32))

    def forward(self, x):
        """Return the rescaled action for each observation row in ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        squashed = torch.tanh(self.fc_mu(hidden))
        return squashed * self.action_scale + self.action_bias


if __name__ == "__main__":
    # Script entry point. For each seed: train a TD3 agent (optionally inside
    # the DAZE poisoning wrapper), evaluate episodic return and attack success
    # on the trained actor, then save the metrics and the actor weights.

    args = tyro.cli(Args)
    # Every key of the YAML file is written straight onto ``args`` and
    # overrides the parsed value. Attributes read below that are not declared
    # on ``Args`` (``target_action``, ``p_rates``, ``daze``, ``n_eval``, the
    # attack flags, ...) are expected to come from this file.
    config = load_dict_from_yaml(args.config)
    args.__dict__.update(config)

    # Resolve the device once and put the target action on it. The previous
    # unconditional ``.cuda()`` failed on CPU-only machines and disagreed with
    # the actor's device whenever ``args.cuda`` was disabled.
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    args.target_action = torch.tensor(args.target_action).to(device)
    seeds = [1, 2, 3, 4, 5]
    args.p_rate = args.p_rates[0]
    args.unique_id = int(time.time())

    os.makedirs(f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}", exist_ok=True)
    for seed in seeds:
        args.seed = seed
        asr = 0

        # The run name encodes which attack flag is active and its parameters.
        if args.sn_outer:
            run_name = f"SN_{args.p_rate}_{args.rew_p}_{args.alpha}_{args.clip}"
        elif args.inception:
            run_name = f"QIn_{args.p_rate}_{args.learned}"
        elif args.trojdrl:
            run_name = f"TrojDRL_{args.p_rate}_{args.rew_p}"
        elif args.badrl:
            run_name = f"BadRL_{args.p_rate}_{args.rew_p}"
        elif args.badbots:
            run_name = f"OnCeption_{args.p_rate}_{args.learned}_{args.n_updates}"
        elif args.daze:
            run_name = f"DAZE_{args.p_rate}_{args.num_daze}"
        else:
            run_name = "Benign"
        run_name += f"_{args.exp_name}"
        run_attack = run_name.split("_")[0]
        args.unique = int(time.time())

        if args.track:
            import wandb

            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=True,
                config=vars(args),
                name=run_name,
                monitor_gym=True,
                save_code=True,
            )
        # NOTE(review): ``run_name`` does not contain the seed, so every seed
        # logs into the same ``runs/{run_name}`` directory -- confirm intended.
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

        # TRY NOT TO MODIFY: seeding
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = args.torch_deterministic

        # env setup
        envs = gym.vector.AsyncVectorEnv(
            [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, args.daze) for i in range(args.num_envs)]
        )

        # --- Setup Daze Attack --- #
        if args.daze:
            poison = patterns.SingleValuePoison(-1, 1)
            poison_batch = patterns.SingleValuePoison(-1, 1)
            dist = lambda x, y: log_dist(x, y, numpy=True)
            dazer = patterns.SingleValuePoison(-2, 1)
            envs = DAZE_Outer(envs, poison, dazer, dist, args.target_action.cpu().numpy(), args)

        # Online networks plus target copies initialised to the same weights.
        actor = Actor(envs).to(device)
        qf1 = QNetwork(envs).to(device)
        qf2 = QNetwork(envs).to(device)
        qf1_target = QNetwork(envs).to(device)
        qf2_target = QNetwork(envs).to(device)
        target_actor = Actor(envs).to(device)
        target_actor.load_state_dict(actor.state_dict())
        qf1_target.load_state_dict(qf1.state_dict())
        qf2_target.load_state_dict(qf2.state_dict())
        q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate)
        actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate)

        envs.single_observation_space.dtype = np.float32
        rb = ReplayBuffer(
            args.buffer_size,
            envs.single_observation_space,
            envs.single_action_space,
            device,
            n_envs=args.num_envs,
            handle_timeout_termination=False,
        )
        start_time = time.time()

        # Poisoning counters are refreshed whenever an episode ends; start at
        # zero so the metric logging below cannot hit an unbound name.
        total_poisoned = 0
        num_dazed = 0

        # TRY NOT TO MODIFY: start the game
        obs, _ = envs.reset(seed=args.seed)
        for global_step in range(args.total_timesteps):
            # ALGO LOGIC: put action logic here
            if global_step < args.learning_starts:
                # Uniform random actions until learning begins.
                actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
            else:
                with torch.no_grad():
                    actions = actor(torch.Tensor(obs).to(device))
                    actions += torch.normal(0, actor.action_scale * args.exploration_noise)
                    actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, rewards, terminations, truncations, infos = envs.step(actions)

            # TRY NOT TO MODIFY: record rewards for plotting purposes
            try:
                next_done = np.logical_or(terminations, truncations)
                sps = int(global_step / (time.time() - start_time))
                if next_done.any():
                    episodic_return = np.mean(infos["episode"]["r"][next_done==1])
                    episodic_length = np.mean(infos["episode"]["l"][next_done==1])

                    print(f"{run_attack}_{args.exp_name} - global_step={global_step}, return={episodic_return}, length={episodic_length}, sps={sps}, asr={asr}            ", end = "\r")
                    writer.add_scalar("charts/episodic_return", episodic_return, global_step)
                    writer.add_scalar("charts/episodic_length", episodic_length, global_step)

                    if args.daze:
                        total_poisoned = infos["poison_stats"][1]
                        num_dazed = infos["poison_stats"][2]
            except Exception:
                # Episode logging is best-effort (e.g. ``infos`` may lack the
                # expected keys). Catch ``Exception`` rather than everything so
                # that Ctrl-C / SystemExit still stop the run.
                pass

            # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
            real_next_obs = next_obs.copy()
            for idx, trunc in enumerate(truncations):
                if trunc:
                    # On truncation the stored next observation is replaced by
                    # the current observation of that environment.
                    real_next_obs[idx] = obs[idx]
            rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

            # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
            obs = next_obs

            # ALGO LOGIC: training.
            if global_step > args.learning_starts:
                data = rb.sample(args.batch_size)
                with torch.no_grad():
                    # Target policy smoothing: clipped Gaussian noise, scaled
                    # to the action range, is added to the target action.
                    clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp(
                        -args.noise_clip, args.noise_clip
                    ) * target_actor.action_scale

                    next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp(
                        envs.single_action_space.low[0], envs.single_action_space.high[0]
                    )
                    # Bootstrapped target uses the smaller of the two target critics.
                    qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                    qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
                    next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

                qf1_a_values = qf1(data.observations, data.actions).view(-1)
                qf2_a_values = qf2(data.observations, data.actions).view(-1)
                qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
                qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
                qf_loss = qf1_loss + qf2_loss

                # optimize the model
                q_optimizer.zero_grad()
                qf_loss.backward()
                q_optimizer.step()

                # Delayed policy update: the actor and all target networks are
                # only touched every ``policy_frequency`` steps.
                if global_step % args.policy_frequency == 0:
                    actor_loss = -qf1(data.observations, actor(data.observations)).mean()
                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    # update the target network
                    for param, target_param in zip(actor.parameters(), target_actor.parameters()):
                        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                    for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                    for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                        target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

                if global_step % 100 == 0:
                    writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                    writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                    writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                    writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                    writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                    writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                    writer.add_scalar(
                        "charts/SPS",
                        int(global_step / (time.time() - start_time)),
                        global_step,
                    )

                    # --- Evaluate Attack Success Rate --- #
                    with torch.no_grad():
                        if args.daze:
                            # Apply the trigger to the sampled batch and compare
                            # the actor's output with the target action.
                            poisoned = poison_batch(data.observations)
                            mean = actor(poisoned)
                            mean = torch.clamp(mean, -1, 1)
                            attack_err = (log_dist(mean.cpu(), args.target_action.cpu()).mean().item())
                            inf_err = torch.linalg.norm(torch.absolute(mean - args.target_action), ord = torch.inf, dim = 1)
                            # Success = every action component within 0.2 of the target.
                            asr = torch.mean((inf_err < 0.2).float())
                            inf_err = torch.mean(inf_err)
                            writer.add_scalar("attack_metrics/AttackSuccessRate", asr, global_step)
                            writer.add_scalar("attack_metrics/L-InfError", inf_err, global_step)
                            writer.add_scalar("attack_metrics/AttackError", attack_err, global_step)
                            writer.add_scalar("attack_metrics/daze_rate", num_dazed/(args.num_envs*global_step), global_step)
                            # NOTE(review): the divisor uses a literal 8 where
                            # daze_rate uses args.num_envs -- confirm whether
                            # this should also be args.num_envs.
                            writer.add_scalar("attack_metrics/poisoning_rate", total_poisoned/(8*global_step), global_step)

        # --- Evaluate Final (Attack) Performance --- #
        actor.eval()
        n_eval = args.n_eval
        count = 0
        with torch.no_grad():
            # --- Compute BR Score --- #
            returns = torch.zeros(n_eval)

            next_obs, _ = envs.reset(seed=args.seed)
            next_obs = torch.Tensor(next_obs).to(device)
            # Pre-allocated store for visited observations; writes past its
            # capacity are dropped.
            obs = torch.zeros([n_eval * 1000] + list(next_obs.size())[1:])
            count2 = 0
            indices_eval = np.arange(0, args.num_envs, 1)

            if args.daze:
                envs.p_rate = 0

            print("\nEvaluating Performance")
            # NOTE(review): this loop only ends once ``n_eval`` episode returns
            # were recorded; if the lookups in the ``try`` below always fail it
            # will not terminate.
            while count < n_eval:
                # ALGO LOGIC: action logic
                if count2 < len(obs):
                    # Truncate the final write so a batch that does not fit
                    # exactly cannot raise a shape mismatch.
                    end = min(count2 + len(next_obs), len(obs))
                    obs[count2:end] = next_obs[: end - count2].cpu()

                count2 += len(next_obs)
                action = actor(next_obs)

                # TRY NOT TO MODIFY: execute the game and log data.
                next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
                next_done = np.logical_or(terminations, truncations)
                next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

                try:
                    if args.mujoco and next_done.any():
                        for i in indices_eval[next_done.cpu().numpy()==1]:
                            if count >= n_eval:
                                break
                            returns[count] = torch.tensor(infos['episode']['r'][i])
                            count += 1
                            print(f"Evaluations: {count} / {n_eval}", end = "\r")

                    if "final_info" in infos:
                        print(infos)
                        for info in infos["final_info"]:
                            print(count, info)
                            if count >= n_eval:
                                break
                            if info and "episode" in info[0]:
                                returns[count] = torch.tensor(info[0]['episode']['r'])
                                count += 1
                                print(f"Evaluations: {count} / {n_eval}", end = "\r")
                except Exception:
                    # Best-effort bookkeeping; ``Exception`` (not a bare
                    # ``except``) keeps Ctrl-C able to break out of this loop.
                    pass

            obs = obs[:count2]
            actions_eval = torch.zeros([len(obs)]+ list(envs.single_action_space.shape))

            # --- Compute ASR Score --- #
            index = 0
            asr = 0
            err = 0
            inf_ = 0
            score = 0
            print()
            # NOTE(review): ``poison_batch`` is only created when ``args.daze``
            # is set; the other attack flags would reach this branch with it
            # undefined -- confirm how those attacks are meant to be evaluated.
            if args.inception or args.sn_outer or args.badrl or args.trojdrl or args.daze:
                while index < len(obs):
                    print(f"Evaluating ASR {index}/{len(obs)}", end = "\r")
                    poisoned = poison_batch(obs[index: index + args.batch_size].to(device))
                    action = actor(poisoned)
                    action = torch.clamp(action, -1, 1)
                    actions_eval[index: index + args.batch_size] = action.cpu()
                    index += args.batch_size

                err = log_dist(actions_eval.cpu(), args.target_action.cpu())
                inf_ = torch.linalg.norm(torch.absolute(actions_eval.cpu()- args.target_action.cpu()), ord = torch.inf, dim = 1)
                asr = (inf_ < 0.2).float()
                err = err.mean().item()
                asr = asr.mean().item()
                inf_ = inf_.mean().item()
            score = returns.mean().item()

        # --- Save Model and Experiment Results --- #
        tempid = args.env_id.replace("/", "")
        os.makedirs("results/" + tempid, exist_ok=True)
        save_name = f"{args.seed}_{run_name}_{args.unique}"
        res_done = {"asr": asr, "err": err, "linf": inf_, "return": score}
        print(res_done)
        torch.save(res_done, f"results/{tempid}/{save_name}")

        # Save final model
        # NOTE(review): the file name still says "ppo" although this script
        # trains TD3; left unchanged in case other tooling expects this path.
        torch.save(actor.state_dict(), f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}/ppo_final_{args.seed}.pt")

        envs.close()
        writer.close()
        if args.track:
            wandb.finish()