import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils

class Encoder_orginial(nn.Module):
    """Convolutional pixel encoder that returns a flat feature vector.

    The network is one stride-2 3x3 convolution followed by three stride-1
    3x3 convolutions, all with 32 output channels and each followed by ReLU.

    Args:
        obs_shape: 3-element observation shape; only ``obs_shape[0]`` (the
            channel count) is used to build the network.
    """

    def __init__(self, obs_shape):
        super().__init__()

        assert len(obs_shape) == 3
        # The flattened feature size is fixed here, not derived from obs_shape.
        # 133280 == 32 * 35 * 119, which is what this conv stack produces for
        # an 84 x 252 input. NOTE(review): confirm the actual observation size.
        self.repr_dim = 133280

        in_channels = obs_shape[0]
        layers = [nn.Conv2d(in_channels, 32, 3, stride=2), nn.ReLU()]
        for _ in range(3):
            layers.append(nn.Conv2d(32, 32, 3, stride=1))
            layers.append(nn.ReLU())
        self.convnet = nn.Sequential(*layers)

        self.apply(utils.weight_init)

    def forward(self, obs):
        """Rescale ``obs`` by ``/ 255.0 - 0.5``, encode it and flatten per sample."""
        scaled = obs / 255.0 - 0.5
        features = self.convnet(scaled)
        batch_size = features.shape[0]
        return features.view(batch_size, -1)

class RewardDecoder(nn.Module):
    """Two-layer MLP mapping a vector of critic Q-values to a reward sequence.

    Args:
        num_critics: Size of the input vector (one entry per critic).
        reward_length: Number of outputs produced per sample.
        hidden_units: Width of the single hidden layer.
    """

    def __init__(self, num_critics, reward_length, hidden_units=128):
        super().__init__()
        input_layer = nn.Linear(num_critics, hidden_units)
        output_layer = nn.Linear(hidden_units, reward_length)
        self.net = nn.Sequential(input_layer, nn.ReLU(), output_layer)

    def forward(self, q_vector):
        """Return the decoder output for ``q_vector`` (last dim ``num_critics``)."""
        decoded = self.net(q_vector)
        return decoded

class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention for a single query token.

    Args:
        query_dim: Feature size of the query input.
        key_dim: Feature size of the key and value inputs.
        hidden_dim: Total projection width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, query_dim, key_dim, hidden_dim, num_heads=4):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(query_dim, hidden_dim)
        self.k_proj = nn.Linear(key_dim, hidden_dim)
        self.v_proj = nn.Linear(key_dim, hidden_dim)

        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        # Not applied in forward(); kept so the module's state dict is unchanged.
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, query, key, value):
        """Attend from one query token to ``T`` key/value tokens.

        Args:
            query: Tensor of shape [B, 1, query_dim] (the view below requires
                a query length of exactly 1).
            key: Tensor of shape [B, T, key_dim].
            value: Tensor of shape [B, T, key_dim].

        Returns:
            Tuple ``(output, weights)`` with ``output`` of shape
            [B, hidden_dim] and ``weights`` of shape [B, num_heads, T].
        """
        batch, num_keys, _ = key.shape

        def split_heads(projected, length):
            # [B, length, hidden_dim] -> [B, num_heads, length, head_dim]
            return projected.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

        q_heads = split_heads(self.q_proj(query), 1)
        k_heads = split_heads(self.k_proj(key), num_keys)
        v_heads = split_heads(self.v_proj(value), num_keys)

        scale = self.head_dim ** 0.5
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / scale  # [B, num_heads, 1, T]
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v_heads)  # [B, num_heads, 1, head_dim]

        # Merge the heads back into a single hidden_dim-wide vector.
        merged = context.transpose(1, 2).contiguous().view(batch, 1, self.num_heads * self.head_dim)
        output = self.out_proj(merged).squeeze(1)  # [B, hidden_dim]

        return output, weights.squeeze(2)

class QWeightPredictor(nn.Module):
    """Predict per-critic mixing weights with cross-attention.

    The observation encoding is the single attention query; each critic
    contributes one key/value token made of its features, its value estimate
    and its discount factor.

    Args:
        critic_num: Number of critics (tokens attended over).
        hidden_dim_predictor: Hidden width of the attention module.
        repr_dim: Feature size of the observation encoding (query).
        key_dim: Feature size of each key/value token, i.e. the critic feature
            size plus 2 (value and gamma are concatenated in ``forward``).
        device: Device on which the gamma tensor is initially created.
        gamma_list: One discount factor per critic.
        beta: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, critic_num, hidden_dim_predictor, repr_dim, key_dim, device, gamma_list, beta=2.0):
        super().__init__()
        self.critic_num = critic_num
        self.device = device
        # Registered as a buffer so it follows module.to()/.cuda() instead of
        # staying pinned to the constructor device. persistent=False keeps it
        # out of state_dict(), so checkpoint keys are unchanged.
        self.register_buffer(
            'gammas_tensor',
            torch.tensor(gamma_list, device=self.device, dtype=torch.float32),  # [critic_num]
            persistent=False,
        )

        query_dim = repr_dim
        self.attention_hidden_dim = hidden_dim_predictor

        self.cross_attn = CrossAttention(
            query_dim=query_dim,
            key_dim=key_dim,
            hidden_dim=self.attention_hidden_dim,
            num_heads=4
        )

    def forward(self, obs_enc, V_stacked, critic_features):
        """Return mixing weights over the critics.

        Args:
            obs_enc: Observation encoding, [B, repr_dim].
            V_stacked: Per-critic value estimates, [B, critic_num].
            critic_features: Per-critic features, [B, critic_num, feature_dim].

        Returns:
            Tensor of shape [B, critic_num] whose rows sum to 1.
        """
        B, critic_num = V_stacked.shape
        assert critic_num == self.critic_num

        query = obs_enc.unsqueeze(1)  # [B, 1, repr_dim]
        values = V_stacked.unsqueeze(-1)  # [B, critic_num, 1]
        # expand() broadcasts over the batch without copying; cat() materialises it.
        gammas = self.gammas_tensor.view(1, critic_num, 1).expand(B, -1, -1)  # [B, critic_num, 1]
        key_value = torch.cat([critic_features, values, gammas], dim=-1)  # [B, critic_num, feature_dim+2]

        # Only the attention weights are used; the attended output is discarded.
        _, attn_weights = self.cross_attn(
            query=query,
            key=key_value,
            value=key_value
        )  # attn_weights: [B, num_heads, critic_num]

        learned_weights = attn_weights.mean(dim=1)  # [B, critic_num]
        # NOTE(review): the head-averaged weights already sum to 1; this second
        # softmax flattens them towards uniform. Kept as-is - confirm intent.
        learned_weights = F.softmax(learned_weights, dim=-1)

        return learned_weights  # [B, critic_num]


class RandomShiftsAug(nn.Module):
    """Random-shift augmentation for images holding three side-by-side views.

    The width axis is split into three equal views that are moved into the
    channel axis, the result is replicate-padded by ``pad`` pixels and
    resampled on a grid shifted by a random integer offset per sample, and the
    views are finally laid back out along the width.

    Args:
        pad: Padding in pixels; shifts are drawn from ``[0, 2 * pad]``.
    """

    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        """Return a randomly shifted copy of ``x`` with the same shape.

        ``x`` is [n, c, h, w]; each of the three views must be square
        (asserted after the split).
        """
        num_cameras = 3
        n, c, h, w = x.size()

        # [n, c, h, 3 * w_view] -> [n, c * 3, h, w_view]
        stacked = x.reshape(n, c, h, num_cameras, -1)
        stacked = stacked.permute(0, 1, 3, 2, 4)
        stacked = stacked.reshape(n, c * num_cameras, h, -1)
        n, _, h, w = stacked.size()
        assert h == w

        padded_size = h + 2 * self.pad
        stacked = F.pad(stacked, (self.pad,) * 4, 'replicate')

        # Normalised coordinates of the first h pixel centres of the padded image.
        eps = 1.0 / padded_size
        coords = torch.linspace(-1.0 + eps,
                                1.0 - eps,
                                padded_size,
                                device=stacked.device,
                                dtype=stacked.dtype)[:h]
        x_coords = coords.view(1, h, 1).expand(h, h, 1)
        y_coords = x_coords.transpose(1, 0)
        base_grid = torch.cat([x_coords, y_coords], dim=2)  # [h, h, 2]

        # One integer (x, y) pixel offset per sample, converted to grid units.
        shift = torch.randint(0,
                              2 * self.pad + 1,
                              size=(n, 1, 1, 2),
                              device=stacked.device,
                              dtype=stacked.dtype)
        shift = shift * (2.0 / padded_size)

        grid = base_grid.unsqueeze(0) + shift  # broadcasts to [n, h, h, 2]
        shifted = F.grid_sample(stacked,
                                grid,
                                padding_mode='zeros',
                                align_corners=False)

        # [n, c * 3, h, w_view] -> [n, c, h, 3 * w_view]
        shifted = shifted.reshape(n, c, num_cameras, h, -1)
        shifted = shifted.permute(0, 1, 3, 2, 4)
        return shifted.reshape(n, c, h, -1)

class NormalizeImg(nn.Module):
	"""Module that divides its input by 255."""

	def __init__(self):
		super().__init__()

	def forward(self, x):
		"""Return ``x`` divided by 255."""
		scaled = x / 255.
		return scaled

class CenterCrop(nn.Module):
	"""Center-crop a batch of images to ``size`` x ``size``.

	Args:
		size: Target side length; must be 84 or 100.
	"""

	def __init__(self, size):
		super().__init__()
		assert size in {84, 100}, f'unexpected size: {size}'
		self.size = size

	def forward(self, x):
		"""Return the central ``size`` x ``size`` window of ``x``.

		Args:
			x: 4D tensor; the last two dims are treated as height and width.
				Inputs that already have the target size are returned
				unchanged (same object). Otherwise the width must be 100.

		Returns:
			The cropped tensor (a view of ``x``).
		"""
		assert x.ndim == 4, 'input must be a 4D tensor'
		height, width = x.size(2), x.size(3)
		if height == self.size and width == self.size:
			return x
		assert width == 100, f'unexpected size: {width}'
		assert height >= self.size, f'unexpected size: {height}'
		# Derive the offsets from the actual dims. The previous code defined the
		# margin only for size == 84 and raised UnboundLocalError for size == 100.
		# For 100 x 100 input and size 84 this is the same window as [8:-8, 8:-8].
		top = (height - self.size) // 2
		left = (width - self.size) // 2
		return x[:, :, top:top + self.size, left:left + self.size]

def _get_out_shape(in_shape, layers):
	x = torch.randn(*in_shape).unsqueeze(0)
	return layers(x).squeeze(0).shape[0]

def weight_init(m):
	"""Custom weight init for Conv2D and Linear layers.

	Linear layers get an orthogonal weight and a zero bias. Conv2d and
	ConvTranspose2d layers (square kernels only) get all-zero weights and bias,
	except the centre tap of the kernel, which is orthogonal with the ReLU
	gain. Any other module is left untouched.
	"""
	if isinstance(m, nn.Linear):
		nn.init.orthogonal_(m.weight.data)
		if hasattr(m.bias, 'data'):
			m.bias.data.fill_(0.0)
		return

	if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
		return

	kernel_h, kernel_w = m.weight.size(2), m.weight.size(3)
	assert kernel_h == kernel_w
	m.weight.data.fill_(0.0)
	if hasattr(m.bias, 'data'):
		m.bias.data.fill_(0.0)
	centre = kernel_h // 2
	relu_gain = nn.init.calculate_gain('relu')
	# The indexed slice is a view, so the in-place init writes into m.weight.
	nn.init.orthogonal_(m.weight.data[:, :, centre, centre], relu_gain)

class Flatten(nn.Module):
	"""Module that collapses every non-batch dimension into one."""

	def __init__(self):
		super().__init__()

	def forward(self, x):
		"""Return ``x`` viewed as [batch, -1]."""
		batch_size = x.size(0)
		return x.view(batch_size, -1)

class Actor(nn.Module):
    """Policy network producing a truncated-normal action distribution.

    Args:
        repr_dim: Size of the encoded observation.
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        feature_dim: Output width of the trunk.
        hidden_dim: Width of the policy MLP's hidden layers.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        action_dim = action_shape[0]
        self.policy = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return ``utils.TruncatedNormal`` with a tanh-squashed mean.

        ``std`` is broadcast to the shape of the mean.
        """
        features = self.trunk(obs)
        mean = torch.tanh(self.policy(features))
        scale = torch.ones_like(mean) * std
        return utils.TruncatedNormal(mean, scale)


class Critic(nn.Module):
    """Twin Q-network on top of a shared trunk.

    Args:
        repr_dim: Size of the encoded observation.
        action_shape: Action shape; ``action_shape[0]`` is the action size.
        feature_dim: Output width of the trunk.
        hidden_dim: Width of each Q-head's hidden layers.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()

        self.trunk = nn.Sequential(
            nn.Linear(repr_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.Tanh(),
        )

        q_input_dim = feature_dim + action_shape[0]

        def make_q_head():
            # Both heads share this architecture but have separate parameters.
            return nn.Sequential(
                nn.Linear(q_input_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, 1),
            )

        self.Q1 = make_q_head()
        self.Q2 = make_q_head()

        self.apply(utils.weight_init)

    def forward(self, obs, action, return_feature=False):
        """Evaluate both Q-heads for ``(obs, action)``.

        Returns:
            ``(q1, q2)``, each [B, 1]; with ``return_feature=True`` the trunk
            output [B, feature_dim] is appended as a third element.
        """
        features = self.trunk(obs)  # [B, feature_dim]
        q_input = torch.cat([features, action], dim=-1)

        q_values = (self.Q1(q_input), self.Q2(q_input))  # each [B, 1]
        if return_feature:
            return q_values + (features,)
        return q_values

class DrQV2Agent:
    """Actor-critic agent with one twin-Q critic per discount factor.

    Each entry of ``gammas`` gets its own critic and target critic. The actor
    is trained on a weighted sum of the critics' Q-values, with the weights
    produced by a ``QWeightPredictor``. A ``RewardDecoder`` is trained
    alongside the critics on their stacked Q-values.
    """

    def __init__(self, obs_shape, action_shape, device, lr, feature_dim, hidden_dim, critic_target_tau,
                 num_expl_steps, update_every_steps, stddev_schedule, stddev_clip, use_tb,
                 gammas, reward_length, nstep, aug_type):
        """Build the networks, the discount matrix and the optimizers.

        Args:
            obs_shape: Observation shape passed to the encoder.
            action_shape: Action shape; ``action_shape[0]`` is the action size.
            device: Device that all modules and tensors are placed on.
            lr: Learning rate shared by all optimizers.
            feature_dim: Trunk width of the actor and critics.
            hidden_dim: Hidden width of actor, critics and weight predictor.
            critic_target_tau: Rate passed to ``utils.soft_update_params``.
            num_expl_steps: Stored on the instance; not read by this class.
            update_every_steps: ``update`` is a no-op on other steps.
            stddev_schedule: Schedule spec passed to ``utils.schedule``.
            stddev_clip: Clip value used when sampling actions for training.
            use_tb: If true, the update methods fill their metrics dicts.
            gammas: One discount factor per critic.
            reward_length: Currently ignored; see the note in the body.
            nstep: Number of reward / not-done entries per sample.
            aug_type: Accepted but not used by this class.
        """
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip

        self.n_step = nstep
        self.gammas = gammas
        self.critic_num = len(self.gammas)
        # NOTE(review): the reward_length argument is overridden by nstep here,
        # so the decoder horizon always equals the n-step length. Confirm intent.
        self.reward_length = nstep

        # models
        self.encoder = Encoder_orginial(obs_shape).to(device)
        self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim,
                           hidden_dim).to(device)

        self.critics = torch.nn.ModuleList([
            Critic(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device)
            for _ in range(self.critic_num)
        ])
        self.critic_targets = torch.nn.ModuleList([
            Critic(self.encoder.repr_dim, action_shape, feature_dim, hidden_dim).to(device)
            for _ in range(self.critic_num)
        ])
        # Targets start as exact copies of their online critics.
        for critic, target_critic in zip(self.critics, self.critic_targets):
            target_critic.load_state_dict(critic.state_dict())

        self.q_weight_predictor = QWeightPredictor(
            critic_num=self.critic_num,
            hidden_dim_predictor=hidden_dim,
            repr_dim=self.encoder.repr_dim,
            key_dim=feature_dim + 2,  # feature_dim + V + gamma
            device=self.device,
            gamma_list=self.gammas
        ).to(device)

        self.reward_decoder = RewardDecoder(
            num_critics=self.critic_num,
            reward_length=self.reward_length
        ).to(device)

        # L_matrix[i, j] = gammas[i] ** j, so (rewards @ L_matrix.T)[:, i] is the
        # gamma_i-discounted sum of a reward sequence. Built in one construction
        # instead of one device write per element.
        discount_powers = [
            [gamma ** j for j in range(self.reward_length)]
            for gamma in self.gammas
        ]
        self.L_matrix = torch.tensor(
            discount_powers,
            dtype=torch.get_default_dtype(),
            device=device,
        ).reshape(self.critic_num, self.reward_length)

        # Ridge-regularised pseudo-inverse of L_matrix via SVD: each singular
        # value s is inverted as s / (s^2 + alpha^2). Not read by any method of
        # this class; kept as a public attribute.
        U, S, Vh = torch.linalg.svd(self.L_matrix, full_matrices=False)
        alpha = 0.2
        S_inv = S / (S**2 + alpha**2)
        self.L_pinv = Vh.T @ torch.diag(S_inv) @ U.T

        # optimizers
        # NOTE(review): encoder_opt is never stepped by this class; the encoder
        # parameters are trained through critic_opt below.
        self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
        critic_params = []
        for critic in self.critics:
            critic_params.extend(critic.parameters())
        critic_params.extend(self.encoder.parameters())
        self.critic_opt = torch.optim.Adam(
            critic_params,
            lr=lr,
            betas=(0.9, 0.999)
        )

        self.q_weight_opt = torch.optim.Adam(
            self.q_weight_predictor.parameters(),
            lr=lr,
            betas=(0.9, 0.999)
        )
        self.decoder_opt = torch.optim.Adam(
            self.reward_decoder.parameters(),
            lr=lr,
            betas=(0.9, 0.999)
        )

        # data augmentation
        self.aug = RandomShiftsAug(pad=4)

        self.train()
        for target_critic in self.critic_targets:
            target_critic.train()

    def train(self, training=True):
        """Switch every trained sub-module between train and eval mode.

        Target critics are not touched here.
        """
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        for critic in self.critics:
            critic.train(training)
        # These two were previously skipped, leaving them in train mode while
        # the rest of the agent was in eval mode.
        self.q_weight_predictor.train(training)
        self.reward_decoder.train(training)

    def _action_dist(self, obs, step):
        """Encode a single observation and return the actor's distribution."""
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.unsqueeze(0)  # add the batch dimension
        obs = self.encoder(obs)
        stddev = utils.schedule(self.stddev_schedule, step)
        return self.actor(obs, stddev)

    def select_action(self, obs, step):
        """Return the mean action for ``obs`` as a flat numpy array."""
        with torch.no_grad():
            dist = self._action_dist(obs, step)
            action = dist.mean
            return action.cpu().data.numpy().flatten()

    def sample_action(self, obs, step):
        """Return an unclipped sampled action for ``obs`` as a flat numpy array."""
        with torch.no_grad():
            dist = self._action_dist(obs, step)
            action = dist.sample(clip=None)
            return action.cpu().data.numpy().flatten()

    def act(self, obs, step, eval_mode):
        """Return the mean action in eval mode, otherwise a sampled action."""
        if eval_mode:
            return self.select_action(obs, step)
        return self.sample_action(obs, step)

    def update_critic(self, obs_enc, action, step_rewards, step_not_dones, next_obs_enc, step):
        """Run one optimisation step on all critics, the encoder and the decoder.

        Args:
            obs_enc: Encoded observations with gradients attached, [B, repr_dim].
            action: Actions taken, concatenated to the critic trunk output.
            step_rewards: Per-step rewards, viewable as [B, n_step].
            step_not_dones: Per-step not-done factors, viewable as [B, n_step].
            next_obs_enc: Encoded next observations (no gradient).
            step: Step used to evaluate the stddev schedule.

        Returns:
            Metrics dict; empty unless ``use_tb`` is set.
        """
        metrics = dict()
        total_critic_loss = 0.0
        q_list = []
        B = obs_enc.shape[0]

        step_rewards = step_rewards.view(B, self.n_step)  # [B, n_step]
        step_not_dones = step_not_dones.view(B, self.n_step)  # [B, n_step]

        first_step_reward = step_rewards[:, 0:1]  # [B, 1]

        # One next action is sampled and shared by every critic's target.
        with torch.no_grad():
            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(next_obs_enc, stddev)
            next_action = dist.sample(clip=self.stddev_clip)

        for i, (gamma_i, critic, target_critic) in enumerate(
                zip(self.gammas, self.critics, self.critic_targets)):
            with torch.no_grad():
                # n-step return under this critic's own discount: rewards are
                # accumulated with a running weight that is multiplied by
                # not_done * gamma_i after every step.
                critic_total_reward = torch.zeros(B, 1, device=self.device)
                current_discount_weight = torch.ones(B, 1, device=self.device)
                for step_idx in range(self.n_step):
                    step_contribution = step_rewards[:, step_idx:step_idx+1] * current_discount_weight
                    critic_total_reward += step_contribution
                    current_discount_weight *= step_not_dones[:, step_idx:step_idx+1] * gamma_i

                critic_total_discount = current_discount_weight
                target_q1, target_q2 = target_critic(next_obs_enc, next_action)
                target_v = torch.min(target_q1, target_q2)
                target_q = critic_total_reward + critic_total_discount * target_v

            current_q1, current_q2 = critic(obs_enc, action)
            current_q = torch.min(current_q1, current_q2)
            q_list.append(current_q)
            critic_i_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
            total_critic_loss += critic_i_loss

            if self.use_tb:
                metrics[f'critic_{i}_q'] = current_q.mean().item()
                metrics[f'critic_{i}_total_reward'] = critic_total_reward.mean().item()
                metrics[f'critic_{i}_total_discount'] = critic_total_discount.mean().item()

        # Decode a reward sequence from the stacked Q-values. The first decoded
        # reward is regressed onto the observed first-step reward, and the
        # discounted sums of the decoded rewards onto the (detached) Q-values.
        q_matrix = torch.cat(q_list, dim=1)  # [B, critic_num]
        pred_reward = self.reward_decoder(q_matrix)  # [B, reward_length]
        reward_loss = F.mse_loss(pred_reward[:, 0:1], first_step_reward)

        q_recon = torch.matmul(pred_reward, self.L_matrix.T)  # [B, critic_num]
        recon_loss = F.mse_loss(q_recon, q_matrix.detach())
        total_critic_loss += reward_loss + recon_loss

        self.critic_opt.zero_grad(set_to_none=True)
        self.decoder_opt.zero_grad(set_to_none=True)
        total_critic_loss.backward()
        self.critic_opt.step()
        self.decoder_opt.step()

        if self.use_tb:
            metrics['total_critic_loss'] = total_critic_loss.item()
            metrics['reward_loss'] = reward_loss.item()
            metrics['recon_loss'] = recon_loss.item()
            metrics['first_step_reward'] = first_step_reward.mean().item()

        return metrics

    def update_actor(self, obs_enc, step):
        """Run one optimisation step on the actor and the weight predictor.

        Both are trained to maximise the weighted sum of the critics' min-Q
        values for actions sampled from the current policy.

        Args:
            obs_enc: Encoded observations, [B, repr_dim] (passed detached by
                ``update``).
            step: Step used to evaluate the stddev schedule.

        Returns:
            Metrics dict; empty unless ``use_tb`` is set.
        """
        metrics = dict()

        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs_enc, stddev)
        action = dist.sample(clip=self.stddev_clip)  # [B, action_dim]
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)  # [B, 1], logged only

        q_list = []
        critic_features = []
        for critic in self.critics:
            q1, q2, h = critic(obs_enc, action, return_feature=True)  # h: [B, feature_dim]
            q_list.append(torch.min(q1, q2))  # [B, 1]
            critic_features.append(h.unsqueeze(1))

        q_stacked = torch.cat(q_list, dim=1).squeeze(-1)  # [B, critic_num]
        critic_features = torch.cat(critic_features, dim=1)  # [B, critic_num, feature_dim]

        learned_weights = self.q_weight_predictor(
            obs_enc=obs_enc,
            V_stacked=q_stacked,
            critic_features=critic_features
        )  # [B, critic_num]

        fused_q = torch.sum(q_stacked * learned_weights, dim=-1, keepdim=True)  # [B, 1]

        actor_loss = - (fused_q).mean()

        self.actor_opt.zero_grad(set_to_none=True)
        self.q_weight_opt.zero_grad(set_to_none=True)
        actor_loss.backward()

        self.actor_opt.step()
        self.q_weight_opt.step()

        if self.use_tb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['actor_log_prob'] = log_prob.mean().item()
            metrics['fused_q'] = fused_q.mean().item()
            for i in range(self.critic_num):
                metrics[f'q_weight_{i}'] = learned_weights[:, i].mean().item()

        return metrics

    def update(self, replay_iter, step):
        """Draw one batch and update critics, actor and target critics.

        Returns an empty dict without touching ``replay_iter`` unless ``step``
        is a multiple of ``update_every_steps``.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, step_rewards, step_not_dones, next_obs = utils.to_torch(batch, self.device)

        obs_aug = self.aug(obs.float())  # [B, C, H, W]
        next_obs_aug = self.aug(next_obs.float())  # [B, C, H, W]

        obs_enc = self.encoder(obs_aug)  # [B, repr_dim]
        with torch.no_grad():
            next_obs_enc = self.encoder(next_obs_aug)  # [B, repr_dim]

        metrics.update(self.update_critic(
            obs_enc=obs_enc,
            action=action,
            step_rewards=step_rewards,
            step_not_dones=step_not_dones,
            next_obs_enc=next_obs_enc,
            step=step
        ))

        # The actor update must not back-propagate into the encoder.
        metrics.update(self.update_actor(obs_enc.detach(), step))
        for critic, target_critic in zip(self.critics, self.critic_targets):
            utils.soft_update_params(
                critic,
                target_critic,
                self.critic_target_tau
            )

        return metrics

    def save(self, model_dir, step):
        """Write the state dict of every network to ``model_dir``.

        Optimizer state is not saved.
        """
        torch.save(self.actor.state_dict(), f'{model_dir}/actor_{step}.pt')
        torch.save(self.encoder.state_dict(), f'{model_dir}/encoder_{step}.pt')
        torch.save(self.q_weight_predictor.state_dict(), f'{model_dir}/q_weight_predictor_{step}.pt')
        torch.save(self.reward_decoder.state_dict(), f'{model_dir}/reward_decoder_{step}.pt')

        for i in range(self.critic_num):
            torch.save(self.critics[i].state_dict(), f'{model_dir}/critic_{i}_{step}.pt')
            torch.save(self.critic_targets[i].state_dict(), f'{model_dir}/critic_target_{i}_{step}.pt')

    def load(self, model_dir, step):
        """Restore every network from the files written by ``save``.

        NOTE: ``torch.load`` unpickles its input; only load checkpoints from
        trusted sources.
        """
        self.actor.load_state_dict(torch.load(f'{model_dir}/actor_{step}.pt', map_location=self.device))
        self.encoder.load_state_dict(torch.load(f'{model_dir}/encoder_{step}.pt', map_location=self.device))
        self.q_weight_predictor.load_state_dict(torch.load(f'{model_dir}/q_weight_predictor_{step}.pt', map_location=self.device))
        self.reward_decoder.load_state_dict(torch.load(f'{model_dir}/reward_decoder_{step}.pt', map_location=self.device))

        for i in range(self.critic_num):
            self.critics[i].load_state_dict(torch.load(f'{model_dir}/critic_{i}_{step}.pt', map_location=self.device))
            self.critic_targets[i].load_state_dict(torch.load(f'{model_dir}/critic_target_{i}_{step}.pt', map_location=self.device))
  