from kernel_tests import MMD
from c2st import C2ST
from utils import *
import torch as ch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.stats as scistats

class TestStatistic(nn.Module):
    r"""Base class for two-sample test statistics.

    Subclasses implement :meth:`statistic`, which maps two batches of samples
    to a scalar tensor. The default :meth:`pval` interprets that scalar as a
    t statistic and returns a two-sided p-value."""

    def __init__(self):
        super(TestStatistic, self).__init__()

    def __call__(self, sample_1, sample_2):
        r"""Evaluate the statistic.

        This overrides ``nn.Module.__call__`` directly, so ``forward`` and
        module hooks are not involved; the call is forwarded to
        :meth:`statistic`.

        Arguments
        ---------
        sample_1: :class:`torch.Tensor`
            Batch of samples from the first distribution (batch dimension first).
        sample_2: :class:`torch.Tensor`
            Batch of samples from the second distribution (batch dimension first).

        Returns
        -------
        Whatever :meth:`statistic` returns (the test statistic)."""
        return self.statistic(sample_1, sample_2)

    def statistic(self, sample_1, sample_2):
        """Compute the test statistic; must be implemented by subclasses."""
        raise NotImplementedError

    def pval(self, sample_1, sample_2):
        """Two-sided p-value of the statistic under a Student t distribution.

        Degrees of freedom are ``sample_1.shape[0] - 1``.
        NOTE(review): for a two-sample t-test this matches the conservative
        ``min(n_1, n_2) - 1`` choice only when both samples have equal size;
        confirm it agrees with the ``ttest_ind`` used by subclasses.

        The statistic must be a tensor (``.detach()`` is called on it).

        Returns
        -------
        The p-value, in [0, 1]."""
        stat_val = np.abs(self.statistic(sample_1, sample_2).detach().cpu().numpy())
        # sf(x) == 1 - cdf(x), but sf stays accurate in the far tail where
        # 1 - cdf(x) rounds to exactly 0.
        return 2 * scistats.t.sf(stat_val, df=sample_1.shape[0] - 1)
        
        
class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one into one axis.

    An input of shape ``(B, d1, d2, ...)`` is returned as
    ``(B, d1 * d2 * ...)``. ``Tensor.view`` is used, so the input must be
    view-compatible (for example, contiguous)."""

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

    
class MMD_Stat(TestStatistic):
    r"""A test statistic that runs the MMD test on two batches of samples
    drawn from two distributions. For example, the two inputs could be a batch
    of images from CIFAR and a batch of images from CIFAR-C.

    Both batches are passed through an encoder and flattened to
    ``(batch, features)`` before the MMD is computed.

    Arguments
    ---------
    encoder: :class:`torch.nn.Module`, optional
        Encoder applied to both samples. If ``None``, a randomly initialised
        encoder from :meth:`initialize_default_enc` is used. A user-supplied
        encoder is moved to the GPU when CUDA is available."""

    def __init__(self, encoder=None):
        super(MMD_Stat, self).__init__()
        if encoder is None:
            self.encoder = self.initialize_default_enc()
        elif ch.cuda.is_available():
            self.encoder = encoder.cuda()
        else:
            # Calling .cuda() unconditionally made the class unusable on
            # CPU-only machines; the encoder is moved to the sample device
            # lazily in statistic()/pval() instead.
            self.encoder = encoder

    def initialize_default_enc(self, encoding_dim=32):
        r"""Build a fresh, randomly initialised encoder (the UAE) for the MMD test.

        The ``nn.Linear(2048, ...)`` input size matches 3 x 32 x 32 inputs:
        three 4x4 stride-2 convolutions map 32 -> 15 -> 6 -> 2 spatially, and
        512 * 2 * 2 = 2048.

        Arguments
        ---------
        encoding_dim: int
            Output size of the final linear layer.

        Returns
        -------
        :class:`torch.nn.Sequential`
            The newly constructed encoder (not moved to any device)."""
        default_enc = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 512, 4, stride=2),
            Flatten(),
            nn.Linear(2048, encoding_dim),
        )
        return default_enc

    def reinitialize_default_encoder(self, encoding_dim=32):
        r"""Replace the current encoder with a freshly initialised default one.

        Sometimes, re-initializing the UAE will lead to a stronger test."""
        # encoding_dim was previously accepted but silently ignored.
        self.encoder = self.initialize_default_enc(encoding_dim=encoding_dim)

    def update_encoder(self, encoder):
        """Replace the encoder used by :meth:`statistic` and :meth:`pval`."""
        self.encoder = encoder

    def __call__(self, sample_1, sample_2, alphas=None, encoder=None,
                 use_new_random_encoders=False):
        """Evaluate the MMD statistic; see :meth:`statistic`."""
        return self.statistic(sample_1, sample_2, alphas, encoder,
                              use_new_random_encoders)

    def statistic(self, sample_1, sample_2, alphas=None, encoder=None,
                  use_new_random_encoders=False):
        """Compute the MMD between the encoded, flattened samples.

        Arguments
        ---------
        sample_1, sample_2: :class:`torch.Tensor`
            Batches from the two distributions (batch dimension first).
        alphas: list of float, optional
            Kernel parameters forwarded to ``MMD``; defaults to ``[500]``.
        encoder: :class:`torch.nn.Module`, optional
            Encoder to use instead of ``self.encoder``. ``self.encoder`` is
            moved to ``sample_1.device`` when it is used.
        use_new_random_encoders: bool
            If True, encode with a brand-new default encoder (moved to
            ``sample_1.device``) and ignore ``encoder``.

        Returns
        -------
        The value returned by ``MMD(...)(x1, x2, ret_matrix=False)``."""
        if alphas is None:
            alphas = [500]  # fresh list per call instead of a shared default
        if encoder is None:
            encoder = self.encoder.to(sample_1.device)
        if use_new_random_encoders:
            # A new module is created on the CPU; move it next to the data.
            encoder = self.initialize_default_enc().to(sample_1.device)

        ds1 = encoder(sample_1)
        ds2 = encoder(sample_2)
        x1 = ds1.view(ds1.size(0), -1)
        x2 = ds2.view(ds2.size(0), -1)
        mmd = MMD(x1.size(0), x2.size(0))
        return mmd(x1, x2, alphas=alphas, ret_matrix=False)

    def pval(self, sample_1, sample_2, alphas=None, n_permutations=1000, verbose=True):
        """Permutation-test p-value of the MMD between the encoded samples.

        Arguments
        ---------
        sample_1, sample_2: :class:`torch.Tensor`
            Batches from the two distributions (batch dimension first).
        alphas: list of float, optional
            Kernel parameters forwarded to ``MMD``; defaults to ``[500]``.
        n_permutations: int
            Number of permutations forwarded to ``MMD.pval``.
        verbose: bool
            If True, report the p-value via ``print_message_pvals``.

        Returns
        -------
        The first element returned by ``MMD.pval``."""
        if alphas is None:
            alphas = [500]
        with ch.no_grad():
            # Encode under no_grad too: no autograd graph is needed for a p-value.
            encoder = self.encoder.to(sample_1.device)
            ds1 = encoder(sample_1)
            ds2 = encoder(sample_2)
            x1 = ds1.view(ds1.size(0), -1)
            x2 = ds2.view(ds2.size(0), -1)
            mmd = MMD(x1.size(0), x2.size(0))
            _, distances = mmd(x1, x2, alphas=alphas, ret_matrix=True)
            p, _ = mmd.pval(distances, n_permutations=n_permutations)
        if verbose:
            print_message_pvals(p)
        return p
        

class Luminance_Stat(TestStatistic):
    r"""A test statistic that runs the 1D luminance test on two batches of
    samples drawn from two distributions. For example, the two inputs could be
    a batch of images from CIFAR and a batch of images from CIFAR-C.

    Each image is converted to grayscale and the skewness of its flattened
    intensities is computed; the two sets of skewness values are then compared
    with ``ttest_ind``.

    Call it as ``stat(sample_1, sample_2)`` with two 4-D image batches
    (batch dimension first)."""

    def __init__(self):
        # nn.Module.__init__ must run; without it, module machinery such as
        # .to(), repr() and attribute registration fails.
        super(Luminance_Stat, self).__init__()

    def get_lum_skew(self, ds):
        """Return one skewness value per image in ``ds``.

        ``ds`` must be 4-D, e.g. B x 3 x 224 x 224 for ImageNet.
        NOTE(review): assumes ``skew`` reduces over the last dimension (one
        value per row) — confirm against utils."""
        assert len(ds.shape) == 4, "expected a 4-D image batch"
        gray = rgb_to_gray(ds)
        # One row of pixel intensities per image.
        gray_reshaped = gray.reshape(gray.shape[0], -1)
        return skew(gray_reshaped)

    def statistic(self, sample_1, sample_2):
        """t statistic comparing per-image luminance skewness of the two samples."""
        skew1 = self.get_lum_skew(sample_1)
        skew2 = self.get_lum_skew(sample_2)
        return ttest_ind(skew1, skew2)
    

class RFR_Stat(TestStatistic):
    r"""A test statistic that runs the 1D random-filter-response test on two
    batches of samples drawn from two distributions. For example, the two
    inputs could be a batch of images from CIFAR and a batch of images from
    CIFAR-C.

    Each image is converted to grayscale and convolved with one random 8x8
    filter. The kurtosis of each image's flattened response is computed, and
    the two sets of kurtosis values are compared with ``ttest_ind``.

    Call it as ``stat(sample_1, sample_2, seed=0)``; the same seed yields the
    same filter for both samples."""

    def __init__(self):
        # nn.Module.__init__ must run; without it, module machinery such as
        # .to(), repr() and attribute registration fails.
        super(RFR_Stat, self).__init__()

    def make_rand_kernel(self, seed, device=None):
        """
        Make a random 8x8 filter, as is done in the 1D-tests paper:
        https://arxiv.org/abs/1708.02688

        Entries are drawn uniformly from [0, 1), then shifted to zero mean and
        scaled to unit L2 norm.

        Arguments
        ---------
        seed : int or None
            Seed for the draw. ``None`` draws from the global NumPy RNG.
        device : torch device or str, optional
            Where to place the kernel. Defaults to CUDA when available,
            otherwise the CPU.

        Returns
        -------
        torch.Tensor
            float32 tensor of shape (1, 1, 8, 8).
        """
        # A private RandomState yields exactly the same draws as
        # np.random.seed(seed) followed by np.random.uniform, but does not
        # reset the global NumPy RNG as a side effect.
        rng = np.random if seed is None else np.random.RandomState(seed)
        kernel = rng.uniform(size=(1, 1, 8, 8)).astype('float32')
        centered = kernel - kernel.mean()
        kernel = centered / np.sqrt((centered ** 2).sum())
        if device is None:
            device = 'cuda' if ch.cuda.is_available() else 'cpu'
        return ch.from_numpy(kernel).to(device)

    def img_rand_filter_response(self, imgs, seed):
        """
        Convolve the input batch of images with one random kernel after converting to grayscale first.
        We convert to grayscale first because that's what they do in the 1D-tests paper's code.

        The convolution runs on the GPU when CUDA is available, otherwise on
        ``imgs.device``.

        Arguments
        ---------
        imgs : torch.Tensor
            A batch of RGB images. ``rgb_to_gray(imgs)`` must produce a
            (B, 1, H, W) tensor, since the kernel has a single input channel.
        seed : int or None
            Seed passed to :meth:`make_rand_kernel`.

        Returns
        -------
        torch.Tensor
            Filter responses of shape (B, 1, H - 7, W - 7) (no padding,
            stride 1), on ``imgs.device``.
        """
        gray = rgb_to_gray(imgs)
        compute_device = 'cuda' if ch.cuda.is_available() else imgs.device
        kernel = self.make_rand_kernel(seed, device=compute_device)
        gray = gray.to(compute_device)
        # Match the kernel dtype to the input so non-float32 batches work.
        response = F.conv2d(gray, kernel.to(dtype=gray.dtype))
        return response.to(device=imgs.device)

    def get_kurtosis(self, ds, seed):
        """Return one kurtosis value per image of its random-filter response.

        NOTE(review): assumes ``kurtosis`` reduces over the last dimension
        (one value per row) — confirm against utils."""
        rfr = self.img_rand_filter_response(ds, seed)
        rfr_reshaped = rfr.reshape(rfr.shape[0], -1)
        return kurtosis(rfr_reshaped)

    def __call__(self, sample_1, sample_2, seed=0):
        """Evaluate the statistic; see :meth:`statistic`."""
        return self.statistic(sample_1, sample_2, seed)

    def statistic(self, ds1, ds2, seed=0):
        """t statistic comparing per-image response kurtosis of the two samples.

        The same ``seed`` (hence the same filter) is used for both batches."""
        kurtosis1 = self.get_kurtosis(ds1, seed)
        kurtosis2 = self.get_kurtosis(ds2, seed)
        return ttest_ind(kurtosis1, kurtosis2)


class Contrast_Stat(TestStatistic):
    r"""A test statistic that runs a 1D RMS-contrast test on two batches of
    samples drawn from two distributions. For example, the two inputs could be
    a batch of images from CIFAR and a batch of images from CIFAR-C.

    Based on the definition of RMS contrast here: https://en.wikipedia.org/wiki/Contrast_(vision)

    Each image's RMS contrast (computed on its grayscale version) is
    calculated, and the two sets of contrast values are compared with
    ``ttest_ind``.

    Call it as ``stat(sample_1, sample_2)`` with two 4-D image batches
    (batch dimension first)."""

    def __init__(self):
        # nn.Module.__init__ must run; without it, module machinery such as
        # .to(), repr() and attribute registration fails.
        super(Contrast_Stat, self).__init__()

    def get_contrast(self, ds):
        """Return the RMS contrast of each image in ``ds``.

        RMS contrast is sqrt(mean((I - mean(I)) ** 2)) over all grayscale
        pixels of one image, i.e. the population standard deviation of its
        intensities. ``ds`` must be 4-D, e.g. B x 3 x 224 x 224 for ImageNet."""
        assert len(ds.shape) == 4, "expected a 4-D image batch"
        gray = rgb_to_gray(ds)
        # gray is expected to be 4-D with batch first; reduce over the rest.
        pixel_dims = [1, 2, 3]
        intensity_means = gray.mean(dim=pixel_dims)
        diff_from_means = gray - intensity_means.view([-1, 1, 1, 1])
        sq_diff = ch.square(diff_from_means)
        rms_contrast = ch.sqrt(sq_diff.mean(dim=pixel_dims))
        return rms_contrast

    def statistic(self, sample_1, sample_2):
        """t statistic comparing per-image RMS contrast of the two samples."""
        rms_contrast1 = self.get_contrast(sample_1)
        rms_contrast2 = self.get_contrast(sample_2)
        return ttest_ind(rms_contrast1, rms_contrast2)
