import os

os.environ["MUJOCO_GL"] = os.environ.get("MUJOCO_GL", "egl")

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv
from PIL import Image
import yaml
import gymnasium as gym

from utils.logger import Logger
from envs.noises import *
from envs.state_image_depth_env import make_sid_env

from preprocessors.linear_comb import LinearComb
from preprocessors.concatenation import ConCat
from preprocessors.curl import Curl
from preprocessors.mmm import MMM
from preprocessors.gmc import GMC
from preprocessors.amdf import AMDF
from preprocessors.coral import CORAL

from rl_algos.sac import SAC
from utils.loops import train

# Environment id -> Gymnasium MuJoCo task name and the reward value passed to
# ``train`` as ``env_max_R`` (how it is used is defined in utils.loops.train).
ENV_ENCODINGS = {
    0: {"name": "Ant-v5", "max_R": 6000},
    1: {"name": "HalfCheetah-v5", "max_R": 10000},
    2: {"name": "Hopper-v5", "max_R": 3500},
    3: {"name": "Humanoid-v5", "max_R": 6500},
    4: {"name": "Walker2d-v5", "max_R": 5500},
    5: {"name": "InvertedPendulum-v5", "max_R": 1000},
}

# Algorithm id -> name of the preprocessor class combined with SAC.
ALGO_ENCODINGS = {
    0: "LinearComb",
    1: "ConCat",
    2: "Curl",
    3: "MMM",
    4: "GMC",
    5: "AMDF",
    6: "CORAL",
}


def build_experiment_name(arg_dict, algo_name):
    """Return the experiment name derived from the parsed CLI arguments.

    Every entry contributes ``"<key>=<value>_"`` in iteration order, except
    ``algo``, which contributes ``"<algo_name>_"`` instead. The result
    therefore ends with a trailing underscore (empty input gives ``""``).

    Args:
        arg_dict: Mapping of argument name to value, e.g. ``vars(args)``.
        algo_name: Human-readable name substituted for the ``algo`` entry.

    Returns:
        The concatenated experiment name.
    """
    return "".join(
        f"{algo_name}_" if key == "algo" else f"{key}={value}_"
        for key, value in arg_dict.items()
    )


def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global random number generators.

    ``torch.manual_seed`` seeds the generator of every device (CPU and all
    CUDA devices), so separate ``torch.random`` / ``torch.cuda`` seeding calls
    are unnecessary. Python's built-in ``random`` is seeded as well so any
    code relying on it is reproducible too.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # local import: not part of this script's top-level imports

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # ``choices`` makes argparse reject unknown ids with a usage error instead
    # of the script crashing later with a KeyError on the encoding tables.
    parser.add_argument(
        '--algo', default=1, type=int, choices=sorted(ALGO_ENCODINGS),
        help="preprocessor id: " + ", ".join(f"{k}={v}" for k, v in ALGO_ENCODINGS.items()),
    )
    parser.add_argument('--z_dim', default=64, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument(
        '--env_id', default=1, type=int, choices=sorted(ENV_ENCODINGS),
        help="environment id: " + ", ".join(f"{k}={v['name']}" for k, v in ENV_ENCODINGS.items()),
    )

    args = parser.parse_args()

    seed = args.seed
    seed_everything(seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    env_name = ENV_ENCODINGS[args.env_id]["name"]
    env_max_R = ENV_ENCODINGS[args.env_id]["max_R"]
    algo_name = ALGO_ENCODINGS[args.algo]

    name_exp = build_experiment_name(vars(args), algo_name)

    # ------------------------------------------------------------------ ENV DEF

    modalities = "all"
    noises = [gaussian_noise]
    # NOTE(review): this mapping is not used anywhere in this script.
    noises_encoding = {
        gaussian_noise: "0",
        salt_and_pepper_noise: "1",
        patches_noise: "2",
        puzzle_noise: "3",
        sensor_failure: "4",
        texture_noise: "5",
        hallucination_noise: "6",
    }

    with open('configs/rl.yml', 'r', encoding='utf-8') as file:
        configs_rl = yaml.safe_load(file)

    # Training envs are vectorized; a separate single env is used for
    # evaluation and for reading the observation/action spaces below.
    envs = SyncVectorEnv([
        make_sid_env(env_name, modalities, noises=noises, p_noise=0.0)
        for _ in range(configs_rl['n_workers'])
    ])
    env_test = make_sid_env(env_name, modalities, noises=noises, p_noise=0.0)()

    state_dim = env_test.observation_space.spaces['state'].shape[0]
    action_dim = env_test.action_space.shape[0]
    # The image modality only exists for these render modes; 0 means "no images".
    img_dim = env_test.observation_space['image'].shape[-1] if env_test._render_mode in ['image', 'all'] else 0
    # Row 0 = per-dimension lower bounds, row 1 = upper bounds.
    action_bounds = np.stack([env_test.action_space.low, env_test.action_space.high], 0)
    max_T = env_test.env.spec.max_episode_steps

    configs_rl['architecture']['action_bounds'] = torch.from_numpy(action_bounds).float().to(device)

    # ---------------------------------------------------------------- MODEL DEF

    rl_algo_name = 'sac'

    frq_training = configs_rl[rl_algo_name]['frq_training']

    # All preprocessors share one constructor signature, so dispatch by name.
    preprocessor_classes = {
        "LinearComb": LinearComb,
        "ConCat": ConCat,
        "Curl": Curl,
        "MMM": MMM,
        "GMC": GMC,
        "AMDF": AMDF,
        "CORAL": CORAL,
    }
    preprocessor = preprocessor_classes[algo_name](
        True, state_dim, img_dim, args.z_dim, action_dim, modalities, configs_rl, device
    ).to(device)

    all_dims = {
        's': state_dim,
        'o': img_dim,
        'd': img_dim,  # NOTE(review): depth reuses the image dim — confirm intended
        'a': action_dim,
        'z': preprocessor.z_dim,
        'time': 3,  # NOTE(review): meaning of this constant not visible here
    }

    agent = SAC(preprocessor, all_dims, configs_rl, device).to(device)

    # -------------------------------------------------------------------- TRAIN

    file_name = f"{env_name}_{modalities}_{algo_name}_train_seed={seed}"

    logger = Logger(name_exp, "mujoco", rl_algo_name, "")

    # NOTE(review): the positional False / True / int(1e6) arguments are
    # interpreted by utils.loops.train — confirm their meaning there.
    train(envs, env_test, agent, file_name, max_T, device, False, logger, env_max_R, frq_training, True, int(1e6))

    print()













