import os
import numpy as np
import copy

import torchvision.models as models

import torch
import torch.nn as nn
import torch.distributions as D

from multiquery_randomized_smoothing.src.models import architectures
from multiquery_randomized_smoothing.src.train_utils import log, get_image_size, get_mask_shape

def mask_penalty(masks):
    """Return the mean squared excursion of ``masks`` outside ``[0, 1]``.

    Entries below 0 contribute ``masks ** 2``, entries above 1 contribute
    ``(masks - 1) ** 2`` and every other entry contributes 0.  The result is
    the mean of those contributions over all entries (a 0-dim tensor).
    """
    # zeros built from `masks` itself so dtype/device follow the input
    inside_range = torch.mul(masks, 0)
    undershoot = torch.pow(masks, 2)
    overshoot = torch.pow(masks - 1, 2)

    above_or_inside = torch.where(masks > 1, overshoot, inside_range)
    per_entry = torch.where(masks < 0, undershoot, above_or_inside)
    return per_entry.mean()

class TWO_QUERY_ARCH(nn.Module):
    """Two noisy queries in front of a base classifier.

    The forward pass
      1. adds Gaussian noise to the (optionally masked) input  -> first query,
      2. feeds that noisy input to ``second_query_mask_model`` to obtain a
         per-input mask,
      3. multiplies the clean input by the mask and adds Gaussian noise whose
         scale is derived from the mask's L2 norm          -> second query,
      4. optionally averages the two noisy queries, and
      5. classifies the result with ``base_classifier``.

    ``args.total_sigma`` is split between the two queries through
    ``first_query_budget_frac`` (stored as ``sqrt(args.first_query_budget_frac)``);
    the second query uses ``sqrt(1 - first_query_budget_frac ** 2)``.

    Attributes read from ``args`` in this class: ``device``,
    ``first_query_with_mask``, ``mask_init``, ``second_query_mask_model``,
    ``mask_output``, ``mask_recon``, ``budget_split``,
    ``first_query_budget_frac``, ``base_classifier``, ``dataset``,
    ``num_classes``, ``total_sigma``, ``wm_across_channels``,
    ``average_queries`` and ``averaging_style``.
    """

    def __init__(self, args):
        """Build the mask model, the budget split and the base classifier.

        Raises:
            ValueError: if ``args.mask_init`` (when a first-query mask is
                requested) or ``args.budget_split`` has an unsupported value.
        """
        super().__init__()

        self.args = args

        # Only a bare integer index (e.g. 0 or "0") is turned into "cuda:<idx>".
        # Anything else ("cuda:1", "cpu", a torch.device, ...) is kept as given,
        # so a CPU run no longer ends up with the invalid string "cuda:cpu".
        device_repr = str(self.args.device)
        if "cuda" not in device_repr and device_repr.isdigit():
            self.device_string = f'cuda:{self.args.device}'
        else:
            self.device_string = self.args.device

        # image geometry; used to normalise the mask norms into a sigma
        self.image_size = get_image_size(args)
        self.image_dims = self.image_size * self.image_size * 3

        # if learning an average mask in first query, initialize a random or identity mask
        self.first_query_with_mask = args.first_query_with_mask
        if self.first_query_with_mask:
            mask_shape = get_mask_shape(args, self.image_size)
            if args.mask_init == "random":
                initial_mask = torch.rand(mask_shape)
            elif args.mask_init == "identity":
                initial_mask = torch.ones(mask_shape)
            else:
                raise ValueError(
                    "mask_init must be 'random' or 'identity' when "
                    f"first_query_with_mask is set, got {args.mask_init!r}"
                )
            self.first_query_mask = nn.Parameter(initial_mask, requires_grad=True)

        # initialize second query mask model
        self.second_query_mask_model = architectures.get_architecture(
            arch=args.second_query_mask_model,
            prepend_preprocess_layer=True,
            prepend_normalize_layer=False,
            mask_output=args.mask_output,
        )

        # With mask_init == "identity" (and the "penalty" output head) the last
        # layer gets weight 0 and bias 1, i.e. it initially emits a constant 1.
        # With mask_init == "random" the model keeps its default initialisation.
        if args.mask_init == "identity" and args.mask_output == "penalty":
            if args.second_query_mask_model in ["modified_resnet"]:
                final_layer = self.second_query_mask_model.final_conv_layer
                final_layer.weight.data.fill_(0)
                final_layer.bias.data.fill_(1)
            elif args.second_query_mask_model in ["modified_fcn8", "modified_fcn16", "modified_fcn32"]:
                final_layer = self.second_query_mask_model.final[-1]
                final_layer.weight.data.fill_(0)
                final_layer.bias.data.fill_(1)

        # budget splitting mechanism: a constant tensor ("fixed") or a
        # trainable parameter ("learnt"), both holding sqrt(first fraction)
        if args.budget_split not in ("fixed", "learnt"):
            raise ValueError(
                f"budget_split must be 'fixed' or 'learnt', got {args.budget_split!r}"
            )
        sqrt_first_frac = torch.tensor(np.sqrt(args.first_query_budget_frac),
                                       device=self.device_string)
        if args.budget_split == "fixed":
            self.first_query_budget_frac = sqrt_first_frac
        else:
            self.first_query_budget_frac = nn.Parameter(sqrt_first_frac, requires_grad=True)

        # finally, initialize base classifier
        self.base_classifier = architectures.get_architecture(
            arch=args.base_classifier,
            prepend_preprocess_layer=True,
            prepend_normalize_layer=True,
            dataset=args.dataset,
            input_size=self.image_size,
            input_channels=3,
            num_classes=args.num_classes,
        )

    def _channel_norm_scale(self):
        """Return the factor applied to a mask's L2 norm.

        ``sqrt(3)`` for ``wm_across_channels == "same"`` (presumably one mask
        shared by the 3 colour channels), ``1.0`` for ``"different"``.

        Raises:
            ValueError: for any other ``args.wm_across_channels`` value.
        """
        mode = self.args.wm_across_channels
        if mode == "same":
            return 3 ** 0.5
        if mode == "different":
            return 1.0
        raise ValueError(
            f"wm_across_channels must be 'same' or 'different', got {mode!r}"
        )

    def _save_snapshots(self, logging_trackers, x_original, x, masks, sigma_2,
                        pre_averaging_x_transformed, x_transformed):
        """Write the first ``num_to_store`` entries of the batch tensors to ``saved_dir``."""
        num_to_store = 100
        saved_dir = logging_trackers["saved_dir"]

        torch.save(x_original[:num_to_store], os.path.join(saved_dir, "original_images.pt"))
        if self.first_query_with_mask:
            torch.save(self.first_query_mask, os.path.join(saved_dir, "first_query_mask.pt"))
        torch.save(x[:num_to_store], os.path.join(saved_dir, "q_1_images.pt"))
        torch.save(masks[:num_to_store], os.path.join(saved_dir, "post_clamped_second_query_masks.pt"))
        torch.save(sigma_2[:num_to_store], os.path.join(saved_dir, "sigma_2.pt"))
        if pre_averaging_x_transformed is not None:
            torch.save(pre_averaging_x_transformed[:num_to_store], os.path.join(saved_dir, "q_2_images_before_averaging.pt"))
            torch.save(x_transformed[:num_to_store], os.path.join(saved_dir, "q_2_images_after_averaging.pt"))
        else:
            torch.save(x_transformed[:num_to_store], os.path.join(saved_dir, "q_2_images.pt"))

    def forward(self, x, logging_trackers: "dict | None" = None):
        """Run both noisy queries and classify the second-query image.

        Args:
            x: input batch.  The code indexes it as a 4-D tensor (the
                second-query noise is permuted with ``(3, 0, 1, 2)``).  The
                caller's tensor is left unmodified.
            logging_trackers: optional dict.  When ``"mode"`` is ``'train'`` or
                ``'test'`` the sigma values are logged (needs
                ``"sigma_log_file"``, ``"epoch"``, ``"batch_idx"``,
                ``"saved_dir"``; ``'train'`` additionally needs
                ``"budget_query_split_log_file"``).  ``None`` or a dict
                without ``"mode"`` disables all logging.

        Returns:
            The output of ``base_classifier`` on the second-query image.

        Raises:
            ValueError: for unsupported ``args.mask_output`` or
                ``args.wm_across_channels`` values.
        """
        if logging_trackers is None:
            logging_trackers = {}

        # Private, graph-detached copy of the clean batch.  `detach().clone()`
        # also works for non-leaf tensors, which `copy.deepcopy` rejects.
        x_original = x.detach().clone()

        # NOTE(review): when total_sigma <= 0 no first-query noise is added and
        # sigma_1 is never defined, so the averaging and sigma-logging code
        # further down cannot run in that configuration.
        if self.args.total_sigma > 0:
            # FIRST QUERY
            if self.args.first_query_with_mask:
                x = torch.mul(x, self.first_query_mask)
                norm_1 = self._channel_norm_scale() * self.first_query_mask.norm(2)
                # computing sigma for first query based on mask's l2 norm
                sigma_1 = (self.args.total_sigma / self.first_query_budget_frac) * (norm_1 / np.sqrt(self.image_dims))
            else:
                # no first-query mask: noise goes straight onto the inputs.
                # Work on a copy so the caller's batch is not modified in place.
                x = x.clone()
                # numerator and denominator of the second term are equal
                # because there is no transformation
                sigma_1 = (self.args.total_sigma / self.first_query_budget_frac) * (np.sqrt(self.image_dims) / np.sqrt(self.image_dims))

            norm_dist = D.Normal(loc=0., scale=sigma_1)
            noise = norm_dist.rsample(x.shape).to(x.device)
            # in-place on our own tensor: keeps x's dtype even if the noise
            # was sampled in a wider dtype
            x += noise

        # SECOND QUERY BEGINS
        # - pass the noisy image from first query into second query mask model
        # - get the per-input mask and
        # - multiply original images with those corresponding masks
        if self.args.mask_output == "penalty":
            second_query_masks = self.second_query_mask_model(x)

            # mean mask penalty across the batch; computed but currently not
            # returned (see the end of this method)
            mean_mask_penalty = mask_penalty(second_query_masks)

            clamped_second_query_masks = second_query_masks.clamp(min=0., max=1.)
        elif self.args.mask_output == "sigmoid":
            model_output = self.second_query_mask_model(x).to(torch.float32)
            if self.args.mask_recon > 0:
                # channel 0 is the mask, channels 1:4 are the reconstruction
                assert model_output.shape[1] == 4
                clamped_second_query_masks = model_output[:, 0:1, :, :]
            else:
                assert model_output.shape[1] == 1
                clamped_second_query_masks = model_output
        else:
            raise ValueError(
                f"mask_output must be 'penalty' or 'sigmoid', got {self.args.mask_output!r}"
            )

        # multiply raw input image with the obtained masks
        x_transformed = torch.mul(x_original, clamped_second_query_masks)

        # per-sample L2 norm of the mask, scaled according to wm_across_channels
        batch_size = clamped_second_query_masks.shape[0]
        per_sample_norm = torch.norm(clamped_second_query_masks.reshape(batch_size, -1), p=2, dim=1)
        norm_2 = self._channel_norm_scale() * per_sample_norm

        # computing the budget for second query (according to GDP formulation)
        second_query_budget_frac = torch.sqrt(1 - torch.square(self.first_query_budget_frac)).to(x_transformed.device)

        # compute sigma from norms and query budget
        sigma_2 = (self.args.total_sigma / second_query_budget_frac) * (norm_2 / np.sqrt(self.image_dims))

        # one sigma per sample: rsample yields (C, H, W, B), permute to (B, C, H, W)
        norm_dist = D.Normal(loc=0., scale=sigma_2)
        noise = norm_dist.rsample(x_transformed.shape[1:]).permute(3, 0, 1, 2).to(x_transformed.device)
        x_transformed += noise

        # Stays None unless the two queries are actually averaged below.
        pre_averaging_x_transformed = None
        if self.args.average_queries and self.args.averaging_style in ("old", "new"):
            pre_averaging_x_transformed = x_transformed.clone().detach()

            # expand both sigmas to (B, 1, image_size, image_size)
            sigma_1 = sigma_1.repeat(sigma_2.shape[0]).reshape(sigma_2.shape[0], 1, 1, 1).repeat(1, 1, self.image_size, self.image_size)
            sigma_2 = sigma_2.reshape(sigma_2.shape[0], 1, 1, 1).repeat(1, 1, self.image_size, self.image_size)

            denominator = ((clamped_second_query_masks ** 2) * (sigma_1 ** 2)) + (sigma_2 ** 2)
            first_query_weight = (sigma_2 ** 2) / denominator

            if self.args.averaging_style == "old":
                # old averaging (first query is not multiplied by the mask)
                second_query_weight = ((sigma_1 ** 2) * (clamped_second_query_masks)) / denominator
                first_query_term = first_query_weight * x
            else:
                # new style averaging: the two weights sum to 1
                second_query_weight = ((clamped_second_query_masks ** 2) * (sigma_1 ** 2)) / denominator
                first_query_term = first_query_weight * clamped_second_query_masks * x

            # average noisy images
            x_transformed *= second_query_weight
            x_transformed += first_query_term

        # Perform the usual forward pass (passing the second query's noisy image)
        output_pred = self.base_classifier(x_transformed)

        mode = logging_trackers.get("mode")

        # store budget split only during train (it does not change at test time)
        if mode == 'train':
            log(logging_trackers["budget_query_split_log_file"], "{} \t {}".format(self.first_query_budget_frac.item(), second_query_budget_frac))

        if mode == 'train' or mode == 'test':
            # store the sigma split
            log(logging_trackers["sigma_log_file"], "{} \t {} \t {}".format(sigma_1.mean(), sigma_2.mean(), self.args.total_sigma))

            # only a subset of the noisy samples is stored: first batch of
            # epoch 1 and of every 10th epoch
            store_this_epoch = logging_trackers["epoch"] % 10 == 0 or logging_trackers["epoch"] == 1
            if store_this_epoch and logging_trackers["batch_idx"] == 0:
                self._save_snapshots(logging_trackers, x_original, x,
                                     clamped_second_query_masks, sigma_2,
                                     pre_averaging_x_transformed, x_transformed)

        # not using mask penalty currently
        # if self.args.mask_output == "penalty":
        #     return output_pred, mean_mask_penalty
        return output_pred
