import argparse
import inspect
import ml_collections
import numpy as np
import torch
from configs import get_configs
from Diffusion import gaussian_diffusion as gd
from Diffusion.respace import SpacedDiffusion, space_timesteps
from Diffusion.unet import UNetModel, EncoderUNetModel, ImageConditionalUNet
import matplotlib.pyplot as plt
from torch.nn.functional import interpolate
from Evaluation_Metrics.ssim import ssim
import torch.nn.functional as F
from PIL import Image
from collections import OrderedDict
import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# create default evaluator for doctests

def eval_step(engine, batch):
    """Ignite process function that hands each batch back unchanged.

    ``engine`` is part of the ``fn(engine, batch)`` signature ignite calls the
    step with; it is not used here.
    """
    output = batch
    return output

# Module-level ignite Engine that runs the identity ``eval_step`` above.
# It is shared, mutable global state: handlers and metrics attached to it
# (e.g. via ``Metric.attach``) remain attached for every later ``run``.
default_evaluator = Engine(eval_step)

# create default optimizer for doctests
# NOTE(review): param_tensor and default_optimizer are not used anywhere else
# in this section of the file; they look like leftover doctest boilerplate.
param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup::`

def get_default_trainer():
    """Return a brand-new ignite ``Engine`` whose step echoes each batch.

    A fresh engine is built on every call, so handlers attached by one user
    never leak into another.
    """
    def _echo_batch(engine, batch):
        # ignite invokes the step as fn(engine, batch); only the batch matters here.
        return batch

    trainer = Engine(_echo_batch)
    return trainer

# create default model for doctests

# Tiny two-layer linear model (4 -> 2 -> 1 features) with named submodules
# 'base' and 'fc'. NOTE(review): not used anywhere else in this section of
# the file; appears to be leftover doctest boilerplate.
default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

# NOTE(review): import-time side effect. ``manual_seed`` is presumably
# ``ignite.utils.manual_seed`` (brought in by the wildcard import), which seeds
# the global RNGs for every importer of this module -- confirm this is intended.
# It runs after ``default_model`` is built, so that model's initial weights are
# not covered by this seed.
manual_seed(666)

def _score_model_kwargs(config: ml_collections.ConfigDict) -> dict:
    """Build the constructor kwargs shared by both score-network factories.

    Every value is read from ``config.score_model``. ``UNetModel`` and
    ``ImageConditionalUNet`` are configured identically, so keeping the mapping
    in one place prevents the two factories from drifting apart.
    """
    sm = config.score_model
    return dict(
        image_size=sm.image_size,
        in_channels=sm.num_input_channels,
        model_channels=sm.num_channels,
        # Output channels are doubled when learn_sigma is set.
        out_channels=(2 * sm.num_input_channels if sm.learn_sigma else sm.num_input_channels),
        num_res_blocks=sm.num_res_blocks,
        attention_resolutions=tuple(sm.attention_ds),
        dropout=sm.dropout,
        channel_mult=sm.channel_mult,
        # Class conditioning is enabled only when class_cond is set.
        num_classes=(sm.num_classes if sm.class_cond else None),
        use_checkpoint=False,
        use_fp16=False,
        num_heads=sm.num_heads,
        num_head_channels=sm.num_head_channels,
        num_heads_upsample=sm.num_heads_upsample,
        use_scale_shift_norm=sm.use_scale_shift_norm,
        resblock_updown=sm.resblock_updown,
        use_new_attention_order=sm.use_new_attention_order,
    )


def create_score_model_(config: ml_collections.ConfigDict):
    """Instantiate a ``UNetModel`` score network from ``config.score_model``.

    The network is class-conditional only when ``config.score_model.class_cond``
    is set; gradient checkpointing and fp16 are always disabled.
    """
    return UNetModel(**_score_model_kwargs(config))


def create_image_cond_score_model(config: ml_collections.ConfigDict):
    """Instantiate an ``ImageConditionalUNet`` score network from ``config.score_model``.

    Uses exactly the same hyper-parameters as ``create_score_model_``; only the
    network class differs.
    """
    return ImageConditionalUNet(**_score_model_kwargs(config))


def create_anti_causal_predictor(config):
    """Build the anti-causal predictor described by ``config.classifier``.

    One ``EncoderUNetModel`` is created per entry of ``config.classifier.label``.
    With exactly one label, that encoder (built with
    ``out_channels=config.classifier.out_channels[0]``) is returned as-is.
    Otherwise every encoder is built with ``out_channels=128`` and they are
    wrapped in an ``AntiCausalMechanism`` along with the labels and the
    per-label ``out_channels``.
    """
    clf = config.classifier
    n_labels = len(clf.label)
    single = n_labels == 1

    encoders = [
        EncoderUNetModel(
            image_size=clf.image_size,
            in_channels=clf.in_channels,
            model_channels=clf.classifier_width,
            out_channels=clf.out_channels[idx] if single else 128,
            num_res_blocks=clf.classifier_depth,
            attention_resolutions=clf.attention_ds,
            channel_mult=clf.channel_mult,
            use_fp16=clf.classifier_use_fp16,
            num_head_channels=64,
            use_scale_shift_norm=clf.classifier_use_scale_shift_norm,
            resblock_updown=clf.classifier_resblock_updown,
            pool=clf.classifier_pool,
        )
        for idx in range(n_labels)
    ]

    if single:
        return encoders[0]
    # NOTE(review): AntiCausalMechanism is not imported or defined in the visible
    # part of this file (the wildcard ignite imports presumably do not provide it);
    # unless it is defined elsewhere, this branch raises NameError -- confirm.
    return AntiCausalMechanism(encoders=encoders, out_labels=clf.label,
                               out_channels=clf.out_channels)


def create_gaussian_diffusion(config):
    """Build a ``SpacedDiffusion`` process from the settings in ``config.diffusion``.

    - betas: named ``noise_schedule`` over ``steps`` timesteps;
    - loss: rescaled KL if ``use_kl``, else rescaled MSE if
      ``rescale_learned_sigmas``, else plain MSE;
    - respacing: ``timestep_respacing``, or all ``steps`` timesteps when it is empty;
    - mean target: x_0 when ``predict_xstart``, otherwise epsilon;
    - variance: learned range when ``learn_sigma``, otherwise fixed small/large
      according to ``sigma_small``.
    """
    dcfg = config.diffusion
    betas = gd.get_named_beta_schedule(dcfg.noise_schedule, dcfg.steps)

    loss_type = gd.LossType.MSE
    if dcfg.use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif dcfg.rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE

    respacing = dcfg.timestep_respacing or [dcfg.steps]
    kept_timesteps = space_timesteps(dcfg.steps, respacing)

    mean_type = gd.ModelMeanType.START_X if dcfg.predict_xstart else gd.ModelMeanType.EPSILON

    if dcfg.learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif dcfg.sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=kept_timesteps,
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        rescale_timesteps=dcfg.rescale_timesteps,
        conditioning_noise=dcfg.conditioning_noise,
    )


def add_dict_to_argparser(parser, default_dict):
    """Register one ``--<key>`` option on *parser* per entry of *default_dict*.

    Each option defaults to the dict value. Its ``type`` converter is ``str``
    for ``None`` defaults, ``str2bool`` for booleans, and otherwise the
    default's own type.
    """
    for name, default in default_dict.items():
        if default is None:
            converter = str
        elif isinstance(default, bool):
            converter = str2bool
        else:
            converter = type(default)
        parser.add_argument(f"--{name}", default=default, type=converter)


def args_to_dict(args, keys):
    """Return a new dict mapping each name in *keys* to ``getattr(args, name)``.

    Raises ``AttributeError`` if *args* lacks one of the requested names.
    """
    collected = {}
    for key in keys:
        collected[key] = getattr(args, key)
    return collected


def str2bool(v):
    """Convert a command-line string (or an existing bool) to ``bool``.

    Accepts, case-insensitively, yes/true/t/y/1 as True and no/false/f/n/0 as
    False; anything else raises ``argparse.ArgumentTypeError``.

    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


def _psnr_mean_over_leading_dim(pred, target, data_range=1.0):
    """Return the PSNR of *pred* vs *target*, averaged over their leading dimension.

    Reproduces what ``ignite.metrics.PSNR`` reports for a single
    ``(pred, target)`` pair. The leading dimension is treated as the batch
    axis, and one PSNR is computed per slice in float64 with a 1e-10 floor on
    the MSE. The per-slice values are then averaged.
    """
    diff = pred.detach().double() - target.detach().double()
    reduce_dims = tuple(range(1, diff.ndim))
    mse = diff.pow(2).mean(dim=reduce_dims)
    return (10.0 * torch.log10(data_range ** 2 / (mse + 1e-10))).mean().item()


def evaluate(init_image, long_image, top=False, top_k=7, size=(208, 160)):
    """Score each image in *init_image* against its closest match in *long_image*.

    Both batches are bilinearly resized to *size*. For every image in
    ``init_image``, the ``long_image`` entry with the lowest MSE is taken as its
    match, and MSE, SSIM and PSNR are computed for that pair.

    Args:
        init_image: 4-D tensor ``(N, C, H, W)``; bilinear ``interpolate`` with a
            2-element size requires 4-D input. Values are assumed to lie in
            ``[0, 1]``, since SSIM uses ``val_range=1`` and PSNR ``data_range=1.0``.
        long_image: 4-D tensor ``(M, C, H, W)`` of candidate matches. It should
            have a matching channel count and value range.
        top: if True, average only the best ``top_k`` scores of each metric
            (lowest MSE, highest SSIM, highest PSNR) instead of all of them.
        top_k: number of best scores kept when ``top`` is set (default 7).
        size: ``(height, width)`` that both batches are resized to before scoring.

    Returns:
        ``(mean_mse, mean_ssim, mean_psnr)`` over the (possibly top-k filtered)
        per-image scores.

    Raises:
        ValueError: if ``long_image`` is empty, or ``top`` is set with ``top_k < 1``.
    """
    if top and top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if len(long_image) == 0:
        raise ValueError("long_image must contain at least one image to match against")

    mse_scores, ssim_scores, psnr_scores = [], [], []
    # Metrics only: no autograd graph is needed, and .numpy() below would fail
    # on tensors that require grad.
    with torch.no_grad():
        init_image = interpolate(init_image, size=size, mode='bilinear', align_corners=True)
        long_image = interpolate(long_image, size=size, mode='bilinear', align_corners=True)
        for img in init_image:
            # MSE of `img` against every candidate at once; per candidate this is
            # F.mse_loss(img, cand) up to floating-point rounding.
            per_candidate = (long_image - img).pow(2).flatten(start_dim=1).mean(dim=1)
            # argmin returns the first minimum, matching list.index(min(...)) on ties.
            best = int(torch.argmin(per_candidate))
            match = long_image[best]

            mse_scores.append(per_candidate[best].cpu().numpy())
            ssim_now = ssim(img[None, ...], match[None, ...], val_range=1,
                            size_average=True, window_size=11)
            ssim_scores.append(ssim_now.cpu().numpy())
            # Computed directly instead of attaching a new ignite PSNR metric to a
            # shared global engine on every iteration (which accumulated handlers).
            psnr_scores.append(_psnr_mean_over_leading_dim(img, match, data_range=1.0))

    mse_scores = np.array(mse_scores)
    ssim_scores = np.array(ssim_scores)
    psnr_scores = np.array(psnr_scores)
    if top:
        # Keep the best top_k of each metric: lowest MSE, highest SSIM and PSNR.
        # Symmetric slicing; fixed [23:] slices would only select 7 when N == 30.
        mse_scores = np.sort(mse_scores)[:top_k]
        ssim_scores = np.sort(ssim_scores)[-top_k:]
        psnr_scores = np.sort(psnr_scores)[-top_k:]

    return np.mean(mse_scores), np.mean(ssim_scores), np.mean(psnr_scores)