# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import argparse
import logging
import os
import pprint
import threading
import time
import timeit
import traceback
import typing
#import wandb

os.environ["OMP_NUM_THREADS"] = "1"  # Necessary for multithreading.

import torch
from torch import multiprocessing as mp
from torch import nn
from torch.nn import functional as F

from torchbeast import atari_wrappers
from torchbeast.core import environment
from torchbeast.core import file_writer
from torchbeast.core import prof
from torchbeast.core import vtrace

import numpy as np


# yapf: disable
# Command-line interface for the script. Flags are grouped by purpose:
# general/run selection, training, loss, optimizer and proxy-reward settings.
parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")

parser.add_argument("--env", type=str, default="PongNoFrameskip-v4",
                    help="Gym environment.")
parser.add_argument("--mode", default="train",
                    choices=["train", "test", "test_render", "eval"],
                    help="Training or test mode.")
parser.add_argument("--xpid", default=None,
                    help="Experiment id (default: None).")
parser.add_argument("--num_episodes", type=int, default=10,
                    help="Num eval episodes.")
# The value is handed to torch.load(); any non-empty value also switches the
# script to the featurized network and the non-frame-stacked environment.
parser.add_argument("--pretrained", default=None, type=str,
                    help="Path to pretrained featurizer weights. When set, "
                         "the featurized network is used.")

# Training settings.
parser.add_argument("--disable_checkpoint", action="store_true",
                    help="Disable saving checkpoint.")
parser.add_argument("--savedir", default="~/logs/torchbeast",
                    help="Root dir where experiment data will be saved.")
parser.add_argument("--num_actors", default=4, type=int, metavar="N",
                    help="Number of actors (default: 4).")
parser.add_argument("--total_steps", default=100000, type=int, metavar="T",
                    help="Total environment steps to train for.")
parser.add_argument("--batch_size", default=8, type=int, metavar="B",
                    help="Learner batch size.")
parser.add_argument("--unroll_length", default=80, type=int, metavar="T",
                    help="The unroll length (time dimension).")
parser.add_argument("--num_buffers", default=None, type=int,
                    metavar="N", help="Number of shared-memory buffers.")
parser.add_argument("--num_learner_threads", "--num_threads", default=2, type=int,
                    metavar="N", help="Number learner threads.")
parser.add_argument("--disable_cuda", action="store_true",
                    help="Disable CUDA.")
parser.add_argument("--use_lstm", action="store_true",
                    help="Use LSTM in agent model.")
parser.add_argument("--num_layers", default=1, type=int,
                    help="Number hidden layers.")
parser.add_argument("--hidden_size", default=512, type=int,
                    help="Dim of model activations.")

# Loss settings.
parser.add_argument("--entropy_cost", default=0.0006,
                    type=float, help="Entropy cost/multiplier.")
parser.add_argument("--baseline_cost", default=0.5,
                    type=float, help="Baseline cost/multiplier.")
parser.add_argument("--discounting", default=0.99,
                    type=float, help="Discounting factor.")
parser.add_argument("--reward_clipping", default="abs_one",
                    choices=["abs_one", "none"],
                    help="Reward clipping.")

# Optimizer settings.
parser.add_argument("--learning_rate", default=0.00048,
                    type=float, metavar="LR", help="Learning rate.")
parser.add_argument("--alpha", default=0.99, type=float,
                    help="RMSProp smoothing constant.")
parser.add_argument("--momentum", default=0, type=float,
                    help="RMSProp momentum.")
parser.add_argument("--epsilon", default=0.0001, type=float,
                    help="RMSProp epsilon.")
parser.add_argument("--beta1", default=0.9, type=float,
                    help="Adam beta1.")
parser.add_argument("--beta2", default=0.999, type=float,
                    help="Adam beta2.")
parser.add_argument("--grad_norm_clipping", default=40.0, type=float,
                    help="Global gradient norm clip.")
# yapf: enable

# Proxy settings (forwarded to atari_wrappers.wrap_deepmind).
parser.add_argument("--fuel_multiplier", default=1.0, type=float,
                    help="Multiplier applied to the score for shooting a fuel canister.")
parser.add_argument("--move_penalty", default=0.0, type=float,
                    help="How much to penalize moving.")
parser.add_argument("--true_move_penalty", default=0.0, type=float,
                    help="How much to truly penalize moving.")

# Root logger configuration; NOTSET (== 0) lets every record through.
logging.basicConfig(
    format=(
        "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
    ),
    level=logging.NOTSET,
)

# Rollout storage type: maps a field name (e.g. "frame", "reward") to a list
# of tensors, one tensor per buffer slot.
Buffers = typing.Dict[str, typing.List[torch.Tensor]]


def compute_baseline_loss(advantages):
    """Return half the sum of the squared advantages (value-function loss)."""
    squared_error = advantages.pow(2)
    return 0.5 * squared_error.sum()


def compute_entropy_loss(logits):
    """Return the negative policy entropy, summed over every element.

    The softmax is taken over the last dimension of `logits`; minimizing the
    returned value therefore maximizes the entropy of the policy.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    probs = F.softmax(logits, dim=-1)
    negative_entropy = (probs * log_probs).sum()
    return negative_entropy


def compute_policy_gradient_loss(logits, actions, advantages):
    """Return the advantage-weighted negative log-likelihood of `actions`.

    The first two dimensions of `logits` and `actions` are merged before the
    per-sample loss is computed, then reshaped to match `advantages`.
    Advantages are detached so no gradient flows through them.
    """
    flat_log_probs = F.log_softmax(torch.flatten(logits, 0, 1), dim=-1)
    flat_actions = torch.flatten(actions, 0, 1)
    nll = F.nll_loss(flat_log_probs, target=flat_actions, reduction="none")
    weighted = nll.view_as(advantages) * advantages.detach()
    return torch.sum(weighted)


def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    """Actor loop: roll out the policy and write the results into `buffers`.

    Repeatedly takes a buffer index from `free_queue`, fills that slot with
    `flags.unroll_length` environment steps (plus the last step of the
    previous rollout at position 0), and hands the index over via
    `full_queue`. A `None` index on `free_queue` ends the loop.

    Args:
        flags: Parsed command-line flags.
        actor_index: Index of this actor; mixed into the env seed and used
            in log messages.
        free_queue: Queue of buffer indices available for writing.
        full_queue: Queue that receives indices of filled buffers.
        model: Policy network used for acting.
        buffers: Rollout storage, indexed as buffers[key][index][t].
        initial_agent_state_buffers: Per-slot storage for the agent state at
            the start of each rollout.

    Raises:
        Exception: Any non-KeyboardInterrupt error is logged, its traceback
            printed, and then re-raised.
    """
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        # The pretrained featurizer uses the non-frame-stacked environment.
        if flags.pretrained:
            gym_env = create_env_seaquest(flags)
        else:
            gym_env = create_env(flags)
        # Mix the actor index with 4 random bytes from the OS for the seed.
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        # Initial forward pass so that `agent_output` exists for slot 0 of
        # the first rollout; the state it returns is discarded.
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                # Sentinel: stop acting.
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # Store the agent state at the start of this rollout.
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                # Slot t + 1 because slot 0 holds the previous rollout's end.
                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        # Only the first actor reports timings.
        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e


def get_batch(
    flags,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: Buffers,
    initial_agent_state_buffers,
    timings,
    lock=threading.Lock(),  # noqa: B008
):
    """Assemble one learner batch from filled rollout buffers.

    Takes `flags.batch_size` indices from `full_queue`, stacks the matching
    buffer tensors along dim 1, concatenates the matching initial agent
    states along dim 1, returns the indices to `free_queue`, and moves the
    results to `flags.device`.

    Args:
        flags: Needs `batch_size` and `device`.
        free_queue: Receives the indices once their contents were copied.
        full_queue: Source of filled buffer indices.
        buffers: Rollout storage, indexed as buffers[key][index].
        initial_agent_state_buffers: Per-slot initial agent states.
        timings: Object with a `time(name)` method for profiling.
        lock: Serializes dequeuing. The default is created once at
            definition time, so it is deliberately shared by every caller
            that does not pass its own lock.

    Returns:
        A `(batch, initial_agent_state)` pair: a dict of stacked tensors and
        a tuple of concatenated state tensors, both on `flags.device`.
    """
    with lock:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(flags.batch_size)]
        timings.time("dequeue")
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers
    }
    # Materialize the concatenated states now. A lazy generator would defer
    # torch.cat until after the indices are back on free_queue, at which
    # point an actor may already be overwriting these shared state tensors.
    initial_agent_state = tuple(
        torch.cat(ts, dim=1)
        for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
    )
    timings.time("batch")
    # Everything has been copied out of the slots; hand them back to actors.
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True) for t in initial_agent_state
    )
    timings.time("device")
    return batch, initial_agent_state


def learn(
    flags,
    actor_model,
    model,
    batch,
    initial_agent_state,
    optimizer,
    scheduler,
    lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step.

    Runs `model` on the batch, builds V-trace targets, sums the policy
    gradient, baseline and entropy losses, takes one optimizer and scheduler
    step, and finally copies the updated weights into `actor_model`.

    Args:
        flags: Needs `reward_clipping`, `discounting`, `baseline_cost`,
            `entropy_cost` and `grad_norm_clipping`.
        actor_model: Model that receives the updated weights.
        model: Model that is optimized.
        batch: Dict of tensors; indexed along dim 0 as the time dimension.
        initial_agent_state: Agent state passed to `model` with the batch.
        optimizer: Optimizer over `model`'s parameters.
        scheduler: Learning-rate scheduler, stepped once per call.
        lock: Serializes learning steps. The default is created once at
            definition time and is therefore shared between callers.

    Returns:
        Dict with the per-episode returns of episodes that finished in this
        batch, their means, and the individual loss values.
    """
    with lock:
        learner_outputs, unused_state = model(batch, initial_agent_state)

        # Take final value function slice for bootstrapping.
        bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()}

        rewards = batch["reward"]
        if flags.reward_clipping == "abs_one":
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif flags.reward_clipping == "none":
            clipped_rewards = rewards

        # Zero discount wherever `done` is set, so nothing is bootstrapped
        # across those steps.
        discounts = (~batch["done"]).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"],
            batch["action"],
            vtrace_returns.pg_advantages,
        )
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"]
        )
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"]
        )

        total_loss = pg_loss + baseline_loss + entropy_loss

        # Keep only the entries at steps where `done` is set.
        episode_returns = batch["episode_return"][batch["done"]]
        episode_true_returns = batch["episode_true_return"][batch["done"]]
        episode_true_move = batch["episode_true_move"][batch["done"]]
        # NOTE(review): the means are taken over possibly empty selections
        # when no episode finished in this batch - confirm the resulting
        # value is acceptable to downstream logging.
        stats = {
            "episode_returns": tuple(episode_returns.cpu().numpy()),
            "episode_true_returns": tuple(episode_true_returns.cpu().numpy()),
            "episode_true_move": tuple(episode_true_move.cpu().numpy()),
            "mean_episode_return": torch.mean(episode_returns).item(),
            "mean_episode_true_return": torch.mean(episode_true_returns).item(),
            "mean_episode_true_move": torch.mean(episode_true_move).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
        }

        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        scheduler.step()

        # Publish the updated weights to the model used for acting.
        actor_model.load_state_dict(model.state_dict())
        return stats


def create_buffers(flags, obs_shape, num_actions) -> Buffers:
    """Allocate `flags.num_buffers` shared-memory rollout slots per field.

    Every tensor has `flags.unroll_length + 1` entries along its first
    dimension. The tensors come from `torch.empty`, so their contents are
    uninitialized until something writes to them.
    """
    steps = flags.unroll_length + 1

    def _slot(dtype, *trailing):
        # Spec for one field: `steps` leading entries plus optional trailing dims.
        return {"size": (steps, *trailing), "dtype": dtype}

    specs = {
        "frame": _slot(torch.uint8, *obs_shape),
        "reward": _slot(torch.float32),
        "true_reward": _slot(torch.float32),
        "true_move": _slot(torch.float32),
        "done": _slot(torch.bool),
        "episode_return": _slot(torch.float32),
        "episode_true_return": _slot(torch.float32),
        "episode_true_move": _slot(torch.float32),
        "episode_step": _slot(torch.int32),
        "policy_logits": _slot(torch.float32, num_actions),
        "baseline": _slot(torch.float32),
        "last_action": _slot(torch.int64),
        "action": _slot(torch.int64),
    }
    return {
        name: [
            torch.empty(**spec).share_memory_() for _ in range(flags.num_buffers)
        ]
        for name, spec in specs.items()
    }


def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    """Run training with forked actor processes and learner threads.

    Sets up logging and the checkpoint path, allocates the shared rollout
    buffers, forks `flags.num_actors` actor processes running `act`, starts
    `flags.num_learner_threads` threads that alternate `get_batch` and
    `learn`, and reports progress every 5 seconds until `flags.total_steps`
    is reached.

    Side effects on `flags`: `xpid`, `num_buffers` and `device` are filled
    in when unset.

    Raises:
        ValueError: If `num_buffers` is not larger than `num_actors`, or is
            smaller than `batch_size`.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    # This environment instance is only used here for its observation and
    # action spaces; each actor creates its own environment in `act`.
    if flags.pretrained:
        env = create_env_seaquest(flags)
    else:
        env = create_env(flags)

    # Actor-side model: never moved to `flags.device`, shared with the actor
    # processes below.
    if flags.pretrained:
        model = FNet(env.observation_space.shape, env.action_space.n, num_layers=flags.num_layers, hidden_size=flags.hidden_size, use_lstm=flags.use_lstm)
        model.load_my_state_dict(torch.load(flags.pretrained))
    else:
        model = Net(env.observation_space.shape, env.action_space.n, num_layers=flags.num_layers, hidden_size=flags.hidden_size, use_lstm=flags.use_lstm)


    buffers = create_buffers(flags, env.observation_space.shape, model.num_actions)

    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    # Learner-side model: lives on `flags.device` and is the one optimized.
    if flags.pretrained:
        learner_model = FNet(
            env.observation_space.shape, env.action_space.n, num_layers=flags.num_layers, hidden_size=flags.hidden_size, use_lstm=flags.use_lstm
        ).to(device=flags.device)
        learner_model.load_my_state_dict(torch.load(flags.pretrained))
    else:
        learner_model = Net(
            env.observation_space.shape, env.action_space.n, num_layers=flags.num_layers, hidden_size=flags.hidden_size, use_lstm=flags.use_lstm
        ).to(device=flags.device)

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    # optimizer = torch.optim.Adam(
    #     learner_model.parameters(),
    #     lr=flags.learning_rate,
    # )

    # Learning-rate multipliers as a function of the scheduler step count.
    # Only `staggered_drop_lambda` is passed to the scheduler below;
    # `lr_lambda` and `drop_lambda` are defined but not used.
    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    def drop_lambda(epoch):
        left = 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps
        return 2 ** (-5 + (left * 10) // 2)

    def staggered_drop_lambda(epoch):
        # NOTE(review): the horizon is the fixed TOTAL below, not
        # flags.total_steps - confirm this is intended for shorter runs.
        TOTAL = 100_000_000
        left = 1 - min(epoch * T * B, TOTAL) / TOTAL
        if left > 0.9:
            return 2 ** -1
        if left > 0.7:
            return 2 ** -2
        if left > 0.5:
            return 2 ** -3
        if left > 0.3:
            return 2 ** -4
        return 2 ** -5 * (left/0.3)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, staggered_drop_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "mean_episode_true_return",
        "mean_episode_true_move",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    # Shared between the learner threads and the monitoring loop below.
    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            # `model` is passed as learn's `actor_model` and `learner_model`
            # as its `model`.
            stats = learn(
                flags, model, learner_model, batch, agent_state, optimizer, scheduler
            )
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    # Initially every buffer slot is free for the actors to fill.
    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i,)
        )
        thread.start()
        threads.append(thread)

    def checkpoint():
        """Save model, optimizer, scheduler and flags unless disabled."""
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    #wandb.init(project="test-space", entity="aypan17", group="atari")
    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        # Monitoring loop: wake up every 5 seconds, checkpoint when due, and
        # log throughput plus the latest stats.
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)
            # The per-episode tuples are empty (falsy) when no episode
            # finished in the latest batch; the matching mean is then left
            # out of the log line.
            if stats.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats["mean_episode_return"]
                )
            else:
                mean_return = ""
            if stats.get("episode_true_returns", None):
                mean_true_return = (
                    "True return per episode: %.1f. " % stats["mean_episode_true_return"]
                )
            else:
                mean_true_return = ""
            if stats.get("episode_true_move", None):
                mean_true_move = (
                    "True move per episode: %.1f. " % stats["mean_episode_true_move"]
                )
            else:
                mean_true_move = ""
            total_loss = stats.get("total_loss", float("inf"))
            #if stats.get("episode_returns", None) and stats.get("true_episode_returns", None):
                #wandb.log({"loss":total_loss, "episode_return": stats["mean_episode_return"], "true_episode_return": stats["mean_episode_true_return"]})
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %s%s%sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                mean_true_return,
                mean_true_move,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        # Returning here skips the final checkpoint() and plogger.close()
        # below; the `finally` clause still runs.
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        # One `None` sentinel per actor tells it to leave its loop.
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    plogger.close()


def evaluate(flags, num_episodes: int = 4):
    """Run `test` on every matching run folder found under `flags.savedir`.

    A folder matches when its underscore-separated name has at least three
    parts and contains "staggered" or "pacifist". The last two parts are
    read as `<hidden_size>_<num_layers>`.

    Args:
        flags: Parsed flags; `savedir` is the directory that is scanned.
        num_episodes: Forwarded to `test`.
    """
    # Expand "~" and environment variables the same way the checkpoint path
    # is expanded elsewhere; the default savedir starts with "~".
    root = os.path.expandvars(os.path.expanduser(flags.savedir))
    for folder in os.listdir(root):
        if not os.path.isdir(os.path.join(root, folder)):
            continue
        params = folder.split("_")
        if ("staggered" not in params and "pacifist" not in params) or len(params) < 3:
            continue
        try:
            hidden_size, num_layers = int(params[-2]), int(params[-1])
        except ValueError:
            # One oddly named folder should not abort the whole sweep.
            logging.warning(
                "Skipping %s: name does not end in <hidden_size>_<num_layers>.",
                folder,
            )
            continue
        test(
            flags,
            num_episodes=num_episodes,
            folder=folder,
            num_layers=num_layers,
            hidden_size=hidden_size,
        )


def test(flags, num_episodes: int = 4, folder=None, num_layers=None, hidden_size=None):
    """Evaluate a saved checkpoint and write summary statistics to JSON.

    Loads `<savedir>/<folder>/model.tar`, runs 15 trials of `num_episodes`
    episodes each in eval mode, logs mean and standard deviation of the
    per-trial totals, and writes them to a JSON file in `savedir`.

    Args:
        flags: Parsed flags (`savedir`, `xpid`, `pretrained`, `use_lstm`,
            `mode`, and, when `folder` is None, `num_layers`/`hidden_size`).
        num_episodes: Number of finished episodes per trial.
        folder: Run folder under `savedir`. Defaults to `flags.xpid`.
        num_layers: Model depth; required when `folder` is given.
        hidden_size: Model width; required when `folder` is given.

    Raises:
        ValueError: If `folder` is given without `num_layers` and
            `hidden_size`, or if neither `folder` nor `flags.xpid` is set.
    """
    if folder is not None:
        if num_layers is None or hidden_size is None:
            raise ValueError(
                "num_layers and hidden_size are required when folder is given."
            )
    elif flags.xpid is None:
        raise ValueError("Either folder or flags.xpid must identify the run to test.")
    else:
        folder = flags.xpid
        num_layers = flags.num_layers
        hidden_size = flags.hidden_size

    # Expand once and reuse for both the checkpoint and the JSON output.
    savedir = os.path.expandvars(os.path.expanduser(flags.savedir))
    checkpointpath = os.path.join(savedir, folder, "model.tar")

    if flags.pretrained:
        gym_env = create_env_seaquest(flags)
        net_cls = FNet
    else:
        gym_env = create_env(flags)
        net_cls = Net
    env = environment.Environment(gym_env)
    model = net_cls(
        gym_env.observation_space.shape,
        gym_env.action_space.n,
        num_layers=num_layers,
        hidden_size=hidden_size,
        use_lstm=flags.use_lstm,
    )

    # NOTE(review): this sets an attribute on the Environment wrapper, not on
    # the gym env - confirm it has the intended effect.
    env._max_episode_steps = 10

    model.eval()
    # Parameter count of the fully connected stack, stored in the JSON.
    num_fc_params = sum(p.numel() for p in model.fc.parameters())
    try:
        checkpoint = torch.load(checkpointpath, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
    except Exception:  # Best effort: report and skip this run.
        logging.exception("Could not load checkpoint %s", checkpointpath)
        print(f"Model of width {hidden_size} and depth {num_layers} failed to load")
        env.close()
        return

    num_trials = 15
    returns = []
    true_returns = []
    true_move = []
    lens = []

    for _ in range(num_trials):
        trial_returns = []
        trial_true_returns = []
        trial_true_move = []
        trial_lens = []
        observation = env.initial()
        agent_state = model.initial_state(batch_size=1)
        while len(trial_returns) < num_episodes:
            if flags.mode == "test_render":
                env.gym_env.render()
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                policy_outputs, agent_state = model(observation, agent_state)
            observation = env.step(policy_outputs["action"])
            if observation["done"].item():
                trial_returns.append(observation["episode_return"].item())
                trial_true_returns.append(observation["episode_true_return"].item())
                trial_true_move.append(observation["episode_true_move"].item())
                trial_lens.append(observation["episode_step"].item())
        logging.info(str(trial_lens))
        # Each trial contributes its totals over `num_episodes` episodes.
        returns.append(sum(trial_returns))
        true_returns.append(sum(trial_true_returns))
        true_move.append(sum(trial_true_move))
        lens.append(sum(trial_lens))
    env.close()

    lens = np.array(lens)
    returns = np.array(returns)
    true_returns = np.array(true_returns)
    true_move = np.array(true_move)
    return_per_step = returns / lens
    true_return_per_step = true_returns / lens

    summaries = (
        ("returns", returns),
        ("true returns", true_returns),
        ("true move", true_move),
        ("return/step", return_per_step),
        ("true return/step", true_return_per_step),
        ("num steps", lens),
    )
    for label, values in summaries:
        logging.info(
            "Average " + label + " over %i episodes: %.1f +/- %.1f",
            num_episodes,
            np.mean(values).item(),
            np.std(values).item(),
        )

    # flags.xpid is None when called from `evaluate` without --xpid; avoid
    # concatenating None into the file name.
    if flags.xpid is not None:
        json_name = folder + "_" + flags.xpid + ".json"
    else:
        json_name = folder + ".json"
    results = {
        'params': [num_fc_params],
        'rew': returns.tolist(),
        'true_rew': true_returns.tolist(),
        'true_move': true_move.tolist(),
        'rew_step': return_per_step.tolist(),
        'true_rew_step': true_return_per_step.tolist(),
        'len': lens.tolist(),
    }
    with open(os.path.join(savedir, json_name), "w") as f:
        json.dump(results, f)


class FeaturizedAtariNet(nn.Module):
    """Feed-forward policy/value network with a conv + max-pool featurizer.

    Four conv/max-pool stages feed a stack of `num_layers` fully connected
    layers. The policy and baseline heads read the FC output concatenated
    with the clipped reward and a one-hot of the last action. The first FC
    layer expects 1024 flattened features.

    `use_lstm` is accepted so the constructor matches `AtariNet`, but this
    network has no recurrent core and ignores it.
    """

    def __init__(self, observation_shape, num_actions, num_layers=1, hidden_size=512, use_lstm=False):
        super().__init__()
        self.observation_shape = observation_shape
        self.num_actions = num_actions

        # Feature extraction
        self.conv1 = nn.Conv2d(self.observation_shape[0], 32, 5, stride=1, padding=2)
        self.maxp1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 32, 5, stride=1, padding=1)
        self.maxp2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, 4, stride=1, padding=1)
        self.maxp3 = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.maxp4 = nn.MaxPool2d(2, 2)

        # Fully connected layers. Each hidden layer must be its own module:
        # `[layer] * k` would repeat one object and tie all their weights.
        self.fc = nn.ModuleList(
            [nn.Linear(1024, hidden_size)]
            + [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers - 1)]
        )

        # FC output size + one-hot of last action + last reward.
        core_output_size = hidden_size + num_actions + 1

        self.policy = nn.Linear(core_output_size, self.num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def initial_state(self, batch_size):
        """Return the (empty) agent state; this network is not recurrent."""
        return tuple()

    def forward(self, inputs, core_state=()):
        """Compute policy logits, baseline and an action for each step.

        Args:
            inputs: Dict with "frame" ([T, B, C, H, W]), "last_action" and
                "reward" (both holding T * B elements).
            core_state: Ignored; an empty tuple is always returned.

        Returns:
            `(outputs, core_state)` where `outputs` has `policy_logits`
            ([T, B, num_actions]), `baseline` ([T, B]) and `action` ([T, B]).
            Actions are sampled in training mode and greedy otherwise.
        """
        x = inputs["frame"]  # [T, B, C, H, W].
        T, B, *_ = x.shape
        # Merge time and batch. Frames are only cast to float here, not
        # rescaled.
        x = torch.flatten(x, 0, 1).float()
        x = F.relu(self.maxp1(self.conv1(x)))
        x = F.relu(self.maxp2(self.conv2(x)))
        x = F.relu(self.maxp3(self.conv3(x)))
        x = F.relu(self.maxp4(self.conv4(x)))
        x = x.view(T * B, -1)
        for layer in self.fc:
            x = F.relu(layer(x))

        one_hot_last_action = F.one_hot(
            inputs["last_action"].view(T * B), self.num_actions
        ).float()
        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        core_input = torch.cat([x, clipped_reward, one_hot_last_action], dim=-1)

        core_output = core_input
        core_state = tuple()

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        return (
            dict(policy_logits=policy_logits, baseline=baseline, action=action),
            core_state,
        )

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of `state_dict` into this model in place.

        Keys that this model does not have are skipped, which allows a
        partial load (e.g. featurizer weights only).
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)


class AtariNet(nn.Module):
    """Conv policy/value network with an optional two-layer LSTM core.

    Three conv layers feed `num_layers` fully connected layers. Their output
    is concatenated with the clipped reward and a one-hot of the last action
    before it reaches the (optional) LSTM and the policy/baseline heads. The
    first FC layer expects 3136 flattened features.
    """

    def __init__(self, observation_shape, num_actions, num_layers=1, hidden_size=512, use_lstm=False):
        super().__init__()
        self.observation_shape = observation_shape
        self.num_actions = num_actions

        # Convolutional feature extractor.
        in_channels = self.observation_shape[0]
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected stack: one input layer plus num_layers - 1 hidden.
        fc_layers = [nn.Linear(3136, hidden_size)]
        for _ in range(num_layers - 1):
            fc_layers.append(nn.Linear(hidden_size, hidden_size))
        self.fc = nn.ModuleList(fc_layers)

        # FC features + one-hot last action + scalar reward.
        core_output_size = hidden_size + num_actions + 1

        self.use_lstm = use_lstm
        if use_lstm:
            self.core = nn.LSTM(core_output_size, core_output_size, 2)

        self.policy = nn.Linear(core_output_size, self.num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def initial_state(self, batch_size):
        """Return zeroed LSTM state tensors, or an empty tuple without LSTM."""
        if self.use_lstm:
            state_shape = (self.core.num_layers, batch_size, self.core.hidden_size)
            return (torch.zeros(*state_shape), torch.zeros(*state_shape))
        return tuple()

    def forward(self, inputs, core_state=()):
        """Compute policy logits, baseline and an action for each step.

        `inputs` must provide "frame" ([T, B, C, H, W]), "last_action" and
        "reward", plus "done" when the LSTM is enabled. Returns
        `(outputs, core_state)`; actions are sampled in training mode and
        greedy otherwise.
        """
        frames = inputs["frame"]  # [T, B, C, H, W].
        T, B, *_ = frames.shape
        # Collapse time and batch into one axis and rescale pixel values.
        x = torch.flatten(frames, 0, 1).float() / 255.0
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        x = x.view(T * B, -1)
        for layer in self.fc:
            x = F.relu(layer(x))

        last_action = inputs["last_action"].view(T * B)
        one_hot_last_action = F.one_hot(last_action, self.num_actions).float()
        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        core_input = torch.cat([x, clipped_reward, one_hot_last_action], dim=-1)

        if not self.use_lstm:
            core_output = core_input
            core_state = tuple()
        else:
            per_step_inputs = core_input.view(T, B, -1).unbind()
            per_step_notdone = (~inputs["done"]).float().unbind()
            step_outputs = []
            for step_input, notdone in zip(per_step_inputs, per_step_notdone):
                # Zero the state for entries flagged `done`; the mask is
                # shaped to broadcast against (num_layers, B, hidden_size).
                mask = notdone.view(1, -1, 1)
                core_state = tuple(mask * s for s in core_state)
                step_output, core_state = self.core(
                    step_input.unsqueeze(0), core_state
                )
                step_outputs.append(step_output)
            core_output = torch.flatten(torch.cat(step_outputs), 0, 1)

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            probs = F.softmax(policy_logits, dim=1)
            action = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy action outside of training.
            action = torch.argmax(policy_logits, dim=1)

        outputs = dict(
            policy_logits=policy_logits.view(T, B, self.num_actions),
            baseline=baseline.view(T, B),
            action=action.view(T, B),
        )
        return outputs, core_state


# Aliases used by train() and test(): `Net` is the default network, `FNet` is
# the one selected when --pretrained is given.
Net = AtariNet
FNet = FeaturizedAtariNet

def create_env(flags):
    """Build the frame-stacked Atari environment named by `flags.env`.

    Rewards are left unclipped and observations unscaled; the proxy-reward
    settings from `flags` are forwarded to `wrap_deepmind`.
    """
    base_env = atari_wrappers.make_atari(flags.env)
    wrapped_env = atari_wrappers.wrap_deepmind(
        base_env,
        clip_rewards=False,
        frame_stack=True,
        scale=False,
        fuel_multiplier=flags.fuel_multiplier,
        move_penalty=flags.move_penalty,
        true_move_penalty=flags.true_move_penalty,
    )
    return atari_wrappers.wrap_pytorch(wrapped_env)

def create_env_seaquest(flags):
    """Build the Atari environment named by `flags.env` without frame stacking.

    Apart from `frame_stack=False`, the wrapper settings are those used by
    `create_env`: unclipped rewards, unscaled observations, and the
    proxy-reward settings taken from `flags`.
    """
    base_env = atari_wrappers.make_atari(flags.env)
    wrapped_env = atari_wrappers.wrap_deepmind(
        base_env,
        clip_rewards=False,
        frame_stack=False,
        scale=False,
        fuel_multiplier=flags.fuel_multiplier,
        move_penalty=flags.move_penalty,
        true_move_penalty=flags.true_move_penalty,
    )
    return atari_wrappers.wrap_pytorch(wrapped_env)


def main(flags):
    """Dispatch on `flags.mode`: train, sweep-evaluate, or test a single run."""
    mode = flags.mode
    if mode == "train":
        train(flags)
        return
    if mode == "eval":
        evaluate(flags, num_episodes=flags.num_episodes)
        return
    # Every remaining mode ("test", "test_render") runs a single test.
    test(flags, num_episodes=flags.num_episodes)


# Script entry point: parse the command line and dispatch on --mode.
if __name__ == "__main__":
    flags = parser.parse_args()
    main(flags)
