#!/usr/bin/env python3
from typing import Optional
from pathlib import Path
import sys; sys.path.append(str(Path(__file__).parent.parent.resolve()))

from plot_master import main, colors
import wandb

from rpi.scripts.pretraining.experts import (
    cheetah_ppo,
    cheetah_sac,
    walker_ppo,
    walker_sac,
    pendulum_ppo,
    pendulum_sac,
    cartpole_ppo,
    cartpole_sac,
)

def convert(model_infos):
    """Reduce expert descriptors to ``(policy, path)`` pairs.

    Every element of ``model_infos`` is indexed with the keys ``'policy'``
    and ``'path'``; any other keys are ignored.  The result is a list of
    2-tuples in the same order as the input.
    """
    pairs = []
    for info in model_infos:
        pairs.append((info['policy'], info['path']))
    return pairs

# Rebind every imported expert list to its (policy, path) form.
cheetah_ppo, cheetah_sac = convert(cheetah_ppo), convert(cheetah_sac)
walker_ppo, walker_sac = convert(walker_ppo), convert(walker_sac)
pendulum_ppo, pendulum_sac = convert(pendulum_ppo), convert(pendulum_sac)
cartpole_ppo, cartpole_sac = convert(cartpole_ppo), convert(cartpole_sac)

# For each domain, a list of expert subsets.  Every subset is a slice of at
# most three (policy, path) pairs: the first three, every fourth entry
# (first three of those), or the last three of a converted list.
domain2expert_info = {
    'cheetah-run': [
        cheetah_ppo[:3],
        cheetah_ppo[::4][:3],
        cheetah_sac[:3],
        cheetah_sac[::4][:3],
        cheetah_sac[-3:],
    ],
    'walker-walk': [
        walker_ppo[:3],
        walker_ppo[::4][:3],
        walker_sac[:3],
        walker_sac[::4][:3],
        walker_sac[-3:],
    ],
    # NOTE(review): unlike the other domains there is no
    # pendulum_sac[::4][:3] subset here -- confirm this is intentional.
    'pendulum-swingup': [
        pendulum_ppo[:3],
        pendulum_ppo[::4][:3],
        pendulum_sac[:3],
        pendulum_sac[-3:],
    ],
    'cartpole-swingup': [
        cartpole_ppo[:3],
        cartpole_ppo[::4][:3],
        cartpole_sac[:3],
        cartpole_sac[::4][:3],
        cartpole_sac[-3:],
    ],
}
# domain2expertsteps = {
#     # 'cheetah-run': [[100], [100, 70], [100, 70, 40], [100, 70, 40, 20]],
#     'cheetah-run': [[100, 70, 40]],
#     'walker-walk': [[190, 150, 100, 80], [150, 100, 80, 50], [130, 100, 80, 40]],
#     'pendulum-swingup': [[200, 150], [200, 150, 100], [200, 150, 100, 50]],
#     'cartpole-swingup': [[400, 300, 200, 40], [400, 140, 80], [400, 160, 60]]
# }
# One ASE sigma value per domain name.
domain2ase_sigma = {
    'cheetah-run': 2.5,
    'cartpole-swingup': 0.25,
    'walker-walk': 10,
    'pendulum-swingup': 0.25,
}

# Twenty evenly spaced sigma values: 0.5, 1.0, ..., 10.0.
ase_sigmas_on_cheetah = [half_steps / 2 for half_steps in range(1, 21)]

def maybe_toint(val):
    """Return ``val`` as an ``int`` when it represents a whole number.

    Args:
        val: An ``int``, or any object exposing ``is_integer()`` (such as a
            ``float``).

    Returns:
        ``int(val)`` if ``val`` is integral, otherwise ``val`` unchanged.
    """
    if isinstance(val, int):
        # int.is_integer() only exists on Python 3.12+, so handle ints
        # explicitly instead of relying on the method being present.
        return int(val)
    if val.is_integer():
        return int(val)
    return val


def get_aps_vs_ase_query_set(
        domain,
        expert_steps,
        algorithms=["lops-aps", "lops-aps-ase", "mamba", "pg-gae"],
        learner_pi=["all"],
        group="original",
        ase_sigmas: Optional[list] = None,
        alg2group: Optional[dict] = None
):
    """Build a MongoDB-style run filter comparing ``lops-aps`` and ``lops-aps-ase``.

    Args:
        domain: Domain name such as ``'cheetah-run'``; it is turned into the
            env name ``f"dmc:{domain.capitalize()}-v1"``.
        expert_steps: Value matched (``$eq``) against
            ``config.load_expert_step``.
        algorithms: Algorithms to OR together.  Only ``'lops-aps'`` and
            ``'lops-aps-ase'`` have a query defined here.
        learner_pi: Accepted values (``$in``) of
            ``config.use_riro_for_learner_pi`` for ``'lops-aps'`` runs.
        group: Fallback group filter, used only when ``alg2group`` is not
            given.  ``'original'`` selects every group except ``'sigmas'``;
            any other value must match exactly.
        ase_sigmas: Optional list of accepted ``config.ase_sigma`` values for
            ``'lops-aps-ase'`` runs; ``None`` adds no sigma constraint.
        alg2group: Optional mapping from algorithm name to the exact group
            its runs must belong to.  Only the requested algorithms are
            looked up.

    Returns:
        A nested filter dict of the form
        ``{"$and": [{"$or": [...per-algorithm...], "$and": [{env filter}]}]}``.

    Raises:
        KeyError: If ``algorithms`` contains a name without a query, or
            ``alg2group`` is given but lacks a requested algorithm.
    """
    domain = f"dmc:{domain.capitalize()}-v1"

    # HACK: for the first set of runs, I didn't specify groupname
    fallback_group = {'$ne': 'sigmas'} if group == 'original' else {'$eq': group}

    sigma_filter = {} if ase_sigmas is None else {'config.ase_sigma': {'$in': ase_sigmas}}

    def _group_filter(alg):
        # Without an explicit per-algorithm mapping, fall back to `group`
        # (a fresh copy per query so the sub-queries share no state).
        if alg2group is None:
            return dict(fallback_group)
        return {'$eq': alg2group[alg]}

    def _lops_aps():
        return {
            'config.algorithm': {'$eq': 'lops-aps'},
            'config.use_riro_for_learner_pi': {'$in': learner_pi},
            'config.load_expert_step': {'$eq': expert_steps},
            'group': _group_filter('lops-aps'),
        }

    def _lops_aps_ase():
        return {
            'config.algorithm': {'$eq': 'lops-aps-ase'},
            'config.use_riro_for_learner_pi': {'$eq': 'all'},
            'config.load_expert_step': {'$eq': expert_steps},
            **sigma_filter,
            'group': _group_filter('lops-aps-ase'),
        }

    # Queries are built lazily so only the requested algorithms need an
    # entry in `alg2group`.
    alg2builder = {
        'lops-aps': _lops_aps,
        'lops-aps-ase': _lops_aps_ase,
    }

    unknown = [alg for alg in algorithms if alg not in alg2builder]
    if unknown:
        raise KeyError(
            f"no query defined for algorithm(s) {unknown}; "
            f"supported: {sorted(alg2builder)}"
        )

    return {
        "$and": [{
            "$or": [alg2builder[alg]() for alg in algorithms],
            "$and": [{
                'config.env_name': {'$eq': domain},
            }]
        }]
    }


def get_query_set(
    domain,
    expert_steps,
    algorithms=["lops-aps", "lops-aps-ase", "mamba", "pg-gae"],
    learner_pi=["all"],
    ase_learner_pi=["all"],
    group="original",
    ase_sigmas: Optional[dict] = None,
    aps_ase_extra: dict = {},
):
    """Build the MongoDB-style run filter used for the main comparison plots.

    Args:
        domain: Domain name such as ``'pendulum-swingup'``; it is turned into
            the env name ``f"dmc:{domain.capitalize()}-v1"``.
        expert_steps: Accepted for call compatibility; not part of the filter
            (the ``config.load_expert_step`` constraint is disabled).
        algorithms: Algorithms to OR together.  Supported names are
            ``'rpi'``, ``'lops-aps'``, ``'lops-il'``, ``'lops-lambda'``,
            ``'mamba'`` and ``'pg-gae'``.
        learner_pi: Accepted for call compatibility; every sub-query is
            pinned to ``'rollin'`` regardless of this value.
        ase_learner_pi: Accepted for call compatibility; unused.
        group: A group name or a sequence of group names; runs must have
            ``config.group`` in this set.
        ase_sigmas: Accepted for call compatibility; unused.
        aps_ase_extra: Accepted for call compatibility; unused.

    Returns:
        A nested filter dict of the form
        ``{"$and": [{"$or": [...per-algorithm...], "$and": [{env + group}]}]}``.

    Raises:
        KeyError: If ``algorithms`` contains an unsupported name (this
            includes ``'lops-aps-ase'`` from the default list).
    """
    domain = f"dmc:{domain.capitalize()}-v1"

    # `$in` needs an array operand: wrap a single group name in a list and
    # materialise any other sequence.
    group_names = [group] if isinstance(group, str) else list(group)
    group = {'$in': group_names}
    print("group",group)

    alg2query = {
        'rpi': {
            'config.algorithm': {'$eq': 'rpi'},
            'config.state_in_distribution': {'$eq': 999999},
            'config.lmd': {'$eq': 0.9},
            'config.use_riro_for_learner_pi': {'$in': ['rollin']},
        },
        'lops-aps': {
            'config.algorithm': {'$eq': 'lops-aps'},
            'config.state_in_distribution': {'$eq': 999999},
            'config.lmd': {'$eq': 0.9},
            'config.use_riro_for_learner_pi': {'$in': ['rollin']},
        },
        'lops-il': {
            'config.algorithm': {'$eq': 'lops-il'},
            'config.state_in_distribution': {'$eq': 999999},
            'config.lmd': {'$eq': 0.9},
            'config.use_riro_for_learner_pi': {'$in': ['rollin']},
        },
        'lops-lambda': {
            'config.algorithm': {'$eq': 'lops-lambda'},
            'config.use_riro_for_learner_pi': {'$in': ['rollin']},
            'config.lmd': {'$eq': 0},
        },
        'mamba': {
            'config.algorithm': {'$eq': 'mamba'},
            'config.use_riro_for_learner_pi': {'$eq': 'rollin'},
        },
        'pg-gae': {
            'config.algorithm': {'$eq': 'pg-gae'},
            'config.use_riro_for_learner_pi': {'$eq': 'rollin'},
        },
    }
    return {
        "$and": [{
            "$or": [alg2query[alg] for alg in algorithms],
            "$and": [{
                'config.env_name': {'$eq': domain},
                'config.group': group,
            }],
        }]
    }


if __name__ == '__main__':
    # Script entry point: build one plot configuration per (domain, expert
    # subset index), fetch the matching W&B runs and hand them to
    # plot_master.main.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--user", default='anonymous ', help='wandb user name')
    parser.add_argument('--proj', default='alops-rpi_pendulum_final', help='wandb project name')
    # type=float so a value given on the command line is numeric like the
    # default, rather than arriving as a string.
    parser.add_argument('--best_oracle', default=52, type=float, help='horizontal line for best expert')
    parser.add_argument("--format", default='pdf', choices=['pdf', 'png'], help="pdf or png")
    parser.add_argument("--dry-run", action='store_true')
    parser.add_argument("--force", action='store_true', help='If true, overwrite the existing plot')
    parser.add_argument("--use-stderr", action='store_true', help='If true, use standard error')
    parser.add_argument("--notitle", action='store_true')
    parser.add_argument("--nolegend", action='store_true')
    args = parser.parse_args()

    api = wandb.Api()

    # domains = ['cheetah-run', 'walker-walk', 'pendulum-swingup', 'cartpole-swingup']
    domains = ['pendulum-swingup']
    # domains = ['cheetah-run']

    plot2config = {
        # Desc: main plot that shows performance between ours vs baselines
        # Axes: Training step vs Best-return
        # Domains: each
        # Lines: lops-aps-ase, lops-aps, mamba and pg-gae
        # Experts: each
        # NOTE(review): get_query_set does not filter on expert_steps, so
        # every index i of a domain currently yields the same query.
        **{f"main-plot-p1p3s1-{domain}-{i}": {
            # "query": get_query_set(domain, expert_steps, algorithms=['lops-aps', 'mamba', 'pg-gae'], learner_pi=['all'], group="multi_expert_p1p3s9"),
            # "query": get_query_set(domain, expert_steps, algorithms=['lops-aps', 'mamba', 'lops-lambda'], learner_pi=['all'], group="multi_expert_p1p3s9"),
            "query": get_query_set(
                domain,
                expert_steps,
                algorithms=['rpi', 'lops-aps', 'lops-il', 'mamba', 'lops-lambda', 'pg-gae'],
                learner_pi=['rollin'],
                group=["multi_expert_p1p3s1_final", "ppo-gae-rollin"],
            ),
            # "query": get_query_set(domain, expert_steps, algorithms=['lops-aps', 'lops-il', 'mamba', 'lops-lambda'], learner_pi=['rollin'], group="multi_expert_p1p3s9"),
            "xlabel": "Training step",
            "ylabel": "Best return",
            "group_keys": ["group", "algorithm", "state_in_distribution", "lmd", "use_riro_for_learner_pi"],
            "ykey": "eval/best-so-far",
            "xkey": "step",
            "hbar": "expert_vals",
            # Keys below are the group_keys values joined with '-'.
            "group2legend": {
                "multi_expert_p1p3s1_final-mamba-999999-0.9-rollin": "Mamba",
                "multi_expert_p1p3s1_final-lops-aps-999999-0.9-rollin": "LOPS-APS",
                "multi_expert_p1p3s1_final-lops-lambda-999999-0-rollin": "LOKI",  # "LOPS-Lambda-0"
                "multi_expert_p1p3s1_final-lops-lambda-999999-0.9-rollin": "LOPS-Lambda-0.9",
                "multi_expert_p1p3s1_final-lops-il-999999-0.9-rollin": "Max-aggregation",  # LOPS-IL
                "ppo-gae-rollin-pg-gae-999999-0.9-rollin": "PPO-GAE",
                "multi_expert_p1p3s1_final-rpi-999999-0.9-rollin": "RPI",
                # "mamba-none": "Mamba",
                # "lops-aps-all": "LOPS-APS",
                # "lops-aps-rollin": "LOPS-APS-ri",
                # "lops-aps-ase-all": "LOPS-APS-ASE",
                # "pg-gae-none": "PPO-GAE"
            },
            "group2color": {
                "multi_expert_p1p3s1_final-mamba-999999-0.9-rollin": colors[3],
                "multi_expert_p1p3s1_final-lops-aps-999999-0.9-rollin": colors[1],
                "multi_expert_p1p3s1_final-lops-lambda-999999-0-rollin": colors[2],
                "multi_expert_p1p3s1_final-lops-lambda-999999-0.9-rollin": colors[6],
                "multi_expert_p1p3s1_final-lops-il-999999-0.9-rollin": colors[4],
                "ppo-gae-rollin-pg-gae-999999-0.9-rollin": colors[5],
                "multi_expert_p1p3s1_final-rpi-999999-0.9-rollin": colors[0],
                # "mamba-none": colors[1],
                # "lops-aps-all": colors[0],
                # "lops-aps-rollin": colors[-1],
                # "lops-aps-ase-all": colors[2],
                # "pg-gae-none": colors[2]
            },
            "show_title": not args.notitle,
            "show_legend": not args.nolegend,
            "plot_dir": "generated/test-plot",
        } for domain in domains
           for i, expert_steps in enumerate(domain2expert_info[domain])
           },
    }

    # Strip stray whitespace (the default user name ends with a space) so the
    # "<entity>/<project>" path passed to the API is well-formed.
    user = args.user.strip()
    project = args.proj
    name = f'{user}/{project}'
    print(name)

    for plot_name, config in plot2config.items():
        print('plot_name', plot_name)
        # print('config', config)
        query = config['query']
        print('query\n', query)
        runs = api.runs(name, query)
        print(runs)
        main(
            runs,
            plot_name,
            config,
            ext=f'.{args.format}',
            force=args.force,
            dry_run=args.dry_run,
            use_stderr=args.use_stderr,
            best_expert=args.best_oracle,
        )
