import numpy as np
from experiments.mujoco_experiments.qarl_experiment import QARLExperiment
from experiments.mujoco_experiments.qarl_homogeneous_experiment import QARLHomogeneousExperiment
from experiments.mujoco_experiments.qarl_linear_experiment import QARLLinearExperiment
from experiments.mujoco_experiments.qarl_point_experiment import QARLPointExperiment
from experiments.mujoco_experiments.qarl_single_experiment import QARLSingleExperiment
from experiments.mujoco_experiments.baseline_experiment import BaselineExperiment
from experiments.mujoco_experiments.rarl_experiment import RARLExperiment
from experiments.mujoco_experiments.fix_rarl_experiment import Fix_RARLExperiment
from experiments.mujoco_experiments.force_experiment import ForceExperiment
from experiments.mujoco_experiments.mas_experiment import MASExperiment
from experiments.mujoco_experiments.sgld_experiment import SGLDExperiment
from torch import manual_seed
from pathlib import Path
import torch

def experiment(algorithm: str = "", **kwargs):
    """Seed the RNGs, build the requested experiment and train its protagonist.

    Args:
        algorithm: Key selecting the experiment class (``"baseline"``,
            ``"fix_rarl"``, ``"rarl"``, ``"qarl"``, ``"qarl_homogeneous"``,
            ``"qarl_linear"``, ``"qarl_point"``, ``"qarl_single"``, ``"mas"``,
            ``"sgld"`` or ``"force"``).
        **kwargs: Forwarded unchanged to the experiment constructor. Must
            contain ``"seed"``, which is used to seed NumPy and PyTorch.

    Raises:
        KeyError: If ``"seed"`` is missing from ``kwargs``.
        ValueError: If ``algorithm`` is not one of the known keys.
    """
    rng_seed = kwargs["seed"]
    np.random.seed(rng_seed)
    manual_seed(rng_seed)

    # Mujoco experiment classes, keyed by algorithm name.
    experiment_classes = {
        "baseline": BaselineExperiment,
        "fix_rarl": Fix_RARLExperiment,
        "rarl": RARLExperiment,
        "qarl": QARLExperiment,
        "qarl_homogeneous": QARLHomogeneousExperiment,
        "qarl_linear": QARLLinearExperiment,
        "qarl_point": QARLPointExperiment,
        "qarl_single": QARLSingleExperiment,
        "mas": MASExperiment,
        "sgld": SGLDExperiment,
        "force": ForceExperiment,
    }
    experiment_cls = experiment_classes.get(algorithm)
    if experiment_cls is None:
        raise ValueError("Unknown algorithm provided!")

    runner = experiment_cls(**kwargs)
    # Only the protagonist training stage runs here; whatever
    # train_protagonist() returns is not used by this function.
    runner.train_protagonist()

def collect_eval(algorithm: str = "", **kwargs):
    """Load the saved models of a trained run and return them for evaluation.

    The checkpoints are looked up under ``<results_dir>/Training`` as
    ``exp_<seed>_protagonist.zip`` and ``exp_<seed>_adversary.zip``.

    Args:
        algorithm: Key selecting the experiment class (``"baseline"``,
            ``"fix_rarl"``, ``"rarl"``, ``"qarl"``, ``"qarl_homogeneous"``,
            ``"qarl_linear"``, ``"qarl_point"``, ``"qarl_single"``, ``"mas"``,
            ``"sgld"`` or ``"force"``).
        **kwargs: Forwarded unchanged to the experiment constructor. Must
            contain ``"seed"`` (seeds NumPy and PyTorch and selects the
            checkpoint files) and ``"results_dir"`` (root of the results tree).

    Returns:
        The ``(protagonist, mdp)`` pair produced by the experiment's
        ``load_models``, after ``mdp.reset()`` has been called.

    Raises:
        KeyError: If ``"seed"`` or ``"results_dir"`` is missing from ``kwargs``.
        ValueError: If ``algorithm`` is not one of the known keys.
    """
    seed = kwargs["seed"]
    results_dir = kwargs["results_dir"]
    np.random.seed(seed)
    manual_seed(seed)

    # Mujoco experiment classes, keyed by algorithm name.
    experiment_classes = {
        "baseline": BaselineExperiment,
        "fix_rarl": Fix_RARLExperiment,
        "rarl": RARLExperiment,
        "qarl": QARLExperiment,
        "qarl_homogeneous": QARLHomogeneousExperiment,
        "qarl_linear": QARLLinearExperiment,
        "qarl_point": QARLPointExperiment,
        "qarl_single": QARLSingleExperiment,
        "mas": MASExperiment,
        "sgld": SGLDExperiment,
        "force": ForceExperiment,
    }
    experiment_cls = experiment_classes.get(algorithm)
    if experiment_cls is None:
        raise ValueError("Unknown algorithm provided!")
    runner = experiment_cls(**kwargs)

    protagonist_path = Path(results_dir) / f"Training/exp_{seed}_protagonist.zip"
    adversary_path = Path(results_dir) / f"Training/exp_{seed}_adversary.zip"

    # NOTE(review): the original comment on this flag read "vanilla SAC"; its
    # exact meaning is defined by load_models -- confirm against that method.
    constant = False

    # Device placement is left to load_models; the device this function used
    # to compute locally was never passed anywhere.
    protagonist, mdp = runner.load_models(protagonist_path, adversary_path, constant)
    mdp.reset()

    return protagonist, mdp

def evaluation(algorithm: str = "", **kwargs):
    """Run the robustness evaluation of a trained run from its saved models.

    The checkpoints are looked up under ``<results_dir>/Training`` as
    ``exp_<seed>_protagonist.zip`` and ``exp_<seed>_adversary.zip`` and handed
    to the experiment's ``evaluate_robustness_change`` with one episode per
    metric value.

    Args:
        algorithm: Key selecting the experiment class (``"baseline"``,
            ``"fix_rarl"``, ``"rarl"``, ``"qarl"``, ``"qarl_homogeneous"``,
            ``"qarl_linear"``, ``"qarl_point"``, ``"qarl_single"``, ``"mas"``,
            ``"sgld"`` or ``"force"``).
        **kwargs: Forwarded unchanged to the experiment constructor. Must
            contain ``"seed"`` (seeds NumPy and PyTorch and selects the
            checkpoint files) and ``"results_dir"`` (root of the results tree).

    Raises:
        KeyError: If ``"seed"`` or ``"results_dir"`` is missing from ``kwargs``.
        ValueError: If ``algorithm`` is not one of the known keys.
    """
    seed = kwargs["seed"]
    results_dir = kwargs["results_dir"]
    np.random.seed(seed)
    manual_seed(seed)

    # Mujoco experiment classes, keyed by algorithm name.
    experiment_classes = {
        "baseline": BaselineExperiment,
        "fix_rarl": Fix_RARLExperiment,
        "rarl": RARLExperiment,
        "qarl": QARLExperiment,
        "qarl_homogeneous": QARLHomogeneousExperiment,
        "qarl_linear": QARLLinearExperiment,
        "qarl_point": QARLPointExperiment,
        "qarl_single": QARLSingleExperiment,
        "mas": MASExperiment,
        "sgld": SGLDExperiment,
        "force": ForceExperiment,
    }
    experiment_cls = experiment_classes.get(algorithm)
    if experiment_cls is None:
        raise ValueError("Unknown algorithm provided!")
    runner = experiment_cls(**kwargs)

    protagonist_path = Path(results_dir) / f"Training/exp_{seed}_protagonist.zip"
    adversary_path = Path(results_dir) / f"Training/exp_{seed}_adversary.zip"

    # NOTE(review): the original comment on this flag read "vanilla SAC"; its
    # exact meaning is defined by evaluate_robustness_change -- confirm there.
    constant = True
    print('first constant: ', constant)
    print('load model')

    runner.evaluate_robustness_change(
        protagonist_path,
        adversary_path,
        n_episodes_per_metric_value=1,
        constant=constant,
    )
    # No mdp.reset() here: this function never loads an mdp, so the call that
    # used to follow raised NameError once the evaluation had finished.
