import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils


class RandomShiftsAug(nn.Module):
    """Random-shift augmentation for batches of square images.

    Every image is replicate-padded by ``pad`` pixels on each side and then
    resampled on an ``h x h`` window whose offset (0 to ``2 * pad`` pixels
    along each axis) is drawn independently per image.
    """

    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        """Return a randomly shifted copy of ``x`` with unchanged shape.

        ``x`` is unpacked as ``(n, c, h, w)`` and must satisfy ``h == w``.
        """
        n, _, h, w = x.size()
        assert h == w

        padded_size = h + 2 * self.pad
        x = F.pad(x, (self.pad,) * 4, 'replicate')

        # Pixel-centre coordinates of the padded image in grid_sample's
        # normalised [-1, 1] range; the first h of them form the window that
        # sits in the top-left corner of the padded image.
        eps = 1.0 / padded_size
        coords = torch.linspace(-1.0 + eps,
                                1.0 - eps,
                                padded_size,
                                device=x.device,
                                dtype=x.dtype)[:h]

        # Last axis holds (x, y): x varies along columns, y along rows.
        col_coords = coords.view(1, h, 1).expand(h, h, 1)
        row_coords = coords.view(h, 1, 1).expand(h, h, 1)
        base_grid = torch.cat([col_coords, row_coords], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)

        # One integer pixel offset per image and axis, converted to
        # normalised grid units (one pixel == 2 / padded_size).
        shift = torch.randint(0,
                              2 * self.pad + 1,
                              size=(n, 1, 1, 2),
                              device=x.device,
                              dtype=x.dtype)
        shift = shift * (2.0 / padded_size)

        return F.grid_sample(x,
                             base_grid + shift,
                             padding_mode='zeros',
                             align_corners=False)


class Encoder(nn.Module):
    """Four-layer convolutional encoder that flattens its output per sample.

    Attributes:
        repr_dim: Length of the flattened feature vector ``forward`` produces
            for inputs whose spatial size is ``obs_shape[1:]``.
        convnet: Conv/ReLU stack with 3x3 kernels and no padding; the first
            layer uses stride 2, the remaining three use stride 1.
    """

    def __init__(self, obs_shape):
        """Build the conv stack for observations shaped ``(channels, height, width)``."""
        super().__init__()

        assert len(obs_shape) == 3
        channels, height, width = obs_shape

        # Derived from the input resolution instead of being hard-coded, so
        # inputs other than 84x84 get a matching size. 84x84 still yields
        # 32 * 35 * 35.
        self.repr_dim = (32 * self._conv_out_size(height) *
                         self._conv_out_size(width))

        self.convnet = nn.Sequential(nn.Conv2d(channels, 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        self.apply(utils.weight_init)

    @staticmethod
    def _conv_out_size(size):
        """Return the length of one spatial axis after the four unpadded 3x3 convs."""
        size = (int(size) - 3) // 2 + 1  # first layer: stride 2
        return size - 3 * 2  # each of the three stride-1 layers trims 2 pixels

    def forward(self, obs):
        """Encode ``obs`` and return a ``(batch, -1)`` feature matrix."""
        # Rescale before the convs; presumably obs holds 0-255 pixel values,
        # which this maps to [-0.5, 0.5] -- TODO confirm against the caller.
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        h = h.view(h.shape[0], -1)
        return h


class Actor(nn.Module):
    """Policy network mapping encoded observations to an action distribution."""

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        action_dim = action_shape[0]

        # Projects the encoder output down to feature_dim.
        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        # Two-hidden-layer MLP producing one value per action dimension.
        self.policy = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return ``utils.TruncatedNormal(mean, std)`` for the given features.

        The mean is the tanh-squashed policy output; ``std`` is broadcast to
        the shape of the mean.
        """
        features = self.trunk(obs)
        mean = torch.tanh(self.policy(features))
        scale = torch.ones_like(mean) * std
        return utils.TruncatedNormal(mean, scale)


class Critic(nn.Module):
    """Twin Q-network conditioned on observation features, action and a discount."""

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        # Head input = trunk features + action + one extra slot for gamma.
        input_dim = feature_dim + action_shape[0] + 1

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        def build_q_head():
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, 1),
            )

        self.Q1 = build_q_head()
        self.Q2 = build_q_head()

        self.apply(utils.weight_init)

    def forward(self, obs, action, gamma):
        """Return ``(q1, q2)`` for the given features, actions and discount(s).

        ``gamma`` may be a Python number or a tensor. A number or a 0-dim
        tensor is repeated for every row of ``action``; a 1-dim tensor is
        turned into a column; tensors of higher rank are used as given.
        """
        features = self.trunk(obs)
        batch_size = action.shape[0]

        if isinstance(gamma, torch.Tensor):
            if gamma.dim() == 0:
                gamma = gamma.reshape(1, 1).repeat(batch_size, 1)
            elif gamma.dim() == 1:
                gamma = gamma.unsqueeze(1)
        else:
            gamma = torch.tensor([gamma],
                                 device=action.device,
                                 dtype=action.dtype).repeat(batch_size, 1)

        critic_input = torch.cat([features, action, gamma], dim=-1)
        return self.Q1(critic_input), self.Q2(critic_input)

class DrQV2Agent:
    def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
                 hidden_dim, critic_target_tau, num_expl_steps,
                 update_every_steps, stddev_schedule, stddev_clip, use_tb,
                 nstep, reward_length, aug_type, gammas):
        """Create the networks, optimizers and augmentation used by the agent.

        ``reward_length``, ``aug_type`` and ``gammas`` are accepted but not
        read by this constructor.
        """
        # Run configuration.
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        print(stddev_schedule)
        self.stddev_clip = stddev_clip

        # Discount sampling: the range and the number of gammas drawn per
        # critic update are fixed here rather than taken from the arguments.
        self.gamma_min = 0
        self.gamma_max = 0.99
        self.num_gammas_per_update = 8
        self.n_step = nstep

        # Models, built in the order encoder -> actor -> critic -> target.
        self.encoder = Encoder(obs_shape).to(device)
        head_args = (self.encoder.repr_dim, action_shape, feature_dim,
                     hidden_dim)
        self.actor = Actor(*head_args).to(device)
        self.critic = Critic(*head_args).to(device)
        self.critic_target = Critic(*head_args).to(device)
        # The target critic starts as an exact copy of the online critic.
        self.critic_target.load_state_dict(self.critic.state_dict())

        # One Adam optimizer per trainable network, all with the same lr.
        self.encoder_opt, self.actor_opt, self.critic_opt = (
            torch.optim.Adam(module.parameters(), lr=lr)
            for module in (self.encoder, self.actor, self.critic))

        # Data augmentation.
        self.aug = RandomShiftsAug(pad=4)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Record ``training`` and switch encoder, actor and critic to that mode."""
        self.training = training
        for module in (self.encoder, self.actor, self.critic):
            module.train(training)

    def act(self, obs, step, eval_mode):
        """Choose an action for a single observation and return it as a NumPy array.

        In eval mode the distribution mean is returned. Otherwise an action
        is sampled without clipping, and while ``step < num_expl_steps`` it is
        overwritten in place with uniform noise in [-1, 1].

        NOTE(review): ``.numpy()`` is called on the result, so this presumably
        relies on the caller having disabled autograd -- confirm.
        """
        obs_batch = torch.as_tensor(obs, device=self.device).unsqueeze(0)
        features = self.encoder(obs_batch)
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(features, stddev)

        if eval_mode:
            return dist.mean.cpu().numpy()[0]

        action = dist.sample(clip=None)
        if step < self.num_expl_steps:
            action.uniform_(-1.0, 1.0)
        return action.cpu().numpy()[0]

    def sample_gammas(self, B, num_g, gamma_min, gamma_max, device):
        """Sample ``num_g`` discount factors for each of ``B`` batch rows.

        Column layout of the result:
          * column 0 is ``gamma_min`` and column 1 is ``gamma_max``;
          * the next ``(num_g - 2) // 2`` columns are uniform in
            ``[gamma_min, gamma_max)``;
          * the remaining columns are ``1 - 1 / tau`` with ``tau`` uniform in
            ``[1, 100)``. This range is fixed and does not depend on
            ``gamma_min`` / ``gamma_max``.

        Args:
            B: Number of rows (batch size).
            num_g: Number of discounts per row; must be at least 2.
            gamma_min: Lower bound of the uniform draw; also column 0.
            gamma_max: Upper bound of the uniform draw; also column 1.
            device: Device on which the tensors are created.

        Returns:
            Tensor of shape ``[B, num_g]``.

        Raises:
            ValueError: If ``num_g`` is smaller than 2.
        """
        if num_g < 2:
            raise ValueError(
                f"num_g must be at least 2 to hold gamma_min and gamma_max, got {num_g}")

        # Explicit float dtype so integer bounds (e.g. 0 and 1) do not
        # produce an integer tensor.
        gammas_base = torch.tensor([gamma_min, gamma_max],
                                   device=device,
                                   dtype=torch.get_default_dtype()).repeat(B, 1)  # [B, 2]

        num_remaining = num_g - 2
        num_uniform = num_remaining // 2
        # When num_remaining is odd the extra column goes to the tau-based
        # draw, so exactly num_g columns are returned. Previously one column
        # was silently dropped.
        num_tau = num_remaining - num_uniform

        gammas_gamma = torch.rand(B, num_uniform, device=device) * (gamma_max - gamma_min) + gamma_min
        taus = torch.rand(B, num_tau, device=device) * (100 - 1) + 1
        gammas_tau = 1 - 1 / taus

        return torch.cat([gammas_base, gammas_gamma, gammas_tau], dim=1)


    def update_critic(self, obs_enc, action, step_rewards, step_not_dones, next_obs_enc, step):
        """Run one gradient step on the critic and the encoder.

        For each batch row, ``num_gammas_per_update`` discounts are sampled.
        Per (row, gamma) pair the n-step target is

            (sum_k w_k * r_k  +  w_n * min(Q1', Q2')) * (1 - gamma)

        where ``w_0 = 1`` and ``w_{k+1} = w_k * not_done_k * gamma``. The
        prediction is scaled by the same ``(1 - gamma)`` factor before the
        MSE loss is taken.

        Args:
            obs_enc: Encoded observations ``[B, repr_dim]``. Gradients flow
                back through it into the encoder.
            action: Actions taken, ``[B, action_dim]``.
            step_rewards: Per-step rewards, sliced as
                ``[:, k:k+1]`` for ``k < n_step`` (presumably ``[B, n_step]``
                -- TODO confirm against the replay buffer).
            step_not_dones: Per-step continuation factors, sliced the same way.
            next_obs_enc: Encoded next observations ``[B, repr_dim]``.
            step: Global step, used to evaluate the stddev schedule.

        Returns:
            Dict with ``critic_loss`` and ``avg_gamma`` when ``use_tb`` is set,
            otherwise an empty dict.
        """
        metrics = dict()
        B = obs_enc.shape[0]
        num_g = self.num_gammas_per_update

        gammas = self.sample_gammas(B, num_g, self.gamma_min, self.gamma_max, self.device)  # [B, num_g]
        gammas_col = gammas.unsqueeze(-1)  # [B, num_g, 1]
        gammas_flat = gammas.view(-1, 1)  # [B*num_g, 1]
        scale_factor = 1 - gammas_col  # [B, num_g, 1]

        def tile_per_gamma(t):
            # [B, D] -> [B*num_g, D], each row repeated once per sampled gamma.
            return t.unsqueeze(1).repeat(1, num_g, 1).view(-1, t.shape[-1])

        obs_enc_expanded = tile_per_gamma(obs_enc)
        next_obs_enc_expanded = tile_per_gamma(next_obs_enc)
        action_expanded = tile_per_gamma(action)

        with torch.no_grad():
            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(next_obs_enc, stddev)
            # One next action per batch row, shared across that row's gammas.
            next_action = dist.sample(clip=self.stddev_clip)
            next_action_expanded = tile_per_gamma(next_action)

            critic_total_reward = torch.zeros(B, num_g, 1, device=self.device)
            current_discount_weight = torch.ones(B, num_g, 1, device=self.device)

            for step_idx in range(self.n_step):
                step_reward = step_rewards[:, step_idx:step_idx + 1].unsqueeze(1)  # broadcasts over num_g
                step_not_done = step_not_dones[:, step_idx:step_idx + 1].unsqueeze(1)
                critic_total_reward += step_reward * current_discount_weight
                current_discount_weight *= step_not_done * gammas_col

            critic_total_discount = current_discount_weight
            target_q1, target_q2 = self.critic_target(next_obs_enc_expanded, next_action_expanded, gammas_flat)
            target_v = torch.min(target_q1, target_q2).view(B, num_g, 1)

            scaled_target_q = (critic_total_reward * scale_factor) + (critic_total_discount * target_v * scale_factor)
            scaled_target_q = scaled_target_q.view(-1, 1)  # [B*num_g, 1]

        current_q1, current_q2 = self.critic(obs_enc_expanded, action_expanded, gammas_flat)
        scale_factor_flat = scale_factor.view(-1, 1)  # [B*num_g, 1]
        scaled_current_q1 = current_q1 * scale_factor_flat
        scaled_current_q2 = current_q2 * scale_factor_flat

        critic_loss = F.mse_loss(scaled_current_q1, scaled_target_q) + F.mse_loss(scaled_current_q2, scaled_target_q)

        # The critic loss also backpropagates into the encoder through
        # obs_enc, so the encoder optimizer must be zeroed and stepped here
        # as well. Without this the encoder never trains and its gradients
        # keep accumulating across updates.
        self.encoder_opt.zero_grad(set_to_none=True)
        self.critic_opt.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_opt.step()
        self.encoder_opt.step()

        if self.use_tb:
            metrics['critic_loss'] = critic_loss.item()
            metrics['avg_gamma'] = gammas.mean().item()
        return metrics

    def update_actor(self, obs, step):
        """Run one gradient step on the actor and return logging metrics.

        The actor is trained to maximise ``min(Q1, Q2)`` of its own sampled
        action. The critic is always queried with a fixed discount of 0.9
        here, not with sampled gammas.
        """
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs, stddev)
        action = dist.sample(clip=self.stddev_clip)
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)

        q_min = torch.min(*self.critic(obs, action, 0.9))
        actor_loss = -q_min.mean()

        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        metrics = dict()
        if self.use_tb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['actor_logprob'] = log_prob.mean().item()
            metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
            # Despite the key name, this is the mean Q at the single fixed
            # discount used above.
            metrics['avg_q_across_gammas'] = q_min.mean().item()
        return metrics

    def update(self, replay_iter, step):
        """Perform one full agent update and return the collected metrics.

        Nothing happens (and an empty dict is returned) unless ``step`` is a
        multiple of ``update_every_steps``. Otherwise one batch is drawn from
        ``replay_iter``, both observation sets are augmented and encoded, and
        the critic, the actor and the target critic are updated in that order.
        """
        metrics = dict()
        if step % self.update_every_steps != 0:
            return metrics

        obs, action, step_rewards, step_not_dones, next_obs = utils.to_torch(
            next(replay_iter), self.device)

        # Augment both observation sets first (obs before next_obs).
        obs_aug = self.aug(obs.float())  # [B, C, H, W]
        next_obs_aug = self.aug(next_obs.float())  # [B, C, H, W]

        # Only the current observation's encoding carries gradients.
        obs_enc = self.encoder(obs_aug)  # [B, repr_dim]
        with torch.no_grad():
            next_obs_enc = self.encoder(next_obs_aug)

        critic_metrics = self.update_critic(obs_enc=obs_enc,
                                            action=action,
                                            step_rewards=step_rewards,
                                            step_not_dones=step_not_dones,
                                            next_obs_enc=next_obs_enc,
                                            step=step)
        metrics.update(critic_metrics)

        # The actor sees detached features, so its loss cannot reach the encoder.
        actor_metrics = self.update_actor(obs_enc.detach(), step)
        metrics.update(actor_metrics)

        utils.soft_update_params(self.critic, self.critic_target,
                                 self.critic_target_tau)

        return metrics