
import copy

import torch


import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch import optim as optim
from matplotlib import pyplot as plt
import torch.nn.functional as F
import seaborn as sns
from torch.utils.data import DataLoader
from torchvision import transforms
import seaborn as sns

from torchvision import transforms, datasets
from torch import optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms

from ModularUtils import ControllerConstants
from ModularUtils.ControllerConstants import map_fill_to_discrete, get_multiple_labels_fill
from ModularUtils.Discriminators import ControllerDiscriminator
from ModularUtils.Generators import DigitImageGenerator, ControllerGenerator


def get_generators(Exp, load_which_models):
    """Build one generator and its Adam optimizer per variable of ``Exp.Observed_DAG``.

    Labels listed in ``Exp.image_labels`` get a ``DigitImageGenerator``; every other
    label gets a ``ControllerGenerator``. Labels whose flag in ``load_which_models`` is
    set are restored from ``<Exp.LOAD_MODEL_PATH>/checkpoints_generators/epochLast.pth``
    (model state under ``label + "state_dict"``, optimizer state under
    ``"optimizer" + label``). All other labels are initialised with
    ``ControllerConstants.init_weights``.

    Args:
        Exp: experiment configuration object. It supplies the DAG, noise/feature
            dimensions, optimizer hyper-parameters, device and checkpoint path.
        load_which_models: mapping from label name to a flag telling whether that
            label's generator should be restored from the checkpoint. Every label of
            ``Exp.label_names`` must be a key (``KeyError`` otherwise).

    Returns:
        ``(label_generators, optimizersMech)``: dicts keyed by label name holding the
        generator modules and their optimizers.
    """
    label_generators = {}
    optimizersMech = {}

    for label in Exp.Observed_DAG:
        conf_dim = Exp.CONF_NOISE_DIM * len(Exp.latent_conf[label])
        noise_dims = Exp.NOISE_DIM + conf_dim

        # Total feature width of the observed parents that condition this generator.
        parent_dims = sum(Exp.label_dim[par]["feature"] for par in Exp.Observed_DAG[label])

        if label in Exp.image_labels:
            # Image generators use their own noise size; confounder noise width is given separately.
            label_generators[label] = DigitImageGenerator(noise_dim=Exp.IMAGE_NOISE_DIM,
                                                          conf_dim=conf_dim,
                                                          parent_dims=parent_dims,
                                                          num_filters=Exp.IMAGE_FILTERS,
                                                          output_dim=3).to(Exp.DEVICE)
        else:
            # Input width = exogenous noise + confounder noise + parent features.
            label_generators[label] = ControllerGenerator(Exp, input_dim=noise_dims + parent_dims,
                                                          feature_dim=Exp.label_dim[label]["feature"],
                                                          ).to(Exp.DEVICE)

        optimizersMech[label] = torch.optim.Adam(label_generators[label].parameters(), lr=Exp.learning_rate,
                                                 betas=Exp.betas, weight_decay=Exp.generator_decay)

    to_restore = {label for label in Exp.label_names if load_which_models[label]}

    checkpoint = None
    if to_restore:
        gfile = Exp.LOAD_MODEL_PATH + "/checkpoints_generators/epochLast.pth"
        # Map onto the configured device rather than hard-coding "cuda", so that
        # checkpoints can also be restored on CPU-only machines.
        checkpoint = torch.load(gfile, map_location=Exp.DEVICE)

    for label in Exp.label_names:
        if label in to_restore:
            label_generators[label].load_state_dict(checkpoint[label + "state_dict"])
            optimizersMech[label].load_state_dict(checkpoint["optimizer" + label])
            # The restored optimizer state carries the saved lr; reset it to the current one.
            for param_group in optimizersMech[label].param_groups:
                param_group["lr"] = Exp.learning_rate
        else:
            label_generators[label].apply(ControllerConstants.init_weights)

    return label_generators, optimizersMech


def get_discriminators(Exp, cur_mechs, load_which_models):
    """Build one discriminator and its Adam optimizer per dataset in ``Exp.Data_intervs``.

    For dataset index ``ino``, the variables to compare are gathered from
    ``Exp.train_mech_dict[mech][ino]["compare"]`` for every mechanism in ``cur_mechs``.
    A variable already contributed by an earlier mechanism is skipped. The
    discriminator's ``input_dim`` is the sum of those variables'
    ``Exp.label_dim[var]["feature"]`` sizes.

    When any flag in ``load_which_models`` is set, the checkpoint file
    ``<Exp.LOAD_MODEL_PATH>/checkpoints_discriminator/epochLast.pth`` is read once. For
    each index ``i``, the entries ``"dstate_dict" + "".join(cur_mechs) + str(i)`` and
    ``"doptimizer" + "".join(cur_mechs) + str(i)`` restore the discriminator and its
    optimizer. Discriminators without a stored state keep their fresh initialisation.

    Args:
        Exp: experiment configuration object.
        cur_mechs: names of the mechanisms trained together. Their concatenated names
            form part of the checkpoint keys.
        load_which_models: mapping from label name to a restore flag.

    Returns:
        ``(discriminatorsMech, doptimizersMech)``: lists aligned with ``Exp.Data_intervs``.

    Raises:
        KeyError: if the checkpoint holds a ``"dstate_dict..."`` entry without the
            matching ``"doptimizer..."`` entry.
    """
    discriminatorsMech = []
    doptimizersMech = []

    for ino, _ in enumerate(Exp.Data_intervs):
        # Variables compared on this dataset, skipping ones an earlier mechanism already added.
        compare_vars = []
        for mech in cur_mechs:
            compare_vars += [lb for lb in Exp.train_mech_dict[mech][ino]["compare"] if lb not in compare_vars]
        compare_dim = sum(Exp.label_dim[var]["feature"] for var in compare_vars)

        discriminator = ControllerDiscriminator(Exp, input_dim=compare_dim).to(Exp.DEVICE)
        discriminatorsMech.append(discriminator)
        doptimizersMech.append(torch.optim.Adam(discriminator.parameters(), lr=Exp.learning_rate,
                                                betas=Exp.betas, weight_decay=Exp.discriminator_decay))

    # Honour the caller-supplied flags (as get_generators does). Read the checkpoint
    # once, after all discriminators exist, and restore each pair exactly once.
    if discriminatorsMech and any(load_which_models.values()):
        dfile = Exp.LOAD_MODEL_PATH + "/checkpoints_discriminator/epochLast.pth"
        # Map onto the configured device rather than hard-coding "cuda".
        checkpoint = torch.load(dfile, map_location=Exp.DEVICE)
        mech_key = "".join(cur_mechs)

        for idx, (disc, dopt) in enumerate(zip(discriminatorsMech, doptimizersMech)):
            state_key = "dstate_dict" + mech_key + str(idx)
            if state_key not in checkpoint:
                continue
            disc.load_state_dict(checkpoint[state_key])
            dopt.load_state_dict(checkpoint["doptimizer" + mech_key + str(idx)])
            # The restored optimizer state carries the saved lr; reset it to the current one.
            for param_group in dopt.param_groups:
                param_group["lr"] = Exp.learning_rate

    return discriminatorsMech, doptimizersMech


def get_generated_labels(Exp, label_generators, label_noises, conf_noises, intervened, chosen_labels, mini_batch, **kwargs):
    """Generate the labels in ``chosen_labels`` by running the per-label generators in DAG order.

    Args:
        Exp: experiment configuration object (DAG, dimensions, device, ...).
        label_generators: dict label -> generator. Non-image generators are called as
            ``gen(Exp, noises, parent_values, gumbel_noise=..., hard=...)``; image
            generators as ``gen(noises, [parent_values])``.
        label_noises: dict keyed by ``Exp.exogenous[label]``. If empty, it is filled
            in place with ``randn(mini_batch, Exp.NOISE_DIM)`` for every non-image label.
        conf_noises: dict keyed by the confounder ids in ``Exp.latent_conf``. If empty,
            it is filled in place with ``randn(mini_batch, Exp.CONF_NOISE_DIM)``.
        intervened: dict label -> either a tensor used as-is, or a column index
            (or indices) turned into a near one-hot tensor: 0.99999 at those columns,
            0.00001 elsewhere.
        chosen_labels: labels to return. Generation stops after the last of them by
            position in ``Exp.label_names``.
        mini_batch: batch size of the sampled noise / intervention tensors.
        **kwargs: optional ``gumbel_noise`` (dict label -> value forwarded to
            non-image generators) and ``hard`` (forwarded, default ``False``).

    Returns:
        dict mapping each label of ``chosen_labels`` to its tensor. An empty
        ``chosen_labels`` yields ``{}``.

    NOTE(review): lbid (position in ``Exp.Observed_DAG``) is compared with positions in
    ``Exp.label_names``, so both are assumed to list labels in the same (topological)
    order — TODO confirm.
    """
    # Sample missing noise in place, so the caller's dicts hold the noise actually used.
    if not label_noises:
        for name in Exp.label_names:
            if name not in Exp.image_labels:
                label_noises[Exp.exogenous[name]] = torch.randn(mini_batch, Exp.NOISE_DIM).to(Exp.DEVICE)

    if not conf_noises:
        for label in Exp.label_names:
            # Only confounder identities matter: one noise tensor per confounder.
            for conf in Exp.latent_conf[label]:
                conf_noises[conf] = torch.randn(mini_batch, Exp.CONF_NOISE_DIM).to(Exp.DEVICE)

    # Variables after the last requested one are never needed; -1 means "nothing requested".
    last_needed = max((Exp.label_names.index(lb) for lb in chosen_labels), default=-1)
    gumbel_noises = kwargs.get("gumbel_noise")
    hard = kwargs.get("hard", False)

    gen_labels = {}
    for lbid, label in enumerate(Exp.Observed_DAG):
        if lbid > last_needed:
            break

        # Assumes parents were generated earlier in the iteration (KeyError otherwise).
        parent_gen_labels = [gen_labels[parent] for parent in Exp.Observed_DAG[label]]

        if label in intervened:
            if torch.is_tensor(intervened[label]):
                gen_labels[label] = intervened[label]
            else:
                fixed = torch.ones(mini_batch, Exp.label_dim[label]["feature"]).to(Exp.DEVICE) * 0.00001
                fixed[:, intervened[label]] = 0.99999
                gen_labels[label] = fixed

        elif label in Exp.image_labels:
            image_noise = torch.randn(mini_batch, Exp.IMAGE_NOISE_DIM).view(-1, Exp.IMAGE_NOISE_DIM, 1, 1).to(
                Exp.DEVICE)
            noises = [image_noise]
            noises += [conf_noises[conf].view(-1, Exp.CONF_NOISE_DIM, 1, 1).to(Exp.DEVICE)
                       for conf in Exp.latent_conf[label]]

            # Convert the continuous "fill" parent values to discrete form before conditioning.
            dims_list = [Exp.label_dim[lb]["feature"] for lb in Exp.Observed_DAG[label]]
            parents = torch.cat(parent_gen_labels, 1)
            parents = map_fill_to_discrete(Exp, parents, dims_list)
            parents = get_multiple_labels_fill(Exp, parents, dims_list, isImage_labels=True, more_dimsize=1)
            gen_labels[label] = label_generators[label](noises, [parents])

        else:
            # Exogenous noise first, then confounder noise in latent_conf order.
            noises = [label_noises[Exp.exogenous[label]]]
            noises += [conf_noises[conf] for conf in Exp.latent_conf[label]]
            gn = gumbel_noises[label] if gumbel_noises is not None else None
            gen_labels[label] = label_generators[label](Exp, noises, parent_gen_labels, gumbel_noise=gn, hard=hard)

    return {label: gen_labels[label] for label in chosen_labels}
