import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from Model import CrossAttentionTranformer as Backbone
from Codes import sign_to_bin, bin_to_sign

class DiffusionProcess:
    """Forward noising with a constant per-step variance.

    Noise is added to ``x0`` without rescaling ``x0``: a sample at step
    ``k`` is ``x0 + sqrt(betas_bar[k]) * e`` with ``e ~ N(0, I)``.
    Every step contributes variance ``sigma``, so ``betas_bar[k]`` equals
    ``(k + 1) * sigma`` (0-based ``k``).
    """

    def __init__(self, n_steps, sigma, device):
        self.n_steps = n_steps
        self.device = device
        self.sigma = sigma

        # Constant schedule: all n_steps entries equal sigma
        # (default float dtype, same as the tensor it replaces).
        per_step = torch.ones(self.n_steps) * sigma
        self.betas = per_step.view(-1, 1).to(device)
        # Running total of per-step variances, shape (n_steps, 1).
        self.betas_bar = self.betas.cumsum(0).view(-1, 1)

    def q_sample(self, x0_sign, t1, t2):
        """Noise ``x0_sign`` to two steps using ONE shared Gaussian draw.

        Args:
            x0_sign: clean words; the per-row noise scale is broadcast as a
                column, so rows are assumed to be batch items.
            t1, t2: per-row step indices; cast to ``long`` (fractional values
                are truncated) before indexing ``betas_bar``.

        Returns:
            ``(y_t1, y_t2)``, both built from the same noise tensor so they
            differ only in noise magnitude.
        """
        idx_a = t1.to(self.device).long()
        idx_b = t2.to(self.device).long()

        shared_noise = torch.randn_like(x0_sign)

        std_a = self.betas_bar[idx_a].sqrt().view(-1, 1)
        std_b = self.betas_bar[idx_b].sqrt().view(-1, 1)

        return x0_sign + shared_noise * std_a, x0_sign + shared_noise * std_b


class Consistency(nn.Module):
    """Error predictor trained on pairs of noisy words from one noise draw.

    The backbone maps a noisy sign-domain word ``y`` plus an integer
    conditioning value to per-position logits that the hard decision
    ``sign(y)`` is wrong. The conditioning value is the syndrome weight of
    ``sign(y)``, which ``decode`` can compute from the received word alone.

    NOTE(review): assumes ``sign_to_bin`` maps +1 -> 0 and -1 -> 1
    (i.e. ``(1 - x) / 2``) and ``bin_to_sign`` is its inverse. The error
    targets and the parity penalty below rely on this convention — confirm
    against ``Codes``.
    """

    def __init__(self, args, device):
        super().__init__()
        self.device = device
        self.args = args

        self.backbone = Backbone(args, device)

        self.diffusion = DiffusionProcess(args.N_steps, args.sigma, device)

        # H^T as a float buffer so it moves with the module on .to(...) and is
        # saved in the state dict.
        pc_matrix_T = args.code.pc_matrix.transpose(0, 1).float()
        self.register_buffer('pc_matrix_T', pc_matrix_T)

    def _calculate_syndrome_weight(self, y_sign):
        """Return the number of unsatisfied parity checks of ``sign(y_sign)``.

        One ``long`` count per row. An exact 0 in ``y_sign`` yields sign 0,
        which is not a valid bit; that is negligible for continuous noise.
        """
        y_sign_clamped = torch.sign(y_sign)
        y_bin = sign_to_bin(y_sign_clamped)
        syndrome = torch.matmul(y_bin, self.pc_matrix_T) % 2

        return syndrome.sum(dim=-1).long()

    def calculate_syn_loss(self, x0_pred_logits):
        """Differentiable parity-check penalty on predicted codewords.

        Args:
            x0_pred_logits: sign-domain logits; a positive value means the
                position is predicted to be +1, i.e. binary 0 (see
                ``decode``).

        Returns:
            Scalar mean, over batch rows and checks, of
            ``(1 - cos(pi * s)) / 2`` with ``s = P(bit == 1) @ H^T``. For hard
            0/1 probabilities this is 0 for a satisfied check and 1 for a
            violated one.
        """
        batch_size = x0_pred_logits.shape[0]
        num_checks = self.pc_matrix_T.shape[1]
        # The soft parity must accumulate P(bit == 1). Positive logits mean
        # binary 0, so P(bit == 1) = sigmoid(-logit). Summing
        # P(bit == 0) = sigmoid(+logit) instead gives d - s for a check of
        # weight d. Since cos(pi * (d - s)) = (-1)^d * cos(pi * s), the penalty
        # would be inverted on every odd-weight check.
        bit_one_prob = torch.sigmoid(-x0_pred_logits)
        soft_syndrome = torch.matmul(bit_one_prob, self.pc_matrix_T)
        syn_loss_terms = (1.0 - torch.cos(torch.pi * soft_syndrome)) / 2.0
        normalized_loss = syn_loss_terms.sum() / (batch_size * num_checks)
        return normalized_loss

    def forward(self, y_t, t):
        """Return the backbone's per-position error logits for ``(y_t, t)``.

        ``t`` is cast to ``long``; callers in this class pass syndrome
        weights rather than diffusion step indices.
        """
        y_t = y_t.to(self.device)
        t = t.to(self.device).long()
        error_pred_logits = self.backbone(y_t, t)
        return error_pred_logits

    def loss(self, x_truth, lambda_guidance=0):
        """Compute the training losses for one batch of codewords.

        Args:
            x_truth: binary (0/1) codewords, converted with ``bin_to_sign``;
                presumably shaped (batch, n) — TODO confirm.
            lambda_guidance: weight of the parity-check penalty in
                ``total_loss``.

        Returns:
            dict with ``'total_loss'``, ``'distillation_loss'`` (summed BCE
            of the error logits at both noise levels) and ``'guidance_loss'``
            (summed parity penalty at both noise levels).
        """
        batch_size = x_truth.shape[0]
        x_truth = x_truth.to(self.device).float()
        x0_sign = bin_to_sign(x_truth).float()

        # t1 ~ U{1, ..., n_steps - 1}. t2 is a fraction of t1 and is
        # truncated to an integer index inside q_sample.
        t1 = torch.randint(
            1,
            self.diffusion.n_steps,
            size=(batch_size,),
            device=self.device
        )
        alpha = getattr(self.args, 'consistency_alpha', 0.5)
        t2 = t1 * alpha

        # Both samples share one noise draw and differ only in magnitude.
        y_t1, y_t2 = self.diffusion.q_sample(x0_sign, t1, t2)

        # The backbone is conditioned on the observable syndrome weight,
        # not on the sampled step indices.
        t1_syn = self._calculate_syndrome_weight(y_t1)
        t2_syn = self._calculate_syndrome_weight(y_t2)

        error_pred_logits_1 = self.forward(y_t1, t1_syn)
        error_pred_logits_2 = self.forward(y_t2, t2_syn)

        # Target is 1 where sign(y_t) disagrees with the true sign.
        error_target_1 = sign_to_bin(torch.sign(y_t1 * x0_sign))
        error_target_2 = sign_to_bin(torch.sign(y_t2 * x0_sign))

        distill_loss_1 = F.binary_cross_entropy_with_logits(error_pred_logits_1, error_target_1)
        distill_loss_2 = F.binary_cross_entropy_with_logits(error_pred_logits_2, error_target_2)

        total_distillation_loss = distill_loss_1 + distill_loss_2

        # Flip the received sign where an error is predicted, giving
        # sign-domain logits for x0.
        x0_pred_logits_1 = -error_pred_logits_1 * torch.sign(y_t1.detach())
        x0_pred_logits_2 = -error_pred_logits_2 * torch.sign(y_t2.detach())

        guidance_loss_1 = self.calculate_syn_loss(x0_pred_logits_1)
        guidance_loss_2 = self.calculate_syn_loss(x0_pred_logits_2)
        total_guidance_loss = guidance_loss_1 + guidance_loss_2

        total_loss = total_distillation_loss + lambda_guidance * total_guidance_loss

        return {
            'total_loss': total_loss,
            'distillation_loss': total_distillation_loss,
            'guidance_loss': total_guidance_loss,
        }

    @torch.no_grad()
    def decode(self, y_noisy):
        """One-shot hard decision of ``y_noisy`` into binary (0/1) bits.

        Runs the backbone in eval mode and then restores whatever
        train/eval mode the module had before the call. Previously the
        module was left in eval mode permanently, which silently affected
        any training that followed.
        """
        was_training = self.training
        self.eval()
        try:
            y_noisy = y_noisy.to(self.device)

            time_tensor = self._calculate_syndrome_weight(y_noisy)

            error_pred_logits = self.forward(y_noisy, time_tensor)

            # Flip the received sign wherever the error logit is positive.
            x0_pred_logits = -error_pred_logits * torch.sign(y_noisy)

            x_0_sign_approx = torch.sign(x0_pred_logits)
            x_pred_hard = sign_to_bin(x_0_sign_approx)

            return x_pred_hard
        finally:
            self.train(was_training)