from collections import OrderedDict
import numpy as np
import os
import os.path
from os import path
import torch
import torch.optim as optim
from torch import nn as nn
import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.core.rl_algorithm import MetaRLAlgorithm
import numpy as np
import warnings


"""
SAC implementation is from rlkit via github
https://github.com/vitchyr/rlkit

Meta learning SAC sample implementation from Rakelly et. al. Pearl on Github
https://github.com/katerakelly/oyster
"""

class Flap(MetaRLAlgorithm):
    def __init__(
            self,
            env,
            train_tasks,
            eval_tasks,
            q1_net,
            q2_net,
            policy,
            target_q1,
            target_q2,
            test_policy,
            adapter,
            tasks_num,
            gpu_id,
            adapt_steps=200,
            policy_lr=3e-4,
            qf_lr=3e-4,
            optimizer_class=optim.Adam,
            soft_target_tau=5e-3,
            plotter=None,
            render_eval_paths=False,
            target_entropy=None,
            **kwargs
    ):
        """Set up networks, optimizers and SAC hyper-parameters.

        Args:
            env: environment; only ``env.action_space.shape`` is read here
                (for the default target entropy). Also forwarded to the base
                class.
            train_tasks: forwarded to ``MetaRLAlgorithm``.
            eval_tasks: forwarded to ``MetaRLAlgorithm``.
            q1_net, q2_net: the two Q-networks, optimized with ``qf_lr``.
            policy: the policy network, optimized with ``policy_lr``.
            target_q1, target_q2: target copies of the Q-networks.
            test_policy: stored and forwarded to the base class.
            adapter: adapter network, optimized with ``qf_lr``.
            tasks_num: stored on the instance as ``self.tasks_num``.
            gpu_id: CUDA device index, stored and forwarded to the base class.
            adapt_steps: forwarded to ``MetaRLAlgorithm``.
            policy_lr: learning rate for the policy and for ``log_alpha``.
            qf_lr: learning rate for both Q-networks and the adapter.
            optimizer_class: optimizer constructor, called as
                ``optimizer_class(params, lr=...)``.
            soft_target_tau: coefficient passed to the soft target update.
            plotter: stored on the instance.
            render_eval_paths: stored on the instance.
            target_entropy: entropy target for the temperature loss. When
                ``None`` the heuristic ``-prod(action_space.shape)`` is used.
            **kwargs: forwarded unchanged to ``MetaRLAlgorithm``.
        """
        super().__init__(
            env=env,
            policy=policy,
            train_tasks=train_tasks,
            eval_tasks=eval_tasks,
            adapter=adapter,
            test_policy=test_policy,
            gpu_id=gpu_id,
            adapt_steps=adapt_steps,
            **kwargs
        )
        self.policy_lr = policy_lr
        self.qf_lr = qf_lr
        self.soft_target_tau = soft_target_tau
        self.plotter = plotter
        self.render_eval_paths = render_eval_paths
        self.tasks_num = tasks_num
        self.gpu_id = gpu_id
        self.adapter = adapter
        self.test_policy = test_policy

        # Learnable log of the entropy temperature, with its own optimizer.
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=policy_lr,
        )

        # Honour an explicitly supplied target entropy; previously the
        # argument was always overwritten by the heuristic below.
        if target_entropy is None:
            # heuristic value from Tuomas from rlkit
            target_entropy = -np.prod(env.action_space.shape).item()
        self.target_entropy = target_entropy

        # Summed (not averaged) squared error; used for the Q losses and the
        # adapter loss.
        self.qf_criterion = nn.MSELoss(reduction='sum')

        self.policy_optimizer = optimizer_class(
            policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            q1_net.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            q2_net.parameters(),
            lr=qf_lr,
        )
        # The adapter shares the Q-function learning rate.
        self.adapter_optimizer = optimizer_class(
            adapter.parameters(),
            lr=qf_lr,
        )

        self.q1_net = q1_net
        self.q2_net = q2_net
        self.policy = policy
        self.target_q1 = target_q1
        self.target_q2 = target_q2
        
    ###### Torch stuff #####
    @property
    def networks(self):
        """Return a new list of every network held by this algorithm.

        Order: policy, the two Q-networks, the two target Q-networks, the
        adapter, and finally the test policy.
        """
        actor = [self.policy]
        critics = [self.q1_net, self.q2_net]
        target_critics = [self.target_q1, self.target_q2]
        extras = [self.adapter, self.test_policy]
        return actor + critics + target_critics + extras

    def training_mode(self, mode):
        """Forward ``mode`` to ``train()`` on each entry of ``self.networks``."""
        for module in self.networks:
            module.train(mode)

    def to(self, device=None):
        """Move every network in ``self.networks`` to ``device``.

        Args:
            device: target device (anything accepted by the networks' ``to``
                method). When ``None``, ``"cuda:<gpu_id>"`` is used if CUDA is
                available; otherwise the networks go to ``ptu.device`` instead
                of failing on a CPU-only host.
        """
        if device is None:
            if torch.cuda.is_available():
                device = "cuda:" + str(self.gpu_id)
            else:
                device = ptu.device
        for net in self.networks:
            net.to(device)

    ##### Data handling #####
    def unpack_batch(self, batch, sparse_reward=False):
        """Split a batch mapping into its five transition components.

        Each component is indexed with ``[None, ...]``, i.e. it gains a new
        leading axis of length one.

        Args:
            batch: mapping with the keys ``observations``, ``actions``,
                ``rewards`` (or ``sparse_rewards``), ``next_observations``
                and ``terminals``.
            sparse_reward: read ``sparse_rewards`` instead of ``rewards``.

        Returns:
            ``[obs, actions, rewards, next_obs, terminals]``.
        """
        reward_key = 'sparse_rewards' if sparse_reward else 'rewards'
        ordered_keys = (
            'observations',
            'actions',
            reward_key,
            'next_observations',
            'terminals',
        )
        return [batch[key][None, ...] for key in ordered_keys]

    def sample_sac(self, indices):
        """Draw one random replay batch per task and stack them by component.

        For every task index in ``indices`` a batch of ``self.batch_size``
        transitions is sampled from ``self.replay_buffer``, converted with
        ``ptu.np_to_pytorch_batch`` and unpacked with ``unpack_batch`` using
        the dense ``rewards`` key.

        Returns:
            ``[obs, actions, rewards, next_obs, terminals]``, each the
            concatenation over tasks along dimension 0.
        """
        torch_batches = []
        for task_idx in indices:
            numpy_batch = self.replay_buffer.random_batch(
                task_idx, batch_size=self.batch_size
            )
            torch_batches.append(ptu.np_to_pytorch_batch(numpy_batch))

        per_task = [self.unpack_batch(tb) for tb in torch_batches]

        # Regroup from task-major to component-major, then join along dim 0.
        component_count = len(per_task[0])
        grouped = []
        for k in range(component_count):
            grouped.append([fields[k] for fields in per_task])
        return [torch.cat(group, dim=0) for group in grouped]

    ##### Training #####
    def _do_training(self):
        """Run a single optimisation step by delegating to ``_take_step``.

        Presumably called by the ``MetaRLAlgorithm`` training loop; confirm
        against the base class.
        """
        self._take_step()

    def _update_target_network(self):
        """Soft-update both target Q-networks from their online counterparts.

        Calls ``ptu.soft_update_from_to(source, target, self.soft_target_tau)``
        for the Q1 pair first, then for the Q2 pair.
        """
        source_target_pairs = (
            (self.q1_net, self.target_q1),
            (self.q2_net, self.target_q2),
        )
        for source, target in source_target_pairs:
            ptu.soft_update_from_to(source, target, self.soft_target_tau)

    def _take_step(self):
        """Perform one optimisation step over every task in ``self.train_tasks``.

        For each training task a batch is drawn with ``sample_sac`` and used
        to compute:

        * the adapter loss: summed squared error between the adapter output
          and the concatenated, flattened ``lc<task>.weight`` /
          ``lc<task>.bias`` entries of the policy's state dict;
        * the entropy-temperature loss, which is optimised immediately (one
          ``alpha_optimizer`` step per task);
        * the policy loss ``alpha * log_pi - min(Q1, Q2)``;
        * the two Q-function losses against the soft Bellman target.

        After the loop the adapter, policy, Q1 and Q2 optimizers are each
        stepped once, in that order, on the losses aggregated over all tasks,
        and the target Q-networks are soft-updated.

        Side effects: populates ``self.eval_statistics`` (from the first task
        processed) when it is ``None``.

        Raises:
            ValueError: if ``self.train_tasks`` yields no task.
        """
        # NOTE(review): this silences *all* warnings process-wide on every
        # step; kept for backward compatibility, but it can hide real
        # problems (e.g. shape-mismatch warnings from the loss).
        warnings.filterwarnings("ignore")

        # The policy is not stepped until after the loop, so one state-dict
        # snapshot serves every task.
        policy_state = self.policy.state_dict()

        adapter_losses = []
        policy_losses = []
        qf1_losses = []
        qf2_losses = []
        total = 0  # number of Q predictions, used to average the summed Q losses

        for i in self.train_tasks:
            # data is (task, batch, feat)
            obs, actions, rewards, next_obs, terms = self.sample_sac([i])

            # Flatten out the task dimension.
            t, b, _ = actions.size()
            flat = t * b
            obs = obs.view(flat, -1)
            actions = actions.view(flat, -1)
            next_obs = next_obs.view(flat, -1)
            rewards_flat = rewards.view(flat, -1)
            terms_flat = terms.view(flat, -1)

            # ---- adapter: regress onto this task's policy head parameters
            prediction = self.adapter(obs, actions, rewards_flat, next_obs)
            head = 'lc' + str(i)
            weights_bias_concat = torch.cat((
                torch.flatten(policy_state[head + '.weight']),
                torch.flatten(policy_state[head + '.bias']),
            ))
            adapter_loss = self.qf_criterion(prediction, weights_bias_concat)
            adapter_losses.append(adapter_loss)

            # ---- policy forward pass on the current observations
            policy_outputs = self.policy(
                obs, i, reparameterize=True, return_log_prob=True
            )
            new_actions, _, _, log_pi = policy_outputs[:4]

            # ---- entropy temperature
            # Follow the device the policy actually runs on rather than
            # forcing CUDA whenever a GPU merely exists on the host.
            device = log_pi.device
            alpha_loss = -(
                self.log_alpha.to(device)
                * (log_pi + self.target_entropy).detach()
            ).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.to(device).exp()

            # ---- policy loss: compute min Q on the new actions
            min_q_new_actions = torch.min(
                self.q1_net(i, obs, new_actions),
                self.q2_net(i, obs, new_actions),
            )
            policy_loss = alpha * log_pi - min_q_new_actions
            policy_losses.append(policy_loss)

            # ---- Q-function losses
            q1_pred = self.q1_net(i, obs, actions)
            q2_pred = self.q2_net(i, obs, actions)

            # The Bellman target is a constant w.r.t. every optimizer, so no
            # autograd graph is needed for it.
            with torch.no_grad():
                policy_outputs_target = self.policy(
                    next_obs, i, reparameterize=True, return_log_prob=True
                )
                next_new_actions, _, _, new_log_pi = policy_outputs_target[:4]
                target_q = torch.min(
                    self.target_q1(i, next_obs, next_new_actions),
                    self.target_q2(i, next_obs, next_new_actions),
                ) - alpha * new_log_pi
                # Scale rewards for the Bellman update.
                q_target = (
                    rewards_flat * self.reward_scale
                    + (1. - terms_flat) * self.discount * target_q
                )

            qf1_losses.append(self.qf_criterion(q1_pred, q_target))
            qf2_losses.append(self.qf_criterion(q2_pred, q_target))
            total += q1_pred.shape[0]

            # Eval statistics for one task, used for debugging to determine
            # whether the algorithm works.
            if self.eval_statistics is None:
                # eval should set this to None.
                # this way, these statistics are only computed for one batch.
                self.eval_statistics = OrderedDict()
                self.eval_statistics['Fitting Weight'] = np.mean(ptu.get_numpy(
                    adapter_loss
                ))
                self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
                    policy_loss
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q1_pred),
                ))

        if not policy_losses:
            raise ValueError("train_tasks is empty; nothing to train on")

        # Aggregate over all training tasks: the policy loss is the mean over
        # every sample, the Q losses are summed errors divided by the sample
        # count, and the adapter loss is the plain sum over tasks.
        adapter_loss_combined = sum(adapter_losses)
        policy_loss_combined = torch.cat(policy_losses, dim=0).mean()
        qf1_loss_combined = sum(qf1_losses) / total
        qf2_loss_combined = sum(qf2_losses) / total

        # Each loss comes from its own forward passes, so no graph has to be
        # retained between the backward calls.
        self.adapter_optimizer.zero_grad()
        adapter_loss_combined.backward()
        self.adapter_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss_combined.backward()
        self.policy_optimizer.step()

        # zero_grad() also discards the Q-network gradients that the policy
        # backward pass accumulated.
        self.qf1_optimizer.zero_grad()
        qf1_loss_combined.backward()
        self.qf1_optimizer.step()

        self.qf2_optimizer.zero_grad()
        qf2_loss_combined.backward()
        self.qf2_optimizer.step()

        self._update_target_network()
        

    def get_epoch_snapshot(self, epoch):
        """Collect the state dicts to checkpoint for this epoch.

        NOTE: overrides the parent method, which also optionally saves the
        env. ``epoch`` is accepted for interface compatibility and not used.

        Returns:
            An ``OrderedDict`` with the keys ``qf1``, ``qf2``, ``policy`` and
            ``adapter``, in that order.
        """
        snapshot = OrderedDict()
        snapshot['qf1'] = self.q1_net.state_dict()
        snapshot['qf2'] = self.q2_net.state_dict()
        snapshot['policy'] = self.policy.state_dict()
        snapshot['adapter'] = self.adapter.state_dict()
        return snapshot
