import numpy as np
from data_generator import load_data, get_generator
from torch import nn
from itertools import combinations
import torch
from metrics.metrics import Metric
import einops
from algos.ae import AE

"""
Compute the independence metrics characterising disentanglement, following
"Linear Disentangled Representations and Unsupervised Action Estimation"
(Painter et al., 2020).
"""

def compute_P(r, epsilon, M):
    """Return ``epsilon * (1 + r + r**2 + ... + r**M + 1)``.

    The geometric sum ``sum_{k=0}^{M} r**k`` is evaluated in closed form,
    ``(1 - r**(M+1)) / (1 - r)``. When ``r`` equals 1 exactly, that form
    divides by zero, so its limit ``M + 1`` is used instead. The total
    factor then becomes ``M + 2``.

    Args:
        r: ratio of the geometric series.
        epsilon: scale factor applied to the result.
        M: highest power of ``r`` included in the sum.
    """
    if r != 1:
        geometric_sum = (1 - r ** (M + 1)) / (1 - r)
        return epsilon * (geometric_sum + 1)
    return epsilon * (M + 2)

class ValuesMetric(Metric):
    """Latent-space quantities from Painter et al. (2020), evaluated on the dataset.

    ``compute_metrics`` reports:

    * ``epsilon``: the largest norm of ``encode_image(x_a) - encode_image(x, a)``
      over every image ``x`` and every action index ``a``. Here ``x_a`` is the
      image that ``generator.group.transition`` maps ``x`` to under ``a``.
      (Presumably ``encode_image(x, a)`` is the model's latent prediction
      after acting with ``a``; TODO confirm.)
    * ``r``: the largest matrix 2-norm (largest singular value) among the
      outputs of ``encode_action`` for all action indices.
    * ``delta``: the smallest Euclidean distance between the latent codes of
      two distinct dataset images.
    * ``eta_M2`` / ``eta_M3``: ``compute_P(r, epsilon, M)`` for M = 2 and 3.
    """

    def __init__(self, algo, nfo, loaders):
        super().__init__(algo, nfo, loaders)
        self.generator = get_generator(nfo["environment"], specs=nfo["specs"])
        # Element [-2] of load_data's result is used as the image tensor,
        # moved once to the metric's device.
        self.data_images = load_data(nfo["dataname"])[-2].to(self.device)

    def __repr__(self):
        return "values"

    def _compute_epsilon(self, n_action):
        """Return the largest latent prediction error over all images and actions."""
        N = len(self.data_images)
        idxs = np.arange(N)
        epsilon = 0.0
        for a in range(n_action):
            # Same action index for every image, shaped (N, 1).
            A = torch.full((N, 1), a, dtype=torch.int, device=self.device)
            idxs_p = self.generator.group.transition(
                idxs, np.full_like(idxs, a, dtype=int)
            )
            # data_images already lives on self.device; indexing keeps it there.
            images_p = self.data_images[idxs_p]
            with torch.no_grad():
                Zp_hat = self.algo.encode_image(self.data_images, A)
                Zp = self.algo.encode_image(images_p)
            error = torch.norm(Zp - Zp_hat, dim=1)
            # .item() keeps epsilon a plain Python float instead of a device tensor.
            epsilon = max(error.max().item(), epsilon)
        return epsilon

    def _compute_r(self, n_action):
        """Return the largest spectral norm among the ``n_action`` action matrices."""
        with torch.no_grad():
            Az = self.algo.encode_action(
                torch.arange(n_action).int().to(self.device)
            ).cpu().detach().numpy()
        # ord=2 over the last two axes gives each matrix's largest singular value.
        norms = np.linalg.norm(Az, 2, axis=(-1, -2))
        return float(norms.max())

    def _compute_delta(self, batch_size=256):
        """Return the minimum pairwise Euclidean distance between latent codes.

        The distances are computed in row batches of ``batch_size``. This
        bounds the temporary distance array to ``batch_size x N`` entries.
        """
        # no_grad: encoding the whole dataset would otherwise build (and then
        # discard) an autograd graph over every activation.
        with torch.no_grad():
            Z = self.algo.encode_image(self.data_images).cpu().detach().numpy()

        min_dist = np.inf
        for i in range(0, len(Z), batch_size):
            Z_batch = Z[i:i + batch_size]
            rows = np.arange(len(Z_batch))
            dist = np.linalg.norm(
                einops.rearrange(Z_batch, 'b ... -> b 1 ...')
                - einops.rearrange(Z, 'b ... -> 1 b ...'),
                axis=-1,
            )
            # Exclude each code's zero distance to itself.
            dist[rows, i + rows] = np.inf
            min_dist = min(np.min(dist), min_dist)
        return float(min_dist)

    def compute_metrics(self):
        """Compute epsilon, r, delta and the eta bounds for M = 2 and M = 3.

        Returns:
            A dict with float values under the keys "epsilon", "r", "delta",
            "eta_M2" and "eta_M3". Returns an empty dict when ``self.algo``
            is an ``AE`` instance.
        """
        if isinstance(self.algo, AE):
            return {}
        n_action = self.nfo["n_action"]

        epsilon = self._compute_epsilon(n_action)
        r = self._compute_r(n_action)
        delta = self._compute_delta()

        return {
            "epsilon": epsilon,
            "r": r,
            "delta": delta,
            "eta_M2": compute_P(r, epsilon, 2),
            "eta_M3": compute_P(r, epsilon, 3),
        }