import torch
import numpy as np
import sys
from Utils.context_fid import Context_FID
from Utils.metric_utils import display_scores
import abc
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
import scipy
import os
import warnings
from Utils.metric_utils import display_scores
import jax.random as random
import equinox as eqx
import tqdm
from Utils.discriminative_metric_jax import discriminative_score_metrics

class AbstractResult(eqx.Module):
  """Base container pairing a score's ``mean`` with its ``sigma``.

  The constructor takes ``mean`` and ``sigma`` (positionally or by keyword)
  and stores them unchanged. It is the one Equinox generates from the field
  declarations below.
  """

  # Central value of the collected scores.
  mean: Scalar
  # Spread value reported alongside ``mean``.
  sigma: Scalar

class ContextFIDResult(AbstractResult):
  """Summary (``mean``, ``sigma``) returned by ``get_context_fid_score``."""

class CorrelationalScore(AbstractResult):
  """Result type for a correlational score; adds nothing to ``AbstractResult``.

  NOTE(review): no producer of this type is visible in this part of the file.
  """

class DiscriminativeScore(AbstractResult):
  """Summary (``mean``, ``sigma``) returned by ``get_discriminative_score``."""

class PredictiveScore(AbstractResult):
  """Result type for a predictive score; adds nothing to ``AbstractResult``.

  NOTE(review): no producer of this type is visible in this part of the file.
  """

def get_context_fid_score(
    original_data: Float[Array, 'T D'],
    generated_data: Float[Array, 'T D'],
    *,
    iterations: int = 5,
) -> ContextFIDResult:
  """Evaluate Context-FID repeatedly and summarise the collected scores.

  Both inputs are converted to NumPy arrays once, then ``Context_FID`` is
  called ``iterations`` times on that same pair. Each score is printed as it
  is produced, and the full list is handed to ``display_scores``.

  Args:
    original_data: Reference data.
    generated_data: Data to be compared against the reference.
    iterations: Number of ``Context_FID`` evaluations to run.

  Returns:
    A ``ContextFIDResult`` built from the ``(mean, sigma)`` pair returned by
    ``display_scores``.
  """
  real = np.array(original_data)
  fake = np.array(generated_data)

  scores = []
  for run in range(iterations):
    score = Context_FID(real, fake)
    scores.append(score)
    print(f'Iter {run}: ', 'context-fid =', score, '\n')

  mean, sigma = display_scores(scores)
  return ContextFIDResult(mean=mean, sigma=sigma)


def get_discriminative_score(
    original_data: Float[Array, 'T D'],
    generated_data: Float[Array, 'T D'],
    *,
    iterations: int = 5,
) -> DiscriminativeScore:
  """Run the discriminative metric several times and summarise the scores.

  A PRNG key seeded with 0 is split into ``iterations`` sub-keys, and
  ``discriminative_score_metrics`` is called once per sub-key. Each call's
  score and two accuracies are printed, and the scores are handed to
  ``display_scores``.

  Args:
    original_data: Reference data.
    generated_data: Data to be compared against the reference. Only its
      first ``original_data.shape[0]`` entries along axis 0 are used.
    iterations: Number of metric evaluations to run.

  Returns:
    A ``DiscriminativeScore`` built from the ``(mean, sigma)`` pair returned
    by ``display_scores``.
  """
  # The seed is fixed, so every call derives the same sequence of sub-keys.
  subkeys = random.split(random.PRNGKey(0), iterations)

  scores = []
  for run, subkey in tqdm.tqdm(enumerate(subkeys), total=len(subkeys)):
    # Cap the generated data at the length of the original data.
    n_original = original_data.shape[0]
    score, fake_acc, real_acc = discriminative_score_metrics(
        original_data, generated_data[:n_original], subkey
    )
    scores.append(score)
    print(f'Iter {run}: ', score, ',', fake_acc, ',', real_acc, '\n')

  mean, sigma = display_scores(scores)
  return DiscriminativeScore(mean=mean, sigma=sigma)
