import torch
from abc import ABC, abstractmethod
import numpy as np

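# TODO: Decide whether losses should be looked up through a registry (possibly shared with backends).
# TODO: Check whether each loss should use x or x_enc.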


class OptimisationState:
    """Bundle of values handed to every ``Loss.loss`` call.

    This is a plain attribute container: the constructor stores each argument
    under the attribute of the same name and does no validation.

    Attributes (as used by the losses in this module):
        model: wrapper object; losses here call ``model.compute_loss(...)``
            and ``model.pytorch_model(...)`` on it.
        z, z_factual: compared by ``DistanceLoss`` (current vs. factual point).
        x_enc: fed to ``model.pytorch_model`` by ``EnergyLoss``.
        y_enc, y_target: passed to ``model.compute_loss`` by
            ``ClassificationLoss``; ``y_target`` is also argmax-ed by
            ``EnergyLoss``.
        x_factual, y_factual: stored for callers; unused by the losses here.
        it: current iteration index (``EnergyLoss`` uses ``it + 1`` as step).
        n_it: total number of iterations (``EnergyLoss``'s ``max_steps``).
    """

    def __init__(self, model, z, z_factual, x_enc, y_enc, x_factual, y_factual, y_target, it, n_it):
        self.model, self.z, self.z_factual = model, z, z_factual
        self.x_enc, self.y_enc = x_enc, y_enc
        self.x_factual, self.y_factual, self.y_target = x_factual, y_factual, y_target
        self.it, self.n_it = it, n_it

    def __str__(self):
        # Field order here defines the order in the rendered string.
        field_names = (
            "model", "z", "z_factual", "x_enc", "y_enc",
            "x_factual", "y_factual", "y_target", "it", "n_it",
        )
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
        return f"OptimisationState({rendered})"

class Loss(ABC):
    """Interface for one term of the counterfactual-search objective.

    Concrete terms must override :meth:`loss`; because it is an
    ``abstractmethod``, a subclass cannot be instantiated until it does.
    """

    @abstractmethod
    def loss(self, opt_state: OptimisationState):
        """Compute this term from the current optimisation state ``opt_state``."""

class ClassificationLoss(Loss):
    """Classification term of the objective.

    Delegates to ``opt_state.model.compute_loss(y_enc, y_target)``.
    Presumably ``y_enc`` is the model's prediction for the current
    counterfactual; confirm this against the caller that builds the state.

    Inherits from ``Loss`` rather than bare ``ABC``, so it satisfies
    ``isinstance(obj, Loss)`` like the other loss terms in this module.
    """

    def loss(self, opt_state: OptimisationState):
        """Return the model's own loss between ``y_enc`` and ``y_target``."""
        return opt_state.model.compute_loss(opt_state.y_enc, opt_state.y_target)

class DistanceLoss(Loss):
    """Proximity term: p-norm of ``z - z_factual``, optionally feature-weighted.

    Args:
        norm: order ``p`` of the vector norm. This applies to both the
            unweighted and the weighted distance. The default is 1.
        mad: if True, per-feature weights are set to ``1 / MAD``, where MAD is
            the median absolute deviation of ``mad_data`` along axis 0. This
            overrides any ``dist_weight`` that was passed in.
        mad_data: array-like reference data. It is required when ``mad`` is True.
        dist_weight: optional weights, multiplied elementwise with
            ``z - z_factual`` before the norm is taken.

    Raises:
        ValueError: if ``mad`` is True and ``mad_data`` is None.
    """

    def __init__(self, norm=1, mad=False, mad_data=None, dist_weight=None):
        self.norm = norm
        self.mad = mad
        self.dist_weight = dist_weight

        if self.mad:
            if mad_data is None:
                raise ValueError("DistanceLoss: mad=True requires mad_data")
            data = np.asarray(mad_data, dtype=np.float64)
            med = np.median(data, axis=0)
            mad_values = np.median(np.abs(data - med), axis=0)
            # Constant features have MAD == 0; replace it with a tiny value to
            # avoid division by zero. np.where (instead of masked assignment)
            # also works when 1-D data yields a scalar median.
            mad_values = np.where(mad_values == 0, 1e-9, mad_values)
            self.dist_weight = 1.0 / mad_values

        if self.dist_weight is not None:
            if isinstance(self.dist_weight, torch.Tensor):
                # Copy like torch.tensor() would, but without its copy-construct warning.
                self.dist_weight = self.dist_weight.detach().clone().to(torch.float32)
            else:
                self.dist_weight = torch.tensor(self.dist_weight, dtype=torch.float32)

    def loss(self, opt_state: OptimisationState):
        """Return ``||w * (z - z_factual)||_p``, or ``||z - z_factual||_p`` without weights."""
        if self.dist_weight is None:
            return torch.dist(opt_state.z, opt_state.z_factual, p=self.norm)
        weighted = self.dist_weight.to(opt_state.z.device) * (opt_state.z - opt_state.z_factual)
        # Honour self.norm here too; this branch used to hard-code L1. The
        # default norm=1 keeps the previous result.
        return torch.linalg.vector_norm(weighted, ord=self.norm)

class EnergyLoss(Loss):
    """Generative term: negated target-class probability under a decaying weight.

    The returned value is ``phi * (gen + reg_strength * gen**2)``. When
    ``reg_strength == 0``, it is just ``phi * gen``. The pieces are:

    * ``gen = -softmax(model(x_enc))[..., t]``, where
      ``t = argmax(y_target)`` and ``model = opt_state.model.pytorch_model``.
    * ``phi = a * (1 + step / b) ** (-decay)``, where
      ``step = it + 1``, ``b = max(1, round(n_it / 25))`` and ``a = b / 10``.

    NOTE(review): energy-based formulations usually use the raw *logit* of
    the target class, but this code applies a softmax first. Confirm which
    one is intended; the softmax behaviour is kept here unchanged.

    Args:
        reg_strength: weight of the squared-norm regulariser on ``gen``.
        decay: exponent of the polynomial decay of ``phi``.
    """

    def __init__(self, reg_strength: float = 1e-3, decay: float = 0.9):
        self.decay = decay
        self.reg_strength = reg_strength

    def loss(self, opt_state: OptimisationState):
        model = opt_state.model.pytorch_model
        x_prime = opt_state.x_enc
        # NOTE(review): argmax over the flattened y_target. This gives a class
        # index only for a single sample (1-D or shape (1, C)).
        t = torch.argmax(opt_state.y_target)
        step = opt_state.it + 1
        max_steps = opt_state.n_it

        # Polynomial decay multiplier. round(max_steps / 25) is 0 for
        # max_steps <= 12, which used to raise ZeroDivisionError, so b is
        # clamped to at least 1. Longer runs are unaffected.
        b = max(1, round(max_steps / 25))
        a = b / 10
        phi = (1.0 + step / b) ** (-self.decay) * a

        # Generative loss: negated softmax probability of the target class.
        # Indexing the last axis with [..., t] is identical for a 1-D output,
        # and it also handles a leading batch axis instead of indexing that axis.
        probs = torch.softmax(model(x_prime), dim=-1)
        gen_loss = -probs[..., t]

        if self.reg_strength == 0.0:
            return phi * gen_loss

        reg_loss = torch.linalg.vector_norm(gen_loss) ** 2
        return phi * (gen_loss + self.reg_strength * reg_loss)
