import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

import utils

class RewardDecoder(nn.Module):
    """Two-layer MLP mapping a vector of per-critic values to a reward sequence.

    Args:
        num_critics: size of the input vector (one entry per critic).
        reward_length: size of the output vector.
        hidden_units: width of the single hidden ReLU layer.
    """

    def __init__(self, num_critics, reward_length, hidden_units=128):
        super().__init__()
        # Kept as a Sequential named `net` so state_dict keys stay
        # net.0.* / net.2.* and saved checkpoints remain loadable.
        layers = [
            nn.Linear(num_critics, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, reward_length),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, q_vector):
        """Apply the MLP to ``q_vector`` (last dim must equal ``num_critics``)."""
        decoded = self.net(q_vector)
        return decoded


class HyperbolicWeightCalculator:
    """Normalised Riemann-sum weights for mixing per-gamma values into a hyperbolic value.

    Hyperbolic discounting is a mixture of exponential discounts:
    ``1 / (1 + k*t) = integral_0^1 gamma**t * w(gamma) dgamma`` with
    ``w(gamma) = (1/k) * gamma**(1/k - 1)``. This class approximates that
    integral on the grid ``gamma_list``. Each node gets
    ``step_width * w(gamma)``, and the weights are normalised to sum to 1.

    Args:
        gamma_list: discount factors; presumably sorted ascending (the step
            widths are differences of consecutive entries) - confirm at call sites.
        k: hyperbolic discount rate; must be non-zero (used as ``1 / k``).
        device: torch device for the weight tensors.
    """

    def __init__(self, gamma_list, k=0.1, device='cuda'):
        self.gammas = torch.tensor(gamma_list, device=device, dtype=torch.float32)
        self.k = k
        self.device = device
        self.num_gammas = len(gamma_list)

        # The weights depend only on the grid and k, so compute them once.
        self.riemann_weights = self._compute_riemann_weights()

    def _compute_riemann_weights(self):
        """Return a 1-D tensor of ``num_gammas`` weights that sums to 1.

        A single gamma gets weight 1. The difference-based step width would
        be 0 there and normalising would produce NaN. An empty grid yields an
        empty tensor.
        """
        n = self.num_gammas
        if n == 0:
            return torch.zeros(0, device=self.device)
        if n == 1:
            return torch.ones(1, device=self.device)

        gammas = self.gammas
        # Step widths are forward differences gamma[i+1] - gamma[i]. The last
        # node has no successor, so it reuses the final interval width
        # (backward difference).
        deltas = torch.empty_like(gammas)
        deltas[:-1] = gammas[1:] - gammas[:-1]
        deltas[-1] = deltas[-2]

        # Mixing density w(gamma) = (1/k) * gamma**(1/k - 1) at every node.
        density = (1.0 / self.k) * torch.pow(gammas, (1.0 / self.k) - 1)

        weights = deltas * density
        return weights / weights.sum()

    def get_weights(self, batch_size=1):
        """Return the weights broadcast to ``(batch_size, num_gammas)``.

        The result is an ``expand`` view; do not modify it in place.
        """
        return self.riemann_weights.unsqueeze(0).expand(batch_size, -1)


class RandomShiftsAug(nn.Module):
    """Random-shift image augmentation.

    The batch is replicate-padded by ``pad`` pixels on every side. Each image
    is then resampled back to its original size at an integer pixel offset in
    ``[0, 2 * pad]`` per axis, drawn independently for every sample. Inputs
    must be square (``H == W``).
    """

    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        batch, _, height, width = x.size()
        assert height == width
        pad = self.pad
        padded = height + 2 * pad
        x = F.pad(x, (pad, pad, pad, pad), 'replicate')

        # Normalised coordinates of the pixel centres of the padded image
        # (align_corners=False convention). Keep only the first `height`.
        half_step = 1.0 / padded
        centres = torch.linspace(-1.0 + half_step,
                                 1.0 - half_step,
                                 padded,
                                 device=x.device,
                                 dtype=x.dtype)[:height]
        grid_x = centres.view(1, height).expand(height, height)
        grid_y = centres.view(height, 1).expand(height, height)
        base_grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)

        # Integer pixel offsets converted to normalised units
        # (one pixel = 2 / padded).
        offsets = torch.randint(0,
                                2 * pad + 1,
                                size=(batch, 1, 1, 2),
                                device=x.device,
                                dtype=x.dtype)
        offsets = offsets * (2.0 / padded)

        grid = base_grid.expand(batch, -1, -1, -1) + offsets
        return F.grid_sample(x,
                             grid,
                             padding_mode='zeros',
                             align_corners=False)


class Encoder(nn.Module):
    """Convolutional image encoder producing a flat feature vector.

    Four 3x3 conv layers with 32 channels each are used. The first has
    stride 2 and the rest stride 1, each followed by ReLU. The output
    size ``repr_dim`` is found by probing the network with a zero image of
    ``obs_shape``.

    Args:
        obs_shape: ``(C, H, W)`` of a single observation.
    """

    def __init__(self, obs_shape):
        super().__init__()
        assert len(obs_shape) == 3, f"Expected (C,H,W), got {obs_shape}"

        in_channels, _height, _width = obs_shape

        # Built incrementally, but the modules are created in the same order
        # and land at the same Sequential indices as a literal listing.
        layers = [nn.Conv2d(in_channels, 32, 3, stride=2), nn.ReLU(inplace=True)]
        for _ in range(3):
            layers.append(nn.Conv2d(32, 32, 3, stride=1))
            layers.append(nn.ReLU(inplace=True))
        self.convnet = nn.Sequential(*layers)

        # Probe once with a dummy batch to learn the flattened output size.
        with torch.no_grad():
            probe = self.convnet(torch.zeros(1, *obs_shape))
            self.repr_dim = int(np.prod(probe.shape[1:]))

        self.apply(utils.weight_init)

        print(f"[Encoder] Input: {obs_shape} → Output feature dim: {self.repr_dim}")

    def forward(self, obs):
        """Encode raw observations into ``(batch, repr_dim)`` features.

        Pixel values are scaled by ``1 / 255`` and centred by subtracting 0.5
        (assumes 0-255 inputs - confirm at call sites).
        """
        features = self.convnet(obs / 255.0 - 0.5)
        return features.view(features.size(0), -1)


class Actor(nn.Module):
    """Policy head over encoder features.

    The trunk projects features to ``feature_dim`` (Linear -> LayerNorm ->
    Tanh). A 3-layer MLP then outputs a tanh-squashed mean of size
    ``action_shape[0]``. The action distribution is ``utils.TruncatedNormal``
    with that mean and a caller-supplied standard deviation.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()

        trunk_layers = [
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        ]
        self.trunk = nn.Sequential(*trunk_layers)

        out_dim = action_shape[0]
        self.policy = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return the action distribution for encoded observations ``obs``.

        ``std`` is broadcast against the mean, so a scalar gives the same
        scale on every action dimension.
        """
        mean = torch.tanh(self.policy(self.trunk(obs)))
        scale = torch.ones_like(mean) * std
        return utils.TruncatedNormal(mean, scale)


class Critic(nn.Module):
    """Twin Q-network over encoder features and actions.

    A shared trunk (Linear -> LayerNorm -> Tanh) maps encoder output to
    ``feature_dim``. Two independent MLP heads ``Q1``/``Q2`` then score
    ``[feature, action]`` with a single scalar each.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()
        self.action_dim = action_shape[0]

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        head_in = feature_dim + self.action_dim
        self.Q1 = self._make_q_head(head_in, hidden_dim)
        self.Q2 = self._make_q_head(head_in, hidden_dim)

        self.apply(utils.weight_init)

    @staticmethod
    def _make_q_head(in_dim, hidden_dim):
        """Build one Q MLP: in_dim -> hidden -> hidden -> 1 with ReLU activations."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, obs, action, return_feature=False):
        """Return ``(q1, q2)``, or ``(q1, q2, trunk_feature)`` if ``return_feature``."""
        feature = self.trunk(obs)
        joint = torch.cat([feature, action], dim=-1)
        q1 = self.Q1(joint)
        q2 = self.Q2(joint)
        return (q1, q2, feature) if return_feature else (q1, q2)


class DrQV2Agent:
    """DrQ-v2 agent with a multi-discount critic ensemble and hyperbolic fusion.

    One twin-Q ``Critic`` (plus a target copy) is trained per discount factor
    in ``self.gammas`` on n-step returns. The actor maximises a weighted sum
    of the per-critic Q-values, with weights from ``HyperbolicWeightCalculator``.

    The critics are also regularised through a reward decoding step. Their
    stacked Q-values are mapped to a reward sequence with the pseudo-inverse
    of ``L_matrix`` (``L[i, j] = gamma_i ** j``). The first decoded reward is
    regressed onto the observed first-step reward.

    Note:
        ``lr``, ``aug_type`` and ``gammas`` are accepted but not used:
        - Learning rates are fixed at ``1e-4``.
        - Augmentation is always ``RandomShiftsAug(pad=4)``.
        - The discount grid is generated from ``hyperbolic_k`` in ``__init__``.
    """

    def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
                 hidden_dim, critic_target_tau, num_expl_steps,
                 update_every_steps, stddev_schedule, stddev_clip, use_tb,
                 nstep, reward_length, aug_type, gammas, hyperbolic_k=0.1):
        # Basic parameters
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip
        self.action_shape = action_shape
        self.action_dim = action_shape[0]

        self.hyperbolic_k = hyperbolic_k
        self.num_gammas = 10
        self.gamma_max = 0.99

        self.n_step = nstep
        # Discount grid gamma_i = (1 - b**i) ** k for i = 1..num_gammas.
        # b satisfies b**num_gammas == 1 - gamma_max**(1/k), so the last
        # entry equals gamma_max. Values are rounded to 4 decimals.
        # NOTE(review): the `gammas` constructor argument is ignored here.
        b = math.exp(math.log(1 - self.gamma_max ** (1 / hyperbolic_k)) / self.num_gammas)
        self.gammas = tuple(
            round((1 - b ** i) ** hyperbolic_k, 4)
            for i in range(1, self.num_gammas + 1)
        )
        self.critic_num = len(self.gammas)
        self.reward_length = reward_length
        self.encoder_feature_dim = feature_dim

        # Optimizer parameters (NOTE(review): the `lr` argument is not used).
        self.actor_lr = 1e-4
        self.critic_lr = 1e-4
        self.decoder_lr = 1e-4
        self.actor_beta = 0.9
        self.critic_beta = 0.9

        # Initialize models (creation order kept stable for seeded reproducibility).
        self.encoder = Encoder(obs_shape).to(device)
        self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device)

        # Multi-Critics: one online and one target critic per discount factor.
        self.critics = torch.nn.ModuleList([
            Critic(self.encoder.repr_dim, action_shape, self.encoder_feature_dim, hidden_dim).to(device)
            for _ in range(self.critic_num)
        ])
        self.critic_targets = torch.nn.ModuleList([
            Critic(self.encoder.repr_dim, action_shape, self.encoder_feature_dim, hidden_dim).to(device)
            for _ in range(self.critic_num)
        ])

        for critic, target in zip(self.critics, self.critic_targets):
            target.load_state_dict(critic.state_dict())

        self.hyperbolic_weight_calc = HyperbolicWeightCalculator(
            gamma_list=self.gammas,
            k=self.hyperbolic_k,
            device=self.device
        )

        print(f"[HyperbolicAgent] Using hyperbolic discounting with k={self.hyperbolic_k}")
        print(f"[HyperbolicAgent] Riemann weights: {self.hyperbolic_weight_calc.riemann_weights}")

        self.reward_decoder = RewardDecoder(
            num_critics=self.critic_num,
            reward_length=self.reward_length
        ).to(device)

        # L[i, j] = gamma_i ** j, so Q-values relate to a reward sequence r
        # through Q ~= r @ L^T (this is the reconstruction used in update_critic).
        # Built and inverted in float64, then stored as float32.
        gamma_col = torch.tensor(self.gammas, dtype=torch.float64).unsqueeze(1)
        powers = torch.arange(self.reward_length, dtype=torch.float64).unsqueeze(0)
        l_matrix64 = gamma_col ** powers  # (critic_num, reward_length)
        self.L_matrix = l_matrix64.to(device=device, dtype=torch.float32)
        # Pseudo-inverse (reward_length, critic_num). update_critic uses it to
        # decode rewards from Q-values, so it must be defined here.
        self.L_pinv = torch.linalg.pinv(l_matrix64).to(device=device, dtype=torch.float32)

        # Optimizers
        self.actor_opt = torch.optim.Adam(
            self.actor.parameters(),
            lr=self.actor_lr,
            betas=(self.actor_beta, 0.999)
        )

        # The encoder is trained only through the critic loss.
        critic_params = []
        for critic in self.critics:
            critic_params.extend(critic.parameters())
        critic_params.extend(self.encoder.parameters())
        self.critic_opt = torch.optim.Adam(
            critic_params,
            lr=self.critic_lr,
            betas=(self.critic_beta, 0.999)
        )

        self.decoder_opt = torch.optim.Adam(
            self.reward_decoder.parameters(),
            lr=self.decoder_lr,
            betas=(self.actor_beta, 0.999)
        )

        # Data augmentation (NOTE(review): `aug_type` is not consulted).
        self.aug = RandomShiftsAug(pad=4)

        self.train()
        for target_critic in self.critic_targets:
            target_critic.train()

    def train(self, training=True):
        """Set train/eval mode on the encoder, actor and online critics."""
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        for critic in self.critics:
            critic.train(training)

    def act(self, obs, step, eval_mode):
        """Choose an action for a single, unbatched observation.

        In eval mode this is the policy mean. Otherwise it is an unclipped
        sample, overwritten with a uniform random action in [-1, 1] during
        the first ``num_expl_steps`` steps. Returns a numpy array of shape
        ``(action_dim,)``.

        NOTE(review): autograd is not disabled here; presumably the caller
        wraps this in ``torch.no_grad()`` - confirm.
        """
        obs = torch.as_tensor(obs, device=self.device)
        obs_enc = self.encoder(obs.unsqueeze(0))

        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs_enc, stddev)
        if eval_mode:
            action = dist.mean
        else:
            action = dist.sample(clip=None)
            if step < self.num_expl_steps:
                action.uniform_(-1.0, 1.0)

        return action.cpu().numpy()[0]

    def update_critic(self, obs_enc, action, step_rewards, step_not_dones, next_obs_enc, step):
        """Run one optimisation step on every critic and the shared encoder.

        For critic ``i``, the TD target is the n-step reward sum discounted by
        ``gammas[i]`` plus a bootstrap term. The bootstrap term is the
        discounted min of the matching target critic's twin Q-values at
        ``next_obs_enc``. Rewards after the first step with ``not_done == 0``,
        and the bootstrap term, are masked out.

        The reward decoding and reconstruction losses described on the class
        are added to the summed TD losses.

        Args:
            obs_enc: encoded observations with batch size ``B``. Its graph is
                kept so the encoder trains through this loss.
            action: actions taken at those observations.
            step_rewards: per-step rewards, reshaped to ``(B, n_step)``.
            step_not_dones: per-step continuation masks, reshaped to ``(B, n_step)``.
            next_obs_enc: encoded bootstrap observations (presumably ``n_step``
                steps after ``obs``; no grad).
            step: global step, used for the exploration-noise schedule.

        Returns:
            dict of scalar metrics (empty unless ``use_tb``).
        """
        metrics = dict()
        total_critic_loss = 0.0
        q_list = []
        B = obs_enc.shape[0]

        # reshape (rather than view) also accepts non-contiguous batches.
        step_rewards = step_rewards.reshape(B, self.n_step)
        step_not_dones = step_not_dones.reshape(B, self.n_step)
        first_step_reward = step_rewards[:, 0:1]

        with torch.no_grad():
            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(next_obs_enc, stddev)
            next_action = dist.sample(clip=self.stddev_clip)

        for i, (critic, target_critic, gamma_i) in enumerate(
                zip(self.critics, self.critic_targets, self.gammas)):
            with torch.no_grad():
                critic_total_reward = torch.zeros(B, 1, device=self.device)
                discount = torch.ones(B, 1, device=self.device)
                for step_idx in range(self.n_step):
                    critic_total_reward += step_rewards[:, step_idx:step_idx + 1] * discount
                    # Multiplying by not_done zeroes all later rewards and the
                    # bootstrap term once an episode has ended.
                    discount *= step_not_dones[:, step_idx:step_idx + 1] * gamma_i

                target_q1, target_q2 = target_critic(next_obs_enc, next_action)
                target_v = torch.min(target_q1, target_q2)
                target_q = critic_total_reward + discount * target_v

            current_q1, current_q2 = critic(obs_enc, action)
            current_q = torch.min(current_q1, current_q2)
            q_list.append(current_q)
            critic_i_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
            total_critic_loss += critic_i_loss

            if self.use_tb:
                metrics[f'critic_{i}_q'] = current_q.mean().item()

        # Reward reconstruction loss: Q ~= r @ L^T  =>  r ~= Q @ pinv(L)^T.
        q_matrix = torch.cat(q_list, dim=1)  # (B, critic_num)
        pred_reward = torch.matmul(q_matrix, self.L_pinv.T)  # (B, reward_length)
        reward_loss = F.mse_loss(pred_reward[:, 0:1], first_step_reward)

        q_recon = torch.matmul(pred_reward, self.L_matrix.T)  # (B, critic_num)
        recon_loss = F.mse_loss(q_recon, q_matrix.detach())
        total_critic_loss += reward_loss + recon_loss

        self.critic_opt.zero_grad(set_to_none=True)
        # NOTE(review): reward_decoder is not part of this graph, so its grads
        # stay None and decoder_opt.step() leaves its parameters unchanged.
        self.decoder_opt.zero_grad(set_to_none=True)
        total_critic_loss.backward()
        self.critic_opt.step()
        self.decoder_opt.step()

        if self.use_tb:
            metrics['total_critic_loss'] = total_critic_loss.item()
            metrics['reward_loss'] = reward_loss.item()
            metrics['recon_loss'] = recon_loss.item()

        return metrics

    def update_actor(self, obs_enc, step):
        """Update the actor to maximise the hyperbolically weighted critic value.

        ``fused_q`` is the per-sample weighted sum of each critic's
        ``min(Q1, Q2)``. The loss is its negative batch mean. ``log_prob``
        is computed only for logging.
        """
        metrics = dict()

        # Sample action
        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs_enc, stddev)
        action = dist.sample(clip=self.stddev_clip)
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)

        q_list = []
        for critic in self.critics:
            q1, q2 = critic(obs_enc, action, return_feature=False)
            q_list.append(torch.min(q1, q2))

        q_stacked = torch.cat(q_list, dim=1)  # [B, critic_num]

        hyperbolic_weights = self.hyperbolic_weight_calc.get_weights(
            batch_size=q_stacked.shape[0]
        )  # [B, critic_num]

        fused_q = torch.sum(q_stacked * hyperbolic_weights, dim=-1, keepdim=True)  # [B, 1]

        actor_loss = -fused_q.mean()

        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        if self.use_tb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['actor_log_prob'] = log_prob.mean().item()
            metrics['fused_q'] = fused_q.mean().item()
            for i in range(self.critic_num):
                metrics[f'hyperbolic_weight_{i}'] = hyperbolic_weights[0, i].item()

        return metrics

    def update(self, replay_iter, step):
        """Run one training iteration every ``update_every_steps`` steps.

        Each iteration does the following:
        - Draws a batch and applies random-shift augmentation.
        - Updates the critics (with encoder gradients), then the actor
          (on detached features).
        - Soft-updates every target critic with rate ``critic_target_tau``.

        Returns the merged metrics dict, which is empty on skipped steps.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, step_rewards, step_not_dones, next_obs = utils.to_torch(batch, self.device)

        obs_aug = self.aug(obs.float())
        next_obs_aug = self.aug(next_obs.float())

        obs_enc = self.encoder(obs_aug)
        with torch.no_grad():
            next_obs_enc = self.encoder(next_obs_aug)

        metrics.update(self.update_critic(
            obs_enc=obs_enc,
            action=action,
            step_rewards=step_rewards,
            step_not_dones=step_not_dones,
            next_obs_enc=next_obs_enc,
            step=step
        ))

        # Detached: the actor loss must not update the encoder.
        metrics.update(self.update_actor(obs_enc.detach(), step))

        for critic, target in zip(self.critics, self.critic_targets):
            utils.soft_update_params(critic, target, self.critic_target_tau)

        return metrics

    def save(self, model_dir, step):
        """Write actor, encoder, reward decoder and all (target) critic state_dicts to ``model_dir``."""
        torch.save(self.actor.state_dict(), f'{model_dir}/actor_{step}.pt')
        torch.save(self.encoder.state_dict(), f'{model_dir}/encoder_{step}.pt')
        torch.save(self.reward_decoder.state_dict(), f'{model_dir}/reward_decoder_{step}.pt')

        for i in range(self.critic_num):
            torch.save(self.critics[i].state_dict(), f'{model_dir}/critic_{i}_{step}.pt')
            torch.save(self.critic_targets[i].state_dict(), f'{model_dir}/critic_target_{i}_{step}.pt')

    def load(self, model_dir, step):
        """Load the files written by :meth:`save` for ``step``, mapped onto ``self.device``.

        ``torch.load`` unpickles its input, so only load trusted checkpoints.
        """
        self.actor.load_state_dict(torch.load(f'{model_dir}/actor_{step}.pt', map_location=self.device))
        self.encoder.load_state_dict(torch.load(f'{model_dir}/encoder_{step}.pt', map_location=self.device))
        self.reward_decoder.load_state_dict(torch.load(f'{model_dir}/reward_decoder_{step}.pt', map_location=self.device))

        for i in range(self.critic_num):
            self.critics[i].load_state_dict(torch.load(f'{model_dir}/critic_{i}_{step}.pt', map_location=self.device))
            self.critic_targets[i].load_state_dict(torch.load(f'{model_dir}/critic_target_{i}_{step}.pt', map_location=self.device))