import torch
from typing import Tuple
import einops
import numpy as np
import torch.nn.functional as F

def scale(clip, x, b, c, p):
    """Return the rational soft-clipping factor for a magnitude ``x``.

    With ``f = (x + b) ** p - c`` the function evaluates
    ``clip * f / ((clip + f) * (x + 1e-5))``.

    Args:
        clip: Value multiplied into, and added to, ``f``.
        x: Magnitude(s) to be rescaled (scalar or tensor).
        b: Offset added to ``x`` before exponentiation.
        c: Constant subtracted after exponentiation.
        p: Exponent applied to ``x + b``.

    Returns:
        The ratio described above, with the same type/shape as ``x``.
    """
    shifted_pow = (x + b) ** p - c
    numerator = clip * shifted_pow
    # The 1e-5 keeps the denominator away from zero when x itself is zero.
    denominator = (clip + shifted_pow) * (x + 1e-5)
    return numerator / denominator


def scale_softplus(clip, x, b):
    """Return a per-element scale that clips magnitude ``x`` towards ``clip``.

    For ``b > 100`` a hard clip is used: ``min(clip / (x + 1e-5), 1)``.
    Otherwise the clipped magnitude is
    ``clip - a * softplus(-k * x + b)`` with ``a = clip / softplus(b)`` and
    ``k = softplus(b) / (sigmoid(b) * clip)``, and the returned scale is that
    magnitude divided by ``x + 1e-5``.

    Args:
        clip: Clipping level.
        x: Tensor of magnitudes.
        b: Scalar sharpness parameter; values above 100 select the hard clip.

    Returns:
        Tensor of scale factors shaped like ``x``.
    """
    safe_x = x + 1e-5

    if b > 100:
        # Hard clipping: never scale up, scale down to ``clip`` otherwise.
        return torch.minimum(clip / safe_x, torch.ones_like(x))

    softplus_b = np.log(1 + np.exp(b))
    sigmoid_b = 1 / (1 + np.exp(-b))

    slope = softplus_b / (sigmoid_b * clip)
    amplitude = clip / softplus_b

    clipped = clip - amplitude * F.softplus(-1 * slope * x + b)
    return clipped / safe_x


def scale_better(clip, x, p):
    """Return a power-law soft-clipping scale for magnitude ``x``.

    With ``q = exp(p)`` the clipped magnitude is

        mag = clip - clip * (clip * q) ** q / (x + clip * q) ** q

    and the returned scale is ``mag / (x + 1e-5)``.

    The ratio of the two powers is evaluated in log space. The direct powers
    were previously also computed here but never used; they have been dropped
    because they only added work and could overflow for large exponents.

    Args:
        clip: Positive clipping level (its logarithm is taken).
        x: Tensor of magnitudes; ``x + clip * exp(p)`` must be positive.
        p: Log of the exponent ``q``.

    Returns:
        Tensor of scale factors shaped like ``x``.
    """
    p = np.exp(p)

    # log of clip * (clip * p) ** p
    log_a = np.log(clip) + p * np.log(clip * p)
    # log of (x + clip * p) ** p
    log_f = p * torch.log(x + clip*p)

    mag = -1 * torch.exp(log_a - log_f) + clip

    return mag/(x+1e-5)


def scale_softmax(clip, x, T):
    """Return a softplus-based soft-clipping scale for magnitude ``x``.

    The clipped magnitude is ``clip - softplus(-(x + b), beta=T)`` where the
    offset ``b = -log(1 - exp(-T * clip)) / T - clip``; the result is that
    magnitude divided by ``x + 1e-5``.

    Args:
        clip: Clipping level.
        x: Tensor of magnitudes.
        T: Passed as the ``beta`` argument of ``F.softplus``.

    Returns:
        Tensor of scale factors shaped like ``x``.
    """
    offset = -np.log(1 - np.exp(-1 * T * clip)) / T - clip
    clipped = clip - F.softplus(-(x + offset), T)
    return clipped / (x + 1e-5)

def get_mean_cov_scales(x):
    """Compute per-sample first and (uncentered) second moments and their squared norms.

    Args:
        x: 4-D tensor indexed as (batch, channel, height, width), as required
            by the einsum pattern below.

    Returns:
        A 4-tuple ``(mean_per_img, second_moment, mean_norms_sq, cov_norms_sq)``:
        * ``mean_per_img`` -- mean over the last two axes, shape B x C.
        * ``second_moment`` -- channel-by-channel products summed over the
          last two axes and divided by their combined size, shape B x C x C.
        * ``mean_norms_sq`` -- per-sample sum of squares of ``mean_per_img``.
        * ``cov_norms_sq`` -- per-sample sum of squares of the upper triangle
          (diagonal included) of ``second_moment``.
    """
    n_positions = x.shape[2] * x.shape[3]

    per_img_mean = x.mean([2, 3])  # B x C
    second_moment = einops.einsum(x, x, 'b c1 h w, b c2 h w -> b c1 c2') / n_positions

    # Only the upper triangle enters the norm.
    upper = torch.triu(second_moment)

    mean_sq = (per_img_mean ** 2).sum(1)
    cov_sq = (upper ** 2).sum([1, 2])

    return per_img_mean, second_moment, mean_sq, cov_sq
    
    
    


def get_stats_conv(x, running_mean = None, running_var = None, sigmoid_scale = 1.0, class_sigmoid_scale = 1.0, main_proj_mat = None, class_proj_mat = None, copy_main = False):
    """Compute moment statistics of squashed (optionally projected) features.

    The input is optionally projected along its channel axis, mapped through
    ``2 * sigmoid(.) - 1`` and handed to ``get_mean_cov_scales``. This is done
    once for the "main" projection and, unless ``copy_main`` is set, once more
    for the "class" projection.

    Args:
        x: 4-D feature tensor indexed as (batch, channel, height, width).
        running_mean: Accepted but not used by this function.
        running_var: Accepted but not used by this function.
        sigmoid_scale: Multiplier applied to ``main_proj_mat`` before projecting.
        class_sigmoid_scale: Multiplier applied to ``class_proj_mat`` before projecting.
        main_proj_mat: Optional (channel, out) matrix; ``None`` means no projection.
        class_proj_mat: Optional (channel, out) matrix; ``None`` means no projection.
        copy_main: When true, the main statistics are returned for both outputs
            and the class branch is not computed.

    Returns:
        ``(main_stats, class_stats)``, each the 4-tuple produced by
        ``get_mean_cov_scales``.
    """

    def _squashed_features(proj_mat, gain):
        # Project channels when a matrix is supplied, then squash into (-1, 1).
        if proj_mat is None:
            features = x
        else:
            features = einops.einsum(x, gain * proj_mat, 'b c h w, c o -> b o h w')
        return 2 * F.sigmoid(features) - 1

    main_stats = get_mean_cov_scales(_squashed_features(main_proj_mat, sigmoid_scale))

    if copy_main:
        return main_stats, main_stats

    class_stats = get_mean_cov_scales(_squashed_features(class_proj_mat, class_sigmoid_scale))
    return main_stats, class_stats
    
def get_per_example_bn_grads(x, d_out, running_mean, running_var):
    """Return per-example gradients w.r.t. a BatchNorm shift and scale.

    For ``out = x * scale + shift`` (with ``x`` already normalised by the caller):
        d_shift = sum over the last two axes of d_out
        d_scale = sum over the last two axes of d_out * x

    Args:
        x: Tensor of shape B x C x H x W, used as-is (no normalisation is
            applied here).
        d_out: Sequence whose first element is the gradient w.r.t. the output,
            shape B x C x H x W.
        running_mean: Accepted but not used by this function.
        running_var: Accepted but not used by this function.

    Returns:
        ``(d_shift, d_scale)``, each of shape B x C.
    """
    grad_output = d_out[0]
    spatial_axes = [2, 3]

    d_shift = torch.sum(grad_output, spatial_axes)
    d_scale = torch.sum(grad_output * x, spatial_axes)
    return d_shift, d_scale


@torch.jit.script
def get_stats_linear(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the batch mean and the uncentered second moment of ``input``.

    ``input`` is treated as (samples, features): the mean is taken over axis 0
    and the second output is the biased covariance (``correction = 0``) plus
    the outer product of the mean with itself.
    """
    first_moment = input.mean([0])

    centered_cov = torch.cov(input.T, correction = 0)
    mean_outer = first_moment.reshape(-1, 1) @ first_moment.reshape(1, -1)
    second_moment = centered_cov + mean_outer

    return first_moment, second_moment

class CovStatsHook():
    """Attach forward and backward-pre hooks to a module to record statistics.

    The wrapped module must expose ``running_mean`` and ``running_var``
    (e.g. a BatchNorm layer); its first input is indexed as a 4-D tensor with
    channels on axis 1.

    On every forward call the hook stores ``centered_x``, the input normalised
    with the module's running statistics. Unless ``backwards_only`` is set it
    also stores the per-sample moment statistics returned by
    ``get_stats_conv`` for the "main" and "class" feature sets. On every
    backward call it stores per-example shift/scale gradients and their
    sigmoid-squashed versions.
    """

    def __init__(self, module, main_proj_mat = None, class_proj_mat = None, backwards_only = False, bwd_sigmoid_scale = 1.0, bwd_proj_mat = None, sigmoid_scale = 1.0, class_sigmoid_scale = 1.0, copy_main = False):
        """Store the configuration and register both hooks on ``module``.

        Args:
            module: Module to observe.
            main_proj_mat: Optional projection for the main statistics.
            class_proj_mat: Optional projection for the class statistics.
            backwards_only: Skip the forward moment statistics when true.
            bwd_sigmoid_scale: Multiplier applied to gradients before the sigmoid.
            bwd_proj_mat: Stored on the instance; not used by this class's own methods.
            sigmoid_scale: Forwarded to ``get_stats_conv``.
            class_sigmoid_scale: Forwarded to ``get_stats_conv``.
            copy_main: Forwarded to ``get_stats_conv``.
        """
        self.module = module

        # Forward-statistics configuration.
        self.main_proj_mat = main_proj_mat
        self.class_proj_mat = class_proj_mat
        self.sigmoid_scale = sigmoid_scale
        self.class_sigmoid_scale = class_sigmoid_scale
        self.copy_main = copy_main
        self.backwards_only = backwards_only

        # Backward-statistics configuration.
        self.bwd_sigmoid_scale = bwd_sigmoid_scale
        self.bwd_proj_mat = bwd_proj_mat

        # Handles are kept so that close() can detach the hooks again.
        self.hook = module.register_forward_hook(self.hook_fn)
        self.bwd_hook = module.register_full_backward_pre_hook(self.bwd_hook_fn)

    def hook_fn(self, module, input: Tuple[torch.Tensor], output) -> None:
        """Forward hook: cache the normalised input and, optionally, its statistics."""
        observed = self.module
        mean = observed.running_mean[None, :, None, None]
        inv_std = torch.rsqrt(observed.running_var[None, :, None, None] + 1e-5)
        self.centered_x = (input[0] - mean) * inv_std

        if self.backwards_only:
            return

        main_stats, class_stats = get_stats_conv(
            self.centered_x,
            running_mean = observed.running_mean,
            running_var = observed.running_var,
            sigmoid_scale = self.sigmoid_scale,
            class_sigmoid_scale = self.class_sigmoid_scale,
            main_proj_mat = self.main_proj_mat,
            class_proj_mat = self.class_proj_mat,
            copy_main = self.copy_main,
        )

        (self.mean_unreduced_unclipped,
         self.cov_unreduced_unclipped,
         self.mean_norms_sq,
         self.cov_norms_sq) = main_stats

        (self.mean_unreduced_unclipped_class,
         self.cov_unreduced_unclipped_class,
         self.mean_norms_sq_class,
         self.cov_norms_sq_class) = class_stats

    def bwd_hook_fn(self, module, d_out):
        """Backward-pre hook: store per-example shift/scale gradients.

        Relies on ``centered_x`` cached by the preceding forward call.
        """
        shift_grad, scale_grad = get_per_example_bn_grads(self.centered_x, d_out, self.module.running_mean, self.module.running_var)
        self.d_shift, self.d_scale = shift_grad, scale_grad

        # Squash each gradient into (-1, 1) before any later clipping.
        gain = self.bwd_sigmoid_scale
        self.d_shift_unclipped = 2 * torch.sigmoid(gain * self.d_shift) - 1
        self.d_scale_unclipped = 2 * torch.sigmoid(gain * self.d_scale) - 1

    def close(self):
        """Detach both hooks from the module."""
        self.hook.remove()
        self.bwd_hook.remove()
        

def get_stats_conv_norm(input: Tuple[torch.Tensor], index_start: int, index_end:int, mask_mat = None, proj_mat = None, running_mean = None, running_var = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return per-sample norms of the first and second moments of conv features.

    A batch-level mean was previously computed here and then discarded; that
    unused reduction has been removed. The returned values are unchanged.

    Args:
        input: Tuple whose first element is a tensor indexed as
            (batch, channel, height, width).
        index_start: First channel (inclusive) of the slice that is analysed.
        index_end: Last channel (exclusive) of the slice that is analysed.
        mask_mat: Optional (channel, channel) matrix; the second moment is
            divided element-wise by it.
        proj_mat: Optional (channel_in, channel_out) projection applied after slicing.
        running_mean: Optional per-channel mean; when given, the input is
            normalised with it and ``running_var`` (which must then be given too).
        running_var: Per-channel variance used together with ``running_mean``.

    Returns:
        ``(mean_norms, cov_norms)``: per-sample L2 norm of the spatial mean and
        per-sample Frobenius norm of the upper triangle (diagonal included) of
        the spatially averaged second moment.
    """
    x = input[0]

    if running_mean is not None:
        x = (x - running_mean[None, :, None, None]) * torch.rsqrt(running_var[None, :, None, None] + 1e-5)

    x = x[:, index_start:index_end]

    if proj_mat is not None:
        x = einops.einsum(x, proj_mat, 'b ci h w, ci co -> b co h w')

    mean_per_img = x.mean([2, 3]) #B x C
    expanded = einops.einsum(x, x, 'b c1 h w, b c2 h w -> b c1 c2')/(x.shape[2] * x.shape[3])

    if mask_mat is not None:
        expanded = expanded / mask_mat[None, :, :]

    # Only the upper triangle enters the norm.
    expanded = torch.triu(expanded)

    mean_norms = torch.sqrt(torch.sum(mean_per_img**2, 1))
    cov_norms = torch.sqrt(torch.sum(expanded**2, [1, 2]))

    return mean_norms, cov_norms
        
def scale_cov_hooks(cov_stat_hooks, proj_config):
    """Clip the forward statistics stored on each hook and write them back.

    Two modes are selected by ``proj_config.clip_type``:

    * ``'global'`` -- one per-sample squared norm is accumulated over all
      hooks (mean norms, plus ``cov_scale_multiplier`` times covariance norms,
      plus the ``class_scalar``-weighted class terms for hook indices in
      ``class_include_indices``). A single scale from ``scale_softplus`` with
      ``total_clip`` is then applied to every statistic of every hook.
    * ``'hybrid_all'`` -- every hook gets its own scales, computed separately
      for mean / covariance and for main / class statistics from
      ``m1_clips``, ``m2_clips``, ``m1_clips_class`` and ``m2_clips_class``.

    Any other ``clip_type`` leaves the hooks untouched.

    Each hook receives ``mean_unreduced``, ``cov_unreduced``, their batch
    averages ``mean`` and ``cov``, and ``mean_unreduced_class`` /
    ``cov_unreduced_class``.

    An empty ``cov_stat_hooks`` is a no-op in both modes. Previously the
    ``'global'`` mode crashed on it because the accumulator was never created.

    Args:
        cov_stat_hooks: Re-iterable collection of hooks carrying the
            ``*_unclipped`` statistics and squared norms.
        proj_config: Configuration object providing the attributes named above
            plus ``clip_power``.
    """
    if proj_config.clip_type == 'global':
        sq_norms = None

        for csh_i, cov_stat_hook in enumerate(cov_stat_hooks):
            if sq_norms is None:
                sq_norms = torch.zeros_like(cov_stat_hook.mean_norms_sq)
            sq_norms += cov_stat_hook.mean_norms_sq
            sq_norms += proj_config.cov_scale_multiplier * cov_stat_hook.cov_norms_sq
            if csh_i in proj_config.class_include_indices:
                sq_norms += proj_config.class_scalar * cov_stat_hook.mean_norms_sq_class
                sq_norms += proj_config.cov_scale_multiplier * proj_config.class_scalar * cov_stat_hook.cov_norms_sq_class

        if sq_norms is None:
            # No hooks were supplied, so there is nothing to scale.
            return

        scales = scale_softplus(proj_config.total_clip, torch.sqrt(sq_norms), proj_config.clip_power)

        for cov_stat_hook in cov_stat_hooks:
            cov_stat_hook.mean_unreduced = cov_stat_hook.mean_unreduced_unclipped * scales[:, None]
            cov_stat_hook.cov_unreduced = cov_stat_hook.cov_unreduced_unclipped * scales[:, None, None]

            cov_stat_hook.mean = cov_stat_hook.mean_unreduced.mean(0)
            cov_stat_hook.cov = cov_stat_hook.cov_unreduced.mean(0)

            # Class statistics are scaled for every hook, including those not
            # listed in class_include_indices; in practice these aren't used.
            cov_stat_hook.mean_unreduced_class = cov_stat_hook.mean_unreduced_unclipped_class * scales[:, None]
            cov_stat_hook.cov_unreduced_class = cov_stat_hook.cov_unreduced_unclipped_class * scales[:, None, None]

    elif proj_config.clip_type == 'hybrid_all':
        for csh_i, cov_stat_hook in enumerate(cov_stat_hooks):
            scales1_all = scale_softplus(proj_config.m1_clips[csh_i], torch.sqrt(cov_stat_hook.mean_norms_sq), proj_config.clip_power)
            scales2_all = scale_softplus(proj_config.m2_clips[csh_i], torch.sqrt(cov_stat_hook.cov_norms_sq), proj_config.clip_power)

            cov_stat_hook.mean_unreduced = cov_stat_hook.mean_unreduced_unclipped * scales1_all[:, None]
            cov_stat_hook.cov_unreduced = cov_stat_hook.cov_unreduced_unclipped * scales2_all[:, None, None]

            cov_stat_hook.mean = cov_stat_hook.mean_unreduced.mean(0)
            cov_stat_hook.cov = cov_stat_hook.cov_unreduced.mean(0)

            # The class clip levels are divided by sqrt(class_scalar).
            class_divisor = np.sqrt(proj_config.class_scalar)
            scales1_class = scale_softplus(proj_config.m1_clips_class[csh_i]/class_divisor, torch.sqrt(cov_stat_hook.mean_norms_sq_class), proj_config.clip_power)
            scales2_class = scale_softplus(proj_config.m2_clips_class[csh_i]/class_divisor, torch.sqrt(cov_stat_hook.cov_norms_sq_class), proj_config.clip_power)

            cov_stat_hook.mean_unreduced_class = cov_stat_hook.mean_unreduced_unclipped_class * scales1_class[:, None]
            cov_stat_hook.cov_unreduced_class = cov_stat_hook.cov_unreduced_unclipped_class * scales2_class[:, None, None]


        
def scale_cov_hooks_bwd(cov_stat_hooks, proj_config):
    """Jointly clip the squashed per-example gradients stored on the hooks.

    A per-sample squared norm is accumulated over the ``d_shift_unclipped``
    and ``d_scale_unclipped`` tensors of all hooks. A single scale from
    ``scale_softplus`` (clip level ``total_clip_bwd``, sharpness
    ``clip_power``) is then applied and written to each hook's ``d_shift``
    and ``d_scale``.

    The accumulator is created from the backward quantities themselves. It
    used to be shaped after the forward-only attribute ``mean_norms_sq``,
    which does not exist on hooks that skipped the forward statistics
    (``backwards_only = True``), so this function failed for those hooks.
    An empty ``cov_stat_hooks`` is now a no-op instead of an error.

    Args:
        cov_stat_hooks: Re-iterable collection of hooks carrying
            ``d_shift_unclipped`` and ``d_scale_unclipped`` (batch on axis 0).
        proj_config: Configuration object providing ``total_clip_bwd`` and
            ``clip_power``.
    """
    sq_norms = None

    for cov_stat_hook in cov_stat_hooks:
        shift_sq = torch.sum(cov_stat_hook.d_shift_unclipped**2, 1)
        if sq_norms is None:
            sq_norms = torch.zeros_like(shift_sq)
        sq_norms += shift_sq
        sq_norms += torch.sum(cov_stat_hook.d_scale_unclipped**2, 1)

    if sq_norms is None:
        # No hooks were supplied, so there is nothing to scale.
        return

    # The 1e-5 keeps the argument of the square root strictly positive.
    scales = scale_softplus(proj_config.total_clip_bwd, torch.sqrt(sq_norms + 1e-5), proj_config.clip_power)

    for cov_stat_hook in cov_stat_hooks:
        cov_stat_hook.d_shift = cov_stat_hook.d_shift_unclipped * scales[:, None]
        cov_stat_hook.d_scale = cov_stat_hook.d_scale_unclipped * scales[:, None]
        
        