from defense.base_defense import BaseDefense
from CRDR.src.models import build_comp_model
from CRDR.scripts.compress import CustomConfig
import numpy as np
from torchvision import transforms
import torch

class CRDRDefense(BaseDefense):
    """
    CRDR Defense class for applying CRDR to images.

    Each call resizes the batch to 256x256 and runs it ``iterations`` times
    through the CRDR learned codec (compression followed by decompression).
    It then resizes the reconstruction to 224x224.
    """
    def __init__(self, quality = 4.0, beta = 5.12, device=None, iterations=1):
        """
        Initialize the CRDR defense with a specified quality, realism factor and device.

        Args:
            quality (float): Quality factor for CRDR in [0.0, 4.0]. Values outside
                the range are clamped, with a printed notice (default: 4.0).
            beta (float): Realism factor for CRDR in [0.0, 5.12]. Values outside
                the range are clamped, with a printed notice (default: 5.12).
            device (torch.device or str): Device to move the model to (default: None).
            iterations (int): Number of compress/decompress passes applied by
                ``_defense`` and ``get_bitrate`` (default: 1).

        Raises:
            ValueError: If ``quality`` or ``beta`` cannot be converted to float.
            FileNotFoundError: If the CRDR model weights cannot be found.
        """
        super(CRDRDefense, self).__init__(device, iterations)
        try:
            quality = float(quality)
            beta = float(beta)
        except (TypeError, ValueError) as err:
            # float(None) and float(<non-numeric object>) raise TypeError, not
            # ValueError. Report every unconvertible input the same way.
            raise ValueError("Must be able to convert quality and beta values to float.") from err
        if quality < 0.0 or quality > 4.0:
            quality = max(0.0, min(quality, 4.0))
            print("Quality factor out of range. Clamping to [0.0, 4.0].")
        if beta < 0.0 or beta > 5.12:
            beta = max(0.0, min(beta, 5.12))
            print("Beta factor out of range. Clamping to [0.0, 5.12].")
        self.quality = quality
        self.beta = beta
        # Options handed to MyConfig.get_opt. The paths are relative, presumably
        # resolved against the current working directory (TODO confirm).
        self.args = {
            'config_path':'CRDR/config/crdr.yaml',
            'model_path':'data/crdr.pth.tar',
            'img_dir':None,
            'save_dir':None,
            'quality':quality,
            'beta':beta,
            'decompress':True,
            'device':device}
        # get_opt is a classmethod, so there is no need to build an instance first.
        self.opt = MyConfig.get_opt(self.args)
        self.model = build_comp_model(self.opt).to(self.opt.device)
        try:
            self.model.load_learned_weight(ckpt_path=self.opt.model_path)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Model weights for '{self.opt.model_path}' not found. Please provide valid weights in the data folder.") from err
        self.model.codec_setup()
        self.model = self.model.to(self.device)
        self.model.eval()
        # Register the model as a device-bound attribute so BaseDefense can
        # relocate it (presumably when the defense's device changes).
        self.device_attributes.append('model')

    def _codec_pass(self, x):
        """
        Run one CRDR forward pass on ``x`` with the configured quality and beta.

        Args:
            x (torch.Tensor): Input batch, already resized to the codec resolution.

        Returns:
            dict: Raw model output. This class reads ``'fake_images'`` and
            ``'likelihoods'`` (with ``'y'`` and ``'z'`` entries) from it.
        """
        # NOTE(review): is_train=True is kept from the original call sites. It
        # presumably selects the forward path that returns reconstructions plus
        # likelihoods - confirm against the CRDR model implementation.
        return self.model(x, rate_ind=self.quality, beta=self.beta, is_train=True)

    def _defense(self, x):
        """
        Apply the defense to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: ``x`` after ``self.iterations`` compress/decompress
            passes, resized to spatial size 224x224.
        """
        x = transforms.Resize((256, 256))(x)
        for _ in range(self.iterations):
            x = self._codec_pass(x)['fake_images']
        return transforms.Resize((224, 224))(x)

    def get_bitrate(self, x):
        """
        Estimate the bit cost of coding ``x`` with CRDR and return the reconstruction.

        This runs the same pipeline as ``_defense`` and collects the latent
        likelihoods produced by every pass.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            tuple: ``(y_bits, z_bits, x_hat)``, where:
                - ``y_bits`` and ``z_bits`` are the bit costs of the ``y`` and
                  ``z`` latents, with one entry per batch index, averaged over
                  all iterations.
                - ``x_hat`` is the final reconstruction, resized to 224x224.

        Raises:
            ValueError: If ``self.iterations`` is less than 1, since there
                would be no pass to average over.
        """
        if self.iterations < 1:
            raise ValueError("get_bitrate requires at least one iteration.")
        y_bits = []
        z_bits = []
        x = transforms.Resize((256, 256))(x)
        for _ in range(self.iterations):
            output = self._codec_pass(x)
            x = output['fake_images']
            y_bits.append(self._likelihood_to_bit(output['likelihoods']['y']))
            z_bits.append(self._likelihood_to_bit(output['likelihoods']['z']))
        x = transforms.Resize((224, 224))(x)
        return sum(y_bits) / len(y_bits), sum(z_bits) / len(z_bits), x

    def _likelihood_to_bit(self, likelihood):
        """
        Convert symbol likelihoods into a bit cost per leading (batch) index.

        Args:
            likelihood (torch.Tensor): Likelihood tensor whose first dimension
                is the batch dimension.

        Returns:
            torch.Tensor: ``-sum(log2(likelihood))`` over every non-leading
            dimension, with one value per batch entry.
        """
        sum_dims = tuple(range(1, likelihood.ndim))
        # Natural log divided by ln(2) gives log base 2, i.e. bits.
        bitcost = -(torch.log(likelihood).sum(dim=sum_dims)) / np.log(2)
        return bitcost
    
    
class MyConfig(CustomConfig):
    """
    CustomConfig variant built from an in-memory set of options.

    It does not parse the command line.
    """

    @classmethod
    def get_opt(cls, args) -> 'CustomConfig':
        """
        Build a config from ``args`` merged with the YAML file at ``args['config_path']``.

        Args:
            args (dict or argparse.Namespace): Options. They must include
                ``'config_path'``.

        Returns:
            CustomConfig: An instance of ``cls`` whose ``is_train`` is always False.
        """
        arg_dict = cls.arg_parse(args)
        filename = arg_dict['config_path']
        # The third return value (loaded YAML list) is not needed here.
        cfg_dict, cfg_text, _ = cls._file2dict_yaml(filename)
        arg_dict = cls._merge_a_into_b(arg_dict, cfg_dict)
        # Force is_train off, overriding anything from args or the YAML.
        arg_dict['is_train'] = False
        return cls(arg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def arg_parse(args):
        """
        Normalise ``args`` into a fresh plain dict.

        Args:
            args (dict or argparse.Namespace): Options to convert.

        Returns:
            dict: A shallow copy of the options, so later merging never aliases
            the caller's object.
        """
        from argparse import Namespace

        if isinstance(args, Namespace):
            return dict(vars(args))
        return dict(args)