import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, TransformedDistribution, constraints
from torch.optim.lr_scheduler import CosineAnnealingLR

from torch.distributions.transforms import Transform

import torch.distributions as td
from typing import Any, Dict, List, Optional, Tuple, Union

from torch.autograd import Variable
from tqdm import trange


class ForwardModel(nn.Module):
    """MLP mapping a concatenated (state, action) pair to a ``state_dim`` vector.

    Two 256-unit ReLU hidden layers feed the ``l3`` output head.  A scalar head
    ``l4`` is registered on the same trunk but does not contribute to the value
    returned by :meth:`forward`.

    Args:
        state_dim: width of the state input and of the returned vector.
        action_dim: width of the action input.
    """

    def __init__(self, state_dim, action_dim):
        super(ForwardModel, self).__init__()

        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, state_dim)

        # Not used by forward(); kept so that parameter registration,
        # initialisation order and state_dict keys stay unchanged.
        self.l4 = nn.Linear(256, 1)

    def forward(self, state, action):
        """Return the ``l3`` head output for ``state`` and ``action``.

        Both inputs are concatenated along dimension 1, so they must agree on
        every other dimension.
        """
        sa = torch.cat([state, action], 1)
        latent = F.relu(self.l1(sa))
        latent = F.relu(self.l2(latent))
        # The l4 head used to be evaluated here and thrown away; skipping it
        # leaves the returned tensor unchanged and avoids the wasted matmul.
        return self.l3(latent)

class Model(nn.Module):
    """State-to-state MLP with two 256-unit ReLU hidden layers.

    Args:
        state_dim: width of both the input and the output vector.
    """

    def __init__(self, state_dim):
        super().__init__()

        width = 256
        self.l1 = nn.Linear(state_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, state_dim)

    def forward(self, state):
        """Map ``state`` through the three linear layers (ReLU between them)."""
        hidden = F.relu(self.l1(state))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)

    def weight(self, diff):
        """Hard step on ``diff``: 0 where ``diff <= 0`` and 1 everywhere else."""
        non_positive = diff <= 0
        return torch.where(non_positive, 0, 1)

class Actor(nn.Module):
    """MLP that maps a (state, next_state) pair to a bounded action.

    The output of the last layer is squashed with ``tanh`` and multiplied by
    ``max_action``.

    Args:
        state_dim: width of each of the two state inputs.
        action_dim: width of the returned action.
        max_action: multiplier applied to the tanh-squashed output.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        width = 256
        self.l1 = nn.Linear(2 * state_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, action_dim)
        self.max_action = max_action

    def forward(self, state, next_state):
        """Return the action predicted for the transition ``state -> next_state``."""
        pair = torch.cat([state, next_state], 1)
        hidden = F.relu(self.l1(pair))
        hidden = F.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed
                


class TanhTransform(Transform):
    r"""
    Bijective transform :math:`y = \tanh(x)` from the reals onto ``(-1, 1)``.

    Mathematically this matches
    ```
    ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    ```
    but that composition might not be numerically stable, so `TanhTransform`
    is the recommended form.
    Use `cache_size=1` when `NaN/Inf` values are a concern.
    """
    bijective = True
    sign = +1
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent written as ``0.5 * (log1p(x) - log1p(-x))``."""
        upper = torch.log1p(x)
        lower = torch.log1p(-x)
        return 0.5 * (upper - lower)

    def __eq__(self, other):
        # Every TanhTransform instance compares equal to every other one.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # No clamping to the boundary here, since that may degrade the
        # performance of certain algorithms; rely on `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable form of log(1 - tanh(x)^2); see
        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
        log_two = math.log(2.)
        return 2. * (log_two - x - F.softplus(-2. * x))


class MLPNetwork(nn.Module):
    """Fully connected network: two ReLU hidden layers followed by a linear output.

    Args:
        input_dim: width of the input vector.
        output_dim: width of the output vector.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, input_dim, output_dim, hidden_size=256):
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_dim),
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        out = self.network(x)
        return out
    


class Policy(nn.Module):
    """Deterministic policy: an MLP whose output is tanh-squashed and scaled.

    Args:
        state_dim: width of the state input.
        action_dim: width of the returned action.
        max_action: multiplier applied to the tanh-squashed network output.
        hidden_size: hidden width passed to the underlying ``MLPNetwork``.
    """

    def __init__(self, state_dim, action_dim, max_action, hidden_size=256):
        super().__init__()
        self.action_dim = action_dim
        self.max_action = max_action
        self.network = MLPNetwork(state_dim, action_dim, hidden_size)

    def forward(self, x):
        """Return ``tanh(network(x)) * max_action``."""
        squashed = torch.tanh(self.network(x))
        return squashed * self.max_action


class DoubleQFunc(nn.Module):
    """Pair of independent Q-networks evaluated on the same (state, action) input.

    Args:
        state_dim: width of the state input.
        action_dim: width of the action input.
        hidden_size: hidden width passed to each ``MLPNetwork``.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256):
        super().__init__()
        joint_dim = state_dim + action_dim
        self.network1 = MLPNetwork(joint_dim, 1, hidden_size)
        self.network2 = MLPNetwork(joint_dim, 1, hidden_size)

    def forward(self, state, action):
        """Return ``(q1, q2)`` for ``state`` and ``action`` concatenated along dim 1."""
        joint = torch.cat((state, action), dim=1)
        q_first = self.network1(joint)
        q_second = self.network2(joint)
        return q_first, q_second

class ValueFunc(nn.Module):
    """State-value network producing one scalar per state.

    Args:
        state_dim: width of the state input.
        action_dim: accepted for signature parity with the Q-function; not used.
        hidden_size: hidden width passed to the underlying ``MLPNetwork``.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256):
        super().__init__()
        self.network = MLPNetwork(state_dim, 1, hidden_size)

    def forward(self, state):
        """Return the network output for ``state``."""
        value = self.network(state)
        return value

def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor:
    """Return the mean of ``|tau - 1[u < 0]| * u**2`` over all elements of ``u``.

    Negative entries of ``u`` are weighted by ``|tau - 1|`` and the remaining
    entries by ``|tau|``.
    """
    is_negative = (u < 0).float()
    per_element_weight = torch.abs(tau - is_negative)
    return torch.mean(per_element_weight * u**2)


class RewardModel(nn.Module):
    """MLP that predicts a scalar reward from a (state, action) pair.

    Args:
        zs_dim: width of the state input.
        action_dim: width of the action input.
    """

    def __init__(self, zs_dim, action_dim):
        super().__init__()

        width = 256
        self.r1 = nn.Linear(zs_dim + action_dim, width)
        self.r2 = nn.Linear(width, width)
        self.r3 = nn.Linear(width, 1)

    def forward(self, state, action):
        """Return one predicted reward per row of the concatenated input."""
        joint = torch.cat([state, action], 1)
        hidden = F.relu(self.r1(joint))
        hidden = F.relu(self.r2(hidden))
        return self.r3(hidden)
    
class STC(object):

    def __init__(self,
                 config,
                 device,
                 target_entropy=None,
                 ):
        """Build the networks, target copies and optimizers from ``config``.

        Args:
            config: mapping read here for the keys ``'gamma'``, ``'tau'``,
                ``'action_dim'``, ``'state_dim'``, ``'update_interval'``,
                ``'threshold'``, ``'lam'``, ``'temp'``, ``'hidden_sizes'``,
                ``'max_action'``, ``'critic_lr'``, ``'actor_lr'`` and
                ``'max_step'``.  The mapping is also stored as ``self.config``.
            device: device every network is moved to.
            target_entropy: value stored as ``self.target_entropy``.  When it
                is ``None`` the default ``-config['action_dim']`` is used.
        """
        self.config = config
        self.device = device
        self.discount = config['gamma']
        self.tau = config['tau']
        # Compare against None explicitly: a caller-supplied 0 / 0.0 is a
        # legitimate value and must not be replaced by the default.
        self.target_entropy = -config['action_dim'] if target_entropy is None else target_entropy
        self.update_interval = config['update_interval']
        self.threshold = config['threshold']
        self.r_alpha = 0.5
        # IQL hyperparameter
        self.lam = config['lam']
        self.temp = config['temp']
        self.total_it = 0
        self.inverse_total_it = 0
        self.weight = 1

        state_dim = config['state_dim']
        action_dim = config['action_dim']
        hidden_sizes = config['hidden_sizes']

        # aka critic; the target copy is frozen and kept in eval mode
        self.q_funcs = DoubleQFunc(state_dim, action_dim, hidden_size=hidden_sizes).to(self.device)
        self.target_q_funcs = copy.deepcopy(self.q_funcs)
        self.target_q_funcs.eval()
        for p in self.target_q_funcs.parameters():
            p.requires_grad = False

        # aka value
        self.v_func = ValueFunc(state_dim, action_dim, hidden_size=hidden_sizes).to(self.device)

        # aka actor
        self.policy = Policy(state_dim, action_dim, config['max_action'], hidden_size=hidden_sizes).to(self.device)

        self.q_optimizer = torch.optim.Adam(self.q_funcs.parameters(), lr=config['critic_lr'])
        self.v_optimizer = torch.optim.Adam(self.v_func.parameters(), lr=config['critic_lr'])
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=config['actor_lr'])

        self.policy_lr_schedule = CosineAnnealingLR(self.policy_optimizer, config['max_step'])

        # inverse dynamics: (state, next_state) -> action
        self.actor = Actor(state_dim, action_dim, config['max_action']).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        # aka state forward model
        self.model = Model(state_dim).to(self.device)
        self.model_target = copy.deepcopy(self.model)
        self.model_optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4)

        # aka forward dynamics model
        self.forward_model = ForwardModel(state_dim, action_dim).to(self.device)
        self.forward_model_target = copy.deepcopy(self.forward_model)
        self.forward_model_optimizer = torch.optim.Adam(self.forward_model.parameters(), lr=3e-4)

        # aka reward model
        self.reward_model = RewardModel(state_dim, action_dim).to(self.device)
        self.reward_model_opt = torch.optim.Adam(self.reward_model.parameters(), lr=3e-4)
    
    def select_action(self, state, test=True):
        """Return the policy's action for a single ``state`` as a NumPy array.

        ``state`` is converted with ``torch.Tensor`` and reshaped to one row
        before it is passed to the policy; all size-1 dimensions are squeezed
        out of the result.  ``test`` is accepted but does not change the
        outcome: both settings return the same value.
        """
        observation = torch.Tensor(state).view(1, -1).to(self.device)
        with torch.no_grad():
            action = self.policy(observation)
        return action.squeeze().cpu().numpy()



    def train_inverse_model(self, tar_replay_buffer, batch_size=256, writer=None, num_steps=int(5e4)):
        """Pretrain the forward, inverse-dynamics and reward models on target data.

        Each iteration samples one batch from ``tar_replay_buffer`` and takes
        one optimizer step for each of the three models:

        * ``forward_model`` is fitted to the state change
          ``next_state - state`` given ``(state, action)``;
        * ``actor`` (inverse dynamics) is fitted to the action given
          ``(state, next_state)``;
        * ``reward_model`` is fitted to the reward given ``(state, action)``.

        Args:
            tar_replay_buffer: object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)``.
            batch_size: batch size passed to ``sample``.
            writer: optional object with an ``add_scalar(tag, value, step)``
                method; when given, the three losses are logged every 5000
                iterations.  ``None`` disables logging.
            num_steps: number of pretraining iterations (default 50000).
        """
        print('pretraining......')

        for _ in trange(int(num_steps)):
            self.inverse_total_it += 1

            tar_state, tar_action, tar_next_state, tar_reward, _ = tar_replay_buffer.sample(batch_size)

            # train forward model: the regression target is the state delta,
            # not the absolute next state
            predicted_delta = self.forward_model(tar_state, tar_action)
            forward_model_loss = F.mse_loss(predicted_delta, tar_next_state - tar_state)
            self.forward_model_optimizer.zero_grad()
            forward_model_loss.backward()
            self.forward_model_optimizer.step()

            # update inverse policy model
            inverse_action = self.actor(tar_state, tar_next_state)
            action_loss = F.mse_loss(inverse_action, tar_action)

            self.actor_optimizer.zero_grad()
            action_loss.backward()
            self.actor_optimizer.step()

            # update reward model
            pred_reward = self.reward_model(tar_state, tar_action)
            reward_loss = ((pred_reward - tar_reward)**2).mean()

            self.reward_model_opt.zero_grad()
            reward_loss.backward()
            self.reward_model_opt.step()

            # ``writer`` defaults to None, so it must be checked before use;
            # without the guard the first logging step raised AttributeError.
            if writer is not None and self.inverse_total_it % 5000 == 0:
                writer.add_scalar('train/forward model loss', forward_model_loss, self.inverse_total_it+1)
                writer.add_scalar('train/Inverse Dynamics loss', action_loss, self.inverse_total_it+1)
                writer.add_scalar('train/Reward loss', reward_loss, self.inverse_total_it+1)

        print('end pretraining......')


    def update_target(self):
        """Blend the online critic into the target critic with rate ``self.tau``.

        Every target parameter becomes
        ``tau * online + (1 - tau) * target``; the online critic is untouched.
        """
        retain = 1.0 - self.tau
        param_pairs = zip(self.target_q_funcs.parameters(), self.q_funcs.parameters())
        with torch.no_grad():
            for target_param, online_param in param_pairs:
                blended = self.tau * online_param.data + retain * target_param.data
                target_param.data.copy_(blended)


    def update_q_functions(self, state_batch, action_batch, reward_batch, nextstate_batch, not_done_batch, writer=None):
        """Return the summed MSE loss of both critics against a shared TD target.

        The target is ``reward + not_done * discount * min(Q1', Q2')`` where the
        target critics are evaluated at the policy's action for the next state.
        It is computed without gradient tracking.  When ``writer`` is given,
        the mean of the first critic's output is logged every 5000 iterations.
        """
        with torch.no_grad():
            next_action = self.policy(nextstate_batch)
            next_q1, next_q2 = self.target_q_funcs(nextstate_batch, next_action)
            # the smaller of the two estimates mitigates positive bias
            bootstrap = torch.min(next_q1, next_q2)
            td_target = reward_batch + not_done_batch * self.discount * bootstrap

        q_1, q_2 = self.q_funcs(state_batch, action_batch)

        should_log = writer is not None and self.total_it % 5000 == 0
        if should_log:
            writer.add_scalar('train/q1', q_1.mean(), self.total_it)

        first_loss = F.mse_loss(q_1, td_target)
        second_loss = F.mse_loss(q_2, td_target)
        return first_loss + second_loss

    def bc_loss(self, true_state_batch, true_action_batch, writer = None):
        """Return a Q-weighted behaviour-cloning loss for the policy.

        The squared error between the policy's action and the batch action is
        weighted per sample by ``exp(3 * Q / mean(|Q|))`` capped at 100, where
        ``Q`` is the smaller of the two critic estimates for the batch
        (state, action) pair.  The weight carries no gradient.  When ``writer``
        is given, the mean weight and the loss are logged every 5000
        iterations.
        """
        predicted = self.policy(true_state_batch)

        with torch.no_grad():
            q_first, q_second = self.q_funcs(true_state_batch, true_action_batch)
            min_q = torch.min(q_first, q_second)
            # rescale by the batch mean magnitude before exponentiating
            scaled_q = min_q / min_q.abs().mean()

        sample_weight = torch.exp(3 * scaled_q).clamp(max=100.0)
        squared_error = (predicted - true_action_batch)**2
        policy_loss = torch.mean(sample_weight * squared_error)

        should_log = writer is not None and self.total_it % 5000 == 0
        if should_log:
            writer.add_scalar('train/exp_adv', sample_weight.mean(), self.total_it)
            writer.add_scalar('train/bc_loss', policy_loss.mean(), self.total_it)

        return policy_loss

    def update_policy(self, state_batch, action_batch, writer = None):
        """Return the policy loss: a normalised Q term plus a weighted BC term.

        The Q term is ``-mean(min(Q1, Q2))`` at the policy's own action,
        divided by the detached batch mean of ``|min(Q1, Q2)|``.  The
        behaviour-cloning term from :meth:`bc_loss` is added with coefficient
        ``self.config['bc_coef']``.

        Args:
            state_batch: batch of states.
            action_batch: batch of dataset actions for those states.
            writer: optional object with ``add_scalar(tag, value, step)``;
                used every 5000 iterations, and also passed on to
                :meth:`bc_loss` so its diagnostics are logged too.
        """
        pred_action = self.policy(state_batch)
        q_b1, q_b2 = self.q_funcs(state_batch, pred_action)
        qval_batch = torch.min(q_b1, q_b2)

        # Detached so the normaliser rescales the Q term without contributing
        # a gradient of its own.
        q_scale = 1 / qval_batch.abs().mean().detach()
        policy_loss = q_scale * (-qval_batch).mean()

        # Forward the writer: bc_loss only logges when it receives one, and it
        # was previously never given it.
        bc_coef = self.config['bc_coef']
        policy_loss = policy_loss + bc_coef * self.bc_loss(state_batch, action_batch, writer)

        if writer is not None and self.total_it % 5000 == 0:
            with torch.no_grad():
                q_behavior1, q_behavior2 = self.q_funcs(state_batch, action_batch)
                q_behavior = torch.min(q_behavior1, q_behavior2)
            writer.add_scalar('train/q_behavior', q_behavior.mean(), self.total_it)
            writer.add_scalar('train/q_policy', qval_batch.mean(), self.total_it)
            writer.add_scalar('train/policy_loss', policy_loss, self.total_it)

        return policy_loss

    def train(self, src_replay_buffer, tar_replay_buffer, batch_size=128, writer=None):
        """Run one critic and one policy update on a mixed source/target batch.

        Steps, in order:

        1. Sample ``batch_size`` transitions from each buffer.
        2. Predict, with the inverse-dynamics ``actor``, the action that
           explains each source transition.
        3. Build a corrected source reward: the source reward plus
           ``r_alpha`` times the (batch-normalised) reward-model gradient with
           respect to the action, dotted with ``pred_action - src_action``.
        4. Per source sample, switch to the predicted action and corrected
           reward only if the forward model's next-state error with the
           predicted action is below ``threshold`` times its error with the
           source action.
        5. Keep the ``int(batch_size * config['proportion'])`` source samples
           whose predicted action is closest to the source action, and
           concatenate them with the target batch.
        6. Update the critics, the target critics and then the policy (with
           the critic parameters frozen).

        Args:
            src_replay_buffer: source-domain buffer; ``sample(batch_size)`` must
                return ``(state, action, next_state, reward, not_done)``.
            tar_replay_buffer: target-domain buffer with the same interface.
            batch_size: number of transitions sampled from each buffer.
            writer: optional ``add_scalar`` logger passed to the update steps.

        NOTE(review): rewards and not-done flags are presumably shaped
        ``(batch, 1)`` — the reward correction uses ``keepdim=True`` and is
        added to ``src_reward`` directly; confirm against the buffer.
        """
        self.total_it += 1

        src_state, src_action, src_next_state, src_reward, src_not_done = src_replay_buffer.sample(batch_size)
        tar_state, tar_action, tar_next_state, tar_reward, tar_not_done = tar_replay_buffer.sample(batch_size)
        with torch.no_grad():
            pred_action = self.actor(src_state, src_next_state)

        # Gradient of the predicted reward w.r.t. the action, evaluated at the
        # source action.  A detached copy is used as the differentiation leaf so
        # the sampled batch itself is never marked as requiring grad.
        grad_action = src_action.detach().clone().requires_grad_(True)
        pred_r = self.reward_model(src_state, grad_action)
        # The graph is not reused afterwards, so it does not need retaining.
        reward_gradient = torch.autograd.grad(
            outputs=pred_r,
            inputs=grad_action,
            grad_outputs=torch.ones_like(pred_r),
        )[0].flatten(start_dim=1)

        # Normalise by the norm taken over the whole batch (not per sample).
        reward_gradient = reward_gradient / (reward_gradient.norm() + 1e-3)
        reward_gradient = reward_gradient.clamp(-50, 50)

        with torch.no_grad():
            action_shift = pred_action - src_action
            correct_reward = src_reward + self.r_alpha * (reward_gradient * action_shift).sum(dim=1, keepdim=True)

            distances = (action_shift**2).mean(dim=-1)
            min_dist = distances.min()
            max_dist = distances.max()
            # Floor the range: if every distance is equal (including a batch
            # of one) the plain division is 0/0 and fills the costs with NaN.
            dist_range = (max_dist - min_dist).clamp(min=1e-8)
            normalized_distances = (distances - min_dist) / dist_range

            # The forward model predicts a state delta, hence the added state.
            pred_a_next_state = src_state + self.forward_model(src_state, pred_action)
            pred_a_dis = ((pred_a_next_state - src_next_state)**2).mean(dim=-1)
            src_a_next_state = src_state + self.forward_model(src_state, src_action)
            src_a_dis = ((src_a_next_state - src_next_state)**2).mean(dim=-1)
            # True where the predicted action explains the transition better
            cond = (pred_a_dis < self.threshold * src_a_dis).unsqueeze(1)

            sel_correct_action = torch.where(cond, pred_action, src_action)
            sel_correct_reward = torch.where(cond, correct_reward, src_reward)

        # filter out transitions: keep the source samples with smallest distance
        src_filter_num = int(batch_size * self.config['proportion'])
        _, indices = torch.topk(-normalized_distances, src_filter_num)

        src_state = src_state[indices]
        sel_correct_action = sel_correct_action[indices]
        src_next_state = src_next_state[indices]

        sel_correct_reward = sel_correct_reward[indices]
        src_not_done = src_not_done[indices]
        src_cost = normalized_distances[indices]

        state = torch.cat([src_state, tar_state], 0)
        action = torch.cat([sel_correct_action, tar_action], 0)
        next_state = torch.cat([src_next_state, tar_next_state], 0)
        reward = torch.cat([sel_correct_reward, tar_reward], 0)
        not_done = torch.cat([src_not_done, tar_not_done], 0)

        # Per-sample weights are stored on the instance.
        # NOTE(review): nothing in this method passes ``self.weight`` to the
        # critic or policy loss — confirm whether it is meant to be applied.
        self.weight = torch.ones_like(reward.flatten()).to(self.device)

        if self.config['weight']:
            # calculate cost weight: closer source samples get larger weight
            cost_weight = torch.exp(-1.0 * src_cost)
            self.weight[:src_state.shape[0]] = cost_weight

            self.weight = self.weight.unsqueeze(1)

        q_loss_step = self.update_q_functions(state, action, reward, next_state, not_done, writer)

        self.q_optimizer.zero_grad()
        q_loss_step.backward()
        self.q_optimizer.step()

        self.update_target()

        # update policy with the critic frozen so no critic gradients are
        # accumulated; the flags are restored even if the policy step raises
        q_params = list(self.q_funcs.parameters())
        for p in q_params:
            p.requires_grad = False
        try:
            pi_loss_step = self.update_policy(state, action, writer)
            self.policy_optimizer.zero_grad()
            pi_loss_step.backward()
            self.policy_optimizer.step()
            self.policy_lr_schedule.step()
        finally:
            for p in q_params:
                p.requires_grad = True

    @property
    def alpha(self):
        """Return ``self.log_alpha.exp()``.

        NOTE(review): ``log_alpha`` is not assigned anywhere in the visible
        part of this class, so reading this property presumably raises
        ``AttributeError`` unless the attribute is set elsewhere — confirm.
        """
        return self.log_alpha.exp()

    def save(self, filename):
        """Write critic, value and policy checkpoints using ``filename`` as a prefix.

        Seven files are written, one per state dict: the critic, value function
        and policy networks, their optimizers, and the policy learning-rate
        scheduler.  Each file name is ``filename`` plus a fixed suffix.
        """
        checkpoint_parts = (
            (self.q_funcs, "_critic"),
            (self.q_optimizer, "_critic_optimizer"),
            (self.v_func, "_value"),
            (self.v_optimizer, "_value_optimizer"),
            (self.policy, "_actor"),
            (self.policy_optimizer, "_actor_optimizer"),
            (self.policy_lr_schedule, "_actor_lr_scheduler"),
        )
        for component, suffix in checkpoint_parts:
            torch.save(component.state_dict(), filename + suffix)

    def load(self, filename):
        """Restore critic, value and policy checkpoints written by :meth:`save`.

        Reads the seven files named ``filename`` plus the suffixes ``_critic``,
        ``_critic_optimizer``, ``_value``, ``_value_optimizer``, ``_actor``,
        ``_actor_optimizer`` and ``_actor_lr_scheduler``, and loads each into
        the matching network, optimizer or scheduler.

        Tensors are mapped onto ``self.device`` while loading, so a checkpoint
        saved on a GPU can be restored on a machine without one.
        """
        def read_state(suffix):
            # map_location avoids a failure when the saving device is absent
            return torch.load(filename + suffix, map_location=self.device)

        self.q_funcs.load_state_dict(read_state("_critic"))
        self.q_optimizer.load_state_dict(read_state("_critic_optimizer"))
        self.v_func.load_state_dict(read_state("_value"))
        self.v_optimizer.load_state_dict(read_state("_value_optimizer"))
        self.policy.load_state_dict(read_state("_actor"))
        self.policy_optimizer.load_state_dict(read_state("_actor_optimizer"))
        self.policy_lr_schedule.load_state_dict(read_state("_actor_lr_scheduler"))