import time
import datetime

from numpy.core.fromnumeric import shape
from utils.format import default_preprocess_obss
import torch
import tensorboardX
import sys

import gym
import gc
import numpy as np 

import utils
from model import ACModel
from agent import Agent
from env_rover import RoverEnv
from dst import DeepSea
from mj_cost_wrapper import MJCostWrapper

# Gym ids of the MuJoCo tasks that run() accepts for args.env
# (v3 variants first, followed by the matching v2 variants).
MUJOCO_ENVS = [
    "HalfCheetah-v3",
    "Ant-v3",
    "Hopper-v3",
    "Humanoid-v3",
    "Walker2d-v3",
    "HalfCheetah-v2",
    "Ant-v2",
    "Hopper-v2",
    "Humanoid-v2",
    "Walker2d-v2",
]

class PrefDistCnt():
    """Preference "distribution" that always yields one fixed preference.

    Exposes the same ``sample()`` method as ``PrefDistUniform`` so the two can
    be used interchangeably (presumably "Cnt" stands for "constant").
    """

    def __init__(self, omega) -> None:
        """Remember ``omega``; it is stored by reference, not copied."""
        self.omega = omega

    def sample(self):
        """Return the stored ``omega`` object unchanged on every call."""
        fixed_preference = self.omega
        return fixed_preference
    
class PrefDistUniform():
    """Preference distribution that is uniform over the probability simplex.

    Each sample is a vector of ``reward_dim`` non-negative weights that sum
    to 1, i.e. a normalised trade-off between the reward components.
    """

    def __init__(self, reward_dim) -> None:
        """Store the number of reward components each sample must cover."""
        self.reward_dim = reward_dim

    def sample(self):
        """Draw one preference vector.

        Returns:
            list[float]: ``reward_dim`` non-negative weights summing to 1.
        """
        # A Dirichlet with all concentration parameters equal to 1 is the
        # uniform distribution on the simplex.  For reward_dim == 2 this is
        # [u, 1 - u] with u ~ U(0, 1); for reward_dim == 1 it is always [1.0].
        # Setting the last weight to the *sum* of the others instead of its
        # complement would give the degenerate [u, u] (and [0.0] for one
        # objective), which is not a normalised preference.
        weights = np.random.dirichlet(np.ones(self.reward_dim))
        return weights.tolist()

def _build_model_name(args):
    """Compose the run-directory name from the options that distinguish runs."""
    parts = [f"{args.env}"]
    if args.epsilon is not None:
        # All epsilon values are concatenated without a separator.
        parts.append("_epsilon_" + "".join(str(e) for e in args.epsilon))
    if args.nA:
        parts.append("_nA")
    if args.ud:
        parts.append("_ud")
    parts.append(f"_seed_{args.seed}")
    return "".join(parts)


def _select_device(args):
    """Return the CUDA device with index ``args.device`` if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:" + str(args.device))
    return torch.device("cpu")


def _make_envs(args):
    """Create ``args.procs`` seeded environments for ``args.env``.

    Returns:
        tuple: ``(envs, action_dim, reward_dim, is_continous, include_conv)``
        where the last four are the per-environment settings handed to the
        model and the agent.

    Raises:
        ValueError: if ``args.procs`` is smaller than 1 or ``args.env`` is
            not "Rover", "DST" or one of ``MUJOCO_ENVS``.
    """
    if args.procs < 1:
        # Without at least one env there is nothing to read the observation
        # and action spaces from; fail with a clear message instead.
        raise ValueError("args.procs must be at least 1")

    envs = []
    for _ in range(args.procs):
        if args.env == "Rover":
            env = RoverEnv()
            action_dim = 1
            reward_dim = 2
            is_continous = False
        elif args.env == "DST":
            env = DeepSea()
            action_dim = 1
            reward_dim = 2
            is_continous = False
        elif args.env in MUJOCO_ENVS:
            env = MJCostWrapper(gym.make(args.env))
            action_dim = env.action_space.shape[0]
            reward_dim = 1
            is_continous = True
        else:
            raise ValueError("Invalid env!")
        # NOTE(review): every environment receives the same seed, so parallel
        # environments are not decorrelated through seeding - confirm this
        # is intended.
        env.seed(args.seed)
        envs.append(env)

    include_conv = False  # none of the supported environments enables it
    return envs, action_dim, reward_dim, is_continous, include_conv


def _restore_optimizers(agent, status, txt_logger):
    """Load optimizer states from ``status`` under the keys the save step writes."""
    restored = False
    if "actor_optimizer_state" in status:
        agent.optimizer_actor.load_state_dict(status["actor_optimizer_state"])
        restored = True
    if "critic_optimizer_state" in status:
        agent.optimizer_critic.load_state_dict(status["critic_optimizer_state"])
        restored = True
    if restored:
        txt_logger.info("Agent Optimizer loaded\n")


def _train(args, agent, model, status, model_dir,
           txt_logger, tb_writer, csv_file, csv_logger):
    """Alternate experience collection and parameter updates until ``args.frames``."""
    num_frames = status["num_frames"]
    update = status["update"]
    start_time = time.time()

    while num_frames < args.frames:
        update += 1

        # Train agent
        update_start_time = time.time()
        exps, logs1 = agent.collect_experiences()
        logs2 = agent.update_parameters(exps)
        logs = {**logs1, **logs2}
        update_end_time = time.time()
        num_frames += logs["num_frames"]

        # Print logs
        if update % args.log_interval == 0:
            utils.handle_logs(logs=logs,
                              txt_logger=txt_logger,
                              tb_writer=tb_writer,
                              csv_file=csv_file,
                              csv_logger=csv_logger,
                              update=update,
                              start_time=start_time,
                              num_frames=num_frames,
                              status=status,
                              update_start_time=update_start_time,
                              update_end_time=update_end_time)

        # Save status (a non-positive save_interval disables saving)
        if args.save_interval > 0 and update % args.save_interval == 0:
            status = {"num_frames": num_frames,
                      "update": update,
                      "model_state": model.state_dict(),
                      "actor_optimizer_state": agent.optimizer_actor.state_dict(),
                      "critic_optimizer_state": agent.optimizer_critic.state_dict()}
            utils.save_status(status, model_dir)
            txt_logger.info("Status saved")


def run(args):
    """Train an agent on ``args.env`` according to the parsed options in ``args``.

    Sets up the run directory and loggers, seeds the random sources, builds
    the environments, the actor-critic model and the agent, then trains until
    ``args.frames`` frames have been collected, logging every
    ``args.log_interval`` updates and saving every ``args.save_interval``
    updates.

    Attributes read from ``args``: env, epsilon, nA, ud, seed, s, device,
    procs, frames, frames_per_proc, discount, lr, gae_lambda, entropy_coef,
    value_loss_coef, max_grad_norm, optim_eps, clip_eps, epochs, batch_size,
    obs_clip, log_interval, save_interval.

    Raises:
        ValueError: if ``args.env`` is not supported or ``args.procs`` < 1.
    """
    # Set run dir
    model_dir = utils.get_model_dir(_build_model_name(args), args.s)

    # Load loggers and Tensorboard writer
    txt_logger = utils.get_txt_logger(model_dir)
    csv_file, csv_logger = utils.get_csv_logger(model_dir, "log.csv")
    tb_writer = tensorboardX.SummaryWriter(model_dir)

    try:
        # Log command and all script arguments
        txt_logger.info("{}\n".format(" ".join(sys.argv)))
        txt_logger.info("{}\n".format(args))

        # Set seed for all randomness sources
        utils.seed(args.seed)

        # Set device
        device = _select_device(args)
        txt_logger.info(f"Device: {device}\n")

        # Load environments
        envs, action_dim, reward_dim, is_continous, include_conv = _make_envs(args)
        txt_logger.info("Environments loaded\n")

        # Training always starts from scratch: resuming a saved status
        # (previously via utils.get_status(model_dir)) is currently disabled.
        status = {"num_frames": 0, "update": 0}
        txt_logger.info("No training status found - creating new\n")

        # Load observations preprocessor
        obs_space = envs[0].observation_space
        action_space = envs[0].action_space
        preprocess_obs = default_preprocess_obss
        txt_logger.info("Observations preprocessor loaded\n")

        # Create agent AC model
        model = ACModel(obs_space=obs_space,
                        action_space=action_space,
                        reward_dim=reward_dim,
                        is_continous=is_continous,
                        include_conv=include_conv,
                        device=device)
        model.to(device)
        txt_logger.info("Agent model created\n")

        # Try to load them if saved previously
        if "model_state" in status:
            model.load_state_dict(status["model_state"])
            txt_logger.info("Agent Model loaded\n")
            txt_logger.info("{}\n".format(model))

        # Preferences are sampled only when --ud is set; otherwise the agent
        # receives no preference distribution.
        d_omega = PrefDistUniform(reward_dim) if args.ud else None

        # Create agent
        agent = Agent(envs,
                      model=model,
                      reward_dim=reward_dim,
                      device=device,
                      pref_dist=d_omega,
                      num_frames_per_proc=args.frames_per_proc,
                      discount=args.discount,
                      lr=args.lr,
                      gae_lambda=args.gae_lambda,
                      entropy_coef=args.entropy_coef,
                      value_loss_coef=args.value_loss_coef,
                      max_grad_norm=args.max_grad_norm,
                      adam_eps=args.optim_eps,
                      clip_eps=args.clip_eps,
                      epochs=args.epochs,
                      batch_size=args.batch_size,
                      preprocess_obss=preprocess_obs,
                      action_dim=action_dim,
                      obs_clip=args.obs_clip,
                      epsilon=args.epsilon,
                      adaptive=not args.nA)
        try:
            _restore_optimizers(agent, status, txt_logger)

            # Train model
            _train(args, agent, model, status, model_dir,
                   txt_logger, tb_writer, csv_file, csv_logger)
        finally:
            # Shut the agent's environment handle down even when training
            # fails, so nothing it holds is left behind.
            agent.env.kill()
            del agent
            gc.collect()
    finally:
        # Release the log sinks on every exit path.  csv_file is expected to
        # be the open file handle that backs csv_logger.
        tb_writer.close()
        csv_file.close()