from collections import OrderedDict

import os
import time
import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer
from rlkit.samplers.data_collector.path_collector import MdpPathCollector

class CQLTrainer(TorchTrainer):
    """Train a weight network by TD regression on per-transition ``diff`` values.

    Despite the class name, no Q-function update happens here: each call to
    :meth:`train_from_torch` fits ``weight_net(obs, actions)`` to the target
    ``diff + (1 - terminals) * discount * target_weight_net(next_obs, pi(next_obs))``
    plus a squared-L2 penalty on the weight-net parameters. ``env`` and ``qf1``
    are stored (``qf1`` is also listed in :attr:`networks`) but are not used by
    the methods of this class.
    """

    def __init__(
            self,
            env,
            exp_name,
            policy,
            qf1,
            weight_net,
            target_weight_net,

            qf_lr=1e-3,
            optimizer_class=optim.Adam,

            temp=1.0,
            num_total=10,
            diff_clip=100,
            lambda_reg=1e-2,

            discount=0.99,
            soft_target_tau=None,
            target_update_period=1,
    ):
        """
        Args:
            env: environment object; only stored on ``self.env``.
            exp_name: experiment name; the directory ``./<exp_name>`` is created.
            policy: policy called as ``policy(next_obs, reparameterize=..., return_log_prob=...)``
                to pick next actions for the bootstrapped target.
            qf1: Q network; stored and exposed via ``networks`` but not trained here.
            weight_net: network trained by this class, called as ``weight_net(obs, actions)``.
            target_weight_net: network providing the bootstrapped target.
            qf_lr: learning rate of the weight-net optimizer.
            optimizer_class: optimizer constructor, called as
                ``optimizer_class(params, lr=qf_lr)``.
            temp: stored and printed; not used by the methods of this class.
            num_total: stored and printed; not used by the methods of this class.
            diff_clip: upper bound applied to ``batch['diff']`` before regression.
            lambda_reg: coefficient of the squared-L2 parameter penalty.
            discount: TD discount factor.
            soft_target_tau: if not ``None``, Polyak-average ``weight_net`` into
                ``target_weight_net`` with this rate every ``target_update_period``
                train steps. The default ``None`` keeps the previous behaviour of
                never updating the target network from inside this class.
            target_update_period: train-step interval for the optional soft update.
        """
        super().__init__()
        self.env = env
        self.exp_name = exp_name
        self.policy = policy
        self.qf1 = qf1
        self.weight_net = weight_net
        self.target_weight_net = target_weight_net

        self.qf_criterion = nn.MSELoss()

        # Remember how the optimizer is built so set_snapshot() can rebind it
        # to a freshly loaded weight_net.
        self._optimizer_class = optimizer_class
        self._qf_lr = qf_lr
        self.w_optimizer = self._make_w_optimizer()

        self.lambda_reg = lambda_reg
        print("self.lambda_reg: \t", lambda_reg)
        self.discount = discount
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.eval_statistics = OrderedDict()
        self.eval_wandb = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        # Incremented once per train_from_torch() call, i.e. per gradient step.
        self._current_epoch = 0

        # Configuration kept for logging / external use; not read below.
        self.temp = temp
        print('self.temp: \t', temp)
        self.num_total = num_total
        print('self.num_total: \t', self.num_total)
        self.diff_clip = diff_clip
        print('self.diff_clip: \t', self.diff_clip)

        self.path = f'./{self.exp_name}'
        os.makedirs(self.path, exist_ok=True)

        self.discrete = False

    def _make_w_optimizer(self):
        """Build a new optimizer over the current ``weight_net`` parameters."""
        return self._optimizer_class(self.weight_net.parameters(), lr=self._qf_lr)

    def _get_policy_actions(self, obs, num_actions, network=None):
        """Sample ``num_actions`` actions for every observation in ``obs``.

        Args:
            obs: 2-D observation batch, treated as ``(batch, obs_dim)`` by the
                reshape below.
            num_actions: number of samples drawn per observation.
            network: policy to sample from; defaults to ``self.policy``.

        Returns:
            In continuous mode, ``(actions, log_pi)`` where ``actions`` has
            ``batch * num_actions`` rows and ``log_pi`` is reshaped to
            ``(batch, num_actions, 1)``; in discrete mode only ``actions``.
        """
        if network is None:
            network = self.policy
        batch_size, obs_dim = obs.shape[0], obs.shape[1]
        # Tile each observation num_actions times: (B, D) -> (B * num_actions, D).
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(batch_size * num_actions, obs_dim)
        new_obs_actions, _, _, new_obs_log_pi, *_ = network(
            obs_temp, reparameterize=False, return_log_prob=True,
        )
        if self.discrete:
            return new_obs_actions
        return new_obs_actions, new_obs_log_pi.view(batch_size, num_actions, 1)

    def train_from_torch(self, batch):
        """Take one optimizer step on ``weight_net`` using ``batch``.

        ``batch`` must provide tensors under the keys ``observations``,
        ``actions``, ``next_observations``, ``terminals`` and ``diff``.
        """
        self._current_epoch += 1

        obs = batch['observations']
        terminals = batch['terminals']
        actions = batch['actions']
        next_obs = batch['next_observations']
        # Clamp the regression signal from above for weight-network stability.
        diff = batch['diff'].clamp(max=self.diff_clip)

        # --- Weight net TD regression ---
        w_pred = self.weight_net(obs, actions)

        with torch.no_grad():
            weight_next_actions, *_ = self.policy(
                next_obs, reparameterize=True, return_log_prob=True,
            )
            target_w_pred = self.target_weight_net(next_obs, weight_next_actions)
            # Computed under no_grad, so it is already detached from the graph.
            w_target = diff + (1. - terminals) * self.discount * target_w_pred

        net_error = self.qf_criterion(w_pred, w_target)

        # Squared L2 norm of all weight-net parameters. pow(2).sum() equals
        # torch.norm(p, 2) ** 2 without the sqrt round-trip; the 0-dim start
        # value keeps l2_reg a tensor even if the net has no parameters.
        l2_reg = sum(
            (param.pow(2).sum() for param in self.weight_net.parameters()),
            w_pred.new_zeros(()),
        )

        # MSELoss uses 'mean' reduction, so this is already a scalar.
        w_loss = net_error + self.lambda_reg * l2_reg

        self.w_optimizer.zero_grad()
        # No retain_graph: nothing back-propagates through this graph again,
        # so keeping it alive would only waste memory.
        w_loss.backward()
        self.w_optimizer.step()

        # Optional Polyak update of the target network (disabled by default).
        if (
                self.soft_target_tau is not None
                and self._n_train_steps_total % self.target_update_period == 0
        ):
            ptu.soft_update_from_to(
                self.weight_net, self.target_weight_net, self.soft_target_tau
            )

        # --- Save some statistics for eval ---
        if self._need_to_update_eval_statistics:
            # Statistics are recorded for the first batch after each
            # end_epoch() call, which resets this flag to True.
            self._need_to_update_eval_statistics = False
            self.eval_statistics.update(create_stats_ordered_dict(
                'Batch Diff',
                ptu.get_numpy(diff),
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'W Pred',
                ptu.get_numpy(w_pred),
            ))
            # NOTE: this logs the target network's prediction, not the full
            # TD target w_target.
            self.eval_statistics['W target'] = np.mean(ptu.get_numpy(target_w_pred))
            self.eval_statistics['W net Error'] = np.mean(ptu.get_numpy(net_error))
            self.eval_statistics['l2 reg'] = l2_reg.item()
            self.eval_statistics['Weight net loss'] = w_loss.item()

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics dict filled by ``train_from_torch``."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Request fresh statistics from the next training batch."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All modules held by this trainer."""
        base_list = [
            self.policy,
            self.qf1,
            self.weight_net,
            self.target_weight_net,
        ]
        return base_list

    def get_snapshot(self):
        """Return the trainable state to checkpoint (only ``weight_net``)."""
        return dict(
            weight_net = self.weight_net,
        )

    def set_snapshot(self, snapshot):
        """Restore ``weight_net`` from ``snapshot`` and rebind the optimizer.

        The optimizer is rebuilt so that it steps the restored module's
        parameters instead of the replaced one's. Optimizer state (e.g. Adam
        moments) is not part of the snapshot and therefore starts fresh.
        """
        self.weight_net = snapshot['weight_net']
        self.w_optimizer = self._make_w_optimizer()

