import os
import numpy as np
import click
import json
import torch
import torch.optim as optim

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.torch.sac.sr_sac2 import SRSoftActorCritic2
from rlkit.torch.networks import shared_policy2
from configs.default import default_config
import wandb
from tqdm import tqdm

class launcher():
    """Set up the environment and one agent per training task, then train.

    One ``SRSoftActorCritic2`` agent is created for every entry of
    ``self.train_tasks``; ``self.agents[i]`` is the agent belonging to
    ``self.train_tasks[i]``. Unless ``no_sr`` is set, every agent is given
    the single ``shared_policy2`` network, which is optimised by
    ``meta_update``.
    """

    def __init__(self, variant, debug):
        """Build the environment, the shared network and the per-task agents.

        Args:
            variant: Configuration dict. Keys read here: ``env_name``,
                ``env_params``, ``n_train_tasks``, ``net_size``,
                ``sr_params`` (``no_sr``, ``total_step``, ``eval_step``,
                ``latent_action_dim``) and ``algo_params`` (``policy_lr`` is
                read directly; the whole dict is also forwarded to each agent
                as keyword arguments).
            debug: When true, no wandb run is started and nothing is logged
                or saved to disk by this class.
        """
        self.env_name = variant['env_name']
        self.env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
        self.env.set_train_task(variant['n_train_tasks'])

        self.tasks = self.env.get_all_task_idx()
        self.train_tasks = list(self.tasks[:variant['n_train_tasks']])

        self.obs_dim = int(np.prod(self.env.observation_space.shape))
        self.action_dim = int(np.prod(self.env.action_space.shape))

        # device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # debug
        self.debug = debug

        # sr
        sr_params = variant['sr_params']
        self.no_sr = sr_params['no_sr']
        self.total_step = sr_params['total_step']
        self.eval_step = sr_params['eval_step']

        # instantiate networks
        self.latent_action_dim = sr_params['latent_action_dim']
        self.net_size = variant['net_size']
        self.agents = []
        self.shared_network = shared_policy2(self.net_size,
                                             self.latent_action_dim).to(self.device)
        self.shared_optimizer = optim.Adam(self.shared_network.parameters(),
                                           lr=variant['algo_params']['policy_lr'])
        for _ in self.train_tasks:
            agent = SRSoftActorCritic2(obs_dim=self.obs_dim,
                                       action_dim=self.action_dim,
                                       net_size=self.net_size,
                                       no_sr=self.no_sr,
                                       latent_action_dim=self.latent_action_dim,
                                       device=self.device,
                                       **variant['algo_params'])
            if self.no_sr:
                # Without the shared network each agent keeps its own
                # shared_layer and is allowed to train it.
                agent.policy.shared_layer.requires_grad_(True)
            else:
                agent.load_shared_network(self.shared_network)
            self.agents.append(agent)

        if not self.debug:
            if self.no_sr:
                run_name = 'SAC training'
            else:
                run_name = f'SAC meta training({self.env_name})'
            wandb.init(project='Meta RL', name=run_name)

        # Running total of ``max_path_length`` over every rollout collected.
        self.steps = 0

    def learn(self):
        """Run the training loop.

        Each iteration lets every agent collect data and train on its own
        task, then (unless ``no_sr``) performs a meta update of the shared
        network. Evaluation runs whenever ``self.steps`` has crossed a
        multiple of ``eval_step`` during the iteration; outside debug mode
        the shared network is then checkpointed (unless ``no_sr``) and the
        per-task returns are logged to wandb.
        """
        iters = int(self.total_step / self.agents[0].max_path_length)
        for _ in tqdm(range(iters)):
            steps_before = self.steps
            # Agents are stored positionally, so pair them with their task id
            # instead of using the task id as a list index.
            for task_idx, agent in zip(self.train_tasks, self.agents):
                self.env.reset_task(task_idx)
                agent.collect_data_and_train_filter(self.env)
                if not self.debug:
                    wandb.log({f"Task {task_idx} alpha": agent.alpha}, step=self.steps)
                self.steps += agent.max_path_length

            if not self.no_sr:
                self.meta_update()

            # ``steps`` grows by several path lengths per iteration, so an
            # exact ``steps % eval_step == 0`` test can be skipped forever.
            # Detect that a multiple of ``eval_step`` was crossed instead;
            # this fires on the same iterations whenever the exact test would.
            if self.steps // self.eval_step <= steps_before // self.eval_step:
                continue

            returns = self.eval()
            if self.debug:
                continue
            if not self.no_sr:
                save_dir = f'./shared_policy/{self.env_name}'
                os.makedirs(save_dir, exist_ok=True)
                torch.save(self.shared_network.state_dict(),
                           f'{save_dir}/shared_policy({self.steps}).pt')
            for task_idx, task_return in zip(self.train_tasks, returns):
                wandb.log({f"Task {task_idx} returns": task_return}, step=self.steps)

    def meta_update(self):
        """Optimise the shared network on the agents' meta losses.

        Runs ``min(update_step, 10)`` optimisation steps, where
        ``update_step`` is taken from the first agent. In each step the meta
        losses of all agents whose replay buffer holds more than
        ``batch_size`` entries are summed and divided by the total number of
        agents. A step in which no agent yields a loss is skipped. Outside
        debug mode, the loss of the final step (if it produced one) is
        logged to wandb.
        """
        last_meta_loss = None
        for _ in range(min(self.agents[0].update_step, 10)):
            losses = []
            for agent in self.agents:
                if agent.replay_buffer.size() > agent.batch_size:
                    agent.compute_meta_loss(self.shared_network)
                    loss = agent.get_meta_loss()
                    if loss is not None:
                        losses.append(loss)

            if not losses:
                # Nothing to optimise in this step.
                last_meta_loss = None
                continue

            # Normalise by the number of agents, not by the number of
            # contributing losses.
            meta_loss = sum(losses) / len(self.agents)
            self.shared_optimizer.zero_grad()
            meta_loss.backward()
            self.shared_optimizer.step()
            last_meta_loss = meta_loss.item()

        if last_meta_loss is not None and not self.debug:
            wandb.log({"Meta Loss": last_meta_loss}, step=self.steps)

    def eval(self):
        """Evaluate every agent on its own task.

        Returns:
            A list with the value returned by ``agent.evaluation`` for each
            training task, in the order of ``self.train_tasks``.
        """
        returns = []
        for task_idx, agent in zip(self.train_tasks, self.agents):
            self.env.reset_task(task_idx)
            returns.append(agent.evaluation(self.env))
        return returns


def deep_update_dict(fr, to):
    """Recursively merge ``fr`` into ``to`` in place and return ``to``.

    Values in ``fr`` that are dicts are merged key by key into the
    corresponding dict in ``to``; every other value overwrites (or adds) the
    entry in ``to``. When ``to`` has no dict under a key whose value in
    ``fr`` is a dict, a new dict is created there first, so nested keys
    behave the same way as top-level ones.

    Args:
        fr: Dict holding the new values.
        to: Dict that is updated in place.

    Returns:
        The same ``to`` object that was passed in.
    """
    for key, value in fr.items():
        if isinstance(value, dict):
            target = to.get(key)
            if not isinstance(target, dict):
                # Merge into a fresh dict so ``to`` never aliases ``fr``.
                target = to[key] = {}
            deep_update_dict(value, target)
        else:
            to[key] = value
    return to

@click.command()
@click.option('--train_env', default=None)
@click.option('--debug', is_flag=True, default=False)
# @click.option('--no_sr', is_flag=True, default=False)
def main(train_env, debug):
    """Command-line entry point: build the config, then train.

    Starts from ``default_config``. When ``--train_env`` is given, the JSON
    file ``./configs/<train_env>.json`` is loaded and merged over it with
    ``deep_update_dict``; otherwise the defaults are used unchanged.

    Args:
        train_env: Name of the config file (without ``.json``) or ``None``.
        debug: Forwarded to ``launcher`` as its ``debug`` argument.
    """
    variant = default_config
    # The original tested the formatted path, which is always truthy, so an
    # omitted --train_env tried to open './configs/None.json'.
    if train_env is not None:
        config = f'./configs/{train_env}.json'
        with open(config, encoding='utf-8') as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    launch = launcher(variant, debug=debug)
    launch.learn()

# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()