# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


from einops.layers.torch import Rearrange
from argparse import Namespace
from longformer import Longformer
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import os
from captum.attr import GuidedBackprop, GuidedGradCam
from matplotlib.colors import LinearSegmentedColormap

import utils

class RandomShiftsAug(nn.Module):
    """Random-shift image augmentation.

    Every image is replicate-padded by ``pad`` pixels on each side and then
    resampled back to its original size at a random integer offset in
    ``[0, 2 * pad]`` along each axis (drawn independently per sample). The
    crop is expressed as a sampling grid for ``F.grid_sample`` so the whole
    batch is processed at once on the input's device.
    """

    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        """Shift a batch ``x`` of shape (N, C, H, W) with H == W; same shape out."""
        n, _, h, w = x.size()
        assert h == w
        pad = self.pad
        padded_size = h + 2 * pad
        x = F.pad(x, (pad,) * 4, 'replicate')

        # Normalised centres of the first h pixels of the padded image.
        half_pixel = 1.0 / padded_size
        coords = torch.linspace(-1.0 + half_pixel,
                                1.0 - half_pixel,
                                padded_size,
                                device=x.device,
                                dtype=x.dtype)[:h]
        grid_x = coords.expand(h, h)               # varies along the width
        grid_y = coords.unsqueeze(1).expand(h, h)  # varies along the height
        base_grid = torch.stack((grid_x, grid_y), dim=-1)
        base_grid = base_grid.unsqueeze(0).expand(n, h, h, 2)

        # Integer pixel offsets, converted into normalised grid units.
        offsets = torch.randint(0,
                                2 * pad + 1,
                                size=(n, 1, 1, 2),
                                device=x.device,
                                dtype=x.dtype)
        offsets = offsets * (2.0 / padded_size)

        return F.grid_sample(x,
                             base_grid + offsets,
                             padding_mode='zeros',
                             align_corners=False)


class Encoder(nn.Module):
    """Convolutional pixel encoder.

    Four 3x3 convolutions (the first with stride 2) with ReLUs, followed by a
    flatten. ``repr_dim`` is the length of the flattened output for inputs of
    shape ``obs_shape`` = (C, H, W).
    """

    def __init__(self, obs_shape):
        super().__init__()

        assert len(obs_shape) == 3

        self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        # Derive the flattened feature size from the actual input shape rather
        # than hard-coding 32 * 35 * 35, which is only correct for 84x84 images.
        # A dummy forward pass keeps this in sync with the conv stack above.
        with torch.no_grad():
            dummy = torch.zeros(1, *(int(d) for d in obs_shape))
            self.repr_dim = self.convnet(dummy).numel()

        self.apply(utils.weight_init)

    def forward(self, obs):
        """Encode a batch of images (N, C, H, W) into (N, repr_dim) features.

        Inputs are scaled by 1/255 and shifted by -0.5 before the convolutions
        (so raw 0..255 pixel values are presumably expected).
        """
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        h = h.view(h.shape[0], -1)
        return h


class Actor(nn.Module):
    """Policy network.

    Features pass through a trunk (Linear -> LayerNorm -> Tanh) and a
    three-layer MLP. The tanh-squashed MLP output is the mean of a
    ``utils.TruncatedNormal`` whose scale is the ``std`` given to ``forward``.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        action_dim = action_shape[0]

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        policy_layers = [
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.policy = nn.Sequential(*policy_layers)

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return the action distribution for features ``obs``.

        ``std`` is broadcast to the shape of the mean and used as the scale.
        """
        mean = torch.tanh(self.policy(self.trunk(obs)))
        scale = torch.ones_like(mean) * std
        return utils.TruncatedNormal(mean, scale)


class Critic(nn.Module):
    """Twin Q-network.

    A shared trunk (Linear -> LayerNorm -> Tanh) embeds the observation
    features. The embedding is concatenated with the action and scored by two
    independent MLP heads, ``Q1`` and ``Q2``.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        head_in = feature_dim + action_shape[0]

        def make_q_head():
            # Linear -> ReLU -> Linear -> ReLU -> Linear(1)
            return nn.Sequential(
                nn.Linear(head_in, hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, 1))

        self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
                                   nn.LayerNorm(feature_dim), nn.Tanh())
        self.Q1 = make_q_head()
        self.Q2 = make_q_head()

        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return the pair ``(q1, q2)`` for features ``obs`` and ``action``."""
        joint = torch.cat([self.trunk(obs), action], dim=-1)
        return self.Q1(joint), self.Q2(joint)


class ActionEncoding(nn.Module):
    """Embed actions into a latent space, optionally pooling a whole sequence.

    ``action_tokenizer`` maps each action (last dimension ``action_dim``) to a
    ``latent_action_dim`` vector. When constructed with ``seq_len``, an extra
    linear layer ``action_seq_tokenizer`` maps the concatenated tokens of a
    ``seq_len``-step action sequence to a single ``latent_action_dim`` vector.
    """

    def __init__(self, action_dim, latent_action_dim, seq_len=None):
        super().__init__()
        self.action_dim = action_dim
        self.action_tokenizer = nn.Sequential(
            nn.Linear(action_dim, 64), nn.Tanh(),
            nn.Linear(64, latent_action_dim)
        )
        # forward(seq=True) needs a projection from the flattened token
        # sequence; it can only be sized when the sequence length is known.
        if seq_len is None:
            self.action_seq_tokenizer = None
        else:
            self.action_seq_tokenizer = nn.Linear(seq_len * latent_action_dim,
                                                  latent_action_dim)

        self.apply(utils.weight_init)

    def forward(self, action, seq=False):
        """Encode ``action``.

        Args:
            action: actions whose last dimension is ``action_dim``. With
                ``seq=True`` presumably (batch, seq_len, action_dim).
            seq: if True, pool the per-step tokens of each sample into one
                vector via ``action_seq_tokenizer``.

        Raises:
            ValueError: if ``seq=True`` but the module was built without
                ``seq_len``.
        """
        if not seq:
            return self.action_tokenizer(action)

        seq_tokenizer = getattr(self, 'action_seq_tokenizer', None)
        if seq_tokenizer is None:
            raise ValueError(
                "ActionEncoding.forward(seq=True) requires the module to be "
                "constructed with seq_len")
        batch_size = action.shape[0]
        tokens = self.action_tokenizer(action)
        # Flatten all per-step tokens of a sample into one vector.
        return seq_tokenizer(tokens.reshape(batch_size, -1))


class RCAgent:
    def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
                 hidden_dim, critic_target_tau, num_expl_steps,
                 update_every_steps, stddev_schedule, stddev_clip, use_tb
                 ,temporal_type='longformer', temporal_args=None,
                 value_pred_lr=1e-4, frames=64):
        """Build the actor-critic networks and the temporal return predictor.

        Args:
            obs_shape: (C, H, W) shape of pixel observations.
            action_shape: shape of the action space; ``action_shape[0]`` is the
                action dimension.
            device: torch device every sub-network is moved to.
            lr: Adam learning rate for the encoder, actor, critic and key net.
            feature_dim: width of the projected state / action features.
            hidden_dim: width of the MLP hidden layers.
            critic_target_tau: interpolation factor passed to
                ``utils.soft_update_params`` for the target critic.
            num_expl_steps: number of initial steps with uniform random actions.
            update_every_steps: ``update`` only trains every this many steps.
            stddev_schedule: schedule given to ``utils.schedule`` for the
                policy standard deviation.
            stddev_clip: clip value passed to ``dist.sample``.
            use_tb: whether to collect scalar metrics.
            temporal_type: temporal model; only 'longformer' is implemented.
            temporal_args: overrides for the temporal-model hyper-parameters;
                keys that are not given fall back to the defaults below.
            value_pred_lr: Adam learning rate for the return-prediction stack.
            frames: sequence length seen by the temporal transformer.

        Raises:
            AssertionError: if ``temporal_type`` is not a recognised name.
            NotImplementedError: for 'linformer' / 'transformer', which are
                recognised names but have no model constructed here.
        """
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip

        self.log_interval = 100
        # Regroup flat (batch * frames, dim) rows into (batch, frames, dim)
        # sequences. The frame count must match the transformer's seq_len, so
        # it follows `frames` instead of a hard-coded 64.
        self.spatial2temporal = Rearrange('(b f) d -> b f d', f=frames)
        self.feature_dim = feature_dim
        self.action_dim = action_shape[0]
        self.num_classes = 2000
        self.value_range = (0, 1000)
        self.action_emb = nn.Linear(action_shape[0], feature_dim).to(device)

        # Default temporal Transformer hyper-parameters. The input width is the
        # projected state feature plus the action embedding (feature_dim each).
        default_temporal_args = {
            'dim': feature_dim + feature_dim,
            'depth': 3,
            'heads': 1,
            'dim_head': 64,
            'mlp_dim': 256,
            'attention_window': 32,
            'attention_mode': 'sliding_chunks',
            'dropout': 0.1,
            'emb_dropout': 0.1,
            'pool': 'cls'
        }
        if temporal_args is not None:
            # Caller-supplied values override the defaults key by key, so a
            # partial mapping no longer fails with a missing attribute.
            default_temporal_args.update(temporal_args)
        temporal_args = Namespace(**default_temporal_args)

        assert temporal_type in ['longformer', 'linformer', 'transformer'], \
            "Only longformer, linformer, transformer are supported"
        if temporal_type != 'longformer':
            # These names pass the check above but no model exists for them;
            # fail here rather than with an AttributeError when the optimizer
            # below looks up self.temporal_transformer.
            raise NotImplementedError(
                f"temporal_type '{temporal_type}' is not implemented; "
                "only 'longformer' is currently available")

        self.temporal_transformer = Longformer(
            seq_len=frames,
            dim=temporal_args.dim,
            depth=temporal_args.depth,
            heads=temporal_args.heads,
            dim_head=temporal_args.dim_head,
            mlp_dim=temporal_args.mlp_dim,
            attention_window=temporal_args.attention_window,
            attention_mode=temporal_args.attention_mode,
            dropout=temporal_args.dropout,
            emb_dropout=temporal_args.emb_dropout,
            pool=temporal_args.pool
        ).to(device)

        # models
        self.encoder = Encoder(obs_shape).to(device)
        self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim,
                           hidden_dim).to(device)

        self.critic = Critic(self.encoder.repr_dim, action_shape, feature_dim,
                             hidden_dim).to(device)
        self.critic_target = Critic(self.encoder.repr_dim, action_shape,
                                    feature_dim, hidden_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.proj_s = nn.Sequential(nn.Linear(self.encoder.repr_dim, feature_dim),
                                    nn.LayerNorm(feature_dim), nn.Tanh()).to(device)

        # Classification head over num_classes return bins.
        self.mlphead_cls = nn.Sequential(
            nn.LayerNorm(temporal_args.dim),
            nn.Linear(temporal_args.dim, self.num_classes)
        ).to(device)

        # Scalar return-regression head.
        self.mlphead_pre = nn.Sequential(
            nn.LayerNorm(temporal_args.dim),
            nn.Linear(temporal_args.dim, 1)
        ).to(device)

        # NOTE(review): mlphead_token is not optimised here and not used by the
        # visible methods; confirm whether it is still needed.
        self.mlphead_token = nn.Sequential(
            nn.LayerNorm(temporal_args.dim),
            nn.Linear(temporal_args.dim, 1)
        ).to(device)

        # Scores a (state feature, action) pair with a scalar. Moved to
        # `device` like every other network; it receives on-device inputs.
        self.key = nn.Sequential(
            nn.Linear(feature_dim + action_shape[0], hidden_dim),
            nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)).to(device)

        # optimizers
        self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.key_opt = torch.optim.Adam(self.key.parameters(), lr=lr)
        # action_emb feeds the temporal transformer alongside proj_s, so it is
        # trained with the same optimizer (otherwise it stays at random init).
        self.longformer_optimizer = torch.optim.Adam(
            list(self.temporal_transformer.parameters())
            + list(self.mlphead_pre.parameters())
            + list(self.mlphead_cls.parameters())
            + list(self.proj_s.parameters())
            + list(self.action_emb.parameters()),
            lr=value_pred_lr,
            weight_decay=1e-4
        )

        self.value_loss_cls = nn.CrossEntropyLoss()
        self.value_loss_pre = nn.MSELoss()
        self.key_pre = nn.MSELoss()
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        # data augmentation
        self.aug = RandomShiftsAug(pad=4)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Put every trainable sub-network into train or eval mode.

        Args:
            training: True for training mode, False for evaluation mode.
        """
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        self.critic.train(training)
        # The return-prediction stack (including the temporal transformer,
        # which is built with dropout) and the key network must follow the
        # same mode. getattr guards against calls made before they exist.
        for name in ('temporal_transformer', 'proj_s', 'action_emb',
                     'mlphead_cls', 'mlphead_pre', 'mlphead_token', 'key'):
            module = getattr(self, name, None)
            if module is not None:
                module.train(training)

    def act(self, obs, step, eval_mode):
        """Choose an action for one unbatched observation.

        In eval mode the policy mean is returned. Otherwise an unclipped sample
        is drawn, and it is overwritten with a uniform action in [-1, 1] while
        ``step < num_expl_steps``. The result is the first row as a NumPy array.
        """
        batched = torch.as_tensor(obs, device=self.device).unsqueeze(0)
        features = self.encoder(batched)
        dist = self.actor(features, utils.schedule(self.stddev_schedule, step))
        if not eval_mode:
            action = dist.sample(clip=None)
            if step < self.num_expl_steps:
                action.uniform_(-1.0, 1.0)
        else:
            action = dist.mean
        return action.cpu().numpy()[0]

    def update_critic(self, obs, action, reward, discount, next_obs, step):
        """Take one TD step on both Q-heads of the online critic.

        The target is ``reward + discount * min(Q1', Q2')`` from the target
        critic, evaluated at an action sampled (noise clipped at
        ``stddev_clip``) from the current actor at ``next_obs``. The loss is
        back-propagated into the critic. If ``obs`` carries an encoder graph,
        it also reaches the encoder. Both optimizers then step.

        Returns:
            dict with target/online Q means and the loss when ``use_tb``,
            otherwise an empty dict.
        """
        with torch.no_grad():
            noise_std = utils.schedule(self.stddev_schedule, step)
            next_dist = self.actor(next_obs, noise_std)
            next_action = next_dist.sample(clip=self.stddev_clip)
            next_q1, next_q2 = self.critic_target(next_obs, next_action)
            target_q = reward + discount * torch.min(next_q1, next_q2)

        q1, q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)

        # One backward pass feeds both the encoder and the critic optimizers.
        for opt in (self.encoder_opt, self.critic_opt):
            opt.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_opt.step()
        self.encoder_opt.step()

        if not self.use_tb:
            return {}
        return {
            'critic_target_q': target_q.mean().item(),
            'critic_q1': q1.mean().item(),
            'critic_q2': q2.mean().item(),
            'critic_loss': critic_loss.item(),
        }

    def update_actor(self, obs, step):
        """Update the actor to maximise ``min(Q1, Q2)`` of a sampled action.

        Only ``actor_opt`` steps. ``obs`` is expected to be detached by the
        caller.

        Returns:
            dict with the actor loss, mean log-probability and entropy when
            ``use_tb``, otherwise an empty dict.
        """
        dist = self.actor(obs, utils.schedule(self.stddev_schedule, step))
        sampled = dist.sample(clip=self.stddev_clip)
        log_prob = dist.log_prob(sampled).sum(-1, keepdim=True)
        q1, q2 = self.critic(obs, sampled)
        actor_loss = -torch.min(q1, q2).mean()

        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        if not self.use_tb:
            return {}
        return {
            'actor_loss': actor_loss.item(),
            'actor_logprob': log_prob.mean().item(),
            'actor_ent': dist.entropy().sum(dim=-1).mean().item(),
        }

    def update_RC(self, obs, action, episode_values, start_indices, step, next_obs, reward, discount):
        """Train the return predictor, fit the key net, then refit the critic.

        Three optimisation steps run in sequence:

        1. Value prediction: ``0.5 * MSE(normalised return) + 0.5 * CE(return
           bin)`` on the pooled temporal-transformer output; ``encoder_opt`` and
           ``longformer_optimizer`` step.
        2. Key regression: ``self.key`` (state feature + action -> scalar) is
           fitted with MSE to a slice of the CLS token's attention weights.
        3. Critic TD step on the frames chosen by
           ``self._extract_important_frames``.

        Steps 2 and 3 need attention maps. When the temporal transformer
        returns only its pooled output, they are skipped.

        Args:
            obs, next_obs: pixel observations whose leading dimension is
                (batch * frames), as required by ``spatial2temporal``.
            action: actions aligned row-for-row with ``obs``.
            episode_values: one target return per sequence (assumed to lie in
                [0, value_range[1]]; TODO confirm against the replay buffer).
            start_indices: not used by this method.
            step: global step (stddev schedule and logging cadence).
            reward, discount: presumably per-frame reward and discount used in
                the TD target of step 3.

        Returns:
            dict with ``reward_loss`` when ``use_tb``.
        """
        metrics = dict()
        episode_values = episode_values.float()
        obs = self.aug(obs.float())
        next_obs = self.aug(next_obs.float())

        # ---- 1. value prediction -------------------------------------------
        # Return bins of width 0.5 (value * 2), clamped into [0, num_classes).
        class_labels = (episode_values * 2).clamp(0, self.num_classes - 1).long()
        x_enc = self.proj_s(self.encoder(obs))
        a_vis = self.action_emb(action)
        x = torch.cat([x_enc, a_vis], dim=1)
        # The key net is optimised on its own (key_opt only). Its input is
        # detached so key_loss.backward() does not re-enter the encoder/proj_s
        # graph, which loss.backward() below has already freed.
        h_action = torch.cat([x_enc.detach(), action], dim=-1)
        key = self.key(h_action)
        x = self.spatial2temporal(x)
        transformer_output = self.temporal_transformer(x)

        cls_attentions = None
        if isinstance(transformer_output, tuple) and len(transformer_output) > 1:
            x_cla, all_attentions = transformer_output
            # Attention rows of the first (CLS) query position.
            cls_attentions = all_attentions[:, :, 0, :]
            # NOTE(review): the argument order (..., next_obs, discount, ...)
            # differs from the unpacked order (..., discounts, next_obs);
            # confirm against _extract_important_frames.
            important_obs, important_reward, important_actions, important_discounts, important_next_obs \
                = self._extract_important_frames(obs, reward, action, next_obs, discount, cls_attentions, key)
        else:
            x_cla = transformer_output

        cls_logits = self.mlphead_cls(x_cla).squeeze(1)
        pred_return = self.mlphead_pre(x_cla).squeeze(-1)
        # Regression target scaled by the upper end of the value range.
        episode_values_normalized = episode_values / float(self.value_range[1])

        loss_cls = self.value_loss_cls(cls_logits, class_labels)
        loss_pre = self.value_loss_pre(pred_return, episode_values_normalized)
        loss = 0.5 * loss_pre + 0.5 * loss_cls

        self.encoder_opt.zero_grad(set_to_none=True)
        self.longformer_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.temporal_transformer.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(self.mlphead_cls.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(self.mlphead_pre.parameters(), max_norm=1.0)
        self.encoder_opt.step()
        self.longformer_optimizer.step()

        if step % self.log_interval == 0:
            pred_classes = torch.argmax(cls_logits, dim=1)
            accuracy = (pred_classes == class_labels).float().mean()

            print(f"Predicted classes: {pred_classes.tolist()}")
            print(f"True classes: {class_labels.tolist()}")
            print(f"Accuracy: {accuracy.item():.4f}")

        if self.use_tb:
            metrics['reward_loss'] = loss.item()

        if cls_attentions is None:
            # Without attention maps there are no frame weights or selected
            # frames, so the key and critic refits cannot run.
            return metrics

        # ---- 2. key regression ---------------------------------------------
        # The attention weights are a regression target, so they are detached
        # (their transformer graph was freed by loss.backward() above).
        # NOTE(review): the 64:128 slice of the key axis is hard-coded; confirm
        # it matches the Longformer's attention layout and `frames`.
        frame_weights = cls_attentions[:, :, 64:64 + 64].detach()
        frame_weights = frame_weights.reshape(-1, 1)
        key_loss = self.key_pre(key, frame_weights)
        self.key_opt.zero_grad(set_to_none=True)
        key_loss.backward()
        self.key_opt.step()

        # ---- 3. critic TD step on the selected frames -----------------------
        xi_enc = self.encoder(important_obs)
        with torch.no_grad():
            # Next-state features only feed the no-grad target, so no autograd
            # graph is needed for them.
            xin_enc = self.encoder(important_next_obs)
            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(xin_enc, stddev)
            next_action = dist.sample(clip=self.stddev_clip)
            target_Q1, target_Q2 = self.critic_target(xin_enc, next_action)
            target_V = torch.min(target_Q1, target_Q2)
            target_Q = important_reward + (important_discounts * target_V)

        Q1, Q2 = self.critic(xi_enc, important_actions)
        critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

        self.encoder_opt.zero_grad(set_to_none=True)
        self.critic_opt.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_opt.step()
        self.encoder_opt.step()

        return metrics

    def update(self, replay_iter, aux_replay_iter, step):
        """Run one training iteration.

        Updates only happen every ``update_every_steps`` steps. On those steps
        the agent does the following:

        - runs the critic and actor updates on a batch from ``replay_iter``;
        - soft-updates the target critic;
        - after step 4000, also runs ``update_RC`` on a batch from
          ``aux_replay_iter``.

        Args:
            replay_iter: iterator yielding (obs, action, reward, discount,
                next_obs) batches for ``utils.to_torch``.
            aux_replay_iter: iterator yielding (obs, action, episode rewards,
                start_indices, next_obs, rewards, discounts) sequence batches.
            step: global training step.

        Returns:
            dict of scalar metrics (empty on skipped steps).
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, reward, discount, next_obs = utils.to_torch(
            batch, self.device)

        # augment
        obs = self.aug(obs.float())
        next_obs = self.aug(next_obs.float())
        obs = self.encoder(obs)

        with torch.no_grad():
            next_obs = self.encoder(next_obs)

        if self.use_tb:
            metrics['batch_reward'] = reward.mean().item()

        # update critic
        metrics.update(
            self.update_critic(obs, action, reward, discount, next_obs, step))

        # update actor
        metrics.update(self.update_actor(obs.detach(), step))

        # update critic target
        utils.soft_update_params(self.critic, self.critic_target,
                                 self.critic_target_tau)

        # The return-conditioned update starts after a 4000-step warm-up. The
        # auxiliary sequence batch is only sampled and copied to the device
        # once it is actually used, rather than on every update.
        if step > 4000:
            aux_batch = next(aux_replay_iter)
            aux_obs_batch, aux_action_batch, aux_episode_rewards, start_indices \
                , next_observations, rewards, discounts \
                = utils.to_torch(aux_batch, self.device)
            metrics.update(
                self.update_RC(aux_obs_batch, aux_action_batch, aux_episode_rewards,
                               start_indices, step, next_observations, rewards, discounts))

        return metrics