import os
import torch
import math
# import mlflow
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from multiquery_randomized_smoothing.src.transformations import stn
from multiquery_randomized_smoothing.src.noises import noises
from multiquery_randomized_smoothing.src.models import architectures
from multiquery_randomized_smoothing.src.train_utils import save_param, make_directories
from statsmodels.stats.proportion import proportion_confint

class ARSmodel(nn.Module):
    """Adaptive Randomized Smoothing (ARS) main model class. This class creates our proposed architecture,
    which differs from vanilla randomized smoothing in 2 components: (1) transformation layer and (2) multi-query
    setup. This class also has methods for Monte Carlo predictions and for computing the certified radii at
    certification time.

    # if Adaptive, architecture consists of 3 components:

    - transformation layer
    - noise layer
    - base arch (off-the-shelf DNN model)

    # if Vanilla, architecture consists of 2 components only:

    - noise layer
    - base arch (off-the-shelf DNN model)
    """

    # to abstain on a prediction, ARSmodel returns this int
    ABSTAIN = -1

    def __init__(self, config):
        """Build the noise layer, the base model and, in adaptive mode, the extra ARS components.

        :param config: nested mapping; this constructor reads "mode", "num_queries", "device",
            config["arch"]["noise_distribution"], config["arch"]["base_model"],
            config["dataset"]["name" / "input_channels" / "num_classes"] and, in adaptive mode,
            config["arch"]["transformation"], "budget_split" and "budget_fractions"
        :raises ValueError: if the configured noise distribution, transformation or budget split
            is not one of the options handled below
        """
        super().__init__()

        self.config = config

        # initial general setup
        self.mode = config["mode"]
        self.num_queries = config["num_queries"]

        # initialize noise layer (common to vanilla and adaptive mode)
        noise_distribution = config["arch"]["noise_distribution"]
        if noise_distribution == "expinf":
            self.noise_layer = noises.ExpInf(config)
        elif noise_distribution == "gaussian_inf":
            self.noise_layer = noises.GaussianInf(config)
        else:
            # fail here rather than with an AttributeError on self.noise_layer during the first forward pass
            raise ValueError(
                "unsupported noise distribution {!r}; expected 'expinf' or 'gaussian_inf'".format(noise_distribution)
            )

        # initialize base architecture (common to vanilla and adaptive mode)
        num_classes = config["dataset"]["num_classes"]
        self.base_arch = architectures.get_architecture(config["arch"]["base_model"],
                                                        config["dataset"]["name"],
                                                        config["dataset"]["input_channels"],
                                                        num_classes,
                                                        config["device"])

        # initialize setup only for adaptive mode
        if self.mode == "adaptive":

            # initialize transformation layer
            transformation = config["arch"]["transformation"]
            if transformation == "stn":
                self.transformation_layer = stn.STN(config)
            else:
                # the adaptive pipeline always calls self.transformation_layer, so it must exist
                raise ValueError(
                    "unsupported transformation {!r}; expected 'stn'".format(transformation)
                )

            # if number of queries is > 1 (currently only supports 2 queries)
            if self.num_queries > 1:

                self.split = config["budget_split"]

                # budget splitting
                if self.split == "fixed":
                    # pre-set fixed split of budget fractions
                    self.budget_fractions = config["budget_fractions"]
                elif self.split == "dynamic":
                    # learn the budget split online
                    self.fq_budget_frac_param = nn.Parameter(torch.randn(1, device=config["device"]).squeeze())
                else:
                    # forward() has no budget to use for any other value
                    raise ValueError(
                        "unsupported budget split {!r}; expected 'fixed' or 'dynamic'".format(self.split)
                    )

                # fc layer applied after concatenating the outputs of the two queries.
                # forward() stacks exactly two base-model outputs side by side, so the input width is
                # 2 * num_classes (this equals the previously hard-coded 20 when num_classes == 10).
                # assumes base_arch emits num_classes values per example - TODO confirm
                self.fc_final = nn.Linear(2 * num_classes, num_classes)

    def _count_arr(self, arr: np.ndarray, length: int) -> np.ndarray:
        """Count how many times each class index occurs in ``arr``.

        :param arr: class indices, each expected to lie in ``[0, length)``
        :param length: number of classes, i.e. the size of the returned array
        :return: integer array of size ``length`` whose i-th entry is the number of occurrences of ``i``
        :raises IndexError: if ``arr`` contains an index outside ``[0, length)`` (negative indices
            are rejected instead of wrapping around to the end of the array)
        """
        labels = np.asarray(arr, dtype=np.int64).reshape(-1)

        # np.bincount would silently grow the result for a too-large label, so validate the range first
        if labels.size and (labels.min() < 0 or labels.max() >= length):
            raise IndexError(
                "class index out of range for {} classes".format(length)
            )

        # vectorised histogram; minlength pads classes that never occur with zero counts
        return np.bincount(labels, minlength=length).astype(int, copy=False)

    # def predict(self, x: torch.tensor, n: int, alpha: float, batch_size: int) -> int:
    #     """ Monte Carlo algorithm for evaluating the prediction of g at x.  With probability at least 1 - alpha, the
    #     class returned by this method will equal g(x).

    #     This function uses the hypothesis test described in https://arxiv.org/abs/1610.03944
    #     for identifying the top category of a multinomial distribution.

    #     :param x: the input [channel x height x width]
    #     :param n: the number of Monte Carlo samples to use
    #     :param alpha: the failure probability
    #     :param batch_size: batch size to use when evaluating the base classifier
    #     :return: the predicted class, or ABSTAIN
    #     """
    #     self.base_classifier.eval()
    #     counts = self.monte_carlo_predictions(x, n, batch_size)
    #     top2 = counts.argsort()[::-1][:2]
    #     count1 = counts[top2[0]]
    #     count2 = counts[top2[1]]
    #     if binom_test(count1, count1 + count2, p=0.5) > alpha:
    #         return Smooth.ABSTAIN
    #     else:
    #         return top2[0]

    def _lower_confidence_bound(self, nA: int, n_cert: int, alpha: float) -> float:
        """Return a Clopper-Pearson lower confidence bound on a Bernoulli proportion.

        :param nA: the number of "successes"
        :param n_cert: the total number of draws
        :param alpha: significance level forwarded unchanged to ``proportion_confint``
        :return: the lower endpoint of the interval computed by ``proportion_confint``
        """
        # method="beta" selects the Clopper-Pearson interval; only its lower endpoint is needed.
        # NOTE(review): proportion_confint is documented as building a two-sided interval, so this
        # endpoint is more conservative than a one-sided (1 - alpha) bound - confirm this is intended.
        lower_bound, _upper_bound = proportion_confint(nA, n_cert, alpha=alpha, method="beta")
        return lower_bound

    def certify(self, x, n_pred, n_cert, alpha, adv, batch_size):
        """Predict the class of ``x`` by Monte Carlo voting and certify a radius around it.

        A first batch of ``n_pred`` samples selects the predicted class; a second, independent
        batch of ``n_cert`` samples is used to lower-bound the probability of that class.

        :param x: the input passed on to ``monte_carlo_predictions``
        :param n_pred: number of Monte Carlo samples used to pick the predicted class
        :param n_cert: number of Monte Carlo samples used to estimate the class probability
        :param alpha: value forwarded to ``_lower_confidence_bound``
        :param adv: threat model to certify against; only "linf" is implemented
        :param batch_size: batch size used while sampling
        :return: ``(prediction, radius)`` where ``prediction`` is ``ABSTAIN`` when the
            probability lower bound is below 0.5
        :raises NotImplementedError: if ``adv`` is "l1", "l2" or "all_lp"
        :raises ValueError: if ``adv`` is not a recognised threat model
        """
        # validate the threat model before sampling: previously an unsupported value only failed
        # at the final return (radius was never assigned), after all Monte Carlo work was done
        if adv in ("l1", "l2", "all_lp"):
            raise NotImplementedError(
                "certification for adv={!r} is not implemented; only 'linf' is supported".format(adv)
            )
        if adv != "linf":
            raise ValueError(
                "unknown adv {!r}; expected one of 'l1', 'l2', 'linf', 'all_lp'".format(adv)
            )

        # get n_pred monte carlo samples for prediction
        counts_prediction = self.monte_carlo_predictions(x, n_pred, batch_size)
        prediction = counts_prediction.argmax().item()

        # get n_cert monte carlo samples for certification
        counts_estimation = self.monte_carlo_predictions(x, n_cert, batch_size)
        nA = counts_estimation[prediction].item()
        prob_lb = self._lower_confidence_bound(nA, n_cert, alpha)

        # certify linf
        # NOTE(review): the radius is computed even when prob_lb < 0.5; callers should presumably
        # ignore it when the prediction is ABSTAIN - confirm what certify_linf returns in that case
        radius = self.noise_layer.certify_linf(x, prob_lb)

        if prob_lb < 0.5:
            prediction = self.ABSTAIN

        return prediction, radius

    def monte_carlo_predictions(self, x: torch.tensor, num: int, batch_size: int):
        """Tally the model's predicted classes over repeated forward passes on copies of ``x``.

        The input is replicated into batches of at most ``batch_size`` copies and each batch is
        sent through ``self.forward`` with gradients disabled, until ``num`` copies have been
        evaluated in total.

        :param x: the input [channel x width x height]
        :param num: number of samples to collect
        :param batch_size: maximum number of copies of ``x`` evaluated per forward pass
        :return: an ndarray[int] of length num_classes containing the per-class counts
        """
        num_classes = self.config["dataset"]["num_classes"]
        class_counts = np.zeros(num_classes, dtype=int)
        remaining = num

        with torch.no_grad():
            for batch_index in range(math.ceil(num / batch_size)):
                # the last batch may be smaller than batch_size
                current_size = min(batch_size, remaining)
                remaining -= current_size

                replicated = x.repeat((current_size, 1, 1, 1))

                # trackers consumed by the forward pass for logging / checkpoint paths
                # NOTE(review): query_pass reads config["train"]["epochs"] while this reads
                # config["train_epochs"] - confirm both keys exist in the config
                trackers = {
                    "iter": batch_index,
                    "epoch_num": self.config["train_epochs"],
                    "mode": "certify"
                }
                scores = self.forward(replicated, trackers)

                predicted_labels = scores.argmax(1).cpu().numpy()
                class_counts += self._count_arr(predicted_labels, num_classes)

        return class_counts

    def transform_noise_interpolate(self,
                                    x,
                                    prev_q_op = None,
                                    budget_frac: float = 0.0,
                                    query_number: int = 1,
                                    ckpt_dict: dict = {},
                                    logging_trackers = {}):
        """Run one input batch through transformation (adaptive mode only), noising and resizing.

        Every intermediate image is appended to the lists under ``ckpt_dict["images"]``; the
        final noisy image is always appended to ``ckpt_dict["images"]["final_noisy"]``.

        :param x: batch of images; flattened per example before being handed to the noise layer
        :param prev_q_op: output of the previous query, forwarded to the transformation layer
        :param budget_frac: budget fraction forwarded to the noise layer
        :param query_number: index of the current query (1 for the first query)
        :param ckpt_dict: checkpoint dictionary; must already hold the "images" lists
            ("original", "transformed", "transformed_noisy", "final_noisy"), so the empty
            default cannot actually be used
        :param logging_trackers: forwarded unchanged to the noise layer
        :return: the same ``ckpt_dict`` object, updated in place
        """
        # keep the untouched input
        ckpt_dict["images"]["original"].append(x)

        # in adaptive mode the transformation layer always runs, but its output is only recorded
        # (and later resized back) unless this is a first query configured to skip the transform
        keeps_transformed = False
        if self.mode == "adaptive":
            x = self.transformation_layer(x, query_number, prev_q_op, ckpt_dict)
            keeps_transformed = query_number != 1 or self.config["first_query_w_transform"]
            if keeps_transformed:
                ckpt_dict["images"]["transformed"].append(x)

        # the noise layer works on flattened examples and adds the noise itself;
        # restore the image shape afterwards
        flattened = x.view(len(x), -1)
        noisy_flat = self.noise_layer.sample(flattened, budget_frac, query_number, logging_trackers)
        x = noisy_flat.view(x.shape)

        # resize back to the initial dimensions when the transform was kept
        if keeps_transformed:
            ckpt_dict["images"]["transformed_noisy"].append(x)
            x = F.interpolate(x, self.config["dataset"]["initial_height"])

        # store the image that will be fed to the base model
        ckpt_dict["images"]["final_noisy"].append(x)

        return ckpt_dict

    def query_pass(self,
                   x: torch.Tensor = None,
                   prev_q_op: torch.Tensor = None,
                   budget_frac: float = 0.5,
                   query_number: int = 1,
                   logging_trackers: dict = {}):
        """Run a single query: build the noisy inputs, classify them and optionally checkpoint.

        The first query processes the whole batch at once; later queries process one example at
        a time, pairing each original image with the matching output of the previous query.

        :param x: batch of original images
        :param prev_q_op: batch of outputs from the previous query (unused for the first query)
        :param budget_frac: budget fraction handed to the noise layer for this query
        :param query_number: index of the query (1 for the first query)
        :param logging_trackers: must provide "epoch_num", "mode" and "iter"; they decide whether
            and where parameters/images are saved
        :return: ``(ckpt_dict, output_pred)`` - the checkpoint dictionary (with "final_noisy"
            stacked into a single tensor) and the base model's output on the noisy images
        """

        print("query {}:".format(query_number))

        # one empty list per tracked quantity; filled while the images move through the pipeline
        ckpt_dict = {
            "stn_params": {key: list() for key in ("theta", "grid", "grid_height", "grid_width")},
            "images": {key: list() for key in ("original", "transformed", "transformed_noisy", "final_noisy")},
        }

        if query_number == 1:
            # whole batch in one go, no previous query output
            work_items = ((x, None),)
        else:
            # for multi-query; right now we are processing one image at a time
            work_items = ((orig_image.unsqueeze(0), prev_image.unsqueeze(0))
                          for orig_image, prev_image in zip(x, prev_q_op))

        for current_input, previous_output in work_items:
            ckpt_dict = self.transform_noise_interpolate(x=current_input,
                                                         prev_q_op=previous_output,
                                                         budget_frac=budget_frac,
                                                         query_number=query_number,
                                                         ckpt_dict=ckpt_dict,
                                                         logging_trackers=logging_trackers)

        # stack the collected noisy images into one batch and classify it with the base DNN model
        stacked_noisy = torch.vstack(ckpt_dict["images"]["final_noisy"])
        ckpt_dict["images"]["final_noisy"] = stacked_noisy
        output_pred = self.base_arch(stacked_noisy)

        # save stn parameters every 10th epoch and on the last training epoch
        epoch_num = logging_trackers["epoch_num"]
        if epoch_num % 10 == 0 or epoch_num == self.config["train"]["epochs"] - 1:
            save_path = os.path.join(self.config["outdir"],
                                     "params",
                                     "epoch_"+str(epoch_num),
                                     logging_trackers["mode"],
                                     "iter_"+str(logging_trackers["iter"]))
            make_directories(save_path)

            stn_file = os.path.join(save_path, "stn_params_"+str(query_number)+".pt")
            torch.save(ckpt_dict["stn_params"], stn_file)

            # save the first batch of test images along the query pipeline
            is_first_test_batch = logging_trackers["mode"] == "test" and logging_trackers["iter"]==0
            if is_first_test_batch:
                images_file = os.path.join(save_path, "images_"+str(query_number)+".pt")
                torch.save(ckpt_dict["images"], images_file)

        return ckpt_dict, output_pred

    def forward(self, x, logging_trackers=dict()):
        """Full forward pass: one query in vanilla/single-query mode, two queries otherwise.

        With a single query the whole budget (fraction 1) goes to it and the base model's output
        is returned as log-probabilities. In adaptive mode with more than one query, the budget is
        split (fixed or learned), a second query is run on top of the first query's noisy images,
        and both outputs are concatenated and passed through ``fc_final``.

        :param x: batch of input images
        :param logging_trackers: must provide "epoch_num", "mode" and "iter"; they are read
            unconditionally to build the directory where the budget fractions are saved
        :return: log-softmax of the final output along dimension 1
        """
        uses_two_queries = self.mode == "adaptive" and self.num_queries > 1

        # fetch the budget fractions
        if not uses_two_queries:
            budget_frac_1 = 1 # entire budget for a single query
        elif self.split == "fixed":
            # fixed budget split
            budget_frac_1, budget_frac_2 = self.budget_fractions[0], self.budget_fractions[1]
        elif self.split == "dynamic":
            # dynamic/adaptive budget split: the learned parameter is squashed into (0, 1) and the
            # second fraction is sqrt(1 - first fraction squared)
            budget_frac_1 = torch.sigmoid(self.fq_budget_frac_param)
            budget_frac_2 = torch.sqrt(1 - torch.square(budget_frac_1))

        # log first query's budget
        print("first query budget fraction {}".format(budget_frac_1))
        epoch_dir = "epoch_"+str(logging_trackers["epoch_num"])
        iter_dir = "iter_"+str(logging_trackers["iter"])
        save_path = os.path.join(self.config["outdir"], "params", epoch_dir, logging_trackers["mode"], iter_dir)
        make_directories(save_path)
        torch.save(budget_frac_1, os.path.join(save_path, "budget_frac_1.pt"))

        # first query's forward pass
        ckpt_dict, output_pred = self.query_pass(x,
                                                 budget_frac=budget_frac_1,
                                                 query_number=1,
                                                 logging_trackers=logging_trackers)

        if not uses_two_queries:
            return F.log_softmax(output_pred, dim=1)

        # remaining queries (for now, it supports 2 queries only)
        per_query_outputs = [output_pred]

        # log second query's budget
        print("second query budget fraction {}".format(budget_frac_2))
        torch.save(budget_frac_2, os.path.join(save_path, "budget_frac_2.pt"))

        # second query's forward pass, conditioned on the first query's noisy images
        ckpt_dict, output_pred = self.query_pass(x,
                                                 prev_q_op=ckpt_dict["images"]["final_noisy"],
                                                 budget_frac=budget_frac_2,
                                                 query_number=2,
                                                 logging_trackers=logging_trackers)
        per_query_outputs.append(output_pred)

        # concatenate the per-query outputs side by side and map them to the class scores
        combined = torch.hstack(per_query_outputs)
        output_pred = self.fc_final(combined)

        return F.log_softmax(output_pred, dim=1)