import gymnasium as gym
import argparse
import numpy as np
import torch
import torch.optim as optim
from replay_memory import ReplayMemory
from sample_env import EnvSampler
# from Algorithms.DQN import DQN_algorithm
# from Algorithms.DZN import DZN_algorithm
# from Algorithms.Variational_reward import Variational_algorithm
from Algorithms.Variational_continuous import Variational_cont
# from DZN_network import stable_DZN
# from Agent import stable_DZN_agent
# from Algorithms.c51 import C51
from utils import evaluation
from utils import supervised_training
from utils import exploration
from utils import plot_variational_variance
import matplotlib.pyplot as plt
def readParser(argv=None):
    """Build the command-line parser and parse the experiment settings.

    Every numeric option declares an explicit ``type``. Without it argparse
    hands back command-line values as ``str`` (only the defaults would be
    numeric), so e.g. ``--gamma 0.9`` would reach the trainer as ``"0.9"``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='DZN')
    parser.add_argument('--env_name', type=str, default="Swimmer-v4",
                        help='Environment')
    parser.add_argument('--agent', type=str, default="SAC",
                        help='Network')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='discount factor for reward (default: 0.99)')
    parser.add_argument('--experiment_num', type=int, default=10,
                        help='experiment number')
    parser.add_argument('--tau', type=float, default=0.005,
                        help='target smoothing coefficient(τ) (default: 0.005)')
    parser.add_argument('--beta', type=float, default=-1.0,
                        help='Risk parameter')
    # NOTE(review): this help text duplicates --beta's; the meaning of alpha
    # is not visible in this file, so the wording is left unchanged.
    parser.add_argument('--alpha', type=float, default=0.2,
                        help='Risk parameter')
    parser.add_argument('--lr', type=float, default=0.0003,
                        help='learning rate (default: 0.0003)')
    parser.add_argument('--experience_replay_size', type=int, default=1000000,
                        help='size of experience buffer')
    parser.add_argument('--epoch_length', type=int, default=1000,
                        help='steps per epoch')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='batch size for training policy')
    # Forwarded to EnvSampler(max_path_length=...) by main().
    parser.add_argument('--max_path_length', type=int, default=500,
                        help='maximum path length used by the environment sampler')
    # Passed to exploration() before the training loop starts.
    parser.add_argument('--init_exploration_steps', type=int, default=10000,
                        help='number of exploration steps collected before training')
    parser.add_argument('--num_epoch', type=int, default=1000,
                        help='total number of epochs')
    return parser.parse_args(argv)


def experiment(args, env_sampler, trainer, memory, eval_interval=10):
    """Run the training loop, evaluating and saving results periodically.

    The replay memory is first filled by ``exploration`` with
    ``args.init_exploration_steps`` steps. Then, for each of
    ``args.num_epoch`` epochs, ``args.epoch_length`` transitions are sampled
    from ``env_sampler``, pushed into ``memory`` and followed by one call to
    ``trainer.optimize_model``. Every ``eval_interval`` epochs the two values
    returned by ``evaluation`` are appended to running lists, printed, and
    both lists are saved with ``np.save`` under
    ``./Experiments/<env_name>/<agent>/`` as ``return_<beta>_<experiment_num>``
    and ``right_<beta>_<experiment_num>``.

    Args:
        args: Parsed settings; reads ``init_exploration_steps``,
            ``num_epoch``, ``epoch_length``, ``env_name``, ``agent``,
            ``beta`` and ``experiment_num``.
        env_sampler: Object whose ``sample(trainer)`` returns a 6-tuple
            ``(cur_state, action, next_state, reward, done, info)``.
        trainer: Object exposing ``optimize_model(args, memory)``.
        memory: Replay buffer exposing
            ``push(cur_state, action, reward, next_state, done)``.
        eval_interval: Number of epochs between evaluations/saves
            (default 10, the previously hard-coded value).

    Raises:
        ValueError: If ``eval_interval`` is smaller than 1.
    """
    # Function-scope import: ``os`` is not among this file's top-level imports.
    import os

    if eval_interval < 1:
        raise ValueError('eval_interval must be >= 1, got {0}'.format(eval_interval))

    # Create the output directory before training starts; otherwise the first
    # np.save would fail on a missing folder only after many training steps.
    out_dir = './Experiments/{0}/{1}'.format(args.env_name, args.agent)
    os.makedirs(out_dir, exist_ok=True)
    run_tag = '{0}_{1}'.format(args.beta, int(args.experiment_num))
    return_path = '{0}/return_{1}'.format(out_dir, run_tag)
    right_path = '{0}/right_{1}'.format(out_dir, run_tag)

    G = []
    right_r = []

    # Collect initial data
    exploration(env_sampler, memory, args.init_exploration_steps)

    # Training
    for epochs in range(args.num_epoch):
        for _ in range(args.epoch_length):
            cur_state, action, next_state, reward, done, _info = env_sampler.sample(trainer)
            # sample() yields next_state before reward, whereas push() is
            # called with reward before next_state.
            memory.push(cur_state, action, reward, next_state, done)
            trainer.optimize_model(args, memory)

        if (epochs + 1) % eval_interval == 0:
            mean_return, right = evaluation(args, env_sampler, trainer)
            G.append(mean_return)
            right_r.append(right)
            print(epochs, mean_return, right)
            # Re-save the full history each time so partial results survive
            # an interrupted run.
            np.save(return_path, G)
            np.save(right_path, right_r)
    print('Complete')


# Agent names accepted by --agent; validated before any environment is built.
_SUPPORTED_AGENTS = ('DQN', 'Variational', 'SAC', 'DZN', 'stable_DZN', 'C51', 'VarC51')


def _build_trainer(args, n_observations, n_actions, device, env):
    """Instantiate the trainer selected by ``args.agent``.

    Only ``Variational_cont`` (agent ``'SAC'``) is imported at module level.
    The other trainer classes are imported here on demand, using the module
    paths of the file's commented-out imports, so selecting them no longer
    fails with ``NameError``.

    Raises:
        ValueError: If ``args.agent`` is not a known agent name.
    """
    agent = args.agent
    if agent == 'SAC':
        return Variational_cont(args, n_observations, n_actions, device, env)
    if agent == 'DQN':
        from Algorithms.DQN import DQN_algorithm
        return DQN_algorithm(args, n_observations, n_actions, device, env)
    if agent == 'Variational':
        from Algorithms.Variational_reward import Variational_algorithm
        return Variational_algorithm(args, n_observations, n_actions, device, env)
    if agent == 'DZN':
        from Algorithms.DZN import DZN_algorithm
        return DZN_algorithm(args, n_observations, n_actions, device, env)
    if agent == 'stable_DZN':
        from DZN_network import stable_DZN
        from Agent import stable_DZN_agent
        policy_net = stable_DZN(n_observations, n_actions).to(device)
        target_net = stable_DZN(n_observations, n_actions).to(device)
        # Start the target network as an exact copy of the policy network.
        target_net.load_state_dict(policy_net.state_dict())
        optimizer = optim.AdamW(policy_net.parameters(), lr=args.lr, amsgrad=True)
        return stable_DZN_agent(policy_net, target_net, optimizer, device, env)
    if agent == 'C51':
        from Algorithms.c51 import C51
        return C51(args, n_observations, n_actions, device)
    if agent == 'VarC51':
        # NOTE(review): no import path for Variational_C51 is visible in this
        # file, so this branch still raises NameError until one is added.
        return Variational_C51(args, n_observations, n_actions, device)
    raise ValueError('Unknown agent {0!r}; expected one of: {1}'.format(
        agent, ', '.join(_SUPPORTED_AGENTS)))


def main(args=None):
    """Build the environment, replay memory and trainer, then run training.

    Args:
        args: Parsed settings as returned by ``readParser``. When ``None``
            the command line is parsed.

    Raises:
        ValueError: If ``args.agent`` is not one of the supported agent
            names (previously the function returned silently, after the
            environment had already been created).
    """
    if args is None:
        args = readParser()

    # Fail fast and loudly on a mistyped agent name, before any resources
    # are allocated.
    if args.agent not in _SUPPORTED_AGENTS:
        raise ValueError('Unknown agent {0!r}; expected one of: {1}'.format(
            args.agent, ', '.join(_SUPPORTED_AGENTS)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # The extra keyword is handed to gym.make together with the env id.
    # NOTE(review): presumably only environments that define this option
    # accept it - confirm before switching --env_name.
    env = gym.make(args.env_name, exclude_current_positions_from_observation=False)
    try:
        env_sampler = EnvSampler(env, max_path_length=args.max_path_length)

        n_observations = env.observation_space.shape[0]
        # The action dimension is read from the action-space shape; the
        # discrete alternative (``env.action_space.n``) was disabled here.
        n_actions = env.action_space.shape[0]

        memory = ReplayMemory(args.experience_replay_size)
        trainer = _build_trainer(args, n_observations, n_actions, device, env)

        experiment(args, env_sampler, trainer, memory)
    finally:
        # Release the environment even if training raises.
        env.close()

# Script entry point: start a run only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()