import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MMD_eval(nn.Module):
    """Two-class MLP domain classifier trained on rollouts from two env groups.

    Despite the name, no kernel MMD statistic is computed: ``update`` labels
    samples coming from the first half of the environments as class 0
    ("source") and the remaining environments as class 1 ("target"), trains
    this classifier on (observation, policy-feature) rows, and reports how
    well the two groups can be told apart.

    ``forward`` returns class probabilities (softmax); training uses the raw
    logits from ``_logits`` because ``nn.CrossEntropyLoss`` applies
    log-softmax itself.
    """

    def __init__(self, obs_feature_dim, device='cpu', hidden_dim=64, lr=1e-4) -> None:
        """
        Args:
            obs_feature_dim: width of each input row.
            device: torch device that ``update`` places its tensors on.
            hidden_dim: width of the two hidden layers (default 64).
            lr: Adam learning rate (default 1e-4).
        """
        super().__init__()
        self.device = device
        self.fc1 = nn.Linear(obs_feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 2)
        # orthogonal initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.zeros_(m.bias)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        # Expects unnormalised logits and integer class targets.
        self.loss_func = nn.CrossEntropyLoss()

    def _logits(self, x):
        """Return the unnormalised (batch, 2) class scores for ``x``."""
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        return self.out(x)

    def forward(self, x):
        """Return (batch, 2) class probabilities for ``x``."""
        return F.softmax(self._logits(x), dim=1)

    def _stack(self, groups, name):
        """Concatenate every array of every group along axis 0 into a float32 tensor.

        Raises:
            ValueError: if the groups contain no arrays at all.
        """
        rows = [row for group in groups for row in group]
        if not rows:
            raise ValueError(f"{type(self).__name__}: no {name} samples to train on")
        return torch.as_tensor(np.concatenate(rows, axis=0), dtype=torch.float32, device=self.device)

    def _split_halves(self, groups, env_num):
        """Split per-env sample lists into (source, target) tensors.

        Groups with index < ``env_num // 2`` are source, the rest target.
        """
        groups = list(groups)
        half = env_num // 2
        return self._stack(groups[:half], 'source'), self._stack(groups[half:], 'target')

    def _fit(self, s_x, t_x, batch_size=256, epochs=10, test_num=None):
        """Train on source (label 0) vs target (label 1) rows and evaluate on a hold-out.

        A random subset of ``test_num`` rows (default ``batch_size``) is held
        out. Both classes can appear in it, unlike the tail of the
        concatenated data, which would be target-only. The remaining rows are
        reshuffled every epoch and consumed in full batches only; a trailing
        partial batch is dropped.

        Returns:
            (mean train loss, mean train accuracy, test loss, test accuracy) as
            floats. The train means are nan when no full batch was available.
        """
        if test_num is None:
            test_num = batch_size
        x = torch.cat([s_x, t_x], dim=0)
        y = torch.cat([
            torch.zeros(s_x.size(0), dtype=torch.long, device=x.device),
            torch.ones(t_x.size(0), dtype=torch.long, device=x.device),
        ])

        perm = np.random.permutation(x.size(0))
        test_idx = torch.as_tensor(perm[:test_num], dtype=torch.long, device=x.device)
        train_idx = perm[test_num:].copy()

        train_losses, train_accuracies = [], []
        for _ in range(epochs):
            np.random.shuffle(train_idx)
            for start in range(0, len(train_idx) - batch_size + 1, batch_size):
                batch = torch.as_tensor(
                    train_idx[start:start + batch_size], dtype=torch.long, device=x.device
                )
                logits = self._logits(x[batch])
                batch_y = y[batch]
                loss = self.loss_func(logits, batch_y)
                # optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # record
                train_losses.append(loss.item())
                train_accuracies.append((logits.argmax(dim=1) == batch_y).float().mean().item())

        with torch.no_grad():
            test_logits = self._logits(x[test_idx])
            test_y = y[test_idx]
            test_loss = self.loss_func(test_logits, test_y).item()
            test_accuracy = (test_logits.argmax(dim=1) == test_y).float().mean().item()

        def _mean(values):
            # Avoid numpy's "mean of empty slice" warning when no batch ran.
            return float(np.mean(values)) if values else float('nan')

        return _mean(train_losses), _mean(train_accuracies), test_loss, test_accuracy

    def update(self, obs, features):
        """Train on observations concatenated with policy features.

        Args:
            obs: per-env sequence; each entry is a list of 2-D arrays (the
                collector in this file appends ``reshape(1, -1)`` rows).
            features: per-env sequence of the same layout, aligned with ``obs``.

        Returns:
            (train loss, train accuracy, test loss, test accuracy).
        """
        print("MMD_EVAL: training......")
        env_num = len(obs)
        s_obs, t_obs = self._split_halves(obs, env_num)
        s_features, t_features = self._split_halves(features, env_num)
        s_x = torch.cat([s_obs, s_features], dim=1)
        t_x = torch.cat([t_obs, t_features], dim=1)
        return self._fit(s_x, t_x)


class NMMD_eval(MMD_eval):
    """Same classifier as ``MMD_eval``, trained on raw observations only."""

    def update(self, obs):
        """Train on observations alone.

        Args:
            obs: per-env sequence; each entry is a list of 2-D arrays.

        Returns:
            (train loss, train accuracy, test loss, test accuracy).
        """
        print("NMMD_EVAL: training......")
        s_x, t_x = self._split_halves(obs, len(obs))
        return self._fit(s_x, t_x)


import warnings

import numpy as np
import torch
from numba import njit

from data import Batch, VectorBuffer
from tqdm import *

class MMDCollector(object):
    """Episode collector that also trains two domain classifiers on the fly.

    It steps ``policy`` in the vector env ``env`` and stores transitions in
    ``buffer``. Whenever enough (observation, policy feature) rows have been
    gathered, it calls ``mmd_eval.update(obs, features)`` and
    ``nmmd_eval.update(obs)`` and records the 4-tuples they return.
    """

    # Keys of the per-round evaluator metrics, in the order they are reported.
    _METRIC_KEYS = (
        "train_mmd_loss",
        "train_mmd_accuracy",
        "test_mmd_loss",
        "test_mmd_accuracy",
        "train_nmmd_loss",
        "train_nmmd_accuracy",
        "test_nmmd_loss",
        "test_nmmd_accuracy",
    )

    def __init__(self, policy, env, mmd_eval, nmmd_eval, buffer=None):
        super().__init__()

        self.env = env  # type: ignore
        self.env_num = len(self.env)
        if buffer is None:
            buffer = VectorBuffer(self.env_num, self.env_num)
        self.buffer = buffer
        self.policy = policy
        self._action_space = self.env.action_space
        # avoid creating attribute outside __init__
        self.reset(False)
        self.mmd_eval = mmd_eval
        self.nmmd_eval = nmmd_eval

    def reset(self, reset_buffer=True):
        """Clear the current batch, reset all envs and stats (and optionally the buffer)."""
        self.data = Batch(
            obs={}, act={}, rew={}, done={}, obs_next={}, info={}
        )
        self.reset_env()
        if reset_buffer:
            self.reset_buffer()
        self.reset_stat()

    def reset_stat(self):
        """Zero the cumulative step/episode/time counters."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_buffer(self):
        """Empty the replay buffer."""
        self.buffer.reset()

    def reset_env(self) -> None:
        """Reset every env and store the initial observations in ``self.data.obs``."""
        obs = self.env.reset()
        self.data.obs = obs

    def _empty_buffers(self):
        """Return fresh per-env observation and feature lists."""
        return [[] for _ in range(self.env_num)], [[] for _ in range(self.env_num)]

    def _train_evaluators(self, obs, features, metrics):
        """Run one training round of both evaluators and append their results to ``metrics``."""
        mmd_result = self.mmd_eval.update(obs, features)
        loss, accuracy, test_loss, test_accuracy = mmd_result
        for key, value in zip(self._METRIC_KEYS[:4], (loss, accuracy, test_loss, test_accuracy)):
            metrics[key].append(value)
        nmmd_result = self.nmmd_eval.update(obs)
        n_loss, n_accuracy, n_test_loss, n_test_accuracy = nmmd_result
        for key, value in zip(self._METRIC_KEYS[4:], (n_loss, n_accuracy, n_test_loss, n_test_accuracy)):
            metrics[key].append(value)

    def collect(self, n_episode, mmd_batch_size=2048):
        """Run ``n_episode`` episodes, training the evaluators along the way.

        Args:
            n_episode: number of finished episodes after which collection stops.
                It is the loop's only stopping condition, so it must be truthy.
            mmd_batch_size: number of gathered (obs, feature) rows that triggers
                one evaluator training round (default 2048). Rows still pending
                when collection ends are discarded.

        Returns:
            dict with "n/ep", "n/st", "rews", "rew", "rew_std", "lens", "len",
            "len_std" and one list per evaluator metric (one entry per round).

        Raises:
            ValueError: if ``n_episode`` is falsy (previously an infinite loop).
        """
        if not n_episode:
            raise ValueError("n_episode must be positive: it is collect()'s only stopping condition")

        ready_env_ids = np.arange(self.env_num)
        self.data = self.data[:self.env_num]

        obs, features = self._empty_buffers()
        n_rows = 0
        metrics = {key: [] for key in self._METRIC_KEYS}

        step_count = 0
        episode_count = 0
        # NOTE: one slot per env, overwritten on each finished episode. With
        # n_episode <= env_num the masking below drops an env once it finishes,
        # so no episode is lost; with n_episode > env_num earlier episodes of an
        # env are overwritten.
        episode_rews = np.zeros_like(ready_env_ids, dtype=np.float64) - np.inf
        episode_lens = np.zeros_like(ready_env_ids, dtype=np.float64) - np.inf

        while True:
            assert len(self.data) == len(ready_env_ids)

            with torch.no_grad():
                result = self.policy(self.data)

            act = result.act.detach().cpu().numpy()
            feature = result.feature.detach().cpu().numpy()
            self.data.update(act=act, feature=feature)

            # Gather one (obs, feature) row per active env, keyed by global env id.
            for idx, env_id in enumerate(ready_env_ids):
                obs[env_id].append(self.data.obs[idx].copy().reshape(1, -1))
                features[env_id].append(self.data.feature[idx].copy().reshape(1, -1))
            n_rows += len(ready_env_ids)
            # `>=` rather than `==`: once finished envs are masked out below, the
            # row count no longer grows in steps that divide mmd_batch_size, and an
            # equality test could be skipped forever.
            if n_rows >= mmd_batch_size:
                self._train_evaluators(obs, features, metrics)
                obs, features = self._empty_buffers()
                n_rows = 0

            action_remap = self.policy.map_action(self.data.act)
            obs_next, rew, done, info = self.env.step(action_remap, ready_env_ids)

            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

            ep_rew, ep_len = self.buffer.add(self.data, ready_env_ids)

            step_count += len(ready_env_ids)

            self.data.obs = self.data.obs_next.copy()
            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_rews[env_ind_global] = ep_rew[env_ind_local].copy()
                episode_lens[env_ind_global] = ep_len[env_ind_local].copy()
                obs_reset = self.env.reset(env_ind_global)
                self.data.obs[env_ind_local] = obs_reset

                # Drop just-finished envs that are no longer needed to reach n_episode.
                surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
                if surplus_env_num > 0:
                    mask = np.ones_like(ready_env_ids, dtype=bool)
                    mask[env_ind_local[:surplus_env_num]] = False
                    ready_env_ids = ready_env_ids[mask]
                    self.data = self.data[mask]

            if episode_count >= n_episode:
                break

        self.collect_step += step_count
        self.collect_episode += episode_count

        self.data = Batch(
            obs={}, act={}, rew={}, done={}, obs_next={}, info={}
        )
        self.reset_env()

        if episode_count > 0:
            # Envs that never finished keep -inf and are masked out of the stats.
            rews = np.ma.masked_equal(episode_rews, -np.inf)
            rew_mean, rew_std = rews.mean(), rews.std()
            lens = np.ma.masked_equal(episode_lens, -np.inf)
            len_mean, len_std = lens.mean(), lens.std()
        else:
            rews = np.array([])
            rew_mean = rew_std = 0
            lens = np.array([])
            len_mean = len_std = 0

        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rews": rews,
            "rew": rew_mean,
            "rew_std": rew_std,
            "lens": lens,
            "len": len_mean,
            "len_std": len_std,
            **metrics,
        }


#!/usr/bin/env python3

import argparse
import os
import json

import gym
import numpy as np
import torch
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR

from collector import Collector
from env import BaseVectorEnv
from ppo import PPOPolicy
from network import Actor, Critic
from random_env import get_init_params, get_random_params, get_random_params_target, get_random_params2, get_random_params3


def get_args():
    """Declare the evaluation script's command-line flags and parse ``sys.argv``.

    Flags typed as ``int`` but defaulting to ``True`` (``--rew-norm``,
    ``--obs-norm``, ``--lr-decay``, ``--value-clip``) take an integer on the
    command line, e.g. ``--lr-decay 0``.

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser()
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # (flag, type, default) triples registered before --hidden-sizes.
    leading_flags = (
        ('--task', str, 'HalfCheetah-v3'),
        ('--seed', int, 0),
        ('--buffer-size', int, 4096),
    )
    # (flag, type, default) triples registered after --hidden-sizes.
    trailing_flags = (
        ('--lr', float, 3e-4),
        ('--gamma', float, 0.99),
        ('--epoch', int, 100),
        ('--step-per-epoch', int, 30000),
        ('--step-per-collect', int, 2048),
        ('--repeat-per-collect', int, 10),
        ('--batch-size', int, 64),
        ('--training-num', int, 64),
        ('--test-num', int, 64),
        # ppo special
        ('--rew-norm', int, True),
        ('--obs-norm', int, True),
        # In theory, `vf-coef` will not make any difference if using Adam optimizer.
        ('--vf-coef', float, 0.25),
        ('--ent-coef', float, 0.0),
        ('--gae-lambda', float, 0.95),
        ('--bound-action-method', str, "clip"),
        ('--lr-decay', int, True),
        ('--max-grad-norm', float, 0.5),
        ('--eps-clip', float, 0.2),
        ('--dual-clip', float, None),
        ('--value-clip', int, True),
        ('--norm-adv', int, 1),
        ('--recompute-adv', int, 1),
        ('--logdir', str, 'log'),
        ('--device', str, default_device),
        ('--resume-path', str, 'log_ntl/HalfCheetah-v3/ppo/negative_mmd_loss/seed_0_0801_155813-HalfCheetah_v3_ppo/policy.pth'),
        ('--left-bound', float, 1),
        ('--right-bound', float, 2),
    )

    for flag, arg_type, default in leading_flags:
        parser.add_argument(flag, type=arg_type, default=default)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    for flag, arg_type, default in trailing_flags:
        parser.add_argument(flag, type=arg_type, default=default)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only',
    )
    return parser.parse_args()


def _randomize_target_envs(test_envs, env, target_env, args, group_size=4):
    """Overwrite the physical parameters of the target-domain envs.

    Each consecutive group of ``group_size`` indices in ``target_env`` shares
    one draw of ``get_random_params3`` (log-scale limits
    ``[args.left_bound, args.right_bound]``); body mass, body inertia, DOF
    damping and geom friction are set from that draw.

    Returns:
        list of the drawn parameter dicts, one per group.
    """
    init_params = get_init_params(env)
    random_params_list = []
    for start in range(0, len(target_env), group_size):
        group = target_env[start:start + group_size]
        random_params = get_random_params3(
            init_params, log_scale_limit=[args.left_bound, args.right_bound]
        )
        for attr in ("body_mass", "body_inertia", "dof_damping", "geom_friction"):
            test_envs.set_env_attr(attr, random_params[attr], group)
        random_params_list.append(random_params)
    return random_params_list


def _log_mmd_results(result, log_dir):
    """Write the per-round evaluator metrics of ``result`` to TensorBoard in ``log_dir``."""
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=log_dir, filename_suffix="-mmdeval")
    try:
        # train
        for step, (loss, accuracy, n_loss, n_accuracy) in enumerate(
            zip(
                result["train_mmd_loss"], result["train_mmd_accuracy"], result["train_nmmd_loss"], result["train_nmmd_accuracy"]
            )
        ):
            writer.add_scalars(main_tag="mmd_loss", tag_scalar_dict={"train_mmd": loss, "train_nmmd": n_loss}, global_step=step)
            writer.add_scalars(main_tag="accuracy", tag_scalar_dict={"train_mmd": accuracy, "train_nmmd": n_accuracy}, global_step=step)
        # test
        for step, (loss, accuracy, n_loss, n_accuracy) in enumerate(
            zip(
                result["test_mmd_loss"], result["test_mmd_accuracy"], result["test_nmmd_loss"], result["test_nmmd_accuracy"]
            )
        ):
            writer.add_scalars(main_tag="mmd_loss", tag_scalar_dict={"test_mmd": loss, "test_nmmd": n_loss}, global_step=step)
            writer.add_scalars(main_tag="accuracy", tag_scalar_dict={"test_mmd": accuracy, "test_nmmd": n_accuracy}, global_step=step)
    finally:
        # Flush pending events to disk.
        writer.close()


def test_ppo(args=None):
    """Load a PPO policy, roll it out, and log domain-classifier metrics.

    The second half of the vector envs gets randomised dynamics. The policy
    is run for ``args.test_num`` episodes while ``MMD_eval`` / ``NMMD_eval``
    learn to separate the first-half envs from the second-half envs, and the
    resulting metrics are written to ``<dir of resume_path>/mmd_eval``.

    Args:
        args: parsed options; when None they are parsed from the command line
            at call time (not at import time, as a ``get_args()`` default
            argument would do).
    """
    if args is None:
        args = get_args()

    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    env_num = args.training_num
    test_envs = BaseVectorEnv(
        [lambda: gym.make(args.task) for _ in range(env_num)], norm_obs=args.obs_norm
    )

    # MMD_eval/NMMD_eval label env indices >= env_num // 2 as "target", so
    # exactly those envs get randomised dynamics.
    target_env = list(range(env_num // 2, env_num))
    _randomize_target_envs(test_envs, env, target_env, args)

    # seed
    test_envs.seed(args.seed)

    actor = Actor(args.state_shape[0], args.action_shape[0], device=args.device).to(args.device)
    critic = Critic(args.state_shape[0], device=args.device).to(args.device)

    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # Shrink the final mean-layer weights and zero its biases.
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=args.lr
    )

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect
        ) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
        )

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = PPOPolicy(
        actor,
        critic,
        optim,
        dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
        deterministic_eval=True,
    )

    # load a previous policy (torch.load unpickles: only use trusted checkpoints)
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
        if args.obs_norm:
            p = os.path.join(os.path.split(args.resume_path)[0], 'obs_rms.json')
            print("Loaded obs-norm from: ", p)
            with open(p, 'r') as f:
                d = json.load(f)
            # Freeze the running statistics at the checkpoint's values.
            test_envs.update_obs_rms = False
            test_envs.obs_rms.mean = np.array(d['mean'])
            test_envs.obs_rms.var = np.array(d['var'])
            test_envs.obs_rms.count = d['count']

    # NOTE(review): +64 presumably matches the width of the policy's
    # `result.feature` output — confirm against network.Actor.
    mmd_eval = MMD_eval(args.state_shape[0] + 64, device=args.device).to(args.device)
    nmmd_eval = NMMD_eval(args.state_shape[0], device=args.device).to(args.device)
    test_collector = MMDCollector(policy, test_envs, mmd_eval, nmmd_eval)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num)
    print(result['rews'], type(result["rews"]))
    print(result['rew'], result['rew_std'])
    print(f'Final reward: {result["rews"].mean()}')

    # Log next to the checkpoint, independent of the checkpoint's file name.
    log_dir = os.path.join(os.path.dirname(args.resume_path), 'mmd_eval')
    _log_mmd_results(result, log_dir)

if __name__ == "__main__":
    # Run the evaluation when this file is executed as a script.
    test_ppo()

