#!/usr/bin/env python3

import functools
import os
import time
from collections import deque
from copy import deepcopy
from pathlib import Path
from typing import Callable, List

import numpy as np
import torch
import wandb
from rpi import logger
from rpi.agents.base import Agent
from rpi.agents.mamba import (MaxValueFn, StatePredictorEnsemble,
                                ValueEnsemble)
from rpi.agents.ppo import PPOAgent, update_critic_ensemble
from rpi.helpers import set_random_seed, to_torch
from rpi.helpers.data import flatten
from rpi.helpers.env import rollout_single_ep
from rpi.helpers.initializers import ortho_init
from rpi.nn.empirical_normalization import EmpiricalNormalization
from rpi.policies import (GaussianHeadWithStateIndependentCovariance,
                            SoftmaxCategoricalHead)
from rpi.scripts.sweep.default_args import Args
from rpi.value_estimations import (
    _attach_advantage_and_value_target_to_episode,
    _attach_log_prob_to_episodes, _attach_return_and_value_target_to_episode,
    _attach_value_to_episodes)

from .train import Evaluator


def _attach_targets_to_episode(expert, episode, gamma, lambd):
    """Attach value targets to one freshly collected episode according to ``Args.expert_tgtval``.

    - ``'monte-carlo'``: returns computed with ``bootstrap=False``.
    - ``'monte-carlo-bootstrap'``: attaches ``expert.vfn`` predictions first, then
      bootstrapped returns.
    - ``'gae'``: attaches ``expert.vfn`` predictions first, then GAE advantages/targets.

    Value predictions use ``expert.obs_normalizer``.

    Raises:
        ValueError: if ``Args.expert_tgtval`` is not one of the methods above.
    """
    method = Args.expert_tgtval
    if method == 'monte-carlo':
        _attach_return_and_value_target_to_episode(episode, gamma, bootstrap=False)
    elif method == 'monte-carlo-bootstrap':
        _attach_value_to_episodes(expert.vfn, episode, obs_normalizer=expert.obs_normalizer)
        _attach_return_and_value_target_to_episode(episode, gamma, bootstrap=True)
    elif method == 'gae':
        _attach_value_to_episodes(expert.vfn, episode, obs_normalizer=expert.obs_normalizer)
        _attach_advantage_and_value_target_to_episode(episode, gamma, lambd)
    else:
        raise ValueError(f'Unknown method: {method}')


def _refresh_targets(expert, episodes, transitions, gamma, lambd):
    """Recompute value-dependent targets after ``expert.vfn`` has been updated.

    Value predictions are attached to ``transitions`` in one call. Targets are then
    recomputed per episode of ``episodes``. This assumes ``transitions`` is the
    flattened view of ``episodes`` and shares the same transition dicts (what
    ``flatten`` is expected to return; NOTE(review): confirm).

    Raises:
        ValueError: if ``Args.expert_tgtval`` is not a supported method.
    """
    method = Args.expert_tgtval
    if method == 'monte-carlo':
        # Targets were computed with bootstrap=False, so they do not use value
        # predictions and there is nothing to refresh.
        return
    if method not in ('monte-carlo-bootstrap', 'gae'):
        raise ValueError(f'Unknown method: {method}')

    _attach_value_to_episodes(expert.vfn, transitions, obs_normalizer=expert.obs_normalizer)
    for episode in episodes:
        if method == 'monte-carlo-bootstrap':
            _attach_return_and_value_target_to_episode(episode, gamma, bootstrap=True)
        else:
            _attach_advantage_and_value_target_to_episode(episode, gamma, lambd)


def inspect_valuenns(make_env: Callable, experts: List[Agent], evaluator: Evaluator, max_episode_len: int):
    """Train only each expert's value ensemble on that expert's own rollouts, and inspect it.

    Steps:
      1. For each expert k, roll out pi^k ``Args.pret_num_rollouts`` times into a buffer
         D^k (a deque bounded by ``Args.expert_buffer_size``). Attach value targets
         according to ``Args.expert_tgtval``, and feed the visited states to
         ``expert.obs_normalizer``.
      2. Pick 5 random states (with replacement) from the first buffered episode of
         each expert as fixed reference states. Log
         ``evaluator.inspect_value_nn`` on them with ``step`` 0.
      3. Update each expert's critic ensemble on D^k for
         ``Args.pret_num_val_iterations`` rounds. Recompute value-dependent targets
         after each round.
      4. Log timing and the post-training inspection with ``step`` 1.

    Expectation: the predicted mean should converge (reflecting the expert's ability)
    and the predicted stddev should converge; in the deterministic case it should go
    to zero.

    Args:
        make_env: zero-argument factory for the environment used to collect rollouts.
        experts: agents whose ``vfn`` / ``obs_normalizer`` are trained and inspected.
        evaluator: provides ``inspect_value_nn`` used for logging.
        max_episode_len: step cap passed to ``rollout_single_ep``.

    Raises:
        ValueError: on an unknown ``Args.expert_tgtval``, or if an expert ends up with
            no buffered episodes.
    """
    gamma = Args.gamma
    lambd = Args.lmd

    # For each expert k, collect data D^k by rolling out pi^k.
    # Buffer size: 100 for CartPole, DIP; 2 for HalfCheetah and Ant.
    expert_rollouts = [deque(maxlen=Args.expert_buffer_size) for _ in experts]
    env = make_env()
    try:
        for _ in range(Args.pret_num_rollouts):
            for expert, rollouts in zip(experts, expert_rollouts):
                policy = functools.partial(expert.act, mode=Args.deterministic_experts)
                episode = rollout_single_ep(env, policy, max_episode_len)
                _attach_targets_to_episode(expert, episode, gamma, lambd)
                rollouts.append(episode)

                expert.obs_normalizer.experience(to_torch([tr['state'] for tr in episode]))
    finally:
        # The env is only needed for data collection.
        env.close()

    # Fixed reference states: 5 random states from each expert's first episode.
    ref_stateacts = []
    for expert_idx, rollouts in enumerate(expert_rollouts):
        if not rollouts:
            raise ValueError(
                f'No rollouts buffered for expert {expert_idx}; '
                'check Args.pret_num_rollouts and Args.expert_buffer_size'
            )
        first_ep = rollouts[0]
        rand_inds = np.random.choice(len(first_ep), size=(5, ))
        ref_stateacts += [first_ep[idx]['state'] for idx in rand_inds]

    # Eval before training
    logs = evaluator.inspect_value_nn(experts, ref_states=list(ref_stateacts))
    wandb.log({**logs, 'step': 0})

    # Update value function V^k from D^k.
    num_iterations = Args.pret_num_val_iterations
    started = time.perf_counter()
    for expert_idx, (expert, rollouts) in enumerate(zip(experts, expert_rollouts)):
        expert_k_transitions = flatten(rollouts)

        pret_step = 0
        for i in range(num_iterations):
            print(f'updating critic ensemble {i}/{num_iterations}')

            # NOTE: the number of epochs per call may change the behavior quite a lot.
            _, loss_critic_history = update_critic_ensemble(
                expert,
                expert_k_transitions,
                num_epochs=max(1, Args.pret_num_epochs // num_iterations),
                batch_size=Args.batch_size,
                std_from_means=Args.std_from_means,
            )  # 100 for CartPole, DIP
            for loss in loss_critic_history:
                wandb.log({
                    f'pretrain-expert/loss-critic-{expert_idx}': loss,
                    f'pretrain-expert/num-transitions-{expert_idx}': len(expert_k_transitions),
                    'pret-step': pret_step,
                })
                pret_step += 1

            # The value function changed: recompute targets that depend on it.
            _refresh_targets(expert, rollouts, expert_k_transitions, gamma, lambd)

    elapsed = time.perf_counter() - started
    logs = evaluator.inspect_value_nn(experts, ref_states=list(ref_stateacts))
    timer_logs = {'timer/vfn_iterations': elapsed}
    if num_iterations > 0:
        # Average per iteration; note that `elapsed` covers all experts.
        timer_logs['timer/vfn_iter'] = elapsed / num_iterations
    wandb.log(timer_logs)
    wandb.log({**logs, 'step': 1})

    # TODO: also inspect state predictors after training, e.g.
    # evaluator.inspect_state_predictor(state_predictors=..., ref_stateacts=ref_stateacts)


def main():
    """Assemble policy, experts and learner from ``Args``, then inspect the experts' value nets.

    Seeds everything with ``Args.seed``. Builds a policy head matching the
    environment's action space: Gaussian for ``gym.spaces.Box``, softmax-categorical
    otherwise. Loads one expert per entry of ``Args.load_expert_step`` from
    ``Args.experts_dir/<lower-cased env id>/step_XXXXXX.pt``. Constructs the learner
    selected by ``Args.algorithm``, and finally runs :func:`inspect_valuenns` with
    episodes capped at 1000 steps.
    """
    import gym
    from rpi.agents.mamba import MambaAgent

    from .train import Factory, get_expert

    # Fall back to CPU when CUDA is unavailable. Everything below is moved to this
    # device rather than a hard-coded 'cuda'.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gamma = Args.gamma  # 0.995
    lmd = Args.lmd  # 0.97
    load_expert_step = Args.load_expert_step

    set_random_seed(Args.seed)

    def _make_env(env_name='DartCartPole-v1', test=False, default_seed=0):
        """Create ``env_name``; with ``test=True`` the seed becomes ``42 - default_seed``."""
        from rpi.helpers import env
        seed = default_seed if not test else 42 - default_seed
        if env_name.startswith('dmc'):
            # dmc envs additionally receive the seed via task_kwargs['random'].
            extra_kwargs = {'task_kwargs': {'random': seed}}
        else:
            extra_kwargs = {}
        return env.make_env(env_name, seed=seed, **extra_kwargs)

    def make_env(*args, **kwargs):
        """Create the environment named by ``Args.env_name`` (read at call time)."""
        return _make_env(Args.env_name, *args, **kwargs)

    test_env = make_env()
    state_dim = test_env.observation_space.low.size

    if isinstance(test_env.action_space, gym.spaces.Box):
        # Continuous action space
        act_dim = test_env.action_space.low.size
        policy_head = GaussianHeadWithStateIndependentCovariance(
            action_size=act_dim,
            var_type="diagonal",
            var_func=lambda x: torch.exp(2 * x),  # Parameterize log std
            var_param_init=0,  # log std = 0 => std = 1
        )
    else:
        # Discrete action space (assuming categorical)
        act_dim = test_env.action_space.n
        policy_head = SoftmaxCategoricalHead()

    env_id = test_env.unwrapped.spec.id.lower()
    # test_env is only needed for space/spec metadata; release it.
    test_env.close()

    logger.info('obs_dim', state_dim)
    logger.info('act_dim', act_dim)

    pi = Factory.create_pi(state_dim, act_dim, policy_head=policy_head)

    obs_normalizer = EmpiricalNormalization(state_dim, clip_threshold=5)
    obs_normalizer.to(device)

    # Loading: experts
    expert_dir = Path(Args.experts_dir) / env_id
    experts = []
    # NOTE(review): state_predictors / state_pred_optimizers are built but not used
    # below. They are kept so that model construction (and presumably RNG
    # consumption) is unchanged.
    state_predictors = []
    state_pred_optimizers = []
    for idx in load_expert_step:
        expert = get_expert(state_dim, act_dim, deepcopy(policy_head), expert_dir / f'step_{idx:06d}.pt',
                            obs_normalizer=None if Args.use_expert_obsnormalizer else obs_normalizer)

        state_predictor = StatePredictorEnsemble(lambda: Factory.create_state_nn(state_dim, act_dim),
                                                 num_state_nns=Args.num_expert_vfns,
                                                 state_dim=state_dim,
                                                 act_dim=act_dim,
                                                 obs_normalizer=EmpiricalNormalization(state_dim, clip_threshold=5),
                                                 std_from_means=Args.std_from_means)
        state_predictor.to(device)
        # Assumes each state NN has its initializable layers at indices 0, 2, 4
        # (e.g. Linear-act-Linear-act-Linear); TODO confirm against Factory.create_state_nn.
        for state_nn in state_predictor.nns:
            ortho_init(state_nn[0], gain=Args.expert_vfn_gain)
            ortho_init(state_nn[2], gain=Args.expert_vfn_gain)
            ortho_init(state_nn[4], gain=Args.expert_vfn_gain)

        experts.append(expert)
        state_predictors.append(state_predictor)
        state_pred_optimizers.append(
            torch.optim.Adam(state_predictor.parameters(), lr=1e-3)
        )

    vfn = MaxValueFn([expert.vfn for expert in experts], obs_normalizers=[expert.obs_normalizer for expert in experts])

    if Args.algorithm == 'pg-gae':
        # PPOAgent with its own freshly created value function instead of MaxValueFn.
        vfn = Factory.create_vfn(state_dim)
        optimizer = torch.optim.Adam(list(pi.parameters()) + list(vfn.parameters()), lr=1e-3, betas=(0.9, 0.99))
        learner = PPOAgent(pi, vfn, optimizer, obs_normalizer, gamma=gamma, lambd=lmd)
        vfn.to(device)
    else:
        optimizer = torch.optim.Adam(pi.parameters(), lr=1e-3, betas=(0.9, 0.99))
        learner = MambaAgent(pi, vfn, optimizer, obs_normalizer, gamma=gamma, lambd=lmd, use_ppo_loss=Args.use_ppo_loss)
    pi.to(device)
    # NOTE(review): `learner` is not passed to inspect_valuenns below.
    learner.to(device)

    max_episode_len = 1000
    evaluator = Evaluator(make_env, max_episode_len=max_episode_len)
    inspect_valuenns(make_env, experts, evaluator, max_episode_len=max_episode_len)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("sweep_file", help="sweep file")
    # Required: the index selects the configuration and the GPU below.
    parser.add_argument("-l", "--line-number", type=int, required=True,
                        help="index of the configuration (line) in the sweep file to run")
    args = parser.parse_args()

    # Obtain kwargs for the selected configuration from the sweep.
    from params_proto.hyper import Sweep
    sweep = Sweep(Args).load(args.sweep_file)
    kwargs = list(sweep)[args.line_number]

    Args._update(kwargs)

    # Spread runs round-robin over GPUs 1..num_gpus (GPU 0 is never selected).
    # This only takes effect if set before CUDA is initialized.
    num_gpus = 3
    cvd = args.line_number % num_gpus + 1
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cvd)

    sweep_basename = os.path.splitext(os.path.basename(args.sweep_file))[0]
    wandb.login()
    wandb.init(
        # Set the project where this run will be logged
        project='alops-inspect-valnns',
        group=sweep_basename,
        config=vars(Args),
    )

    run_name = (
        f'{Args.algorithm}-s{Args.seed}-l{Args.lmd}'
        f'-e{Args.load_expert_step}-d{Args.deterministic_experts}'
    )
    if Args.algorithm == "lops":
        run_name += f'-sig{Args.ase_sigma}'
    wandb.run.name = run_name

    main()
