import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils

class RewardDecoder(nn.Module):
    """Two-layer MLP that maps per-critic Q-values to a predicted reward sequence.

    Args:
        num_critics: size of the last input dimension (one value per critic).
        reward_length: size of the last output dimension.
        hidden_units: width of the single hidden ReLU layer.
    """

    def __init__(self, num_critics, reward_length, hidden_units=128):
        super().__init__()
        layers = [
            nn.Linear(num_critics, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, reward_length),
        ]
        # The attribute name and layer positions define the state-dict keys
        # (net.0.*, net.2.*), so keep them stable for checkpoint compatibility.
        self.net = nn.Sequential(*layers)

    def forward(self, q_vector):
        """Decode ``q_vector`` ([..., num_critics]) into [..., reward_length]."""
        decoded = self.net(q_vector)
        return decoded

class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from one query token onto a sequence.

    Args:
        query_dim: feature size of ``query``.
        key_dim: feature size of ``key`` and ``value``.
        hidden_dim: total attention width, split evenly across ``num_heads``.
        num_heads: number of attention heads.

    Note: ``layer_norm`` is constructed (and therefore appears in the state dict)
    but ``forward`` does not use it.
    """

    def __init__(self, query_dim, key_dim, hidden_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        # Creation order fixes both RNG consumption during init and state-dict layout.
        self.q_proj = nn.Linear(query_dim, hidden_dim)
        self.k_proj = nn.Linear(key_dim, hidden_dim)
        self.v_proj = nn.Linear(key_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def _split_heads(self, projected, batch, length):
        # [batch, length, heads * head_dim] -> [batch, heads, length, head_dim]
        return projected.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value):
        """Attend from ``query`` ([B, 1, query_dim]) over ``key``/``value`` ([B, T, key_dim]).

        Returns:
            (output [B, hidden_dim], attention weights [B, num_heads, T]).
        """
        batch, seq_len, _ = key.shape

        q_heads = self._split_heads(self.q_proj(query), batch, 1)
        k_heads = self._split_heads(self.k_proj(key), batch, seq_len)
        v_heads = self._split_heads(self.v_proj(value), batch, seq_len)

        scores = q_heads @ k_heads.transpose(-2, -1) / (self.head_dim ** 0.5)  # [B, H, 1, T]
        probs = torch.softmax(scores, dim=-1)
        context = probs @ v_heads  # [B, H, 1, head_dim]

        merged = context.transpose(1, 2).contiguous().view(batch, 1, self.num_heads * self.head_dim)
        projected = self.out_proj(merged).squeeze(1)  # [B, hidden_dim]
        return projected, probs.squeeze(2)

class QWeightPredictor(nn.Module):
    """Predict per-state mixing weights over an ensemble of critics via cross-attention.

    The encoded observation is a single query token. Each critic contributes one
    key/value token made of its trunk feature, its value estimate and its
    discount factor. The attention weights, averaged over heads, are the mixing
    weights.

    Args:
        critic_num: number of critics (tokens in the key/value sequence).
        hidden_dim_predictor: attention width; must be divisible by the 4 heads.
        repr_dim: feature size of ``obs_enc`` (the query).
        key_dim: feature size of each key/value token (critic feature dim + 2).
        device: device on which the discount-factor buffer is created.
        gamma_list: discount factor of each critic, in critic order.
        beta: currently unused; kept for interface compatibility.
    """

    def __init__(self, critic_num, hidden_dim_predictor, repr_dim, key_dim, device, gamma_list, beta=2.0):
        super().__init__()
        self.critic_num = critic_num
        self.device = device
        # Register this as a buffer so that ``.to(device)`` moves it together with the module.
        # It is non-persistent, so it stays out of the state dict and
        # checkpoints saved before this change still load with strict=True.
        self.register_buffer(
            'gammas_tensor',
            torch.tensor(gamma_list, device=self.device, dtype=torch.float32),  # [critic_num]
            persistent=False,
        )

        query_dim = repr_dim
        self.attention_hidden_dim = hidden_dim_predictor

        self.cross_attn = CrossAttention(
            query_dim=query_dim,
            key_dim=key_dim,
            hidden_dim=self.attention_hidden_dim,
            num_heads=4
        )

    def forward(self, obs_enc, V_stacked, critic_features):
        """Compute mixing weights over critics.

        Args:
            obs_enc: [B, repr_dim] encoded observations (query).
            V_stacked: [B, critic_num] per-critic value estimates.
            critic_features: [B, critic_num, feature_dim] per-critic trunk features.

        Returns:
            [B, critic_num] non-negative weights summing to 1 over the last dim.
        """
        B, critic_num = V_stacked.shape
        assert critic_num == self.critic_num

        query = obs_enc.unsqueeze(1)  # [B, 1, repr_dim]
        values = V_stacked.unsqueeze(-1)  # [B, critic_num, 1]
        gammas = self.gammas_tensor.view(1, critic_num, 1).expand(B, -1, -1)  # [B, critic_num, 1]
        key_value = torch.cat([critic_features, values, gammas], dim=-1)  # [B, critic_num, feature_dim+2]

        _, attn_weights = self.cross_attn(
            query=query,
            key=key_value,
            value=key_value
        )  # attn_weights: [B, num_heads, critic_num]

        # Each head's weights are already a softmax distribution over critics,
        # so their mean is one too. Do not apply a second softmax: on inputs in
        # [0, 1] it squashes the weights toward uniform (the ratio between any
        # two weights could never exceed e).
        return attn_weights.mean(dim=1)  # [B, critic_num]


class RandomShiftsAug(nn.Module):
    """Random-shift augmentation: replicate-pad, then resample at a random offset.

    Each image in the batch gets its own integer shift in [0, 2 * pad] along each
    axis (measured in the padded frame). Sampling uses ``F.grid_sample`` with
    ``align_corners=False``. Inputs must be square ([N, C, H, H]).
    """

    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        n, _, h, w = x.size()
        assert h == w
        pad = self.pad
        padded_size = h + 2 * pad
        x = F.pad(x, (pad, pad, pad, pad), 'replicate')

        # Normalized pixel-centre coordinates of the padded frame; keep the first h.
        half_cell = 1.0 / padded_size
        coords = torch.linspace(-1.0 + half_cell,
                                1.0 - half_cell,
                                padded_size,
                                device=x.device,
                                dtype=x.dtype)[:h]
        grid_x = coords.view(1, h, 1).expand(h, h, 1)  # varies along columns
        grid_y = coords.view(h, 1, 1).expand(h, h, 1)  # varies along rows
        base_grid = torch.cat([grid_x, grid_y], dim=2).unsqueeze(0)  # [1, h, h, 2]

        offsets = torch.randint(0,
                                2 * pad + 1,
                                size=(n, 1, 1, 2),
                                device=x.device,
                                dtype=x.dtype)
        offsets *= 2.0 / padded_size

        # Broadcasting [1, h, h, 2] + [n, 1, 1, 2] yields the per-sample grids.
        return F.grid_sample(x,
                             base_grid + offsets,
                             padding_mode='zeros',
                             align_corners=False)


class Encoder(nn.Module):
    """Convolutional image encoder producing a flat feature vector.

    Four unpadded 3x3 convolutions with 32 channels each: the first has stride
    2 and the rest stride 1, each followed by ReLU. Pixels are rescaled from
    [0, 255] to [-0.5, 0.5] before the convolutions.

    Args:
        obs_shape: (channels, height, width) of the input images.
    """

    def __init__(self, obs_shape):
        super().__init__()

        assert len(obs_shape) == 3
        # Output spatial size of the convnet below (no padding):
        #   stride-2 3x3 conv: (n - 3) // 2 + 1, then three stride-1 3x3 convs: -2 each.
        # For 84x84 this gives 35x35, i.e. the previously hard-coded 32 * 35 * 35;
        # other resolutions now get a matching repr_dim instead of a size mismatch.
        out_h = (int(obs_shape[1]) - 3) // 2 + 1 - 6
        out_w = (int(obs_shape[2]) - 3) // 2 + 1 - 6
        if out_h <= 0 or out_w <= 0:
            raise ValueError(f'observation {tuple(obs_shape)} is too small for the encoder convnet')
        self.repr_dim = 32 * out_h * out_w

        self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        self.apply(utils.weight_init)

    def forward(self, obs):
        """Encode raw pixel observations [B, C, H, W] into [B, repr_dim]."""
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        h = h.view(h.shape[0], -1)
        return h


class Actor(nn.Module):
    """Policy network: encoder features -> action-mean distribution.

    ``trunk`` (Linear -> LayerNorm -> Tanh) compresses the encoder output to
    ``feature_dim``. ``policy`` is an MLP with two ReLU hidden layers whose output
    is squashed to [-1, 1] with tanh. ``forward`` wraps that mean and a
    scalar-broadcast standard deviation in ``utils.TruncatedNormal``.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()

        # Module construction order (and the Sequential indices) is kept stable
        # for deterministic init and state-dict compatibility.
        trunk_layers = [
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        ]
        self.trunk = nn.Sequential(*trunk_layers)

        policy_layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_shape[0]),
        ]
        self.policy = nn.Sequential(*policy_layers)

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return the action distribution for encoded observations ``obs``."""
        mean_action = torch.tanh(self.policy(self.trunk(obs)))
        scale = torch.ones_like(mean_action) * std
        return utils.TruncatedNormal(mean_action, scale)


class Critic(nn.Module):
    """Twin Q-heads sharing one feature trunk.

    ``trunk`` (Linear -> LayerNorm -> Tanh) maps encoder output to
    ``feature_dim``. ``Q1`` and ``Q2`` are independent MLPs with identical
    architecture. Each takes the trunk feature concatenated with the action
    and outputs one scalar.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        self.action_dim = action_shape[0]

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        def make_q_head():
            # Same layer order/indices as before so state-dict keys are unchanged.
            in_dim = feature_dim + self.action_dim
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, 1),
            )

        self.Q1 = make_q_head()
        self.Q2 = make_q_head()

        self.apply(utils.weight_init)

    def forward(self, obs, action, return_feature=False):
        """Return ``(q1, q2)`` ([B, 1] each), plus the trunk feature if requested."""
        feature = self.trunk(obs)  # [B, feature_dim]
        joint = torch.cat([feature, action], dim=-1)
        q_values = (self.Q1(joint), self.Q2(joint))
        if not return_feature:
            return q_values
        return (*q_values, feature)


class DrQV2Agent:
    """DrQ-v2-style agent with an ensemble of critics, one per discount factor.

    Critic ``i`` is trained on an n-step TD target discounted with ``gammas[i]``.
    The actor maximizes a per-state weighted sum of the critics' min(Q1, Q2)
    values; the weights come from ``QWeightPredictor``. An auxiliary
    ``RewardDecoder`` maps the critics' Q-values to a reward sequence of length
    ``reward_length``. It is trained to match the first observed reward and, via
    ``L_matrix`` (``L[i, j] = gammas[i] ** j``), to reconstruct the Q-values.
    """

    def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
                 hidden_dim, critic_target_tau, num_expl_steps,
                 update_every_steps, stddev_schedule, stddev_clip, use_tb, nstep, reward_length, aug_type, gammas):
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip
        self.action_shape = action_shape
        self.action_dim = action_shape[0]

        self.n_step = nstep
        self.gammas = gammas
        self.critic_num = len(self.gammas)
        self.reward_length = reward_length
        self.encoder_feature_dim = feature_dim
        # NOTE(review): ``aug_type`` is accepted but unused; RandomShiftsAug is
        # always applied below.

        # All optimizers use the configured ``lr``. These values were previously
        # hard-coded to 1e-4, which silently ignored the ``lr`` argument.
        self.actor_lr = lr
        self.critic_lr = lr
        self.q_weight_lr = lr
        self.decoder_lr = lr
        self.actor_beta = 0.9
        self.critic_beta = 0.9

        # Construction order is kept fixed so that parameter initialization
        # consumes the global RNG in the same sequence.
        self.encoder = Encoder(obs_shape).to(device)
        self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device)
        self.critics = torch.nn.ModuleList([
            Critic(self.encoder.repr_dim, action_shape, self.encoder_feature_dim, hidden_dim).to(device)
            for _ in range(self.critic_num)
        ])
        self.critic_targets = torch.nn.ModuleList([
            Critic(self.encoder.repr_dim, action_shape, self.encoder_feature_dim, hidden_dim).to(device)
            for _ in range(self.critic_num)
        ])
        for critic, target in zip(self.critics, self.critic_targets):
            target.load_state_dict(critic.state_dict())

        self.q_weight_predictor = QWeightPredictor(
            critic_num=self.critic_num,
            hidden_dim_predictor=hidden_dim,
            repr_dim=self.encoder.repr_dim,
            key_dim=self.encoder_feature_dim + 2,  # feature_dim + V + gamma
            device=self.device,
            gamma_list=self.gammas
        ).to(device)

        self.reward_decoder = RewardDecoder(
            num_critics=self.critic_num,
            reward_length=self.reward_length
        ).to(device)

        # L[i, j] = gammas[i] ** j, so (rewards @ L.T)[:, i] is the gamma_i-discounted
        # sum of the reward sequence (used for the Q reconstruction loss).
        # Powers are computed in float64, then stored as float32.
        gamma_col = torch.tensor(list(self.gammas), dtype=torch.float64).unsqueeze(1)
        powers = torch.arange(self.reward_length, dtype=torch.float64).unsqueeze(0)
        self.L_matrix = (gamma_col ** powers).to(device=device, dtype=torch.float32)

        # Tikhonov-regularized pseudo-inverse of L: each singular value s is
        # inverted as s / (s^2 + alpha^2). Not used within this class.
        U, S, Vh = torch.linalg.svd(self.L_matrix, full_matrices=False)
        alpha = 0.2
        S_inv = S / (S**2 + alpha**2)
        self.L_pinv = Vh.T @ torch.diag(S_inv) @ U.T

        self.actor_opt = torch.optim.Adam(
            self.actor.parameters(),
            lr=self.actor_lr,
            betas=(self.actor_beta, 0.999)
        )

        # The encoder is trained only through the critic loss.
        critic_params = [p for critic in self.critics for p in critic.parameters()]
        critic_params.extend(self.encoder.parameters())
        self.critic_opt = torch.optim.Adam(
            critic_params,
            lr=self.critic_lr,
            betas=(self.critic_beta, 0.999)
        )

        self.q_weight_opt = torch.optim.Adam(
            self.q_weight_predictor.parameters(),
            lr=self.q_weight_lr,
            betas=(self.actor_beta, 0.999)
        )
        self.decoder_opt = torch.optim.Adam(
            self.reward_decoder.parameters(),
            lr=self.decoder_lr,
            betas=(self.actor_beta, 0.999)
        )

        self.aug = RandomShiftsAug(pad=4)

        self.train()
        for target_critic in self.critic_targets:
            target_critic.train()

    def train(self, training=True):
        """Put every trainable module in train (or eval) mode; target critics are not touched."""
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        for critic in self.critics:
            critic.train(training)
        # Toggle these too so that train()/eval() semantics cover every learned
        # module, not just the encoder/actor/critics.
        self.q_weight_predictor.train(training)
        self.reward_decoder.train(training)

    def act(self, obs, step, eval_mode):
        """Select an action for a single (unbatched) observation.

        Args:
            obs: one observation, convertible with ``torch.as_tensor``.
            step: global step; drives the stddev schedule and the uniform
                exploration warm-up (``step < num_expl_steps``).
            eval_mode: if True return the distribution mean, otherwise sample.

        Returns:
            The first (only) row of the batched action, as a numpy array.
        """
        # No gradients are needed here. Without no_grad, ``.numpy()`` fails on a
        # tensor that requires grad unless the caller wrapped this call.
        with torch.no_grad():
            obs = torch.as_tensor(obs, device=self.device)
            obs_enc = self.encoder(obs.unsqueeze(0))  # [1, repr_dim]

            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(obs_enc, stddev)
            if eval_mode:
                action = dist.mean
            else:
                action = dist.sample(clip=None)
                if step < self.num_expl_steps:
                    # Warm-up phase: replace the policy sample with uniform noise.
                    action.uniform_(-1.0, 1.0)

        return action.cpu().numpy()[0]

    def _n_step_return(self, step_rewards, step_not_dones, gamma):
        """Discounted n-step reward sum and bootstrap discount for one discount factor.

        Args:
            step_rewards: [B, n_step] rewards r_0 .. r_{n-1}.
            step_not_dones: [B, n_step] continuation masks multiplied into the discount.
            gamma: discount factor.

        Returns:
            (total_reward, total_discount), each [B, 1]:
            total_reward = sum_k r_k * prod_{m<k} (not_done_m * gamma),
            total_discount = prod_{m<n_step} (not_done_m * gamma).
        """
        batch = step_rewards.shape[0]
        total_reward = torch.zeros(batch, 1, device=self.device)
        discount = torch.ones(batch, 1, device=self.device)
        for k in range(self.n_step):
            total_reward += step_rewards[:, k:k + 1] * discount
            discount *= step_not_dones[:, k:k + 1] * gamma
        return total_reward, discount

    def update_critic(self, obs_enc, action, step_rewards, step_not_dones, next_obs_enc, step):
        """One gradient step on all critics, the encoder and the reward decoder.

        Each critic i regresses both Q-heads onto
        ``R_i + D_i * min(Q1_target_i, Q2_target_i)(next_obs, next_action)``,
        where (R_i, D_i) is the n-step return/discount for ``gammas[i]``. The
        decoder adds a first-step reward loss and a Q reconstruction loss.

        Returns:
            dict of scalar metrics (empty unless ``use_tb``).
        """
        metrics = dict()
        total_critic_loss = 0.0
        q_list = []
        B = obs_enc.shape[0]

        step_rewards = step_rewards.view(B, self.n_step)  # [B, n_step]
        step_not_dones = step_not_dones.view(B, self.n_step)  # [B, n_step]

        first_step_reward = step_rewards[:, 0:1]

        with torch.no_grad():
            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(next_obs_enc, stddev)
            next_action = dist.sample(clip=self.stddev_clip)

        for i, (critic, target_critic) in enumerate(zip(self.critics, self.critic_targets)):
            with torch.no_grad():
                critic_total_reward, critic_total_discount = self._n_step_return(
                    step_rewards, step_not_dones, self.gammas[i])
                target_q1, target_q2 = target_critic(next_obs_enc, next_action)
                target_v = torch.min(target_q1, target_q2)
                target_q = critic_total_reward + critic_total_discount * target_v

            current_q1, current_q2 = critic(obs_enc, action)
            current_q = torch.min(current_q1, current_q2)
            q_list.append(current_q)
            critic_i_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
            total_critic_loss += critic_i_loss

            if self.use_tb:
                metrics[f'critic_{i}_q'] = current_q.mean().item()
                metrics[f'critic_{i}_total_reward'] = critic_total_reward.mean().item()
                metrics[f'critic_{i}_total_discount'] = critic_total_discount.mean().item()

        q_matrix = torch.cat(q_list, dim=1)  # [B, critic_num]
        pred_reward = self.reward_decoder(q_matrix)  # [B, reward_length]
        # q_matrix is not detached here, so this loss also back-propagates into the critics.
        reward_loss = F.mse_loss(pred_reward[:, 0:1], first_step_reward)

        q_recon = torch.matmul(pred_reward, self.L_matrix.T)  # [B, critic_num]
        recon_loss = F.mse_loss(q_recon, q_matrix.detach())
        total_critic_loss += reward_loss + recon_loss

        self.critic_opt.zero_grad(set_to_none=True)
        self.decoder_opt.zero_grad(set_to_none=True)
        total_critic_loss.backward()
        self.critic_opt.step()
        self.decoder_opt.step()

        if self.use_tb:
            metrics['total_critic_loss'] = total_critic_loss.item()
            metrics['reward_loss'] = reward_loss.item()
            metrics['recon_loss'] = recon_loss.item()
            metrics['first_step_reward'] = first_step_reward.mean().item()

        return metrics

    def update_actor(self, obs_enc, step):
        """One gradient step on the actor and the Q-weight predictor.

        The loss is the negative mean of the predictor-weighted sum of each
        critic's min(Q1, Q2) at a freshly sampled action.

        Returns:
            dict of scalar metrics (empty unless ``use_tb``).
        """
        metrics = dict()
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs_enc, stddev)
        action = dist.sample(clip=self.stddev_clip)  # [B, action_dim]
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)  # [B, 1]

        q_list = []
        feature_list = []
        for critic in self.critics:
            q1, q2, h = critic(obs_enc, action, return_feature=True)  # h: [B, feature_dim]
            q_list.append(torch.min(q1, q2))  # [B, 1]
            feature_list.append(h)

        # No squeeze here: with critic_num == 1 a squeeze(-1) would collapse
        # [B, 1] to [B] and break the predictor's [B, critic_num] contract.
        q_stacked = torch.cat(q_list, dim=1)  # [B, critic_num]
        critic_features = torch.stack(feature_list, dim=1)  # [B, critic_num, feature_dim]

        learned_weights = self.q_weight_predictor(
            obs_enc=obs_enc,
            V_stacked=q_stacked,
            critic_features=critic_features
        )  # [B, critic_num]

        fused_q = torch.sum(q_stacked * learned_weights, dim=-1, keepdim=True)  # [B, 1]

        actor_loss = -fused_q.mean()

        self.actor_opt.zero_grad(set_to_none=True)
        self.q_weight_opt.zero_grad(set_to_none=True)
        actor_loss.backward()

        self.actor_opt.step()
        self.q_weight_opt.step()

        if self.use_tb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['actor_log_prob'] = log_prob.mean().item()
            metrics['fused_q'] = fused_q.mean().item()
            for i in range(self.critic_num):
                metrics[f'q_weight_{i}'] = learned_weights[:, i].mean().item()

        return metrics

    def update(self, replay_iter, step):
        """Run one full training update, but only when ``step`` is a multiple of ``update_every_steps``.

        Samples a batch, augments and encodes it, updates the critics (and the
        encoder/decoder), then the actor, and finally soft-updates the target critics.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, step_rewards, step_not_dones, next_obs = utils.to_torch(batch, self.device)

        obs_aug = self.aug(obs.float())  # [B, C, H, W]
        next_obs_aug = self.aug(next_obs.float())  # [B, C, H, W]

        obs_enc = self.encoder(obs_aug)  # [B, repr_dim]
        with torch.no_grad():
            next_obs_enc = self.encoder(next_obs_aug)  # [B, repr_dim]

        metrics.update(self.update_critic(
            obs_enc=obs_enc,
            action=action,
            step_rewards=step_rewards,
            step_not_dones=step_not_dones,
            next_obs_enc=next_obs_enc,
            step=step
        ))

        # Detach so the actor loss does not train the encoder.
        metrics.update(self.update_actor(obs_enc.detach(), step))

        for critic, target in zip(self.critics, self.critic_targets):
            utils.soft_update_params(critic, target, self.critic_target_tau)

        return metrics

    def save(self, model_dir, step):
        """Write each module's state dict to ``{model_dir}/<name>_{step}.pt`` (optimizer state is not saved)."""
        torch.save(self.actor.state_dict(), f'{model_dir}/actor_{step}.pt')
        torch.save(self.encoder.state_dict(), f'{model_dir}/encoder_{step}.pt')
        torch.save(self.q_weight_predictor.state_dict(), f'{model_dir}/q_weight_predictor_{step}.pt')
        torch.save(self.reward_decoder.state_dict(), f'{model_dir}/reward_decoder_{step}.pt')

        for i in range(self.critic_num):
            torch.save(self.critics[i].state_dict(), f'{model_dir}/critic_{i}_{step}.pt')
            torch.save(self.critic_targets[i].state_dict(), f'{model_dir}/critic_target_{i}_{step}.pt')

    def load(self, model_dir, step):
        """Load the state dicts written by ``save`` for the given ``step``."""
        self.actor.load_state_dict(torch.load(f'{model_dir}/actor_{step}.pt', map_location=self.device))
        self.encoder.load_state_dict(torch.load(f'{model_dir}/encoder_{step}.pt', map_location=self.device))
        self.q_weight_predictor.load_state_dict(torch.load(f'{model_dir}/q_weight_predictor_{step}.pt', map_location=self.device))
        self.reward_decoder.load_state_dict(torch.load(f'{model_dir}/reward_decoder_{step}.pt', map_location=self.device))

        for i in range(self.critic_num):
            self.critics[i].load_state_dict(torch.load(f'{model_dir}/critic_{i}_{step}.pt', map_location=self.device))
            self.critic_targets[i].load_state_dict(torch.load(f'{model_dir}/critic_target_{i}_{step}.pt', map_location=self.device))