import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import numpy as np

sys.path.append("..")
import utils
from utils import ProbabilisticTransitionModel, StateRewardDecoder

# Small positive constant added to the denominator of the cosine similarity
# in `cosine_distance`, so the division is never by exactly zero.
EPSILON = 1e-4

def _sqrt(x, tol=0.):
    tol = torch.zeros_like(x)
    return torch.sqrt(torch.maximum(x, tol))

def cosine_distance(x, y):
    """Return the angle between ``x`` and ``y`` along the last dimension.

    The cosine similarity is computed as the dot product divided by the
    product of the two L2 norms (plus ``EPSILON``); the angle is then
    recovered with ``atan2(sqrt(1 - cos^2), cos)``.

    Args:
        x: tensor whose last dimension holds the vector components.
        y: tensor multiplied element-wise with ``x`` (same or broadcastable
            shape).

    Returns:
        Tensor of angles with the last dimension reduced to size 1.
    """
    dot = (x * y).sum(dim=-1, keepdim=True)
    x_norm = x.pow(2.).sum(dim=-1, keepdim=True).sqrt()
    y_norm = y.pow(2.).sum(dim=-1, keepdim=True).sqrt()
    cos_sim = dot / (x_norm * y_norm + EPSILON)

    # `_sqrt` floors its argument at zero before the square root.
    sin_sim = _sqrt(1. - cos_sim.pow(2.))
    return torch.atan2(sin_sim, cos_sim)

class RandomShiftsAug(nn.Module):
    """Random-shift augmentation for batches of square images.

    Each image is replicate-padded by ``pad`` pixels on every side and then
    resampled with ``F.grid_sample`` through a window of the original size
    whose offset is drawn per sample from ``{0, ..., 2 * pad}`` pixels on
    each axis.
    """

    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        """Return a randomly shifted copy of ``x`` with the same shape.

        Args:
            x: tensor of size ``(n, c, h, w)`` with ``h == w``.

        Raises:
            AssertionError: if the images are not square.
        """
        batch, _, height, width = x.size()
        assert height == width
        padded_size = height + 2 * self.pad
        x = F.pad(x, (self.pad,) * 4, 'replicate')

        # Normalised coordinates of the first `height` pixel centres of the
        # padded image (grid_sample convention with align_corners=False).
        half_pixel = 1.0 / padded_size
        coords = torch.linspace(-1.0 + half_pixel,
                                1.0 - half_pixel,
                                padded_size,
                                device=x.device,
                                dtype=x.dtype)[:height]
        xs = coords.unsqueeze(0).expand(height, height)
        ys = coords.unsqueeze(1).expand(height, height)
        base_grid = torch.stack((xs, ys), dim=2)

        # One integer pixel offset per sample and axis, converted to
        # normalised grid units.
        shift = torch.randint(0,
                              2 * self.pad + 1,
                              size=(batch, 1, 1, 2),
                              device=x.device,
                              dtype=x.dtype)
        shift = shift * (2.0 / padded_size)

        # Broadcasting adds each sample's offset to the shared base grid.
        grid = base_grid.unsqueeze(0) + shift
        return F.grid_sample(x,
                             grid,
                             padding_mode='zeros',
                             align_corners=False)


class Encoder(nn.Module):
    """Four-layer convolutional encoder that flattens its output.

    ``repr_dim`` is the length of the flattened feature vector. It is derived
    from ``obs_shape`` instead of being hard-coded, so inputs other than
    84x84 are supported; for 84x84 inputs it equals ``32 * 35 * 35`` as
    before.
    """

    def __init__(self, obs_shape):
        """
        Args:
            obs_shape: 3-element shape; ``obs_shape[0]`` is used as the number
                of input channels and the remaining two entries as the
                spatial size.
        """
        super().__init__()

        assert len(obs_shape) == 3
        out_h = self._conv_out_size(obs_shape[1])
        out_w = self._conv_out_size(obs_shape[2])
        self.repr_dim = 32 * out_h * out_w

        self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        self.apply(utils.weight_init)

    @staticmethod
    def _conv_out_size(size):
        """Spatial size after the conv stack for an input side of ``size``."""
        # First layer: 3x3 kernel, stride 2, no padding.
        size = (size - 3) // 2 + 1
        # Three further 3x3, stride-1, unpadded layers each remove 2 pixels.
        return size - 3 * 2

    def forward(self, obs):
        """Encode ``obs`` into a ``(batch, -1)`` feature matrix."""
        # Rescale; maps [0, 255] to [-0.5, 0.5] (assumes 8-bit pixel values).
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        h = h.view(h.shape[0], -1)
        return h


class Actor(nn.Module):
    """Policy network mapping encoded observations to an action distribution.

    A linear + LayerNorm + Tanh trunk reduces the representation to
    ``feature_dim``; a three-layer MLP then produces the action mean.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        action_dim = action_shape[0]

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        policy_layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.policy = nn.Sequential(*policy_layers)

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return ``utils.TruncatedNormal(mean, std)`` for the given input.

        Args:
            obs: encoded observation fed to the trunk.
            std: standard deviation, broadcast to the shape of the mean.
        """
        features = self.trunk(obs)
        # tanh squashes the mean into (-1, 1).
        mean = torch.tanh(self.policy(features))
        scale = torch.ones_like(mean) * std
        return utils.TruncatedNormal(mean, scale)


class Critic(nn.Module):
    """Twin Q-network sharing one feature trunk.

    The trunk (linear + LayerNorm + Tanh) maps the representation to
    ``feature_dim``; two independent MLP heads ``Q1`` and ``Q2`` score the
    concatenation of those features with the action.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        q_in_dim = feature_dim + action_shape[0]

        def q_head():
            # Both heads share this architecture but not their weights.
            return nn.Sequential(
                nn.Linear(q_in_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, 1),
            )

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )
        self.Q1 = q_head()
        self.Q2 = q_head()

        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return the pair ``(q1, q2)`` for the given observation/action."""
        features = self.trunk(obs)
        q_input = torch.cat([features, action], dim=-1)
        return self.Q1(q_input), self.Q2(q_input)

class DrQV2MY3Agent:
    def __init__(self, obs_shape, action_shape, device, lr, feature_dim,target_lambda,
                 max_perturb_factor, hidden_dim, critic_target_tau, num_expl_steps, update_every_steps, stddev_schedule, stddev_clip, use_tb,
                 AvgStateR, PredictedDiff, StandardAug, int_weight, pre_train_step):
        """Create the networks, auxiliary models, optimizers and augmentation.

        ``target_lambda`` and ``max_perturb_factor`` are accepted but not used
        in this constructor. The encoder parameters are registered with two
        Adam optimizers: ``encoder_opt`` (stepped in ``update_critic``) and
        ``encoder_optimizer`` (stepped in ``aux_supervisor``, together with
        the critic trunk).
        """
        # ---- configuration ----
        self.device = device
        self.use_tb = use_tb
        self.feature_dim = feature_dim
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip
        self.dormant_ratio = 1
        # Flags selecting the auxiliary-loss variant and augmentation timing.
        self.AvgStateR = AvgStateR
        self.PredictedDiff = PredictedDiff
        self.StandardAug = StandardAug
        self.int_weight = int_weight
        self.pre_train_step = pre_train_step
        print("[TEA Setting]:", "V4:", not self.PredictedDiff, "V5:", self.PredictedDiff, "AvgStateR:", self.AvgStateR, "StandardAug:", self.StandardAug, "Int_w:", self.int_weight)

        # ---- actor-critic networks ----
        self.encoder = Encoder(obs_shape).to(device)
        head_args = (self.encoder.repr_dim, action_shape, feature_dim, hidden_dim)
        self.actor = Actor(*head_args).to(device)
        self.critic = Critic(*head_args).to(device)
        self.critic_target = Critic(*head_args).to(device)
        # The target critic starts as an exact copy of the critic.
        self.critic_target.load_state_dict(self.critic.state_dict())

        # ---- actor-critic optimizers ----
        self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # ---- auxiliary (BSM) models ----
        self.transition_model = ProbabilisticTransitionModel(
            feature_dim, action_shape, layer_width=512
        ).to(device)
        self.reward_decoder = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 1),
        ).to(device)
        self.state_reward_decoder = StateRewardDecoder(
            feature_dim, action_shape).to(device)

        aux_params = [
            *self.reward_decoder.parameters(),
            *self.transition_model.parameters(),
            *self.state_reward_decoder.parameters(),
        ]
        self.decoder_optimizer = torch.optim.Adam(
            aux_params, lr=lr, weight_decay=0.0)

        representation_params = [
            *self.encoder.parameters(),
            *self.critic.trunk.parameters(),
        ]
        self.encoder_optimizer = torch.optim.Adam(representation_params, lr=lr)

        # ---- data augmentation ----
        self.aug = RandomShiftsAug(pad=4)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        self.critic.train(training)

    def act(self, obs, step, eval_mode):
        """Select an action for a single observation.

        Args:
            obs: one observation (no batch dimension); a batch axis is added
                before encoding.
            step: current step, used for the std-dev schedule and to decide
                whether uniform exploration is still active.
            eval_mode: if true, return the distribution mean instead of a
                sample.

        Returns:
            The action as a NumPy array with the batch axis removed.
        """
        obs = torch.as_tensor(obs, device=self.device)
        features = self.encoder(obs.unsqueeze(0))
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(features, stddev)

        if eval_mode:
            return dist.mean.cpu().numpy()[0]

        action = dist.sample(clip=None)
        if step < self.num_expl_steps:
            # Early steps: overwrite the sample with uniform noise in [-1, 1].
            action.uniform_(-1.0, 1.0)
        return action.cpu().numpy()[0]


    def update_critic(self, obs, action, reward, discount, next_obs, step):
        """Run one TD update of the critic and the encoder.

        The target is ``reward + discount * min(Q1', Q2')`` where the target
        critic is evaluated at an action sampled from the current actor.

        Returns:
            Dict of scalar metrics; empty unless ``self.use_tb`` is set.
        """
        # The bootstrap target is built without tracking gradients.
        with torch.no_grad():
            stddev = utils.schedule(self.stddev_schedule, step)
            next_dist = self.actor(next_obs, stddev)
            next_action = next_dist.sample(clip=self.stddev_clip)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_Q = reward + (discount * torch.min(target_Q1, target_Q2))

        Q1, Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

        metrics = dict()
        if self.use_tb:
            metrics.update(
                critic_target_q=target_Q.mean().item(),
                critic_q1=Q1.mean().item(),
                critic_q2=Q2.mean().item(),
                critic_loss=critic_loss.item(),
            )

        # The critic loss is back-propagated into the encoder as well.
        self.encoder_opt.zero_grad(set_to_none=True)
        self.critic_opt.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_opt.step()
        self.encoder_opt.step()

        return metrics

    def update_actor(self, obs, step):
        """Run one policy-gradient step maximising ``min(Q1, Q2)``.

        Returns:
            Dict of scalar metrics; empty unless ``self.use_tb`` is set.
        """
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs, stddev)
        action = dist.sample(clip=self.stddev_clip)
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)

        q_first, q_second = self.critic(obs, action)
        actor_loss = -torch.min(q_first, q_second).mean()

        # Only the actor optimizer is stepped here.
        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        metrics = dict()
        if self.use_tb:
            metrics.update(
                actor_loss=actor_loss.item(),
                actor_logprob=log_prob.mean().item(),
                actor_ent=dist.entropy().sum(dim=-1).mean().item(),
            )
        return metrics
    
    def aux_supervisor(self, h, action, reward, next_h, a_next, r_next, discount, step):
        """Run one step of the auxiliary representation/model objectives.

        Combines the transition/reward model losses with the encoder loss of
        the selected variant (``update_encoder_v5`` when ``PredictedDiff`` is
        set, ``update_encoder_v4`` otherwise), back-propagates the weighted
        sum and steps ``encoder_optimizer`` and ``decoder_optimizer``.

        Side effect: stores the variant's third return value (the detached
        intrinsic-reward tensor) in ``self.ri``.

        Returns:
            Dict of scalar metrics; empty unless ``self.use_tb`` is set.
        """
        state_loss, reward_loss = self.update_transition_reward_model(h, action, next_h, reward)

        # V5 when PredictedDiff is set, V4 otherwise.
        encoder_update = self.update_encoder_v5 if self.PredictedDiff else self.update_encoder_v4
        bsm_loss, prer_loss, self.ri, r_dist = encoder_update(
            h, action, reward, next_h, a_next, r_next, discount, step)

        total_loss = 0.5 * (bsm_loss + prer_loss) + 1e-4 * (state_loss + reward_loss)

        optimizers = (self.encoder_optimizer, self.decoder_optimizer)
        for optimizer in optimizers:
            optimizer.zero_grad()
        total_loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        metrics = dict()
        if self.use_tb:
            metrics.update(
                intr_reward=self.ri.mean().item(),
                r_dist=r_dist.mean().item(),
                bsm_loss=bsm_loss.item(),
                prer_loss=prer_loss.item(),
                state_loss=state_loss.item(),
                reward_loss=reward_loss.item(),
            )
        return metrics
    
    def update_transition_reward_model(self, h, action, next_h, reward):
        pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, action], dim=1))
        if pred_next_latent_sigma is None:
            pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)

        diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
        state_loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))

        pred_next_latent = self.transition_model.sample_prediction(torch.cat([h, action], dim=1))
        pred_next_reward = self.reward_decoder(pred_next_latent)
        reward_loss = F.mse_loss(pred_next_reward, reward)

        return state_loss, reward_loss

    def update_encoder_v4(self, h, action, reward, next_h, a_next, r_next, discount, step): # V4
        """Compute the V4 encoder loss, reward-decoder loss and intrinsic reward.

        Every sample is paired with another sample of the batch through a
        random permutation. The encoder loss compares the squared difference
        between the latent distance and the discounted distance of predicted
        next latents against the squared difference of sampled reward
        predictions.

        Args:
            h, next_h: latent features of the current / next observation.
            action, a_next: actions taken at the current / next step.
            reward, r_next: rewards of the current / next step.
            discount: discount factor multiplied into the distances.
            step: unused; kept for signature compatibility with V5.

        Returns:
            ``(loss, loss_reward_decoder, r_int, r_dist)`` where ``r_int`` is
            detached.
        """
        batch_size = h.size(0)
        rep_size = h.size(1)
        perm = np.random.permutation(batch_size)
        h2 = h[perm]

        reward_mu, reward_sigma = self.state_reward_decoder(torch.cat([h, action], dim=1))
        loss_reward_decoder = self.state_reward_decoder.loss(
            reward_mu, reward_sigma, reward)

        with torch.no_grad():
            pred_next_latent_mu1, _ = self.transition_model(torch.cat([h, action], dim=1))
            pred_next_latent_mu11, _ = self.transition_model(torch.cat([next_h, a_next], dim=1))
            prer = self.state_reward_decoder.sample_prediction(torch.cat([h, action], dim=1)).detach()
            # NOTE(review): the result of this call was never used. The call
            # is kept because `sample_prediction` presumably draws random
            # numbers, so dropping it would shift the RNG stream - confirm
            # and remove if reproducibility of old runs is not required.
            self.state_reward_decoder.sample_prediction(torch.cat([next_h, a_next], dim=1))
            reward2 = reward[perm]
            prer2 = prer[perm]

        pred_next_latent_mu2 = pred_next_latent_mu1[perm]

        z_dist = self.metric_func(h, h2, distance="l1_smooth")
        # Uses the default metric and the per-sample permuted prediction, so
        # it must be computed before the optional batch averaging below.
        transition_dist_loss = self.metric_func(pred_next_latent_mu1, pred_next_latent_mu2)

        if self.AvgStateR:
            # Compare against the batch average instead of one paired sample.
            reward_ref = reward2.mean(dim=0, keepdim=True)
            pred_next_latent_mu2 = pred_next_latent_mu2.mean(dim=0, keepdim=True).expand(batch_size, rep_size)
        else:
            reward_ref = reward2
        r_distt0 = F.smooth_l1_loss(reward, reward_ref, reduction='none')
        r_distt10 = F.smooth_l1_loss(r_next, reward_ref, reduction='none')

        transition_dist = self.metric_func(pred_next_latent_mu1, pred_next_latent_mu2, distance="l1_smooth")
        transition_distn = self.metric_func(pred_next_latent_mu11, pred_next_latent_mu2, distance="l1_smooth")

        r_int =(discount * (discount * transition_distn + r_distt10) - (discount * transition_dist + r_distt0)).pow(2.)
        with torch.no_grad():
            r_dist = (prer - prer2).pow(2.)
        loss = F.smooth_l1_loss((z_dist - discount * transition_dist_loss).pow(2.), r_dist, reduction='mean')
        return loss, loss_reward_decoder, r_int.detach(), r_dist

    def update_encoder_v5(self, h, action, reward, next_h, a_next, r_next, discount, step): #v5
        batch_size = h.size(0)
        rep_size = h.size(1)

        perm = np.random.permutation(batch_size)
        h2 = h[perm]

        reward_mu, reward_sigma = self.state_reward_decoder(torch.cat([h, action], dim=1))
        loss_reward_decoder = self.state_reward_decoder.loss(
                reward_mu, reward_sigma, reward
            )
        with torch.no_grad():
            pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, action], dim=1))
            pred_next_latent_mu1, pred_next_latent_sigma1 = self.transition_model(torch.cat([pred_next_latent_mu, a_next], dim=1)) # hat_P'
            
            pred_next_latent_mu2, pred_next_latent_sigma2 = self.transition_model(torch.cat([next_h, a_next], dim=1))

            prer = self.state_reward_decoder.sample_prediction(torch.cat([h, action], dim=1))
            prer = prer.detach()
            reward2 = reward[perm]

            r_next2 = r_next[perm]
            prer2 = prer[perm]
            prer_next = self.state_reward_decoder.sample_prediction(torch.cat([next_h, a_next], dim=1))
            prer_next = prer_next.detach()

        pred_next_latent_mup = pred_next_latent_mu[perm]
        pred_next_latent_mu22 = pred_next_latent_mu2[perm]

        z_dist = self.metric_func(h, h2, distance="l1_smooth")
        if self.AvgStateR:
            r_distt0 = F.smooth_l1_loss(prer, r_next2.mean().unsqueeze(-1), reduction='none')
            r_distt10 = F.smooth_l1_loss(r_next, r_next2.mean().unsqueeze(-1), reduction='none')
            pred_next_latent_mu22 = pred_next_latent_mu22.mean(dim=0, keepdim=True).expand(batch_size, rep_size)
        else:
            r_distt0 = F.smooth_l1_loss(prer_next, r_next2, reduction='none')
            r_distt10 = F.smooth_l1_loss(r_next, r_next2, reduction='none')

        transition_dist = self.metric_func(pred_next_latent_mu1, pred_next_latent_mu22, distance="l1_smooth")
        transition_distn = self.metric_func(pred_next_latent_mu2, pred_next_latent_mu22, distance="l1_smooth")

        r_int = (discount * (discount * transition_distn + r_distt10) - (discount * transition_dist + r_distt0)).pow(2.)
        with torch.no_grad():
            r_dist_square = (prer - prer2).pow(2.) 
            #r_dist_square = (reward - reward2).pow(2.) 
            r_dist = r_dist_square #- r_var - r_var2
        w_dist = self.metric_func(pred_next_latent_mu, pred_next_latent_mup, distance="l1_smooth")
        loss = F.smooth_l1_loss((z_dist - discount * w_dist).pow(2.), r_dist, reduction='mean')
        return loss, loss_reward_decoder, r_int.detach(), r_dist

    def metric_func(self, x, y, distance="mico_angular", opt=None):
        if distance == 'l2':
            dist = F.pairwise_distance(x, y, p=2, keepdim=True)
        elif distance == 'l1_smooth':
            dist = F.smooth_l1_loss(x, y, reduction='none')
            dist = dist.mean(dim=-1, keepdim=True)
        elif distance == 'mico_angular':
            beta = 1e-6
            base_distances = cosine_distance(x, y)
            norm_average = (x.pow(2.).sum(dim=-1, keepdim=True)
                + y.pow(2.).sum(dim=-1, keepdim=True))
            dist = norm_average + beta * base_distances
        elif distance == 'x^2+y^2-xy':
            k = 0.1 # 0 < k < 2
            base_distances = (x * y).sum(dim=-1, keepdim=True)
            norm_average = (x.pow(2.).sum(dim=-1, keepdim=True) 
                + y.pow(2.).sum(dim=-1, keepdim=True))
            dist = norm_average - k * base_distances
        else:
            raise NotImplementedError
        return dist

    def update(self, replay_iter, step):
        """Run one full training update from a replay batch.

        Order of operations: auxiliary update (``aux_supervisor``, which also
        sets ``self.ri``), critic update on the combined reward, actor update,
        then a soft update of the target critic.

        Args:
            replay_iter: iterator yielding replay batches.
            step: current training step.

        Returns:
            Dict of metrics; empty when ``step`` is not a multiple of
            ``self.update_every_steps`` (no batch is consumed in that case).
        """
        metrics = dict()

        # Skip everything, including the reward-weight setup, on steps that
        # are not scheduled for an update.
        if step % self.update_every_steps != 0:
            return metrics

        # Before `pre_train_step` the extrinsic reward is switched off and
        # the intrinsic weight is multiplied by 1e7.
        if step < self.pre_train_step:
            ext_weight = 0.0
            int_weight = self.int_weight * 10000000.0
        else:
            ext_weight = 1.0
            int_weight = self.int_weight

        batch = next(replay_iter)
        # The last three batch fields are not used by this update.
        obs, action, reward, discount, next_obs, reward1, o_next1, a_next1, r_next1, _, _, _ = utils.to_torch(
            batch, self.device)

        # Without StandardAug, augmentation happens before the auxiliary
        # update, which therefore sees augmented observations.
        if not self.StandardAug:
            obs = self.aug(obs.float())
            next_obs = self.aug(next_obs.float())

            o_next1 = self.aug(o_next1.float())
        h = self.critic.trunk(self.encoder(obs))
        next1_h = self.critic.trunk(self.encoder(o_next1))
        # (original author's note: this variant works somewhat better)
        metrics.update(self.aux_supervisor(h, action, reward1*ext_weight, next1_h, a_next1, r_next1*ext_weight, discount, step))

        # With StandardAug, observations are augmented only after the
        # auxiliary update.
        if self.StandardAug:
            obs = self.aug(obs.float())
            next_obs = self.aug(next_obs.float())

        # encode
        obs = self.encoder(obs)
        with torch.no_grad():
            next_obs = self.encoder(next_obs)

        # `self.ri` was set by `aux_supervisor` above.
        ri = self.ri * int_weight
        ri = torch.clamp(ri, -0.1, 0.1)

        r_total = reward * ext_weight + ri

        if self.use_tb:
            metrics['batch_reward_real_use'] = (reward * ext_weight).mean().item()
            metrics['sched_lin_intr_reward'] = (ri).mean().item()
            metrics['batch_reward'] = reward.mean().item()

        # update critic
        metrics.update(
            self.update_critic(obs, action, r_total, discount, next_obs, step))

        # update actor (detached so no gradient reaches the encoder)
        metrics.update(self.update_actor(obs.detach(), step))

        # update critic target
        utils.soft_update_params(self.critic, self.critic_target,
                                 self.critic_target_tau)

        return metrics

class RunningNorm:
    """Exponential moving estimate of a scalar mean and variance.

    The first call to ``update`` adopts the batch statistics directly (with
    ``eps`` added to the variance); later calls blend in new batch statistics
    with weight ``alpha``.
    """

    def __init__(self, alpha=1e-4, eps=1e-8):
        self.alpha, self.eps = alpha, eps
        self.mean, self.var = 0.0, 1.0
        self.initialized = False

    def normalize(self, x: torch.Tensor):
        """Return ``(x - mean) / (sqrt(var) + eps)`` using the current estimates."""
        std = self.var ** 0.5
        return (x - self.mean) / (std + self.eps)

    def update(self, x: torch.Tensor):
        """Fold the mean and (biased) variance of all elements of ``x`` into the estimates."""
        values = x.detach()
        batch_mean = values.mean().item()
        batch_var = values.var(unbiased=False).item()

        if self.initialized:
            keep = 1 - self.alpha
            self.mean = keep * self.mean + self.alpha * batch_mean
            self.var = keep * self.var + self.alpha * batch_var
            return

        # First batch: start from its statistics instead of the defaults.
        self.mean = batch_mean
        self.var = batch_var + self.eps
        self.initialized = True

class ExponentialDecay:
    """Exponentially decaying schedule with a lower bound.

    ``value(t) = max(beta0 * exp(-t / tau), beta_min)``
    """

    def __init__(self, beta0=0.2, beta_min=0.01, tau=1e6):
        self.tau = tau
        self.beta_min = beta_min
        self.beta0 = beta0

    def value(self, step):
        """Return the schedule value at ``step``, never below ``beta_min``."""
        decayed = self.beta0 * math.exp(-step / self.tau)
        return max(decayed, self.beta_min)

import math

class LinearDecay:
    def __init__(self, beta0=0.005, beta_min=0.0001, tau=1e6): #0.00005 - 3.0e-5   //0.00034   dmc=0.05 0.005 mw=0.005 0.0001
        """
        beta(t) = max(beta_min, beta0 - (beta0 - beta_min) * (t / tau))
        """
        self.beta0 = beta0
        self.beta_min = beta_min
        self.tau = tau

    def value(self, step):

        fraction = min(float(step) / self.tau, 1.0)
        beta = self.beta0 - fraction * (self.beta0 - self.beta_min)
        return max(beta, self.beta_min)

