import einops
import torch
from metrics.metrics import Metric

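"""
Action-conditioned prediction error in latent space.

For each pair of consecutive frames, the encoded next frame is compared with
the model's action-conditioned prediction from the current frame. Reports the
mean squared error over all transitions and the mean error per action group.
"""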

class ActionMetric(Metric) :
    """Latent-space next-step prediction error, overall and per action group.

    For every consecutive frame pair ``(X[:, t], X[:, t+1])`` produced by the
    loader, the target is ``algo.encode_image(X[:, t+1])`` and the prediction is
    ``algo.encode_image(X[:, t], action)``. The squared difference is averaged
    over the latent features to give one error per transition. Those errors are
    then averaged over all transitions (``"error"``) and over the transitions
    whose action belongs to each group (``"error/g<k>"``).
    """

    def __init__(self,algo,nfo,loader) :
        """
        Args:
            algo: model exposing ``encode_image(x)`` and
                ``encode_image(x, actions)``.
            nfo: dict with ``"n_action"`` (number of discrete actions) and
                ``"group"`` (sequence of groups, each an iterable of action ids).
            loader: iterable of ``(X, A)`` batches; ``X`` is indexed as
                ``(batch, time, ...)`` and ``A`` as ``(batch, time')``, with one
                action per consecutive frame pair.
        """
        super().__init__(algo,nfo,loader)
        # Lookup table: action id -> group index. Actions not listed in any
        # group stay in group 0; an action listed in several groups keeps the
        # last group it appears in.
        self.groups = torch.zeros(self.nfo["n_action"], dtype=torch.int32, device=self.device)
        for k, g in enumerate(self.nfo["group"]) :
            for a in g :
                self.groups[a] = k

    def __repr__(self) :
        return "action"

    def compute_metrics(self) :
        """Evaluate the action-conditioned prediction error over ``self.loader``.

        Means are taken over individual transitions (sample-weighted), not over
        per-batch means. Batches of different sizes are therefore weighted
        correctly. A group that is absent from some batch does not poison its
        average with NaN.

        Returns:
            dict with ``"error"`` (mean error over all transitions) followed by
            ``"error/g<k>"`` for each group ``k`` (mean error over transitions
            whose action maps to group ``k``). A value is ``nan`` when no
            transition contributed to it (empty loader, or a group never seen).
        """
        algo = self.algo
        n_groups = int(self.groups.max().item()) + 1
        total_sum, total_count = 0.0, 0
        group_sum = [0.0] * n_groups
        group_count = [0] * n_groups

        # Pure evaluation: no gradients needed, so avoid building autograd graphs.
        with torch.no_grad() :
            for X, A in self.loader :
                # Flatten (batch, time) so that every consecutive frame pair
                # becomes one sample.
                X0 = einops.rearrange(X[:,:-1], 'b m ... -> (b m ) ...')
                X1 = einops.rearrange(X[:,1:], 'b m ... -> (b m ) ...')
                A = einops.rearrange(A, 'b m -> (b m )')
                Z1 = algo.encode_image(X1)
                Z1_hat = algo.encode_image(X0, A[:,None])
                # One error value per transition, averaged over all latent
                # features (identical to .mean(dim=1) for 2-D latents).
                error = torch.square(Z1 - Z1_hat).flatten(start_dim=1).mean(dim=1)
                # The group table lives on self.device; move the ids to where
                # the errors are so boolean masking works across devices.
                group_of = self.groups[A].to(error.device)

                total_sum += error.sum().item()
                total_count += error.numel()
                for g in range(n_groups) :
                    mask = group_of == g
                    # Summing an empty selection yields 0, so absent groups
                    # simply contribute nothing to this batch.
                    group_sum[g] += error[mask].sum().item()
                    group_count[g] += int(mask.sum().item())

        nan = float("nan")
        metrics = {"error": total_sum / total_count if total_count else nan}
        for g in range(n_groups) :
            metrics[f"error/g{g}"] = group_sum[g] / group_count[g] if group_count[g] else nan

        return metrics