import datetime
import json
import os
import matplotlib.pyplot as plt
import time
import click
import numpy as np
import scipy.stats as st
import torch
import wandb

from cavia import utils
from cavia.baseline import LinearFeatureBaseline
from cavia.metalearner import MetaLearner
from cavia.policies.normal_mlp import NormalMLPPolicy, CaviaMLPPolicy
from cavia.sampler import BatchSampler


def get_returns(episodes_per_task):
    """Collect the masked return of every rollout, for every task and update step.

    Args:
        episodes_per_task: one sequence per task; each inner sequence holds one
            episode batch per update step, each exposing ``rewards`` and ``mask``
            tensors. The sum over dim 0 suggests they are time-major
            (time, num_evals) — TODO confirm against the sampler.

    Returns:
        A tensor of shape (num_tasks * num_evals, num_updates): one row per
        rollout, one column per update step.
    """
    per_task = []
    for task_episodes in episodes_per_task:
        # Sum the masked rewards of each rollout, giving one column per update step.
        columns = [(ep.rewards * ep.mask).sum(dim=0) for ep in task_episodes]
        # num_evals x num_updates
        per_task.append(torch.stack(columns, dim=1))

    # num_tasks x num_evals x num_updates, then flatten tasks and rollouts together
    stacked = torch.stack(per_task)
    return stacked.reshape((-1, stacked.shape[-1]))


def total_rewards(episodes_per_task, interval=False):
    """Average the per-rollout returns over all tasks and rollouts.

    Args:
        episodes_per_task: per-task episode collections, in the layout accepted
            by ``get_returns``.
        interval: when True, also return the half-width of a 95% Student-t
            confidence interval around each mean.

    Returns:
        ``mean`` (a numpy array with one entry per update step), or the tuple
        ``(mean, half_width)`` when ``interval`` is True.
    """
    # rows = individual rollouts (all tasks pooled), columns = update steps
    returns = get_returns(episodes_per_task).cpu().numpy()

    mean = np.mean(returns, axis=0)
    if not interval:
        return mean

    # Degrees of freedom are (number of samples - 1). The samples are the rows,
    # not the update-step columns; using len(mean) - 1 would give df=1 for a
    # single inner-loop update and grossly inflate the interval.
    dof = returns.shape[0] - 1
    lower, _upper = st.t.interval(0.95, dof, loc=mean, scale=st.sem(returns, axis=0))
    return mean, mean - lower


@click.command()
@click.option('--train_env', default=None)
@click.option('--seed', default=0)
@click.option('--debug', is_flag=True, default=False)
def main(train_env, seed, debug):
    """Meta-train a CAVIA (or MAML) policy and checkpoint the best one.

    Args:
        train_env: environment identifier passed to ``BatchSampler``; also used
            in the wandb run name and the checkpoint directory.
        seed: random seed given to ``utils.set_seed`` and the sampler.
        debug: when set, wandb is neither initialised nor logged to.

    After each meta-update, the policy's ``state_dict`` is written to
    ``./cavia_policy/<train_env>/cavia_policy.pt`` whenever the mean
    post-update return on that training batch beats the best seen so far.
    """
    print('starting....')
    if not debug:
        wandb.init(project='Meta RL', name=f'CAVIA meta training({train_env})')

    # General
    gamma = 0.95
    tau = 1.0
    first_order = False
    num_context_params = 2

    # Run MAML instead of CAVIA
    maml = False
    hidden_size = 300
    num_layers = 3

    # Task-specific
    fast_batch_size = 20
    fast_lr = 1.0

    # Optimization
    num_batches = 500
    meta_batch_size = 20
    max_kl = 1e-2
    cg_iters = 10
    cg_damping = 1e-5
    ls_max_steps = 15
    ls_backtrack_ratio = 0.8

    # Miscellaneous
    make_deterministic = False

    best_return_after = -np.inf
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    utils.set_seed(seed, cudnn=make_deterministic)

    sampler = BatchSampler(train_env, batch_size=fast_batch_size,
                           device=device, seed=seed)

    # Flattened observation / action sizes, shared by the policy and the baseline.
    obs_dim = int(np.prod(sampler.env.observation_space.shape))
    act_dim = int(np.prod(sampler.env.action_space.shape))
    hidden_sizes = (hidden_size,) * num_layers

    if maml:
        policy = NormalMLPPolicy(obs_dim, act_dim, hidden_sizes=hidden_sizes)
    else:
        policy = CaviaMLPPolicy(
            obs_dim,
            act_dim,
            hidden_sizes=hidden_sizes,
            num_context_params=num_context_params,
            device=device,
        )

    baseline = LinearFeatureBaseline(obs_dim)
    metalearner = MetaLearner(sampler, policy, baseline, gamma=gamma, fast_lr=fast_lr, tau=tau,
                              device=device)

    save_dir = f'./cavia_policy/{train_env}'
    save_path = f'{save_dir}/cavia_policy.pt'

    # initialize train task
    sampler.set_train_task()
    for batch in range(num_batches):
        # get a batch of tasks
        tasks = sampler.sample_tasks(num_tasks=meta_batch_size)

        # inner-loop update for each task; returns training (before update)
        # and validation (after update) episodes
        episodes, _inner_losses = metalearner.sample(tasks, first_order=first_order)

        # meta-gradient step
        metalearner.step(episodes, max_kl=max_kl, cg_iters=cg_iters,
                         cg_damping=cg_damping, ls_max_steps=ls_max_steps,
                         ls_backtrack_ratio=ls_backtrack_ratio)

        # -- logging: column 0 is before adaptation, column 1 after
        mean_returns = total_rewards(episodes)
        return_before = mean_returns[0]
        return_after = mean_returns[1]
        print(f'[{batch}]return before update: ', return_before)
        print(f'[{batch}]return after update: ', return_after)

        if not debug:
            wandb.log({"Task mean return before": return_before,
                       "Task mean return after": return_after}, step=batch)

        if return_after > best_return_after:
            best_return_after = return_after
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs
            os.makedirs(save_dir, exist_ok=True)
            torch.save(policy.state_dict(), save_path)


if __name__ == '__main__':
    main()
