from collections import OrderedDict
import re
import torchvision.transforms as transforms
from PIL import Image
import torch
import math

import sys
sys.path.append('..')
from data_generate import transformations as transfm


def enlist_transformation(img_resize=None, resize_interpolation='BILINEAR', is_grayscale=False, device=None, img_normalise=True):
    """Build the ordered list of image transforms selected by the flags.

    The transforms are appended in this fixed order: optional resize,
    optional grayscale conversion, tensor conversion (always present),
    optional device transfer, optional normalisation.

    Args:
        img_resize: Edge length of the square output; ``None`` skips resizing.
        resize_interpolation: Name of the attribute looked up on ``PIL.Image``
            and passed to ``transforms.Resize`` as its interpolation.
        is_grayscale: When true, add ``transforms.Grayscale()``.
        device: Only tested against ``None``; see the review note below.
        img_normalise: When true, add ``transfm.NormaliseMinMax()``.

    Returns:
        A plain list of transform objects (not wrapped in a ``Compose``).
    """
    pipeline = []

    # Optional square resize; the interpolation constant is resolved by name.
    if img_resize is not None:
        interpolation_mode = getattr(Image, resize_interpolation)
        target_size = (img_resize, img_resize)
        pipeline.append(transforms.Resize(size=target_size, interpolation=interpolation_mode))

    # Optional grayscale conversion.
    if is_grayscale:
        pipeline.append(transforms.Grayscale())

    # Tensor conversion is unconditional.
    pipeline.append(transforms.ToTensor())

    # NOTE(review): the value of `device` is not forwarded -- whenever a device
    # is requested the transform is created for 'cpu'. Confirm this is the
    # intended behaviour before passing `device` through.
    if device is not None:
        pipeline.append(transfm.ToDevice(device='cpu'))

    # Optional normalisation (presumably min-max scaling; see
    # data_generate.transformations for the actual definition).
    if img_normalise:
        pipeline.append(transfm.NormaliseMinMax())

    return pipeline


def compute_acc(pred, label, reduction='mean'):
    """Compare predictions with labels element-wise and report the accuracy.

    Args:
        pred: Tensor of predictions. It is compared to ``label`` with ``==``,
            so it is assumed to already hold class ids rather than scores --
            TODO confirm against callers.
        label: Tensor of targets, comparable with ``pred``.
        reduction: ``'mean'`` returns the fraction of matching elements as a
            Python float; ``'none'`` returns the detached per-element 0/1
            float tensor.

    Returns:
        A float for ``reduction='mean'``, a tensor for ``reduction='none'``.

    Raises:
        ValueError: If ``reduction`` is neither ``'mean'`` nor ``'none'``.
            (Previously such a value silently produced ``None``.)
    """
    if reduction not in ('none', 'mean'):
        raise ValueError(
            "unsupported reduction {!r}; expected 'none' or 'mean'".format(reduction)
        )

    result = (pred == label).float()
    if reduction == 'none':
        return result.detach()
    return result.mean().item()

def config_inner_args(inner_args):
    """Fill in the inner-loop settings with defaults and return them.

    The dictionary passed in is updated in place and returned; ``None`` is
    treated as an empty configuration. A setting falls back to its default
    only when it is missing or explicitly ``None`` -- explicit falsy values
    such as ``n_step=0`` or ``lr_inner=0.0`` are kept as given.

    Source key -> written key (default):
        reset_classifier -> reset_classifier (False)
        n_step           -> n_step (5)
        lr_inner         -> encoder_lr and classifier_lr (0.01)
        lr_prompt        -> prompt_lr (0.01)
        lr_decoder       -> decoder_lr (0.01)
        momentum         -> momentum (0.)
        weight_decay     -> weight_decay (0.)
        first_order      -> first_order (False)
        frozen           -> frozen ([])

    Args:
        inner_args: Dict of user-provided settings, or ``None``.

    Returns:
        The same dict (or a new one if ``None`` was given) with every written
        key above present.
    """
    if inner_args is None:
        inner_args = dict()

    def _setting(source_key, default):
        # `is None` instead of `or`, so that 0 / 0.0 / False given on purpose
        # are not replaced by the default.
        value = inner_args.get(source_key)
        return default if value is None else value

    inner_args['reset_classifier'] = _setting('reset_classifier', False)
    inner_args['n_step'] = _setting('n_step', 5)
    # Encoder and classifier share the single 'lr_inner' source setting.
    inner_args['encoder_lr'] = _setting('lr_inner', 0.01)
    inner_args['classifier_lr'] = _setting('lr_inner', 0.01)
    inner_args['prompt_lr'] = _setting('lr_prompt', 0.01)
    inner_args['decoder_lr'] = _setting('lr_decoder', 0.01)
    inner_args['momentum'] = _setting('momentum', 0.)
    inner_args['weight_decay'] = _setting('weight_decay', 0.)
    inner_args['first_order'] = _setting('first_order', False)
    inner_args['frozen'] = _setting('frozen', [])

    return inner_args



def gaussian_sampling(mean_, cov_, detach_mean_cov=False):
    """Draw one sample ``mean_ + sqrt(cov_) * eps`` with ``eps`` standard normal.

    ``cov_`` is used through its element-wise square root, i.e. it is read as
    a per-element variance. The noise tensor has the size, dtype and device of
    ``mean_``.

    Args:
        mean_: Tensor of means.
        cov_: Tensor of per-element variances, combinable with ``mean_``.
        detach_mean_cov: When true, the sample is computed from detached
            copies of ``mean_`` and ``cov_`` and then marked as requiring
            grad itself; otherwise the sample stays connected to ``mean_``
            and ``cov_`` in the autograd graph.

    Returns:
        The sampled tensor.
    """
    # Pick the operands first so that the noise is drawn after them, exactly
    # once, in both modes.
    if detach_mean_cov:
        centre = mean_.detach()
        spread = cov_.detach().sqrt()
    else:
        centre = mean_
        spread = cov_.sqrt()

    noise = torch.randn(mean_.size(), dtype=mean_.dtype, device=mean_.device)
    drawn = centre + spread * noise

    if detach_mean_cov:
        # Cut off from the distribution parameters: the sample itself becomes
        # a tensor that requires grad.
        return drawn.requires_grad_(True)
    return drawn

def kl_divergence_gaussian(mean_, cov_, input_ = None, prior_mean=None, prior_std=None, mixture_=True, sig1_=0.0, sig2_=6.0, weight_=0.25):
    """Divergence term between a diagonal Gaussian and a prior.

    Three modes, selected by the arguments:

    * ``input_ is None`` -- closed form between two diagonal Gaussians,
      summed over all elements. In this mode ``cov_`` and ``prior_std`` are
      both used as variances.
    * ``input_`` given, built-in prior, ``mixture_`` true -- sample-based
      ``log q(input_) - log p(input_)`` where ``log p`` is the weighted sum
      of two zero-mean Gaussian log-densities with standard deviations
      ``exp(-sig1_)`` and ``exp(-sig2_)``.
    * ``input_`` given otherwise -- sample-based ``log q(input_) -
      log p(input_)`` with a single Gaussian prior
      ``Normal(prior_mean, prior_std)``.

    A supplied ``prior_mean`` / ``prior_std`` is honoured (previously the
    function returned ``None`` as soon as either was given). Whichever of the
    two is omitted defaults to zeros / ones. Because the mixture scales come
    from ``sig1_`` / ``sig2_``, an explicit prior with ``input_`` always uses
    the single-Gaussian mode.

    NOTE(review): the closed form as written equals
    KL(N(prior_mean, prior_std) || N(mean_, cov_)), i.e. prior first --
    confirm that this direction is intended.
    NOTE(review): the sample-based modes read ``cov_`` as a standard
    deviation (``log(cov_)``, ``2 * cov_ ** 2``), unlike the closed form --
    confirm what callers pass.
    NOTE(review): the mixture term is a weighted sum of log-densities, not
    the log of a weighted sum of densities -- confirm this is intended.

    Args:
        mean_: Tensor of means of the Gaussian.
        cov_: Tensor of its per-element spread (see the notes above).
        input_: Optional tensor at which the log-densities are evaluated.
        prior_mean: Optional prior mean tensor, same shape as ``mean_``.
        prior_std: Optional prior spread tensor, same shape as ``cov_``.
        mixture_: Use the two-component prior (built-in prior only).
        sig1_: Negative log standard deviation of the first component.
        sig2_: Negative log standard deviation of the second component.
        weight_: Weight of the first component; the second gets ``1 - weight_``.

    Returns:
        A scalar tensor.
    """
    has_custom_prior = prior_mean is not None or prior_std is not None
    if prior_mean is None:
        prior_mean = torch.zeros_like(mean_)
    if prior_std is None:
        prior_std = torch.ones_like(cov_)

    if input_ is None:
        # Closed form; cov_ and prior_std act as variances here.
        mean_diff = mean_ - prior_mean
        sig_q_inv = 1 / cov_
        kl_layer = torch.log(cov_).sum() - torch.log(prior_std).sum() - prior_mean.numel() + (sig_q_inv * prior_std).sum() \
                   + ((mean_diff * sig_q_inv) * mean_diff).sum()

        return kl_layer/2

    # Log-density of input_ under Normal(mean_, std=cov_), written out by hand.
    log_pro_post = (-math.log(math.sqrt(2 * math.pi)) - torch.log(cov_) - ((input_ - mean_) ** 2) / (2 * cov_ ** 2)).sum()

    if mixture_ and not has_custom_prior:
        # Two zero-mean components with standard deviations exp(-sig1_) and
        # exp(-sig2_).
        neg_ones = -torch.ones_like(cov_)
        pri_distri1 = torch.distributions.Normal(prior_mean, (neg_ones*sig1_).exp())
        pri_distri2 = torch.distributions.Normal(prior_mean, (neg_ones*sig2_).exp())
        log_pro_pri = weight_*pri_distri1.log_prob(input_).sum() + (1-weight_)*pri_distri2.log_prob(input_).sum()
    else:
        # Single Gaussian prior evaluated at the input.
        pri_distri = torch.distributions.Normal(prior_mean, prior_std)
        log_pro_pri = pri_distri.log_prob(input_).sum()

    return log_pro_post - log_pro_pri