from typing import Dict, Tuple
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import amp
from .EncoderNoiseDecoder import *


class Network:

    def __init__(
        self,
        noise_layers,
        device: torch.device,
        lr: float,
        accum_steps: int = 1,
        use_ddp: bool = False,
        use_amp: bool = False,
    ):
        self.device = device
        self.use_ddp = use_ddp
        self.use_amp = use_amp

        base = EncoderDecoder(noise_layers).to(device)

        if self.use_ddp:
            self.encoder_decoder = DDP(
                base,
                device_ids=[device.index],
                output_device=device.index,
                find_unused_parameters=False,  
                static_graph=True,           
            )
        else:
            self.encoder_decoder = base

        self.opt_encoder_decoder = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.encoder_decoder.parameters()), lr=lr
        )

        self.criterion_img = nn.MSELoss().to(device)   
        self.criterion_msg = nn.MSELoss().to(device)   

        self.encoder_weight = 1.0
        self.decoder_weight = 100.0

        self.accum_steps = max(1, int(accum_steps))
        self._accum_counter = 0
        self.opt_encoder_decoder.zero_grad(set_to_none=True)

        self.scaler = amp.GradScaler("cuda", enabled=self.use_amp)


    def train(self, t_idx: int, images: torch.Tensor, messages: torch.Tensor) -> Dict[str, torch.Tensor]:
        self.encoder_decoder.train()
        images = images.to(self.device, non_blocking=True)
        messages = messages.to(self.device, non_blocking=True).float()

        if self._accum_counter == 0:
            self.opt_encoder_decoder.zero_grad(set_to_none=True)

        need_sync = ((self._accum_counter + 1) % self.accum_steps == 0)
        sync_ctx = contextlib.nullcontext()
        if self.use_ddp and not need_sync:
            sync_ctx = self.encoder_decoder.no_sync()

        with sync_ctx:
            with amp.autocast("cuda", enabled=self.use_amp):
                encoded_images, noised_images, decoded_logits = self.encoder_decoder(images, messages)

                loss_img = self.criterion_img(encoded_images, images)
                loss_msg = self.criterion_msg(decoded_logits, messages)
                g_loss = self.encoder_weight * loss_img + self.decoder_weight * loss_msg

            if self.use_amp:
                self.scaler.scale(g_loss / self.accum_steps).backward()
            else:
                (g_loss / self.accum_steps).backward()

        self._accum_counter += 1

        if need_sync:
            if self.use_amp:
                self.scaler.step(self.opt_encoder_decoder)
                self.scaler.update()
            else:
                self.opt_encoder_decoder.step()

            self.opt_encoder_decoder.zero_grad(set_to_none=True)
            self._accum_counter = 0

        psnr = self._psnr(encoded_images.detach(), images, max_val=2.0)
        acc, err_rate = self._bit_accuracy_and_error(messages, decoded_logits.detach(), threshold=0.5)

        result = {
            "acc": acc,
            "encoder_weight": torch.tensor(self.encoder_weight, device=self.device),
            "psnr": psnr,
            "g_loss": g_loss.detach(),
            "g_loss_on_encoder": loss_img.detach(),
            "g_loss_on_decoder": loss_msg.detach(),
        }
        return result


    def validation(self, images: torch.Tensor, messages: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Tuple[torch.Tensor, ...]]:
        self.encoder_decoder.eval()
        with torch.no_grad():
            images = images.to(self.device, non_blocking=True)
            messages = messages.to(self.device, non_blocking=True).float()

            with amp.autocast("cuda", enabled=self.use_amp):
                encoded_images, noised_images, decoded_logits = self.encoder_decoder(images, messages)
                loss_img = self.criterion_img(encoded_images, images)
                loss_msg = self.criterion_msg(decoded_logits, messages)
                g_loss = self.encoder_weight * loss_img + self.decoder_weight * loss_msg

            psnr = self._psnr(encoded_images.detach(), images, max_val=2.0)
            acc, err_rate = self._bit_accuracy_and_error(messages, decoded_logits, threshold=0.5)

        result = {
            "acc": acc,
            "psnr": psnr,
            "g_loss": g_loss,
            "g_loss_on_encoder": loss_img,
            "g_loss_on_decoder": loss_msg,
        }
        return result, (images, encoded_images, noised_images, messages, decoded_logits)


    def _unwrap(self):
        return self.encoder_decoder.module if isinstance(self.encoder_decoder, DDP) else self.encoder_decoder

    def save_model(self, path_encoder_decoder: str):
        torch.save(self._unwrap().state_dict(), path_encoder_decoder)

    def load_model(self, path_encoder_decoder: str):
        self._unwrap().load_state_dict(torch.load(path_encoder_decoder, map_location=self.device), strict=False)

    def flush_accum(self):
        has_grad = any(p.grad is not None for p in self._unwrap().parameters())
        if has_grad:
            self.opt_encoder_decoder.zero_grad(set_to_none=True)
        self._accum_counter = 0


    @staticmethod
    def _psnr(x: torch.Tensor, y: torch.Tensor, max_val: float = 1.0, eps: float = 1e-12) -> torch.Tensor:
        mse = F.mse_loss(x, y)
        mse = torch.clamp(mse, min=eps)
        return 10.0 * torch.log10((max_val ** 2) / mse)

    @staticmethod
    def _bit_accuracy_and_error(target_bits: torch.Tensor, preds: torch.Tensor, threshold: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor]:
        pred_bits = (preds >= threshold).to(target_bits.dtype)
        total = torch.numel(target_bits)
        correct = (pred_bits == target_bits).sum()
        acc = correct.float() / float(total)
        err = 1.0 - acc
        return acc, err