import os
from argparse import Namespace
from pdb import set_trace

import attr
import matplotlib

from peagang.models.components.generators.ds.models import PointNetSTGen
from peagang.models.components.utilities_classes import Swish

matplotlib.use("Agg")

from peagang.utils.grad_flow import plot_grad_flow
from peagang.utils.grad_penalty import compute_gradient_penalty

import torch
import torch as pt
import numpy as np
import torch.distributions as ptd
from pytorch_lightning import LightningModule

from torch.utils.data import DataLoader

from peagang.optim import ExtraAdam
from peagang.models.components.generators.att.models import Attention_generator
from peagang.models.components.generators.mlp_row.models import MLP_generator
from peagang.models.components.discriminators.kCycleGIN import Discriminator

from peagang.data.dense.PEAWGANDenseData import PEAWGANDenseData
from peagang.data.dense.PEAWGANDenseStructureData import PEAWGANDenseStructureData

from torchvision.utils import make_grid
import math

def forward_clip_hook(module, input, output):
    """Forward hook that clamps infinite activations to ``[-1e10, 1e10]``.

    Args:
        module: The module whose forward pass produced ``output`` (only used
            in the warning message).
        input: The inputs of the forward call (unused).
        output: Either a single tensor or an iterable of outputs.

    Returns:
        The (possibly clamped) tensor when ``output`` is a tensor, otherwise a
        list with one entry per element of ``output``. Elements without infs,
        and non-tensor elements, are passed through untouched.
    """
    # Local import: ``warn`` is not imported at module level in this file.
    from warnings import warn

    single = pt.is_tensor(output)
    outs = [output] if single else output
    new_o = []
    for o in outs:
        # Only tensors can be inspected for infs; anything else is forwarded.
        if pt.is_tensor(o) and pt.isinf(o).any():
            warn(f"Found {pt.isinf(o).sum().item()} infs in output of{module},clamping")
            # Clamp the offending element itself, not the whole output container.
            new_o.append(pt.clamp(o, -1e10, 1e10))
        else:
            new_o.append(o)
    if single:
        return new_o[0]
    return new_o
def backward_trace_hook(module, grad_input, grad_output):
    """Debugging backward hook: drop into ``pdb`` when a gradient contains NaNs.

    Args:
        module: The module the hook is attached to (unused).
        grad_input: A tensor or an iterable of tensors/``None``.
        grad_output: A tensor or an iterable of tensors/``None``.

    Returns:
        ``grad_input`` exactly as it was passed in (gradients are not altered).
    """
    _grad_input = grad_input
    if pt.is_tensor(grad_input):
        grad_input = [grad_input]
    # Wrap grad_output itself; it must not overwrite grad_input.
    if pt.is_tensor(grad_output):
        grad_output = [grad_output]
    for i in grad_input:
        if pt.is_tensor(i) and pt.isnan(i).any():
            set_trace()
    for o in grad_output:
        # Check the element being inspected (entries may be None).
        if pt.is_tensor(o) and pt.isnan(o).any():
            set_trace()
    return _grad_input
def backward_trace_hook_t(grad_input):
    """Tensor-level debugging hook: open ``pdb`` if a gradient holds NaNs.

    Accepts a single tensor or an iterable of gradients and always returns
    the argument unchanged.
    """
    candidates = [grad_input] if pt.is_tensor(grad_input) else grad_input
    for grad in candidates:
        if not pt.is_tensor(grad):
            continue
        if pt.isnan(grad).any():
            set_trace()
    return grad_input
def backward_clean_hook(module, grad_input, grad_output):
    """Backward hook that replaces NaN entries of ``grad_input`` with zeros.

    Args:
        module: The module the hook is attached to (unused).
        grad_input: A tensor or an iterable of tensors/``None``.
        grad_output: Gradients w.r.t. the module outputs. Only ``grad_input``
            is returned by this hook, so ``grad_output`` is left untouched.

    Returns:
        A cleaned tensor when ``grad_input`` is a tensor, otherwise a list with
        one entry per element of ``grad_input``. Tensors containing NaNs are
        cloned before being patched, so the incoming gradients are never
        modified in place; everything else is passed through as-is.
    """
    single = pt.is_tensor(grad_input)
    grads = [grad_input] if single else grad_input
    gi_out = []
    for g in grads:
        if pt.is_tensor(g):
            nan_mask = pt.isnan(g)
            if nan_mask.any():
                # Clone first: the hook must not mutate its arguments.
                g = g.clone()
                g[nan_mask] = 0.0
        gi_out.append(g)
    if single:
        return gi_out[0]
    return gi_out


def ensure_tensor(x):
    """Return ``x`` as a torch tensor.

    Tensors are returned as-is, numpy arrays go through ``pt.from_numpy`` and
    anything else is handed to ``pt.tensor``.
    """
    if pt.is_tensor(x):
        return x
    convert = pt.from_numpy if isinstance(x, np.ndarray) else pt.tensor
    return convert(x)


class PEAWGAN(LightningModule):
    """LightningModule that trains a graph generator against a ``Discriminator``.

    The generator class is chosen from ``hpars.architecture``; the
    discriminator loss in ``training_step`` is ``-(real - fake)`` plus a
    weighted gradient penalty, and the generator loss is ``-fake_score.mean()``.
    """

    def __init__(self, hparams, *args, **kwargs):
        """Build generator and discriminator and create the output directories.

        Args:
            hparams: Mapping or ``argparse.Namespace`` forwarded to
                ``PEAWGAN_HyperParameters.from_dict``.
            *args, **kwargs: Accepted but not used here.

        Raises:
            NotImplementedError: If ``architecture`` is not one of
                ``"attention"``, ``"mlp_row"`` or ``"deepset"``.
        """
        super().__init__()
        # Last generator embeddings; set in training_step, saved in optimizer_step.
        self.Z0 = None
        self.finetti_u = None
        self.hparams = hparams
        self.hpars=PEAWGAN_HyperParameters.from_dict(hparams)
        # Categorical distribution built from node_count_weights, passed to the generator.
        self.num_node_dist = ptd.Categorical(ensure_tensor(self.hpars.node_count_weights))
        # We want the device because of the plotting as sometimes Z and the context vector come from different.
        if self.hpars.architecture=="attention":
            self.generator = Attention_generator(
                self.hpars.embed_dim,
                self.hpars.node_feature_dim,
                self.num_node_dist,
                finetti_dim=self.hpars.finetti_dim,
                inner_activation=None,
                out_activation=None,
                edge_readout_type=self.hpars.edge_readout,
                attention_mode=self.hpars.attention_mode,
                score_function=self.hpars.score_function,
                discretization=self.hpars.discretization,
                temperature=self.hpars.temperature,
                n_attention_layers=self.hpars.n_attention_layers,
                num_heads=self.hpars.num_heads,
                cycle_opt=self.hpars.cycle_opt,
                seed_batch_size=self.hpars.batch_size,
                trainable_z=self.hpars.finetti_trainable,
                train_fix_context=self.hpars.finetti_train_fix_context,
                flip_finetti=self.hpars.flip_finetti,
                dynamic_creation=self.hpars.dynamic_finetti_creation,
                finneti_MLP=self.hpars.finneti_MLP,
                replicated_Z=self.hpars.replicated_Z,
                bias_mode=self.hpars.edge_bias_mode
            )
        elif self.hpars.architecture=="mlp_row":
            self.generator = MLP_generator(
                self.hpars.embed_dim,
                self.hpars.node_feature_dim,
                self.num_node_dist,
                finetti_dim=self.hpars.finetti_dim,
                edge_readout_type=self.hpars.edge_readout,
                layers_=self.hpars.MLP_layers,
                discretization=self.hpars.discretization,
                temperature=self.hpars.temperature,
                cycle_opt=self.hpars.cycle_opt,
                seed_batch_size=self.hpars.batch_size,
                trainable_z=self.hpars.finetti_trainable,
                train_fix_context=self.hpars.finetti_train_fix_context,
                flip_finetti=self.hpars.flip_finetti,
                dynamic_creation=self.hpars.dynamic_finetti_creation,
                finneti_MLP=self.hpars.finneti_MLP,
                replicated_Z=self.hpars.replicated_Z,
            )
        elif self.hpars.architecture=="deepset":
            # The deepset generator reuses n_attention_layers as its number of set layers.
            self.generator=PointNetSTGen(
                self.hpars.embed_dim,
                self.hpars.node_feature_dim,
                self.num_node_dist,
                finetti_dim=self.hpars.finetti_dim,
                inner_activation=Swish(),
                out_activation=None,
                edge_readout_type=self.hpars.edge_readout,
                attention_mode=self.hpars.attention_mode,
                score_function=self.hpars.score_function,
                discretization=self.hpars.discretization,
                temperature=self.hpars.temperature,
                n_set_layers=self.hpars.n_attention_layers,
                num_heads=self.hpars.num_heads,
                cycle_opt=self.hpars.cycle_opt,
                seed_batch_size=self.hpars.batch_size,
                trainable_z=self.hpars.finetti_trainable,
                train_fix_context=self.hpars.finetti_train_fix_context,
                flip_finetti=self.hpars.flip_finetti,
                dynamic_creation=self.hpars.dynamic_finetti_creation,
                finneti_MLP=self.hpars.finneti_MLP,
                replicated_Z=self.hpars.replicated_Z,
                bias_mode=self.hpars.edge_bias_mode
            )
        else:
            # NOTE(review): the message is stale -- "deepset" is handled above;
            # of the validated architecture values only "rnn" reaches this branch.
            raise NotImplementedError("Still need to add some architectures here, still missing deepset,mlp,rnn")

        self.discriminator = Discriminator(
            self.hpars.node_feature_dim,
            self.hpars.disc_conv_channels,
            kc_flag=self.hpars.kc_flag,
            readout_hidden=self.hpars.disc_readout_hidden, swish=False,
        )

        # TODO remove this and make in the dataset automatic one hot or not
        self.dense_b_dst = None

        # Output root: <cwd>/<save_dir>/<model_n>, with sub-folders for the saved embeddings.
        self.G_path = os.path.join(os.getcwd(), self.hpars.save_dir, self.hpars.model_n)
        os.makedirs(self.G_path, exist_ok=True)
        os.makedirs(os.path.join(self.G_path, "embeddings/Z0/"), exist_ok=True)
        os.makedirs(os.path.join(self.G_path, "embeddings/finetti_u/"), exist_ok=True)

    def forward(self, real_data=None, device=None, *args, **kwargs):
        """Score a real batch against freshly generated graphs, or just sample.

        Args:
            real_data: Optional pair ``(node_features, adjacency)``. When it is
                ``None`` the call is delegated to ``self.sample``.
            device: Device handed to the generator; defaults to ``self.device``.

        Returns:
            ``(fake_nodes, fake_adj, fake_score, real_score)`` when
            ``real_data`` is given, otherwise the tuple returned by ``sample``.
        """
        if device is None:
            device=self.device
        if real_data is not None:
            real_node_feat, realA = real_data
            # Node counts are read from the last feature channel of the first node of each graph.
            num_nodes = real_node_feat[:, 0, -1]
            realA=realA.float()
            fake_nodes, fake_adj, _, _, _ = self.generator.sample(
                batch_size=real_node_feat.shape[0], N=num_nodes, device=device
            )
            # disc_contrast selects which node features accompany the fake adjacency.
            if self.hpars.disc_contrast=="fake-struct_fake":
                # Requires self.train_set, which is only created in prepare_data().
                fake_nodes=self.train_set.get_structural_node_features(fake_adj)[0]
            elif self.hpars.disc_contrast=="real_fake":
                fake_nodes=real_node_feat
            elif self.hpars.disc_contrast=="fake_fake":
                # No-op: the generated node features are used unchanged.
                fake_nodes=fake_nodes
            assert fake_nodes.shape[-1]==real_node_feat.shape[-1]
            fake_score = self.discriminator(fake_nodes, fake_adj)
            real_score = self.discriminator(real_node_feat,realA)
            return fake_nodes, fake_adj, fake_score, real_score
        else:
            return self.sample(device=device)
    def sample(self,device=None,batch_size=None):
        """Draw a batch of graphs from the generator.

        Args:
            device: Device handed to the generator; defaults to ``self.device``.
            batch_size: Number of graphs; defaults to ``hpars.batch_size`` and
                is asserted not to exceed it.

        Returns:
            ``(fake_nodes, fake_adj, Z0, finetti_u, Q)`` as produced by
            ``self.generator.sample``, with ``fake_nodes`` replaced by
            structural features when ``disc_contrast == "fake-struct_fake"``.
        """
        if device is None:
            device=self.device
        if batch_size is None:
            batch_size=self.hpars.batch_size
        else:
            assert batch_size<=self.hpars.batch_size
        fake_nodes, fake_adj, Z0, finetti_u, Q = self.generator.sample(
            batch_size=batch_size, device=device
        )
        if self.hpars.disc_contrast == "fake-struct_fake":
            fake_nodes = self.train_set.get_structural_node_features(fake_adj)[0]
        if self.hpars.disc_contrast=="fake_fake":
            # Generated node features must carry gradients back to the generator.
            assert fake_nodes.requires_grad
        return fake_nodes, fake_adj,  Z0, finetti_u, Q

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the discriminator (idx 0) or generator (idx 1).

        On the very first batch an untrained checkpoint is written to
        ``version_0/checkpoints/epoch=-1.ckpt`` under ``self.G_path``.

        Returns:
            A dict with ``loss`` and, when a real loss was computed, a ``log``
            dict of scalars. For generator steps that are skipped (see
            ``generator_every``) the loss is a zero placeholder.
        """
        if self.trainer.total_batch_idx == 0:
            os.makedirs(os.path.join(self.G_path, "version_0/checkpoints/"),exist_ok=True)
            self.trainer.save_checkpoint(
                os.path.join(self.G_path, "version_0/checkpoints/", "epoch=-1.ckpt")
            )
            print("Saved untrained model \n")

        # give the input device to sample in generator case
        input_device = batch[0].device
        assert input_device==self.device
        if optimizer_idx == 0: # train discriminator
            fake_nodes, fake_adj, fake_score, real_score = self.forward(
                real_data=batch, device=input_device
            )
            if self.training:
                # Debug aid: breaks into pdb if NaNs show up in this gradient.
                fake_score.register_hook(backward_trace_hook_t)
            real_score = real_score.mean()
            fake_score = fake_score.mean()
            # Difference of mean scores on real and generated graphs.
            W1 = real_score - fake_score
            grad_penalty = compute_gradient_penalty(
                self.discriminator,
                batch,
                (fake_nodes, fake_adj),
                LP=self.hpars.LP,
            )
            loss = -W1 + self.hpars.penalty_lambda * grad_penalty
            if self.hpars.score_penalty>0.0:
                # Optional extra penalty on the magnitude of the mean scores.
                loss=loss+ self.hpars.score_penalty*(real_score.norm()+fake_score.norm())
            ret = dict(
                loss=loss,
                log=dict(
                    grad_penalty=grad_penalty,
                    W1=W1,
                    disc_loss=loss,
                    real_score=real_score,
                    fake_score=fake_score,
                ),
            )
            if batch_idx % self.hpars.grid_every == 0:
                # Log real and generated adjacency matrices as image grids.
                grid_real = make_grid(
                    batch[1].unsqueeze(1),
                    nrow=int(math.ceil(math.sqrt(fake_nodes.shape[0]))),
                )
                self.logger.experiment.add_image(
                    "real_graphs", grid_real, self.trainer.total_batch_idx
                )
                grid_fake = make_grid(
                    fake_adj.unsqueeze(1),
                    nrow=int(math.ceil(math.sqrt(fake_nodes.shape[0]))),
                )
                self.logger.experiment.add_image(
                    "generated_graphs", grid_fake, self.trainer.total_batch_idx
                )
                self.log_weights()
        elif optimizer_idx == 1: # train generator
            # The generator is only trained every `generator_every` batches (never on batch 0).
            if batch_idx > 0 and batch_idx % self.hpars.generator_every == 0:
                fake_nodes, _fakeA,  Z0, finetti_u, _ = self.sample(device=input_device)
                fake_adj=_fakeA
                if self.hpars.disc_contrast=="real_fake":
                    fake_nodes=batch[0]
                assert fake_nodes.shape[-1]==batch[0].shape[-1]
                if self.training:
                    if self.hpars.disc_contrast=="fake_fake":
                        fake_nodes.register_hook(backward_trace_hook_t)
                    fake_adj.register_hook(backward_trace_hook_t)
                fake_score = self.discriminator(fake_nodes, fake_adj)
                loss = -fake_score.mean()
                ret = dict(loss=loss, log=dict(gen_loss=loss))
                # Keep the latest embeddings so optimizer_step can save them to disk.
                self.Z0 = Z0
                self.finetti_u = finetti_u
            else:
                # Placeholder loss for skipped generator steps.
                ret = dict(loss=pt.zeros((), requires_grad=True))

        # NOTE(review): `ret` is unbound if optimizer_idx is neither 0 nor 1.
        return ret

    def log_weights(self,hist=False):
        """Log the norm of every parameter (and optionally histograms).

        Also logs the largest of those norms as ``W_fro_max``.
        """
        # NOTE(review): `norms` is never used.
        norms={}
        nmax=0.0
        for name,param in self.named_parameters():
            n=pt.norm(param)
            if n>nmax:
                nmax=n
            self.logger.experiment.add_scalar(f"{name}_fro",n,global_step=self.trainer.total_batch_idx)
            if hist:
                self.logger.experiment.add_histogram(f"{name}_hist",param)
        self.logger.experiment.add_scalar(f"W_fro_max",nmax,global_step=self.trainer.total_batch_idx)

    def optimizer_step(
        self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None,**kwargs
    ):
        """Save embeddings on selected epochs, then step and zero the optimizer.

        The discriminator optimizer (idx 0) steps on every call; the generator
        optimizer (idx 1) only steps on batches where ``training_step``
        computed a generator loss.
        """
        # Embedding snapshots: every 50 epochs while epoch + 1 <= 300, every 100
        # epochs afterwards. The first branch compares against the batch count
        # itself, the others against the batch count minus one.
        # NOTE(review): the epoch-0 branch uses a different batch index than the others -- confirm intent.
        if epoch == 0 and batch_idx == int(len(self.train_set) / self.hpars.batch_size):
            pt.save(self.Z0, os.path.join(self.G_path, "embeddings/Z0/Z0_" + str(epoch).zfill(4) + ".pt"))
            pt.save(self.finetti_u, os.path.join(self.G_path, "embeddings/finetti_u/finetti_" + str(epoch).zfill(4) + ".pt"))
        elif (epoch) % 50 == 0 and batch_idx == int(len(self.train_set) / self.hpars.batch_size) - 1\
                and (epoch + 1) <= 300:
            pt.save(self.Z0, os.path.join(self.G_path, "embeddings/Z0/Z0_" + str(epoch).zfill(4) + ".pt"))
            pt.save(self.finetti_u, os.path.join(self.G_path, "embeddings/finetti_u/finetti_" + str(epoch).zfill(4) + ".pt"))
        elif (epoch) % 100 == 0 and batch_idx == int(len(self.train_set) / self.hpars.batch_size) - 1 \
                and (epoch + 1) > 300:
            pt.save(self.Z0, os.path.join(self.G_path, "embeddings/Z0/Z0_" + str(epoch).zfill(4) + ".pt"))
            pt.save(self.finetti_u, os.path.join(self.G_path, "embeddings/finetti_u/finetti_" + str(epoch).zfill(4) + ".pt"))

        if optimizer_idx == 0:
            optimizer.step()
            # Plot discriminator gradients before they are zeroed.
            if (
                epoch % 50 == 0
                and batch_idx
                == int(len(self.train_set) / self.hpars.batch_size) - 1
            ):
                plot_grad_flow(
                    self.discriminator.named_parameters(),
                    epoch,
                    self.hpars.save_dir,
                    self.hpars.model_n,
                )
            optimizer.zero_grad()
        elif (
            optimizer_idx == 1
            and batch_idx % self.hpars.generator_every == 0
            and batch_idx > 0
        ):
            optimizer.step()
            # NOTE(review): batch index 15 is hard-coded; it only matches when
            # 15 is a multiple of generator_every.
            if epoch % 50 == 0 and batch_idx == 15:
                plot_grad_flow(
                    self.generator.named_parameters(),
                    epoch,
                    self.hpars.save_dir,
                    self.hpars.model_n,
                    g_=True,
                )
            optimizer.zero_grad()

    def configure_optimizers(self):
        """Create ``ExtraAdam`` optimizers and ``StepLR`` schedulers.

        Returns:
            ``([disc_opt, gen_opt], [disc_sched, gen_sched])`` -- the
            discriminator comes first, so it is ``optimizer_idx == 0``.
        """
        # optimizers
        gen_opt = ExtraAdam(
            self.generator.parameters(), **self.hpars.gen_optim_args
        )
        disc_opt = ExtraAdam(
            self.discriminator.parameters(), **self.hpars.disc_optim_args
        )

        from torch.optim.lr_scheduler import StepLR

        # Both schedulers share reduce_every / lr_gamma.
        disc_sched, gen_sched = [
            StepLR(
                x,
                step_size=self.hpars.reduce_every,
                gamma=self.hpars.lr_gamma,
            )
            for x in [disc_opt, gen_opt]
        ]
        return [disc_opt, gen_opt], [disc_sched, gen_sched]

    def prepare_data(self):
        """Create ``self.train_set``.

        Uses ``PEAWGANDenseStructureData`` when ``structured_features`` is
        set, otherwise ``PEAWGANDenseData``.
        """
        if self.hpars.structured_features:
            self.train_set = PEAWGANDenseStructureData(
                data_dir=self.hpars.data_dir,
                filename=self.hpars.filename,
                dataset=self.hpars.dataset,
                k_eigenvals=self.hpars.k_eigenvals,
                use_laplacian=self.hpars.use_laplacian,
                large_N_approx=self.hpars.large_N_approx,
                inner_kwargs=self.hpars.dataset_kwargs,
                cut_train_size=self.hpars.cut_train_size,
                zero_pad=True
            )
        else:
            self.train_set = PEAWGANDenseData(
                data_dir=self.hpars.data_dir,
                filename=self.hpars.filename,
                dataset=self.hpars.dataset,
                inner_kwargs=self.hpars.dataset_kwargs,
                one_hot=self.hpars.label_one_hot,
                cut_train_size=self.hpars.cut_train_size,
                zero_pad=True
            )


    def train_dataloader(self):
        """Return a ``DataLoader`` over ``self.train_set`` configured from ``hpars``."""
        dl = DataLoader(
            self.train_set,
            batch_size=self.hpars.batch_size,
            shuffle=self.hpars.shuffle,
            num_workers=self.hpars.num_workers,
            pin_memory=True,
        )
        return dl


@attr.s
class PEAWGAN_HyperParameters:
    """Container for every hyper-parameter read by ``PEAWGAN``.

    ``node_count_weights`` is the only mandatory field; all others have
    defaults. Mutable defaults are created through ``factory`` so that
    instances never share the same list/dict object.
    """

    # old parameters
    node_count_weights = attr.ib()
    model_n = attr.ib(default="GG-GAN")
    base_dir = attr.ib(default=".")
    data_dir = attr.ib(default=".")
    filename = attr.ib(default=".")
    # Deep toggle...can be removed after further fixes
    deep = attr.ib(
        default=False
    )  # kept for compatibility, deep_disc/gen overrides this if set
    deep_disc = attr.ib(default=None)
    deep_gen = attr.ib(default=None)
    # hyper parameters
    dataset = attr.ib(
        default="MolGAN_5k"
    )  # MolGAN, MolGAN_5k, MolGAN_kC4, MolGAN_kC5, MolGAN_kC6, CommunitySmall_12, CommunitySmall_20 anu_graphs_chordal9
    cut_train_size = attr.ib(default=False)
    batch_size = attr.ib(default=20)
    shuffle = attr.ib(default=True)
    # separate disc and gen learning rates for TTUR https://github.com/bioinf-jku/TTUR/blob/master/WGAN_GP/gan_64x64_FID.py : disc 3e-4 G 1e-4 beta=0,0.9
    # ema motivated by https://arxiv.org/pdf/1806.04498.pdf
    architecture=attr.ib(default="attention",validator=attr.validators.in_({
        "attention",
        "mlp_row",
        "deepset",
        "rnn"
    }))
    disc_optim_args = attr.ib(
        factory=lambda: dict(
            lr=1e-4,
            betas=(0.5, 0.9999),
            eps=1e-8,
            weight_decay=1e-3,
            ema=False,
            ema_start=100,
        )
    )  # recommended setting from WGAN-GP/optimistic Adam paper, half the learning rate tho
    gen_optim_args = attr.ib(
        factory=lambda: dict(
            lr=1e-4,
            betas=(0.5, 0.9999),
            eps=1e-8,
            weight_decay=1e-3,
            ema=False,
            ema_start=100,
        )
    )  # recommended setting from WGAN-GP/optimistic Adam paper, half the learning rate tho
    extra_adam = attr.ib(default=True)
    reduce_every = attr.ib(default=100)
    lr_gamma = attr.ib(default=0.1)
    embed_dim = attr.ib(default=50)
    finetti_dim = attr.ib(default=50)
    label_one_hot = attr.ib(default=5)  # same as in node feature dim...
    node_feature_dim = attr.ib(
        default=5
    )  # + (9 - 2)+1)  # number of cycles+number of nodes in graph
    kc_flag = attr.ib(default=True)
    # factory instead of a literal list: a list default would be shared by all instances
    disc_conv_channels = attr.ib(factory=lambda: [32, 64, 64, 64])
    LP = attr.ib(default=True)  # leaky penalty, False,True, or "ZP" string
    penalty_lambda = attr.ib(default=5)
    penalty_onfake = attr.ib(default=False)
    penalty_onreal = attr.ib(default=False)
    generator_every = attr.ib(default=5)
    attention_mode = attr.ib(
        default="QK", validator=attr.validators.in_({"QQ", "QK"})
    )  # QQ,QK, other
    score_function = attr.ib(
        default="sigmoid", validator=attr.validators.in_({"sigmoid", "softmax"})
    )
    edge_readout = attr.ib(
        default="attention_weights",
        validator=attr.validators.in_(
            {
                "biased_sigmoid",
                "rescaled_softmax",
                "gaussian_kernel",
                "attention_weights",
                "QQ_sig",
            }
        ),
    )
    edge_bias_mode = attr.ib(
        default="scalar"
    )  # nodes/scalar for biased sigmoid, True/False for rescaled_softmax
    edge_bias_hidden = attr.ib(
        default=128
    )  # nodes/scalar for biased sigmoid, True/False for rescaled_softmax
    cycle_opt = attr.ib(
        default="finetti_noDS",
        validator=attr.validators.in_(
            {"standard", "finetti_noDS", "finetti_ds"}
        ),
    )
    score_penalty=attr.ib(default=0.0)
    disc_readout_hidden = attr.ib(default=32)
    n_attention_layers = attr.ib(default=12)
    # factory instead of a literal list: a list default would be shared by all instances
    MLP_layers = attr.ib(factory=lambda: [128, 256, 512])
    num_workers = attr.ib(default=8)
    attention_inner_layers = attr.ib(type=list, factory=list)
    num_heads = attr.ib(default=1)
    discretization = attr.ib(default="relaxed_bernoulli")
    disc_swish = attr.ib(default=True)
    deep_gen_inner_act = attr.ib(default=None)  # swish, relu
    deep_gen_out_act = attr.ib(default=None)
    disc_spectral_norm = attr.ib(default=None)  # none, diff,nondiff
    disc_dropout = attr.ib(default=None)  # none, diff,nondiff
    gen_spectral_norm = attr.ib(default=None)  # none, diff,nondiff
    temperature = attr.ib(
        default=2 / 3.0
    )  # the lower the more discrete, but less smooth. Taking 2/3 from http://www.stats.ox.ac.uk/~cmaddis/pubs/concrete.pdf
    disc_contrast = attr.ib(
        default="real_fake", validator=attr.validators.in_({"real_fake", "fake_fake", "fake-struct_fake"})
    )  # real/fake node embeddings with fake adjacency matrix
    disc_penalty_mode = attr.ib(
        default="interpolate_Adj",
        validator=attr.validators.in_(
            {
                "avg_grads",
                "interpolate_Adj",
                "interpolate_emebbeddings",
                "GNN_layerwise_penatlty",
            }
        ),
    )
    save_dir=attr.ib(default="peagang_save")
    plot_lcc=attr.ib(default=True) # plot only largest connected component

    # TODO: figure out how to move away from dynamic init
    #seed_batch_shape = None,
    #seed_batch_size = None,
    #seedN = None,
    grid_every = attr.ib(default=100)
    finetti_trainable=attr.ib(default=True)
    flip_finetti = attr.ib(default=False)
    finetti_train_fix_context=attr.ib(default=False)
    dynamic_finetti_creation=attr.ib(default=False)
    replicated_Z=attr.ib(default=False)
    exp_name=attr.ib(default=None)
    device = attr.ib(default="cuda:0" if torch.cuda.is_available() else "cpu")
    finneti_MLP=attr.ib(default=False)

    node_feat_proj=attr.ib(default=None) # project the input features onto finetti_dim+embed_dim
    structured_features=attr.ib(default=False)
    k_eigenvals=attr.ib(default=4)
    use_laplacian=attr.ib(default=False)
    large_N_approx=attr.ib(default=False)
    dataset_kwargs=attr.ib(factory=dict)
    make_node_features=attr.ib(default=True)#TODO: remove, only there to ensure consistency between versions
    viz=attr.ib(default=False)
    disc_eigenfeat=attr.ib(default=False)

    @classmethod
    def from_dict(cls,d):
        """Build an instance from a mapping or an ``argparse.Namespace``.

        Args:
            d: Either a ``Namespace`` or a mapping whose keys are field names.

        Returns:
            A new ``PEAWGAN_HyperParameters`` instance.

        Raises:
            TypeError: If ``d`` contains keys that are not fields of the class
                or lacks ``node_count_weights``.
        """
        if isinstance(d,Namespace):
            # Namespace has no ``.vars()`` method; the builtin ``vars`` returns
            # its attribute dict.
            d = vars(d)
        return cls(**d)
