from typing import Dict
import torch as th
import torch.nn.functional as F
from pathlib import Path
import sys
sys.path.append(str(Path.cwd()))
from Diffusion import logger, dist_util
from torch.nn.functional import interpolate


from script_util import create_gaussian_diffusion, create_image_cond_score_model, create_anti_causal_predictor, \
    create_score_model_


def get_models_from_config(config, noised=True):
    """Create the diffusion process, the score model and the anti-causal predictor.

    Both networks are restored from checkpoints named in ``config.sampling``,
    moved to ``dist_util.dev()``, optionally converted to fp16, and switched
    to eval mode.

    Args:
        config: Configuration object. ``config.sampling.image_conditional``
            selects the image-conditioned score model (weights from
            ``config.sampling.image_conditioned_model_path``) or the plain one
            (weights from ``config.sampling.model_path``).
            ``config.score_model.use_fp16`` and
            ``config.classifier.classifier_use_fp16`` control fp16 conversion.
        noised: If true, the predictor weights come from
            ``config.sampling.classifier_path``; otherwise from
            ``config.sampling.denoised_classifier_path``.

    Returns:
        Tuple ``(classifier, diffusion, model)``.
    """

    def restore(network, checkpoint_path):
        # Load weights from disk, mapping tensors onto the active device.
        network.load_state_dict(
            dist_util.load_state_dict(checkpoint_path, map_location=dist_util.dev())
        )

    def finalize(network, use_fp16):
        # Move to the device, optionally switch to half precision, then
        # put the network in inference mode.
        network.to(dist_util.dev())
        if use_fp16:
            network.convert_to_fp16()
        network.eval()

    diffusion = create_gaussian_diffusion(config)

    sampling = config.sampling
    if sampling.image_conditional:
        model = create_image_cond_score_model(config)
        restore(model, sampling.image_conditioned_model_path)
    else:
        model = create_score_model_(config)
        restore(model, sampling.model_path)
    finalize(model, config.score_model.use_fp16)

    classifier = create_anti_causal_predictor(config)
    restore(
        classifier,
        sampling.classifier_path if noised else sampling.denoised_classifier_path,
    )
    finalize(classifier, config.classifier.classifier_use_fp16)

    return classifier, diffusion, model


def get_models_functions(config, model, anti_causal_predictor, reg_or_class='class', noised=True, pred=False, reconstruction=False,**kwargs):
    """Build the guidance, model-wrapper and clamping closures used during sampling.

    Args:
        config: Configuration object. The closures read
            ``config.sampling.label_of_intervention``,
            ``config.sampling.classifier_scale`` and
            ``config.sampling.regressor_scale``.
        model: Score model, called as ``model(x_t, ts, age=..., conditioning_x=...)``.
        anti_causal_predictor: Network called as ``anti_causal_predictor(x, t)``.
            It may return a tensor or a dict keyed by label name.
        reg_or_class: ``'class'`` selects classifier guidance (``cond_fn``);
            any other value selects regressor guidance (``reg_fn``).
        noised: Unused here; kept for interface compatibility.
        pred: Unused here (``reg_fn`` has its own ``pred`` argument); kept for
            interface compatibility.
        reconstruction: When true, the regressor guidance is scaled by 0.
        **kwargs: Ignored.

    Returns:
        Tuple ``(guidance_fn, model_fn, clamp_to_spatial_quantile)``.
    """

    def cond_fn(x, t, age=None, **kwargs):
        """Return the classifier-guidance gradient and the argmax class of sample 0.

        The gradient is ``classifier_scale * d/dx sum_i log softmax(logits_i)[age_i]``.
        ``age`` holds integer class indices into the last logit axis.
        """
        assert age is not None
        with th.enable_grad():
            # Fresh leaf so the gradient is taken w.r.t. the input only.
            x_in = x.detach().requires_grad_(True)
            out = anti_causal_predictor(x_in, t)

            if isinstance(out, dict):
                logits = out[config.sampling.label_of_intervention]
            else:
                logits = out

            # Debug trace for timesteps below 3 or above 997 (presumably the
            # two ends of a 1000-step schedule).
            if t[0] < 3 or t[0] > 997:
                print(f'logits:{th.argmax(logits[0])}, ')
            log_probs = F.log_softmax(logits, dim=-1)

            # Select log p(age_i | x_i) for every sample in the batch.
            selected = log_probs[range(len(logits)), age.view(-1)]
            grad_log_conditional = th.autograd.grad(selected.sum(), x_in)[0]

            return grad_log_conditional * config.sampling.classifier_scale, th.argmax(logits[0])

    def reg_fn(x, t, age=None, pred=False, **kwargs):
        """Return the regressor-guidance gradient and the regressor prediction.

        If ``pred`` is true, the raw predictor output is returned instead.
        Otherwise the gradient of the summed prediction w.r.t. ``x`` is weighted
        per sample by ``d * (1 + sqrt(|d|))``, with ``d = age - prediction``,
        and multiplied by ``config.sampling.regressor_scale``. The scale is 0
        when ``reconstruction`` is set.
        """
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)

            out = anti_causal_predictor(x_in, t)
            if pred:
                return out
            if isinstance(out, dict):
                logits = out[config.sampling.label_of_intervention]
            else:
                logits = out

            # Signed error. It only weights the gradient and is never
            # differentiated, so detach it to keep the result graph-free.
            reg_diff = age[..., None] - logits.detach()
            reg_diff = th.mul(reg_diff, (1 + th.sqrt(th.abs(reg_diff))))
            # Append two singleton axes so a (B, 1) weight broadcasts over the
            # gradient's trailing dims (assumes a (B, 1) prediction, e.g. (B, C, H, W) input).
            reg_diff = reg_diff[..., None, None]

            # Differentiate the selected prediction; ``out`` itself may be a dict.
            grad_out = th.autograd.grad(logits.sum(), x_in)[0]
            grad_log_reg = th.mul(reg_diff, grad_out)

            scale = config.sampling.regressor_scale
            if reconstruction:
                scale = 0
            return scale * grad_log_reg, logits

    def model_fn(x_t, ts, age=None, conditioning_x=None, **kwargs):
        """Call the score model, forwarding only ``age`` and ``conditioning_x``."""
        return model(x_t, ts, age=age, conditioning_x=conditioning_x)

    def clamp_to_spatial_quantile(x):
        """Clamp each (sample, channel) map to its 99th |x| percentile, then rescale.

        The per-map threshold is ``max(quantile_0.99(|x|), 1)``. The output lies
        in [-1, 1], and maps already inside [-1, 1] are returned unchanged.
        Accepts input of shape ``(B, C, *spatial)`` with any number of spatial dims.
        """
        p = 0.99
        b, c, *spatial = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        quantile = th.quantile(th.abs(x).reshape(b, c, -1), p, dim=-1, keepdim=True)
        quantile = th.max(quantile, th.ones_like(quantile))
        # (b, c, 1) -> (b, c, 1, ..., 1): one singleton per spatial axis.
        quantile = quantile.view(b, c, *([1] * len(spatial)))
        return th.min(th.max(x, -quantile), quantile) / quantile

    if reg_or_class == 'class':
        return cond_fn, model_fn, clamp_to_spatial_quantile
    else:
        return reg_fn, model_fn, clamp_to_spatial_quantile
