import argparse
import os

import torch
from torch import nn
import pytorch_lightning as pl
import numpy as np
import itertools
import functools
from collections import defaultdict

import lifted_pddl

from torch.nn.parameter import Parameter
from torch.distributions.normal import Normal # Gaussian distribution
from stable_trunc_gaussian import TruncatedGaussian as TG # Truncated gaussian distribution

from .nlm_new import NLM
# TODO, add other models

class HeuristicLearner(pl.LightningModule):
    """
    Base LightningModule that learns a heuristic as a probability distribution.

    Subclasses implement `_forward`, which maps the model-specific part of a batch
    (plus the lower bounds) to a predicted offset `delta_mu` and a scale `sigma`.
    This class builds a Gaussian or truncated Gaussian from them (selected by
    `hparams.dist`), computes the metrics and handles logging.
    """

    def __init__(self, args:argparse.Namespace):
        super().__init__()

        # Step counter registered as a non-trainable parameter so it is part of the state_dict
        self.register_parameter('persistent_global_step', Parameter(torch.tensor(0, dtype=int), requires_grad=False))

        self.save_hyperparameters(args)


    @staticmethod
    def id(args):
        """Return an identifier string built from the shared hyperparameters in `args`."""
        return (
            f"{args.domain}_{args.model}_{args.dist}_{args.sigma}_{args.l_train}_"
            f"{args.u_train}_{args.l_test}_{args.u_test}_{args.res_train}_{args.res_test}_{args.l_as_input}_"
            f"{args.train_size}_{args.lr}_{args.decay}_{args.clip}_{args.batch_size}_{args.seed}"
        )


    def _to_tensor(self, val, use_grad=False):
        """
        Auxiliary method to obtain a float32 tensor on the module's device from some value.
        By default, requires_grad is set to False.
        """
        return torch.tensor(val, dtype=torch.float32, requires_grad=use_grad, device=self.device)

    def _clip_sigma(self, val, min=0.1, max=10):
        """
        Squash the raw sigma predicted by the model into a bounded positive range.

        Computes sigmoid(val)*max + min, so the result lies in the open interval
        (min, min + max); with the defaults this is (0.1, 10.1).
        """
        return torch.sigmoid(val)*max + min


    def configure_optimizers(self):
        """Adam over all parameters, with lr and weight decay taken from the hyperparameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.decay)
        return optimizer


    def forward(self, batch, metrics=True):
        """
        Predict the heuristic distribution for a batch and optionally compute metrics.

        Args:
            batch: tuple (model_specific_input, target, lb, ub, res). lb/ub are the lower and
                upper bounds used for clipping/truncation; res is a residual added to the
                model's prediction.
            metrics: if True, also fill results["metrics"] (nll, nll_clip, mse, mse_clip, mu, sigma).

        Returns:
            defaultdict(dict) with results["prediction"]["mean"], ["mean_clip"] and ["heuristic"],
            plus results["metrics"] when `metrics` is True.

        Raises:
            ValueError: if hparams.dist is neither "gaussian" nor "truncated".
        """
        model_specific_input, target, lb, ub, res = batch

        target = target.to(device=self.device)
        lb = lb.to(device=self.device)
        ub = ub.to(device=self.device)
        res = res.to(device=self.device)

        delta_mu, sigma = self._forward(model_specific_input, lb)

        mu = delta_mu + res # note: res(idual) is always added, but its default value is res == 0
        mu_clip = torch.clip(mu,min=lb,max=ub)

        dist_name = self.hparams.dist
        if dist_name == "gaussian":
            dist = Normal(mu, sigma)
            dist_clip = Normal(mu_clip, sigma)
        elif dist_name == "truncated":
            dist = TG(mu, sigma, lb, ub)
            dist_clip = TG(mu_clip, sigma, lb, ub)
        else:
            raise ValueError(f"Unknown distribution {dist_name!r}; expected 'gaussian' or 'truncated'")

        results = defaultdict(dict)
        mean = dist.mean
        mean_clip = dist_clip.mean
        results["prediction"]["mean"] = mean
        results["prediction"]["mean_clip"] = mean_clip
        # A plain Gaussian mean may fall outside [lb, ub], so the clipped-mu version is used;
        # a Gaussian truncated to [lb, ub] already has its mean inside the bounds.
        results["prediction"]["heuristic"] = mean_clip if dist_name == "gaussian" else mean

        if metrics:
            results["metrics"]["nll"]      = -dist.log_prob(target).mean()
            results["metrics"]["nll_clip"] = -dist_clip.log_prob(target).mean()
            results["metrics"]["mse"]      = (mean - target).square().mean()
            results["metrics"]["mse_clip"] = (mean_clip - target).square().mean()
            results["metrics"]["mu"]      = mean.mean()
            results["metrics"]["sigma"]   = sigma.mean()

        return results


    def _forward(self, model_specific_input, lb):
        """
        Model-specific forward pass, implemented by subclasses.

        Must return (delta_mu, sigma): the offset added to the residual `res` and the scale
        of the distribution, both broadcastable against the batch targets.
        """
        raise NotImplementedError()

    def _log_metrics(self, prefix, metrics):
        """Log every metric as f"{prefix}_{name}" at epoch level; only nll goes to the progress bar."""
        for name, value in metrics.items():
            self.log(f"{prefix}_{name}", value, on_step=False, on_epoch=True,
                     batch_size=self.hparams.batch_size, prog_bar=(name=="nll"))

    def training_step(self, batch, batch_idx=0):
        """Log training metrics with prefix "t", advance the step counter and return the NLL loss."""
        results = self.forward(batch)
        self._log_metrics("t", results["metrics"])

        self.persistent_global_step += 1
        return results["metrics"]["nll"]

    def validation_step(self, batch, batch_idx=0):
        """Log validation metrics with prefix "v"."""
        results = self.forward(batch)
        self._log_metrics("v", results["metrics"])

    def test_step(self, batch, batch_idx=0):
        """Log test metrics with prefix "T"."""
        results = self.forward(batch)
        self._log_metrics("T", results["metrics"])

    def predict_step(self, batch, batch_idx=0):
        """Return the mean of the distribution built around the clipped mu."""
        results = self.forward(batch, metrics=False)

        # NOTE(review): forward() stores results["prediction"]["heuristic"], which for the
        # truncated distribution is `mean` rather than `mean_clip`; confirm which one
        # prediction is meant to return.
        return results["prediction"]["mean_clip"]




class HeuristicLearnerNLM(HeuristicLearner):
    """
    Heuristic learner backed by a Neural Logic Machine (NLM).

    mu (and sigma, when learned) are read from the nullary outputs of the NLM.
    """

    def __init__(self, args):
        super().__init__(args)
        self.model = NLM(self.hparams.hidden_features,
                         self.hparams.out_features,
                         self.hparams.mlp_hidden_features,
                         self.hparams.residual,
                         self.hparams.exclude_self,
                         self.hparams.max_objs_cache_reduce_masks,
                         self.hparams.use_batch_norm,
                         self.hparams.activation_function)


    @staticmethod
    def id(args):
        """Return the base identifier extended with the NLM-specific hyperparameters."""
        return HeuristicLearner.id(args) + \
            (
                f"_{args.breadth}_{args.depth}_{args.hidden_features}_{args.mlp_hidden_features}"
                f"_{args.residual}_{args.exclude_self}"
                f"_{args.max_objs_cache_reduce_masks}_{args.use_batch_norm}_{args.activation_function}"
            )

    @staticmethod
    def parse_arguments(args):
        """
        Parse the NLM-specific options from `args.rest` into `args` and derive layer sizes.

        After parsing, `args.hidden_features` becomes a list of (depth - 1) rows, each with
        (breadth + 1) copies of the requested width. `args.out_features` is also set, with
        2 nullary outputs (mu, sigma) when sigma is learned and 1 (mu) when it is fixed.

        Raises:
            ValueError: if args.sigma is neither "learn" nor "fixed".
        """
        import ast  # local import: only needed to parse --exclude-self safely

        parser = argparse.ArgumentParser(description="Additional arguments for NLM")

        parser.add_argument('--breadth', default=3,type=int)
        parser.add_argument('--depth', default=5,type=int)
        parser.add_argument('--hidden-features', default=8,type=int)
        parser.add_argument('--mlp-hidden-features', default=0,type=int)
        parser.add_argument('--max-objs-cache-reduce-masks', default=30, type=int)

        parser.add_argument('--residual', choices=[None, "all", "input"])
        # literal_eval accepts the same literals as before ("True", "False", "0", ...)
        # without executing arbitrary code the way `eval` would.
        parser.add_argument('--exclude-self', default=True, type=ast.literal_eval)
        parser.add_argument('--use-batch-norm', action="store_true")
        parser.add_argument('--activation-function', default='sigmoid', choices=["sigmoid","relu"])

        parser.parse_args(args=args.rest, namespace=args)

        # NOTE(review): the parsed domain is not used below; presumably this only checks that
        # the domain file exists and parses — confirm.
        lp_parser = lifted_pddl.Parser()
        lp_parser.parse_domain(os.path.join(args.output, args.domain, 'domain.pddl'))

        # Build each row separately so rows are independent lists, not aliases of one list.
        width = args.hidden_features
        args.hidden_features = [[width] * (args.breadth + 1) for _ in range(args.depth - 1)]

        if args.sigma == "learn":
            args.out_features = [2,0,0,0]
        elif args.sigma == "fixed":
            args.out_features = [1,0,0,0]
        else:
            raise ValueError(f"Unknown sigma mode {args.sigma!r}; expected 'learn' or 'fixed'")

        return args


    def _forward(self, model_specific_input, lb):
        """
        Run the NLM and read mu (and sigma, if learned) from its nullary outputs.

        Args:
            model_specific_input: (inputs, num_objs_list); `inputs` is indexed first by arity,
                then by batch element, with None for missing tensors.
            lb: per-sample lower bounds, appended to the nullary inputs when hparams.l_as_input.

        Returns:
            (mu, sigma), one value per batch element each.
        """
        inputs, num_objs_list = model_specific_input
        # A = len(inputs)    # max arity
        # B = len(inputs[0]) # batch size
        inputs = [ [ input_per_arity_per_batch.to(device=self.device)
                     if input_per_arity_per_batch is not None else None
                     for input_per_arity_per_batch in input_per_arity ]
                   for input_per_arity in inputs ]

        if self.hparams.l_as_input:
            with torch.no_grad():
                # Append the lower bounds to the nullary predicates
                nullary: list = inputs[0]  # per-batch tensors, or None entries
                if nullary[0] is None:
                    # No nullary predicates: the lower bound becomes the only nullary feature
                    inputs[0] = lb.unsqueeze(-1)
                else:
                    new_nullary = [torch.cat((t, l)) for t, l in zip(nullary, lb.unsqueeze(-1))]
                    inputs[0] = new_nullary

        model_output = self.model(inputs, num_objs_list)

        nullary_output = model_output[0] # Obtain the nullary predicates

        # Output channel 0 is mu; channel 1 (only present when sigma is learned) is raw sigma
        mu = torch.cat([t[0].view(1) for t in nullary_output])
        if self.hparams.sigma == "learn":
            sigma = torch.cat([self._clip_sigma(t[1].view(1),0.1,10) for t in nullary_output])
        else:
            # Fixed sigma = 0.5**(-0.5) = sqrt(2)
            sigma = torch.full_like(mu, 0.5**(-0.5))
        return mu, sigma


class HeuristicLearnerRR(HeuristicLearner):
    """
    Heuristic learner backed by a single linear layer over a flat feature tensor.

    The layer outputs mu and, when sigma is learned, a raw sigma.
    """

    def __init__(self, args):
        super().__init__(args)
        # LazyLinear infers its input size from the first batch it sees
        self.model = nn.LazyLinear(self.hparams.num_output_features)


    @staticmethod
    def id(args):
        """No model-specific hyperparameters: the base identifier is used as-is."""
        return HeuristicLearner.id(args)

    @staticmethod
    def parse_arguments(args):
        """
        Set `args.num_output_features` from `args.sigma`: 2 (mu, sigma) for "learn", 1 (mu) for "fixed".

        Raises:
            ValueError: if args.sigma is neither "learn" nor "fixed".
        """
        # no specific hyperparameter
        if args.sigma == "learn":
            args.num_output_features = 2
        elif args.sigma == "fixed":
            args.num_output_features = 1
        else:
            raise ValueError(f"Unknown sigma mode {args.sigma!r}; expected 'learn' or 'fixed'")

        return args

    def _forward(self, model_specific_input, lb):
        """
        Apply the linear layer to the feature tensor.

        Args:
            model_specific_input: feature tensor, assumed (batch, num_features) — TODO confirm.
            lb: per-sample lower bounds, appended as an extra feature column when
                hparams.l_as_input is set.

        Returns:
            (mu, sigma), each of shape (batch,).
        """
        features_tensor = model_specific_input.to(device=self.device)
        if self.hparams.l_as_input:
            # torch.cat needs both tensors to be 2-D, so a 1-D lb (one bound per sample)
            # is turned into a column first.
            lb_column = lb.unsqueeze(-1) if lb.dim() == 1 else lb
            features_tensor = torch.cat((features_tensor, lb_column), dim=1)

        # Do a forward pass with self.model
        model_output = self.model(features_tensor)

        mu = model_output[:,0]
        if self.hparams.sigma == "learn":
            sigma = self._clip_sigma(model_output[:,1],0.1,10)
        else:
            # Fixed sigma = 0.5**(-0.5) = sqrt(2)
            sigma = torch.full_like(mu, 0.5**(-0.5))
        return mu, sigma



from strips_hgn.features.global_features import NumberOfNodesAndEdgesGlobalFeatureMapper
from strips_hgn.features.hyperedge_features import ComplexHyperedgeFeatureMapper
from strips_hgn.features.node_features import PropositionInStateAndGoal
from hypergraph_nets.models import EncodeProcessDecode
from hypergraph_nets.hypergraphs import HypergraphsTuple

class HeuristicLearnerHGN(HeuristicLearner):
    """
    Heuristic learner backed by a STRIPS-HGN hypergraph network (EncodeProcessDecode).

    mu (and sigma, when learned) are read from the global output of the network.
    """

    def __init__(self, args):
        super().__init__(args)
        self.model = EncodeProcessDecode(
            receiver_k = self.hparams.max_num_add_effects,
            sender_k = self.hparams.max_num_preconditions,
            hidden_size = self.hparams.hidden_size,
            edge_input_size=self.hparams.edge_input_size,
            node_input_size=self.hparams.node_input_size,
            global_input_size=self.hparams.global_input_size,
            global_output_size=self.hparams.global_output_size
        )


    @staticmethod
    def id(args):
        """Return the base identifier extended with the STRIPS-HGN hyperparameters."""
        return HeuristicLearner.id(args) + \
            (
                f"_{args.num_recursion_steps}_{args.max_num_add_effects}_{args.max_num_preconditions}"
                f"_{args.hidden_size}_{args.edge_input_size}_{args.node_input_size}_{args.global_input_size}"
            )

    @staticmethod
    def parse_arguments(args):
        """
        Parse the STRIPS-HGN options from `args.rest` into `args` and set `args.global_output_size`.

        Raises:
            ValueError: if args.l_as_input is set (unsupported by this model) or if
                args.sigma is neither "learn" nor "fixed".
        """
        parser = argparse.ArgumentParser(description="Additional arguments for STRIPS-HGN")

        # type=int so that values given on the command line are ints, not strings
        # (argparse only skips `type` for non-string defaults).
        parser.add_argument('--num-recursion-steps', default=10, type=int)
        parser.add_argument('--max-num-add-effects', default=3, type=int)
        parser.add_argument('--max-num-preconditions', default=7, type=int)
        parser.add_argument('--hidden-size', default=32, type=int)
        parser.add_argument('--edge-input-size', default=ComplexHyperedgeFeatureMapper.input_size(), type=int)
        parser.add_argument('--node-input-size', default=PropositionInStateAndGoal.input_size(), type=int)
        parser.add_argument('--global-input-size', default=NumberOfNodesAndEdgesGlobalFeatureMapper.input_size(), type=int)

        parser.parse_args(args=args.rest, namespace=args)

        if args.l_as_input:
            raise ValueError("Currently, the STRIPS-HGN model is incompatible with the 'l_as_input' argument")

        if args.sigma == "learn":
            args.global_output_size = 2
        elif args.sigma == "fixed":
            args.global_output_size = 1
        else:
            raise ValueError(f"Unknown sigma mode {args.sigma!r}; expected 'learn' or 'fixed'")

        return args

    def _forward(self, model_specific_input, lb):
        """
        Run STRIPS-HGN on a batch of hypergraphs and read mu/sigma from the global outputs.

        `lb` is unused: l_as_input is rejected by parse_arguments.

        Returns:
            (mu, sigma), each indexed as [step, graph].
        """
        hypergraphs_tuple : HypergraphsTuple = model_specific_input
        # Move every tensor field to the module's device; leave non-tensor fields as they are
        hypergraphs_tuple = HypergraphsTuple(*[
            elem.to(device=self.device)
            if isinstance(elem, torch.Tensor)
            else elem
            for elem in hypergraphs_tuple
        ])

        # Forward pass. While training, pred_mode=False and the model returns the outputs of
        # all recursion steps, stacked here to (num_steps, num_graphs, global_size), so the
        # loss covers every step. In eval mode, pred_mode=True and (per the comments below)
        # only the last step is returned — presumably as a one-element list, giving a
        # leading dimension of 1; TODO confirm.
        model_output = \
            torch.stack(
                self.model(hypergraphs_tuple,
                           steps=self.hparams.num_recursion_steps,
                           # When pred_mode=True, it returns only the last step of num_recursion_steps
                           # When pred_mode=False, it returns all steps
                           pred_mode=not self.training),
                dim=0)

        mu = model_output[:,:,0]
        if self.hparams.sigma == "learn":
            sigma = self._clip_sigma(model_output[:,:,1],0.1,10)
        else:
            # Fixed sigma = 0.5**(-0.5) = sqrt(2)
            sigma = torch.full_like(mu, 0.5**(-0.5))
        return mu, sigma

