from collections import OrderedDict

import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.distributions
from torch.optim.lr_scheduler import StepLR

from .base_agent import BaseAgent
from .sac_agent import SACAgent
from .dataset import ReplayBuffer, RandomSampler
from .expert_dataset import ExpertDataset
from networks.discriminator import Discriminator
from utils.info_dict import Info
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    sync_networks,
    sync_grads,
    to_tensor,
    mse_dict_tensor,
)
from utils.general import cat_dict_tensor
import algorithms
import copy
from contrastive_learning import Encoder
# from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

class IQLearnAgent(SACAgent):

    def __init__(self, config, ob_space, ac_space, env_ob_space):
        """Initialize the SAC base agent and, when training, the expert data pipeline.

        Args:
            config: run configuration. This method reads ``is_train``,
                ``demo_path``, ``demo_subsample_interval``, ``demo_low_level``,
                ``demo_sample_range_start``, ``demo_sample_range_end``,
                ``num_task``, ``target_taskID``, ``num_target_demos`` and
                ``target_demo_path`` from it, and ``batch_size`` from
                ``self._config``.
            ob_space: forwarded unchanged to ``SACAgent.__init__``.
            ac_space: forwarded to ``SACAgent.__init__`` and to ``ExpertDataset``.
            env_ob_space: forwarded unchanged to ``SACAgent.__init__``.
        """
        super().__init__(config, ob_space, ac_space, env_ob_space)

        # The demonstration dataset, loader and iterator are only created in
        # training mode; otherwise none of the three attributes is set.
        if not config.is_train:
            return

        demo_path = config.demo_path
        subsample_interval = config.demo_subsample_interval
        dataset_options = {
            "use_low_level": config.demo_low_level,
            "sample_range_start": config.demo_sample_range_start,
            "sample_range_end": config.demo_sample_range_end,
            "num_task": config.num_task,
            "target_taskID": config.target_taskID,
            "num_target_demos": config.num_target_demos,
            "target_demo_path": config.target_demo_path,
        }
        self._dataset = ExpertDataset(
            demo_path, subsample_interval, ac_space, **dataset_options
        )

        # drop_last=True keeps every expert batch at exactly batch_size rows.
        loader = torch.utils.data.DataLoader(
            self._dataset,
            batch_size=self._config.batch_size,
            shuffle=True,
            drop_last=True,
        )
        self._data_loader = loader
        self._data_iter = iter(loader)

    def train(self, step=0):
        """Run the configured number of network updates and return averaged statistics.

        Every update pairs a batch sampled from the replay buffer with the
        next batch of expert demonstrations. When the demonstration iterator
        is exhausted, a new one is created from the data loader.

        Args:
            step: not used by this method; kept for the existing call signature.

        Returns:
            The result of ``mpi_average`` applied to the scalar entries of the
            collected training info.
        """
        collected = Info()
        exhausted = object()  # sentinel marking the end of a pass over the demos

        self._num_updates = 1
        updates_left = self._num_updates
        while updates_left > 0:
            updates_left -= 1

            policy_data = self._buffer.sample(self._config.batch_size)

            expert_data = next(self._data_iter, exhausted)
            if expert_data is exhausted:
                # Start a fresh pass over the demonstrations.
                self._data_iter = iter(self._data_loader)
                expert_data = next(self._data_iter)

            collected.add(self._update_network(policy_data, expert_data))

        scalars = collected.get_dict(only_scalar=True)
        return mpi_average(scalars)

    def preprocess_transitions(self, transitions):
        """Normalize observations and convert a transition batch to device tensors.

        Args:
            transitions: mapping with the keys ``"ob"``, ``"ob_next"``,
                ``"ac"``, ``"done"`` and ``"rew"``.

        Returns:
            Tuple ``(o, o_next, ac, done, rew)``. ``done`` is reshaped to
            ``(batch, 1)`` and cast to float; ``rew`` is reshaped to
            ``(batch, 1)``. The batch size is taken from ``len(transitions["done"])``.
        """
        device = self._config.device

        # Normalize both observation sets before any tensor conversion.
        normalized_ob = self.normalize(transitions["ob"])
        normalized_ob_next = self.normalize(transitions["ob_next"])

        batch_size = len(transitions["done"])

        ob_tensor = to_tensor(normalized_ob, device)
        ob_next_tensor = to_tensor(normalized_ob_next, device)
        ac_tensor = to_tensor(transitions["ac"], device)
        done_tensor = to_tensor(transitions["done"], device)
        done_tensor = done_tensor.reshape(batch_size, 1).float()
        rew_tensor = to_tensor(transitions["rew"], device).reshape(batch_size, 1)

        return ob_tensor, ob_next_tensor, ac_tensor, done_tensor, rew_tensor
    def _update_network(self, policy_data, expert_data):
        """Run one critic update plus the scheduled actor and target-network updates.

        The critic is updated on every call. The actor/alpha update runs when
        the update counter is a multiple of ``config.actor_update_freq``, and
        the target critic is soft-updated when it is a multiple of
        ``config.critic_target_update_freq``.

        Args:
            policy_data: transition batch sampled from the replay buffer.
            expert_data: transition batch drawn from the expert demonstrations.

        Returns:
            Dict with the scalar statistics gathered from the critic update
            and, when it ran, the actor update.
        """
        info = Info()

        # Each preprocessed batch is the tuple (o, o_next, ac, done, rew).
        policy_batch = self.preprocess_transitions(policy_data)
        expert_batch = self.preprocess_transitions(expert_data)

        self._update_iter += 1

        critic_train_info = self._update_critic(policy_batch, expert_batch)
        info.add(critic_train_info)

        if self._update_iter % self._config.actor_update_freq == 0:
            # The actor update concatenates policy and expert observations
            # itself, so no merged observation batch is built here.
            actor_train_info = self._update_actor_and_alpha(
                policy_batch[0], expert_batch[0], expert_batch[2]
            )
            info.add(actor_train_info)

        if self._update_iter % self._config.critic_target_update_freq == 0:
            # Soft-update every critic head, then the shared encoder (which
            # uses its own update weight).
            for i, fc in enumerate(self._critic.fcs):
                self._soft_update_target_network(
                    self._critic_target.fcs[i],
                    fc,
                    self._config.critic_soft_update_weight,
                )
            self._soft_update_target_network(
                self._critic_target.encoder,
                self._critic.encoder,
                self._config.encoder_soft_update_weight,
            )

        return info.get_dict(only_scalar=True)

    def getV(self, obs, both=True):
        """Estimate the soft value of ``obs`` with the online critic.

        An action is sampled from the actor and the value is formed as
        ``Q(obs, action) - alpha * log_pi``, with ``alpha`` detached from the
        graph.

        Args:
            obs: observations passed to the actor and the critic.
            both: when True, return one value per critic output as
                ``(V1, V2)``; when False, return a single value built from
                the element-wise minimum of the two Q outputs.

        Returns:
            ``(V1, V2)`` if ``both`` is True, otherwise a single tensor ``V``.
        """
        sampled_ac, _, log_pis, _ = self._actor.act(
            obs, return_log_prob=True, detach_conv=True
        )
        q1, q2 = self._critic(obs, sampled_ac)

        # Entropy bonus shared by every returned value.
        entropy_term = self._log_alpha.exp().detach() * log_pis

        if both:
            return q1 - entropy_term, q2 - entropy_term
        return torch.min(q1, q2) - entropy_term

    def get_targetV(self, obs, both=True):
        """Estimate the soft value of ``obs`` with the target critic.

        An action is sampled from the actor and the value is formed as
        ``Q_target(obs, action) - alpha * log_pi``, with ``alpha`` detached
        from the graph.

        Args:
            obs: observations passed to the actor and the target critic.
            both: when True, return one value per target-critic output as
                ``(V1, V2)``; when False, return a single value built from
                the element-wise minimum of the two Q outputs.

        Returns:
            ``(V1, V2)`` if ``both`` is True, otherwise a single tensor ``V``.
        """
        sampled_ac, _, log_pis, _ = self._actor.act(
            obs, return_log_prob=True, detach_conv=True
        )
        target_q1, target_q2 = self._critic_target(obs, sampled_ac)

        # Entropy bonus shared by every returned value.
        entropy_term = self._log_alpha.exp().detach() * log_pis

        if not both:
            return torch.min(target_q1, target_q2) - entropy_term
        return target_q1 - entropy_term, target_q2 - entropy_term

    def _update_critic(self, policy_batch, expert_batch):
        """Take one optimizer step on the critic using the IQ-Learn objective.

        Policy and expert transitions are concatenated with the policy rows
        first. For each of the two critic outputs the loss is the sum of:

        1. ``-E_expert[Q(s, a) - gamma * V(s')]``
        2. ``E_all[V(s) - gamma * V(s')]`` over policy and expert rows
        3. ``1 / (4 * iq_learn_alpha) * E_expert[(Q(s, a) - gamma * V(s'))^2]``

        and the final loss is the average over the two outputs.

        When ``config.oldcode`` is set, the "expert" terms are computed over
        the whole concatenated batch, the first term is added without the
        minus sign, and the current value is computed without gradients.

        Args:
            policy_batch: tuple ``(o, o_next, ac, done, rew)`` of replay-buffer
                tensors; ``rew`` is not used.
            expert_batch: tuple ``(o, o_next, ac, done, rew)`` of demonstration
                tensors; ``rew`` is not used.

        Returns:
            ``Info`` holding scalar diagnostics of the update.
        """
        info = Info()
        policy_o, policy_o_next, policy_ac, policy_done, _ = policy_batch
        expert_o, expert_o_next, expert_ac, expert_done, _ = expert_batch

        # Rows [0, policy_count) of every concatenated tensor are policy samples.
        policy_count = len(policy_done)
        # First row that the loss treats as "expert"; legacy mode uses all rows.
        num_policy_samples = 0 if self._config.oldcode else policy_count

        all_o = cat_dict_tensor([policy_o, expert_o])
        all_o_next = cat_dict_tensor([policy_o_next, expert_o_next])
        all_ac = cat_dict_tensor([policy_ac, expert_ac])
        all_done = torch.cat([policy_done, expert_done], dim=0)

        separate_V = self._config.separate_V
        if self._config.oldcode:
            with torch.no_grad():
                current_V = self.getV(all_o, both=separate_V)
        else:
            current_V = self.getV(all_o, both=separate_V)

        if self._config.iq_learn_use_target:
            # Bootstrapped values from the target critic carry no gradient.
            with torch.no_grad():
                next_V = self.get_targetV(all_o_next, both=separate_V)
        else:
            next_V = self.getV(all_o_next, both=separate_V)

        if separate_V:
            current_V1, current_V2 = current_V
            next_V1, next_V2 = next_V
        else:
            # A single value estimate is shared by both critic outputs.
            current_V1, current_V2 = current_V, current_V
            next_V1, next_V2 = next_V, next_V

        current_Q1, current_Q2 = self._critic(all_o, all_ac)

        # Discounted next-state value, zeroed on terminal transitions.
        discount = self._config.rl_discount_factor
        y1 = (1 - all_done) * discount * next_V1
        y2 = (1 - all_done) * discount * next_V2

        # 1st term of IQ loss: -E_(p_expert)[Q(s,a) - yV(s')]
        reward1 = (current_Q1 - y1)[num_policy_samples:]
        reward2 = (current_Q2 - y2)[num_policy_samples:]
        if self._config.oldcode:
            critic_loss1 = reward1.mean()
            critic_loss2 = reward2.mean()
        else:
            critic_loss1 = -reward1.mean()
            critic_loss2 = -reward2.mean()

        # 2nd term of IQ loss: E_(p)[V(s) - yV(s')], policy and expert states
        value_loss1 = (current_V1 - y1).mean()
        value_loss2 = (current_V2 - y2).mean()
        critic_loss1 = critic_loss1 + value_loss1
        critic_loss2 = critic_loss2 + value_loss2

        # 3rd term of IQ loss: 1/4alpha E_(p_expert)[(Q(s,a) - yV(s'))^2]
        chi2_scale = 1 / (4 * self._config.iq_learn_alpha)
        chi2_loss1 = chi2_scale * (reward1 ** 2).mean()
        chi2_loss2 = chi2_scale * (reward2 ** 2).mean()
        critic_loss1 = critic_loss1 + chi2_loss1
        critic_loss2 = critic_loss2 + chi2_loss2

        critic_loss = 0.5 * (critic_loss1 + critic_loss2)

        # update the critic
        self._critic_optim.zero_grad()
        critic_loss.backward()
        sync_grads(self._critic)
        self._critic_optim.step()

        mean_Q = 0.5 * (current_Q1 + current_Q2)

        info["value_loss"] = 0.5 * (value_loss1 + value_loss2).cpu().item()
        info["expert_reward"] = 0.5 * (reward1 + reward2).mean().cpu().item()
        info["max_expert_reward"] = reward1.abs().max().cpu().item()
        info["critic_loss"] = critic_loss.cpu().item()
        info["chi2_loss"] = 0.5 * (chi2_loss1 + chi2_loss2).cpu().item()
        info["expert_q"] = mean_Q[num_policy_samples:].mean().cpu().item()
        # Slice by the real policy batch size: in legacy mode
        # num_policy_samples is 0 and the mean of that empty slice would be NaN.
        info["policy_q"] = mean_Q[:policy_count].mean().cpu().item()

        info["min_target_q"] = y1.min().cpu().item()
        info["target_q"] = 0.5 * (y1 + y2).mean().cpu().item()
        info["min_real1_q"] = current_Q1.min().cpu().item()
        info["min_real2_q"] = current_Q2.min().cpu().item()
        info["real1_q"] = current_Q1.mean().cpu().item()
        info["real2_q"] = current_Q2.mean().cpu().item()

        return info

    def _update_actor_and_alpha(self, policy_o, expert_o, expert_a):
        """Update the actor with a SAC loss plus a behavior-cloning term, then alpha.

        The actor loss is ``-min(Q1, Q2) + alpha * log_pi`` averaged over the
        concatenated policy and expert observations, plus
        ``config.BC_loss_coeff`` times the MSE between the expert actions and
        the actions sampled for the last ``config.batch_size`` rows (the
        expert observations are concatenated last).

        Args:
            policy_o: observations from the replay buffer batch.
            expert_o: observations from the expert batch.
            expert_a: expert actions used as the behavior-cloning target.

        Returns:
            ``Info`` with ``entropy_alpha``, ``entropy_loss``, ``actor_loss``
            (the Q term only, before the other terms are added) and ``BC_loss``.
        """
        info = Info()

        obs = cat_dict_tensor([policy_o, expert_o])
        sampled_ac, _, log_pi, _ = self._actor.act(
            obs, return_log_prob=True, detach_conv=True
        )
        alpha = self._log_alpha.exp()

        # SAC actor terms: entropy regularizer and negated pessimistic Q.
        entropy_loss = (alpha.detach() * log_pi).mean()
        q_values = self._critic(obs, sampled_ac, detach_conv=True)
        actor_loss = -torch.min(*q_values).mean()

        # Behavior cloning on the trailing rows, which hold the expert samples.
        num_expert = self._config.batch_size
        if isinstance(sampled_ac, OrderedDict):
            pred_ac = OrderedDict(
                (key, value[-num_expert:]) for key, value in sampled_ac.items()
            )
        else:
            pred_ac = sampled_ac[-num_expert:]
        BC_loss = self._config.BC_loss_coeff * mse_dict_tensor(expert_a, pred_ac)

        info["entropy_alpha"] = alpha.cpu().item()
        info["entropy_loss"] = entropy_loss.cpu().item()
        info["actor_loss"] = actor_loss.cpu().item()
        info["BC_loss"] = BC_loss.cpu().item()

        total_actor_loss = actor_loss + entropy_loss + BC_loss

        # update the actor
        self._actor_optim.zero_grad()
        total_actor_loss.backward()
        sync_grads(self._actor)
        self._actor_optim.step()

        # update alpha; log_pi is detached so only log_alpha receives gradient
        entropy_gap = (log_pi + self._target_entropy).detach()
        alpha_loss = -(alpha * entropy_gap).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()

        return info
