import copy

import torch

from CausalMNISTAddition.DigitImageGeneration.mnist_image_generation import produce_uniform_images
from CausalMNISTAddition.mnistDiscriminators import DigitImageDiscriminator, ControllerDiscriminator
from CausalMNISTAddition.mnistGenerators import DigitImageGenerator, ControllerGenerator, ClassificationNet
from CausalTwoDiscrimMechTrain import ControllerConstants
import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch import optim as optim
from matplotlib import pyplot as plt
import torch.nn.functional as F
import seaborn as sns
from torch.utils.data import DataLoader
from torchvision import transforms
import seaborn as sns
from pathlib import Path
from numpy import uint8

from torchvision import transforms, datasets
from torch import optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms

from CausalTwoDiscrimMechTrain.ConstantFunctions import load_checkpoint, load_checkpointed_generators, \
    load_checkpointed_discriminators, get_training_variables
from CausalTwoDiscrimMechTrain.ControllerConstants import get_label_fill, map_fill_to_discrete, \
    get_multiple_labels_fill, map_dictfill_to_discrete
from Image_Backdoor_Training.celebaVAE import VAE
from Image_Backdoor_Training.imageVae import DeepAutoencoder


def get_generators(Exp, load_which_models):
    """Build one mechanism network and optimizer per label of ``Exp.Observed_DAG``.

    The network type is chosen per label:
      * label in ``Exp.image_labels``  -> ``DigitImageGenerator`` with Adam;
      * label in ``Exp.rep_labels``    -> ``DeepAutoencoder`` with Adam;
      * label with any image parent    -> ``ClassificationNet`` with SGD (momentum 0.5);
      * any other label                -> ``ControllerGenerator`` with Adam.

    Afterwards, each label of ``Exp.label_names`` is either restored from
    ``<Exp.LOAD_MODEL_PATH>/checkpoints_generators/epochLast.pth`` (when its
    flag in ``load_which_models`` is truthy) or initialised with
    ``ControllerConstants.init_weights``. Restored optimizers get their
    learning rate reset to ``Exp.learning_rate``.

    Args:
        Exp: experiment configuration object (DAG, dimensions, hyper-parameters
            and ``DEVICE``).
        load_which_models: mapping label -> flag telling whether that label's
            generator/optimizer state should be restored from the checkpoint.

    Returns:
        ``(label_generators, optimizersMech)``: two dicts keyed by label.

    Raises:
        KeyError: if a label of ``Exp.label_names`` is missing from
            ``load_which_models``, or a required key is missing from the checkpoint.
    """
    label_generators = {}
    optimizersMech = {}

    for label in Exp.Observed_DAG:
        noise_dims = Exp.NOISE_DIM + Exp.CONF_NOISE_DIM * len(Exp.latent_conf[label])
        parent_dims = sum(Exp.label_dim[par] for par in Exp.Observed_DAG[label])

        if label in Exp.image_labels:
            # Confounder noise is passed separately via ``conf_dim``, so only the
            # image noise goes into ``noise_dim``.
            generator = DigitImageGenerator(noise_dim=Exp.IMAGE_NOISE_DIM,
                                            conf_dim=Exp.CONF_NOISE_DIM * len(Exp.latent_conf[label]),
                                            parent_dims=parent_dims,
                                            num_filters=Exp.IMAGE_FILTERS,
                                            output_dim=3).to(Exp.DEVICE)
            optimizer = torch.optim.Adam(generator.parameters(), lr=Exp.learning_rate,
                                         betas=Exp.betas, weight_decay=Exp.generator_decay)

        elif label in Exp.rep_labels:
            generator = DeepAutoencoder(Exp, parent_dims, latent_dim=100).to(Exp.DEVICE)
            optimizer = torch.optim.Adam(generator.parameters(), lr=Exp.learning_rate,
                                         betas=Exp.betas, weight_decay=Exp.generator_decay)

        elif set(Exp.Observed_DAG[label]) & set(Exp.image_labels):
            # Labels predicted from an image parent use a classifier trained with SGD.
            generator = ClassificationNet(output_dim=Exp.label_dim[label]).to(Exp.DEVICE)
            optimizer = optim.SGD(generator.parameters(), lr=Exp.learning_rate, momentum=0.5)

        else:
            generator = ControllerGenerator(Exp, input_dim=noise_dims + parent_dims,
                                            feature_dim=Exp.label_dim[label]).to(Exp.DEVICE)
            optimizer = torch.optim.Adam(generator.parameters(), lr=Exp.learning_rate,
                                         betas=Exp.betas, weight_decay=Exp.generator_decay)

        label_generators[label] = generator
        optimizersMech[label] = optimizer

    # Only read the checkpoint when at least one label we actually handle is flagged.
    labels_to_load = {label for label in Exp.label_names if load_which_models[label]}
    checkpointx = None
    if labels_to_load:
        gfile = Exp.LOAD_MODEL_PATH + "/checkpoints_generators/epochLast.pth"
        # Map onto the configured device instead of a hard-coded "cuda" so the
        # checkpoint can also be restored on CPU-only machines.
        checkpointx = torch.load(gfile, map_location=Exp.DEVICE)

    for label in Exp.label_names:
        if label in labels_to_load:
            label_generators[label].load_state_dict(checkpointx[label + "state_dict"])
            optimizersMech[label].load_state_dict(checkpointx["optimizer" + label])
            # The saved optimizer carries its old lr; force the configured one.
            for param_group in optimizersMech[label].param_groups:
                param_group["lr"] = Exp.learning_rate

            print(f'{label} generator loaded')
        else:
            label_generators[label].apply(ControllerConstants.init_weights)

    return label_generators, optimizersMech


def get_discriminators(Exp, cur_hnodes, load_which_models):
    """Build discriminators (and Adam optimizers) for every hnode and dataset.

    For each ``hnode -> cur_mechs`` entry of ``cur_hnodes``, one discriminator is
    created per interventional dataset in ``Exp.Data_intervs`` (same order):
      * a ``DigitImageDiscriminator`` if ``cur_mechs`` contains an image label;
      * otherwise a ``ControllerDiscriminator`` whose input size is the summed
        ``Exp.label_dim`` of the compared real variables, plus 10 when any of the
        compared variables is a representation label.

    If any flag in ``load_which_models`` is truthy and
    ``<Exp.LOAD_MODEL_PATH>/checkpoints_discriminator/epochLast.pth`` exists,
    the discriminator/optimizer at position ``idx`` of hnode ``h`` is restored
    from the keys ``"dstate_dict" + "".join(cur_hnodes[h]) + str(idx)`` and
    ``"doptimizer" + "".join(cur_hnodes[h]) + str(idx)`` when present. Restored
    optimizers get their learning rate reset to ``Exp.learning_rate``.

    Args:
        Exp: experiment configuration object.
        cur_hnodes: mapping hnode -> sequence of mechanism label names.
        load_which_models: mapping label -> flag; any truthy flag enables restoring.

    Returns:
        ``(discriminatorsMech, doptimizersMech)``: dicts keyed by hnode, each
        value a list indexed like ``Exp.Data_intervs``.
    """
    discriminatorsMech = {}
    doptimizersMech = {}

    for hnode, cur_mechs in cur_hnodes.items():
        discriminatorsMech[hnode] = []
        doptimizersMech[hnode] = []
        for ino, intv in enumerate(Exp.Data_intervs):
            all_compare_Var, _compare_Var, _intervened_Var, real_labels_vars = get_training_variables(
                Exp, cur_mechs, ino, intv)

            compare_dims = sum(Exp.label_dim[var] for var in real_labels_vars)

            if set(cur_mechs) & set(Exp.image_labels):
                cur_discriminator = DigitImageDiscriminator(
                    image_dim=3,
                    label_dims=compare_dims,
                    num_filters=Exp.IMAGE_FILTERS[::-1],
                    output_dim=1,
                ).to(Exp.DEVICE)
            else:
                rep_dim = 10 if set(all_compare_Var) & set(Exp.rep_labels) else 0
                cur_discriminator = ControllerDiscriminator(Exp, input_dim=compare_dims + rep_dim).to(Exp.DEVICE)

            discriminatorsMech[hnode].append(cur_discriminator)
            doptimizersMech[hnode].append(torch.optim.Adam(cur_discriminator.parameters(), lr=Exp.learning_rate,
                                                           betas=Exp.betas, weight_decay=Exp.discriminator_decay))

    # Restore saved discriminators (for both observational and interventional datasets).
    # Use the parameter, consistent with get_generators.
    if True in load_which_models.values():
        dfile = Exp.LOAD_MODEL_PATH + "/checkpoints_discriminator/epochLast.pth"

        if Path(dfile).is_file():
            checkpointx = torch.load(dfile, map_location=Exp.DEVICE)

            # Each hnode has its own key prefix and its own list of discriminators.
            for hnode, cur_mechs in cur_hnodes.items():
                var_list = "".join(cur_mechs)
                pairs = zip(discriminatorsMech[hnode], doptimizersMech[hnode])
                for idx, (discriminator, doptimizer) in enumerate(pairs):
                    key_suffix = var_list + str(idx)
                    if "dstate_dict" + key_suffix not in checkpointx:
                        continue
                    discriminator.load_state_dict(checkpointx["dstate_dict" + key_suffix])
                    doptimizer.load_state_dict(checkpointx["doptimizer" + key_suffix])

                    for param_group in doptimizer.param_groups:
                        param_group["lr"] = Exp.learning_rate
        else:
            print("No discriminator loaded")

    return discriminatorsMech, doptimizersMech


def get_generated_labels(Exp, label_generators, label_noises, conf_noises, intervened, chosen_labels, mini_batch, **kwargs):
    """Generate values for the DAG variables up to the last requested one.

    Walks ``Exp.Observed_DAG`` in iteration order and produces a value for each
    label from its already-generated parents (plus noise), using
    ``label_generators[label]``. It stops after the position of the last label
    of ``chosen_labels``, and returns only the labels listed in
    ``chosen_labels``.

    Args:
        Exp: experiment configuration object.
        label_generators: dict label -> callable mechanism network (presumably the
            dict returned by ``get_generators`` — confirm at call sites).
        label_noises: dict of exogenous noise tensors keyed by
            ``Exp.exogenous[label]``. If it is empty/falsy it is filled IN PLACE
            with ``torch.randn(mini_batch, Exp.NOISE_DIM)`` for every non-image
            label, so the caller's dict receives the sampled noise.
        conf_noises: dict of confounder noise tensors keyed by the ids found in
            ``Exp.latent_conf``. If empty/falsy it is filled IN PLACE with
            ``torch.randn(mini_batch, Exp.CONF_NOISE_DIM)``. A confounder id
            listed under several labels is drawn again each time, and the last
            draw is kept.
        intervened: dict label -> intervention. A tensor value is used verbatim.
            Any other value is used as a column index: that column is set to
            0.99999 and all other entries to 0.00001.
        chosen_labels: labels to return. ``max`` over an empty list raises
            ValueError, so this must be non-empty.
        mini_batch: number of samples to generate.
        **kwargs:
            true_scm: when ``== True``, image labels are rendered with
                ``produce_uniform_images`` from the discretised parent values
                instead of the learned image generator.
            gumbel_noise: dict label -> value forwarded as ``gumbel_noise``, only
                in the final ``else`` branch.
            hard: forwarded as ``hard``, only in the final ``else`` branch.

    Returns:
        dict mapping each label in ``chosen_labels`` to its generated value.
    """
    if not label_noises:
        for name in Exp.label_names:
            if name not in Exp.image_labels:
                label_noises[Exp.exogenous[name]] = torch.randn(mini_batch, Exp.NOISE_DIM).to(
                    Exp.DEVICE)  # standard-normal noise, one tensor per exogenous variable

    if not conf_noises:
        for label in Exp.label_names:
            confounders = Exp.latent_conf[label]
            for conf in confounders:  # confounders are identified only by their id/position, not a name.
                conf_noises[conf] = torch.randn(mini_batch, Exp.CONF_NOISE_DIM).to(Exp.DEVICE)  # standard-normal noise

    # Position (in Exp.label_names) of the last label we must produce.
    max_in_top_order = max([Exp.label_names.index(lb) for lb in chosen_labels])
    # print("max_in_top_order", max_in_top_order)
    gen_labels = {}
    # NOTE(review): lbid is a position in Exp.Observed_DAG but is compared with a
    # position in Exp.label_names; this assumes both share the same (topological) order — confirm.
    for lbid, label in enumerate(Exp.Observed_DAG):
        if lbid > max_in_top_order:  # the remaining variables are not needed.
            break

        # print(lbid, label)
        # Noise inputs for non-image mechanisms: exogenous noise + confounder noise.
        Noises = []
        if label not in Exp.image_labels:
            # NOTE(review): the original author flagged an error here; this raises
            # KeyError when a caller-supplied label_noises lacks Exp.exogenous[label].
            Noises.append(label_noises[Exp.exogenous[label]])

        for conf in Exp.latent_conf[label]:
            Noises.append(conf_noises[conf])


        # getting observed parent values (parents must already be in gen_labels)
        parent_gen_labels = []
        for parent in Exp.Observed_DAG[label]:
            parent_gen_labels.append(gen_labels[parent])

        if label in intervened.keys():
            if torch.is_tensor(intervened[label]):
                gen_labels[label] = intervened[label]
            else:
                # Near-one-hot vector with the intervened column set to ~1.
                gen_labels[label] = torch.ones(mini_batch, Exp.label_dim[label]).to(Exp.DEVICE) * 0.00001
                gen_labels[label][:, intervened[label]] = 0.99999

        elif label in Exp.image_labels:

            if 'true_scm' in kwargs and kwargs['true_scm']==True:  # render images with the ground-truth function instead of the generator
                parent_gen_labels = torch.tensor(map_dictfill_to_discrete(Exp, {par:gen_labels[par] for par in Exp.Observed_DAG[label]} , Exp.Observed_DAG[label])).to(Exp.DEVICE)
                gen_image= produce_uniform_images(Exp, 0, parent_gen_labels, mini_batch , True )

                # Elements of gen_image are presumably numpy arrays (astype is called on them) — confirm.
                # Each one is converted to a tensor and the results are concatenated along dim 0.
                transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
                digit_images = [torch.unsqueeze(transform(img.astype(uint8)), dim=0).to(Exp.DEVICE) for img in gen_image]
                gen_labels[label]= torch.cat(digit_images, 0)
                continue

            # Image generators take their own noise: fresh image noise plus the
            # confounder noise, each reshaped to (batch, dim, 1, 1).
            Noises = []
            image_noise = torch.randn(mini_batch, Exp.IMAGE_NOISE_DIM).view(-1, Exp.IMAGE_NOISE_DIM, 1, 1).to(
                Exp.DEVICE)
            Noises.append(image_noise)
            for conf in Exp.latent_conf[label]:
                Noises.append(conf_noises[conf].view(-1, Exp.CONF_NOISE_DIM, 1, 1).to(Exp.DEVICE))

            # converting continuous fill parents to discrete fill
            parent_gen_labels = torch.cat(parent_gen_labels, 1)
            dims_list = [Exp.label_dim[lb] for lb in Exp.Observed_DAG[label]]
            parent_gen_labels = map_fill_to_discrete(Exp, parent_gen_labels, dims_list)
            parent_gen_labels = get_multiple_labels_fill(Exp, parent_gen_labels, dims_list, isImage_labels=True, more_dimsize=1)
            gen_labels[label] = label_generators[label](Noises, [parent_gen_labels])

        # elif set(Exp.Observed_DAG[label]) & set(Exp.rep_labels) != set():
        elif label in Exp.rep_labels:
            # gen_labels[label] = label_generators[label](Exp, parent_gen_labels, gumbel_noise=None, hard=False)
            # NOTE(review): assumes exactly this parent order: [image, label] — confirm against Exp.Observed_DAG.
            img= parent_gen_labels[0]
            label_data= parent_gen_labels[1]
            par= Exp.Observed_DAG[label][1]
            dim_list= [Exp.label_dim[par]]
            gen_labels[label] = label_generators[label](Exp, img, label_data, dim_list,  isOnehot=True, isLatent=True)

        elif set(Exp.Observed_DAG[label]) & set(Exp.image_labels) != set():
            # Label with an image parent: no noise input, and the gumbel_noise/hard kwargs are ignored here.
            gen_labels[label] = label_generators[label](Exp, parent_gen_labels, gumbel_noise=None, hard=False)
        else:
            gn=None
            hard= False
            if "gumbel_noise" in kwargs:
                gn=kwargs["gumbel_noise"][label]
            if "hard" in kwargs:
                hard= kwargs["hard"]
            gen_labels[label] = label_generators[label](Exp, Noises, parent_gen_labels, gumbel_noise=gn, hard=hard)

    # Return only the requested labels.
    return_labels = {}
    for label in chosen_labels:
        return_labels[label] = gen_labels[label]

    return return_labels
