import itertools

import numpy as np
import torch

from metrics.metrics import Metric

"""
For SOBDRL algorithm, compute the argmax rotation plane activated for each action
Check if the actions of the same group are activated on the same plane
Check if the actions of different groups are activated on different planes
"""

class PlaneMetric(Metric) :
    """Measure how SOBDRL actions are spread over latent rotation planes.

    A rotation plane is an unordered pair of latent dimensions ``(i, j)`` with
    ``i < j``. Each action is assigned the plane whose entry in
    ``algo.action_encoder.theta`` has the largest absolute value. Actions are
    then aggregated by the groups listed in ``nfo["group"]``:

    * ``mean_planes``: average number of distinct planes used per group
      (1.0 means every action of a group picks the same plane);
    * ``mean_common_planes``: average number of planes shared by each pair of
      distinct groups (0.0 means no two groups share a plane).
    """

    def __init__(self,algo,nfo,loaders) :
        super().__init__(algo,nfo,loaders)
        # Lookup table: action index -> index of the group that action is in.
        # NOTE(review): not read by compute_metrics below; presumably used by
        # the base class or external callers -- confirm before removing.
        self.groups = torch.zeros(self.nfo["n_action"]).int().to(self.device)
        for k,g in enumerate(self.nfo["group"]) :
            for a in g :
                self.groups[a] = k

    def __repr__(self) :
        return "plane"

    def compute_metrics(self) :
        """Compute plane-usage statistics for the action groups.

        Returns:
            dict with keys
            ``"mean_planes"``: mean number of distinct planes activated per
            group (0.0 when there are no groups), and
            ``"mean_common_planes"``: mean number of planes shared by each
            pair of groups (0.0 when there are fewer than two groups).
        """
        # Detach before moving to CPU so the copy is not tracked by autograd.
        theta = np.abs(self.algo.action_encoder.theta.detach().cpu().numpy())
        z_dim = self.algo.z_dim

        # All planes (i, j) with i < j, in the same lexicographic order as
        # nested `for i` / `for j in range(i + 1, ...)` loops. Assumes the
        # last axis of theta is indexed in this order -- TODO confirm against
        # the action encoder.
        planes = list(itertools.combinations(range(z_dim), 2))

        groups = self.nfo["group"]
        n_group = len(groups)

        # Each plane is added as ONE element (a tuple). Using
        # set.update({i, j}) would add dims i and j separately and count
        # latent dimensions instead of planes.
        group_planes = [
            {planes[int(theta[action].argmax())] for action in group}
            for group in groups
        ]

        # Mean number of planes activated for each group.
        mean_metric = (
            sum(len(p) for p in group_planes) / n_group if n_group else 0.0
        )

        # Mean number of common planes for each pair of distinct groups.
        common_planes = [
            len(first & second)
            for first, second in itertools.combinations(group_planes, 2)
        ]
        mean_common_planes = (
            sum(common_planes) / len(common_planes) if common_planes else 0.0
        )

        return {"mean_planes": mean_metric,
                "mean_common_planes": mean_common_planes}