import datetime
import os
import click
import numpy as np
import scipy.stats as st
import torch
import wandb

from cavia import utils
from cavia.baseline import LinearFeatureBaseline
from cavia.metalearner import MetaLearner
from cavia.policies.normal_mlp import NormalMLPPolicy, CaviaMLPPolicy
from cavia.sampler import BatchSampler


def get_returns(episodes_per_task):
    """Collect the masked, undiscounted return of every rollout at every update step.

    ``episodes_per_task[t]`` is a sequence of episode batches for task ``t``,
    one batch per update step.

    For each batch, ``rewards * mask`` is summed over dim 0, giving one return
    per rollout. Per-step returns are then stacked along dim 1, the per-task
    results are stacked, and everything is flattened to 2-D.

    Assumes ``rewards``/``mask`` are laid out as (time, rollouts) and that every
    task has the same number of rollouts and update steps. TODO: confirm against
    the sampler.

    Returns:
        A 2-D tensor whose last dimension indexes the update step, with all
        (task, rollout) pairs flattened into the first dimension.
    """
    per_task_returns = []
    for task_batches in episodes_per_task:
        # One entry per update step; each entry holds one return per rollout.
        step_returns = [(batch.rewards * batch.mask).sum(dim=0) for batch in task_batches]
        per_task_returns.append(torch.stack(step_returns, dim=1))

    stacked = torch.stack(per_task_returns)
    # Merge the task and rollout axes, keeping the update-step axis last.
    return stacked.reshape((-1, stacked.shape[-1]))


def total_rewards(episodes_per_task, interval=False):
    """Mean return per update step, optionally with a 95% confidence half-width.

    Args:
        episodes_per_task: per-task episode batches, forwarded to ``get_returns``.
        interval: if True, also return the half-width of the two-sided 95%
            Student-t confidence interval around each mean.

    Returns:
        ``mean``, a 1-D array with one entry per update step. When ``interval``
        is True, returns ``(mean, half_width)`` with ``half_width`` the same shape.
    """
    returns = get_returns(episodes_per_task).cpu().numpy()

    mean = np.mean(returns, axis=0)
    if not interval:
        return mean

    # Each column (update step) holds returns.shape[0] samples. The t
    # distribution therefore has n - 1 degrees of freedom. That is the sample
    # count, not the number of update steps (len(mean)).
    num_samples = returns.shape[0]
    lower, _upper = st.t.interval(0.95, num_samples - 1, loc=mean,
                                  scale=st.sem(returns, axis=0))
    # The interval is symmetric around the mean, so mean - lower is the half-width.
    return mean, mean - lower


# Meta-test environments that are evaluated with a policy meta-trained on a
# different environment: test env -> env whose checkpoint is loaded.
_SOURCE_ENV_FOR_TEST_ENV = {'cheetah-dir': 'cheetah-vel'}


def _source_env(test_env):
    """Return the environment whose trained checkpoint is used for ``test_env``."""
    return _SOURCE_ENV_FOR_TEST_ENV.get(test_env, test_env)


def _policy_checkpoint_path(test_env):
    """Path of the saved CAVIA policy state dict used to meta-test ``test_env``."""
    return f'cavia_policy/{_source_env(test_env)}/cavia_policy.pt'


def _init_logger(test_env, seed, debug):
    """Start a wandb run for this evaluation; return None when ``debug`` is set."""
    if debug:
        return None
    source_env = _source_env(test_env)
    if source_env != test_env:
        # Transfer setting, e.g. 'Meta Test cheetah-vel -> cheetah-dir'.
        project = f'Meta Test {source_env} -> {test_env}'
    else:
        project = f'Meta Test {test_env}'
    return wandb.init(project=project, name=f'CAVIA({seed})', group='CAVIA')


def _build_policy(obs_dim, act_dim, *, maml, hidden_size, num_layers,
                  num_context_params, device):
    """Construct the untrained policy network that the checkpoint is loaded into."""
    hidden_sizes = (hidden_size,) * num_layers
    if maml:
        # NOTE(review): NormalMLPPolicy receives no device argument here; confirm
        # it ends up on `device` if MAML mode is ever enabled.
        return NormalMLPPolicy(obs_dim, act_dim, hidden_sizes=hidden_sizes)
    return CaviaMLPPolicy(obs_dim, act_dim, hidden_sizes=hidden_sizes,
                          num_context_params=num_context_params, device=device)


@click.command()
@click.option('--test_env', default=None, help='Environment to meta-test on.')
@click.option('--seed', default=0, type=int, help='Random seed.')
@click.option('--debug', is_flag=True, default=False, help='Disable wandb logging.')
def main(test_env, seed, debug):
    """Meta-test a pre-trained CAVIA policy on TEST_ENV.

    Loads the policy from cavia_policy/<env>/cavia_policy.pt (cheetah-dir
    reuses the cheetah-vel checkpoint) and runs MetaLearner.test on the
    test task. Results are logged to wandb unless --debug is given.
    """
    print('starting....')
    logger = _init_logger(test_env, seed, debug)

    # General
    gamma = 0.95
    tau = 1.0
    num_context_params = 2

    # Run MAML instead of CAVIA
    maml = False
    hidden_size = 300
    num_layers = 3

    # Testing
    num_test_steps = 5
    # NOTE(review): a separate test_batch_size of 40 was once declared but never
    # used; the test has always run with 10. Confirm which value is intended.
    test_batch_size = 10
    halve_test_lr = False

    # Task-specific
    fast_batch_size = 20
    fast_lr = 1.0

    # Miscellaneous
    make_deterministic = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    utils.set_seed(seed, cudnn=make_deterministic)

    sampler = BatchSampler(test_env, batch_size=fast_batch_size,
                           device=device, seed=seed)
    obs_dim = int(np.prod(sampler.env.observation_space.shape))
    act_dim = int(np.prod(sampler.env.action_space.shape))

    policy = _build_policy(obs_dim, act_dim, maml=maml, hidden_size=hidden_size,
                           num_layers=num_layers,
                           num_context_params=num_context_params, device=device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    state_dict = torch.load(_policy_checkpoint_path(test_env), map_location=device)
    policy.load_state_dict(state_dict)
    print('policy loaded')

    baseline = LinearFeatureBaseline(obs_dim)
    metalearner = MetaLearner(sampler, policy, baseline, gamma=gamma, fast_lr=fast_lr,
                              tau=tau, device=device)

    # Initialise the train task before selecting the test task.
    sampler.set_train_task()
    test_task = sampler.set_test_task(test_env, seed)
    metalearner.test(test_task,
                     logger,
                     num_steps=num_test_steps,
                     batch_size=test_batch_size,
                     halve_lr=halve_test_lr)



if __name__ == "__main__":
    # click parses --test_env / --seed / --debug from the command line.
    main()
