# train passive model
import logging
import numpy as np
import os, cv2, time, copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import Counter
from Network.network_utils import pytorch_model, run_optimizer, get_gradient
from ActualCausal.Inference.inference_utils import compute_distributional_distance
from ActualCausal.Train.train_utils import compute_likelihood_adaptive_lasso, compute_mean_adaptive_lasso, compute_mean_var_adaptive_lasso
from tianshou.data import Batch

def random_reassignment(batch_len, num_objects, query_size):
    """Draw fresh uniform values for every (sample, object, feature) slot.

    Returns a NumPy array of shape ``(batch_len, num_objects, query_size)`` with
    entries drawn from the global NumPy RNG on the half-open interval [0, 1).
    NOTE(review): callers state that values are normalized to [-1, 1]; this draw only
    covers [0, 1) — confirm whether that is intended.
    """
    sample_shape = (batch_len, num_objects, query_size)
    return np.random.random_sample(sample_shape)

def perturbation_reassignment(query_reshaped, perturb_magnitude):
    """Return ``query_reshaped`` plus uniform noise in [-perturb_magnitude, perturb_magnitude).

    The noise has the same shape as ``query_reshaped`` and comes from the global
    NumPy RNG; the input array itself is not modified.
    """
    centered_noise = np.random.rand(*query_reshaped.shape) - 0.5  # uniform on [-0.5, 0.5)
    return query_reshaped + centered_noise * 2 * perturb_magnitude

def generate_counterfactual_perturbation(idx_set, model, batch, reassignment_type, perturbation_magnitude):
    """Apply a single counterfactual reassignment to the objects selected by ``idx_set``.

    ``batch`` is modified in place and also returned.

    Args:
        idx_set: index used to select the objects to reassign (and, for ``"null"``,
            the entries of ``batch.valid`` to zero).
        model: provides ``model.extractor.num_objects`` and ``model.extractor.pad_dim``.
        batch: object with ``obs`` (reshapeable to ``(len(batch), num_objects, -1)``)
            and ``valid``.
        reassignment_type: ``"null"`` marks ``idx_set`` invalid; ``"random"`` draws
            new values; ``"perturb"`` adds bounded noise; anything else zeroes them.
        perturbation_magnitude: noise bound used by the ``"perturb"`` form.

    Returns:
        The same ``batch`` object, with ``obs`` (or ``valid``) updated.
    """
    # assumes no padding, values normalized between -1, 1
    if reassignment_type == "null":
        # NOTE(review): this indexes the first axis of ``valid``; confirm whether samples
        # or objects are meant to be invalidated for the shapes callers pass in.
        batch.valid[idx_set] = 0  # treat the idx set as invalid
        return batch

    batch_len = len(batch)
    num_objects = model.extractor.num_objects
    pad_dim = model.extractor.pad_dim
    query_reshaped = batch.obs.reshape(batch_len, num_objects, -1)
    feature_dim = query_reshaped.shape[-1]
    if reassignment_type == "random":  # TODO: support alternative reassignment strategies
        reassigned_values = random_reassignment(batch_len, num_objects, feature_dim)
    elif reassignment_type == "perturb":  # TODO: support alternative reassignment strategies
        reassigned_values = perturbation_reassignment(query_reshaped, perturbation_magnitude)
    else:
        # fallback: reassign the selected objects to zero (np.zeros needs a shape tuple)
        reassigned_values = np.zeros((batch_len, num_objects, feature_dim))
    rows = np.arange(batch_len)
    # only the first pad_dim features of each selected object are overwritten
    query_reshaped[rows, idx_set, :pad_dim] = reassigned_values[rows, idx_set, :pad_dim]
    batch.obs = query_reshaped.reshape(batch_len, -1)
    return batch

def compute_counterfactual_cause(args, params, model, batch, result = None, keep_all=False):
    """Estimate per-object interaction masks by counterfactually perturbing each object.

    For every object slot ``i`` the slot is reassigned
    ``args.infer.counterfactual.num_counterfactual`` times, and the distributional
    distance between the perturbed and the actual model output is averaged over those
    reassignments. The averaged distance is thresholded into a binary mask, either at
    ``counterfactual_threshold`` or, when ``select_ideal`` is set, at the midpoint
    between the median distance of trace-one rows and trace-zero rows.

    Args:
        args: config; reads ``args.infer.counterfactual.*``.
        params: ``params.mask_mode`` selects the inference mode.
        model: provides ``infer``, ``extractor``, ``target_name`` and ``name``.
        batch: provides ``obs``, ``valid`` and ``trace``.
        result: optional precomputed ``model.infer`` output keyed by mask mode;
            inferred from ``batch`` when ``None``.
        keep_all: forwarded to ``model.infer`` and to the distance computation.

    Returns:
        Batch with ``inter_masks`` (binary mask), ``mask_logits`` (mean distances),
        ``inter_one_trace_rate`` / ``inter_zero_trace_rate`` (per-object medians),
        ``bin_error`` / ``total_error`` against the trace, and ``omit_flags``,
        ``utrace``, ``trace``, ``valid``.
    """
    # TODO: probably does not work for "all" versions of the code, uses model.name to determine name
    # compare against None explicitly instead of relying on the result's truthiness
    actual_result = result if result is not None else model.infer(batch, batch.valid, params.mask_mode, additional=[], keep_all=keep_all)
    omit_flags = actual_result[params.mask_mode].omit_flags

    out = Batch()
    out.omit_flags = omit_flags
    if params.mask_mode.find("all") != -1:
        out.utrace = batch.trace
    else:
        out.utrace = batch.trace[:, model.extractor.get_index([model.target_name])]
    out.utrace = out.utrace[omit_flags[0], 0]

    cf_args = args.infer.counterfactual
    num_objects = model.extractor.num_objects
    num_rows = len(omit_flags[0])
    bins = np.zeros((num_rows, num_objects))
    counterfactual_variance = np.zeros((num_rows, num_objects))
    one_rates = np.zeros((num_objects,))
    zero_rates = np.zeros((num_objects,))
    midpoints = np.zeros((num_objects,))
    for i in range(num_objects):
        model_dists, _ = compute_counterfactual_model_dists([i], model, batch, args, cf_args.num_counterfactual, cf_args.distance_form, cf_args.reassignment_type, cf_args.perturbation_magnitude, params, actual_result=actual_result, keep_all=keep_all)
        # average distance over the counterfactual draws (last axis)
        counterfactual_variance[:, i] = np.mean(pytorch_model.unwrap(model_dists), axis=-1)
        object_trace = out.utrace[..., i]
        object_variance = counterfactual_variance[..., i]
        # median distance among trace-one / trace-zero rows; defaults to 1 when a class is absent
        one_rates[i] = np.median(pytorch_model.unwrap(object_variance[object_trace == 1])) if np.sum(object_trace == 1) else 1
        zero_rates[i] = np.median(pytorch_model.unwrap(object_variance[object_trace == 0])) if np.sum(object_trace == 0) else 1
        midpoints[i] = (one_rates[i] + zero_rates[i]) / 2
        threshold = midpoints[i] if cf_args.select_ideal else cf_args.counterfactual_threshold
        bins[:, i] = counterfactual_variance[:, i] > threshold
    logging.getLogger(__name__).debug("counterfactual midpoints: %s", midpoints)

    out.inter_masks, out.mask_logits = bins, counterfactual_variance
    out.inter_one_trace_rate = np.expand_dims(one_rates, axis=0)
    out.inter_zero_trace_rate = np.expand_dims(zero_rates, axis=0)
    out.bin_error = out.inter_masks - out.utrace if model.name != "all" else out.inter_masks - batch.trace[omit_flags[0]]
    out.total_error = np.abs(out.bin_error)
    out.valid = batch.valid
    out.trace = batch.trace[omit_flags[0]]
    return out

def compute_counterfactual_model_dists(idxes, model, batch, args, num_counterfactual, distance_form, reassignment_type, perturbation_magnitude, params, actual_result=None, keep_all= False):
    """Compute model distributional distances for ``num_counterfactual`` perturbations.

    Each perturbation starts from a fresh deep copy of ``batch``. Perturbations
    therefore do not accumulate across iterations, and the caller's batch is left
    unmodified.

    Args:
        idxes: objects to reassign, forwarded to ``generate_counterfactual_perturbation``.
        model: provides ``infer(batch, valid, mode, additional=..., keep_all=...)``,
            whose output is indexed by ``params.mask_mode``.
        batch: the unperturbed batch.
        args: unused here; kept for interface compatibility.
        num_counterfactual: number of independent perturbations.
        distance_form: forwarded to ``compute_distributional_distance``.
        reassignment_type: forwarded to ``generate_counterfactual_perturbation``.
        perturbation_magnitude: forwarded to ``generate_counterfactual_perturbation``.
        params: ``params.mask_mode`` selects the inference mode.
        actual_result: optional ground (unperturbed) result. It may be either this
            mode's result or the full ``model.infer`` output keyed by mask mode. It is
            inferred from ``batch`` when ``None``.
        keep_all: forwarded to ``model.infer`` for the ground result; perturbed
            results always use ``keep_all=True``.

    Returns:
        ``(model_dists, actual_result)``. ``model_dists`` stacks the per-perturbation
        distances along dim 1. ``actual_result`` is this mode's ground result.
    """
    # compute the ground result (TODO: all evaluation not tested, but might be more efficient)
    if actual_result is None:
        actual_result = model.infer(batch, batch.valid, params.mask_mode, additional=[], keep_all=keep_all)[params.mask_mode]
    elif params.mask_mode in actual_result:
        # caller passed the full infer output keyed by mask mode: select this mode's entry
        actual_result = actual_result[params.mask_mode]
    model_dists = list()
    for _ in range(num_counterfactual):
        # perturb a fresh copy so earlier perturbations never leak into later ones or the caller
        reassign_batch = generate_counterfactual_perturbation(idxes, model, copy.deepcopy(batch), reassignment_type, perturbation_magnitude)
        perturb_result = model.infer(reassign_batch, reassign_batch.valid, params.mask_mode, additional=[], keep_all=True)[params.mask_mode]
        model_dists.append(compute_distributional_distance(distance_form, actual_result, perturb_result, cflags=actual_result.omit_flags[0]).squeeze())
    model_dists = torch.stack(model_dists, dim=1)  # shapes: batch -> batch x num_counterfactuals
    return model_dists, actual_result

def splitting_loss(split_type, split_bias, model_dists):
    """Return a sigmoid penalty on the mean counterfactual distance.

    For ``split_type == "ones"`` the mean is negated, so small distances (too close to
    0) are penalized. For any other split type, large distances are penalized.
    ``split_bias`` shifts the sigmoid in both cases.
    """
    mean_distance = model_dists.mean()
    signed_distance = -mean_distance if split_type == "ones" else mean_distance
    return torch.sigmoid(signed_distance + split_bias)
    
# split name -> index into split_args.splitting_lambda / splitting_bias (0 for zeros, 1 for ones and all)
NAME_SPLIT_MAP = {"ones": 1, "zeros": 0, "all": 1}
def compute_adaptive_splitting_lambda(split_args, split_name, params, result, batch):
    """Return the (optionally adaptive) lasso weight for one splitting loss.

    Values could later be moved to scheduled ``params`` if they prove useful.

    Args:
        split_args: splitting config; reads ``adaptive_splitting`` (bias, flatten
            factor), ``splitting_lambda`` and ``adaptive_splitting_type``.
        split_name: ``"ones"``, ``"zeros"`` or ``"all"``.
        params: runtime parameters; reads ``adative_lasso`` and, for the likelihood
            form, ``converged_active_loss_value`` when present.
        result: forwarded to the adaptive-lasso helper. The likelihood form reads
            ``result.log_probs`` when ``params`` has no converged loss value.
        batch: forwarded to the adaptive-lasso helper.

    Raises:
        ValueError: if ``split_args.adaptive_splitting_type`` is not ``"likelihood"``,
            ``"mean"`` or ``"meanvar"``.
    """
    # a negative bias is passed to the helpers as the "don't adapt" flag
    not_use_adaptive = split_args.adaptive_splitting[0] < 0
    base_value = split_args.splitting_lambda[NAME_SPLIT_MAP[split_name]]
    # NOTE(review): attribute is spelled ``adative_lasso``; confirm this matches how params is populated
    adaptive_lasso = params.adative_lasso
    bias = split_args.adaptive_splitting[0]
    flatten_factor = split_args.adaptive_splitting[1]
    adaptive_type = split_args.adaptive_splitting_type
    if adaptive_type == "likelihood":
        # only the likelihood form needs a baseline, so result.log_probs is read only here
        baseline_likelihood = params.converged_active_loss_value if "converged_active_loss_value" in params else 3.5 * result.log_probs.shape[-1]
        return compute_likelihood_adaptive_lasso(not_use_adaptive, base_value, adaptive_lasso, result, batch, baseline_likelihood, bias, flatten_factor)
    if adaptive_type == "mean":
        return compute_mean_adaptive_lasso(not_use_adaptive, base_value, adaptive_lasso, result, batch, flatten_factor)
    if adaptive_type == "meanvar":
        return compute_mean_var_adaptive_lasso(not_use_adaptive, base_value, adaptive_lasso, result, batch, flatten_factor)
    raise ValueError(f"unknown adaptive_splitting_type: {adaptive_type!r}")


def compute_splitting_losses(args, params, model, batch, result = None, keep_all=False, mode = 'full'):
    """Generate counterfactuals on the ones/zeros and compute the splitting losses.

    Each split's counterfactual distances are compared with the actual distribution,
    and the lambda-weighted means of the per-split losses are summed.
    Four kinds of losses can be computed:
        - penalizing local perturbations above a certain magnitude
        - penalizing lack of change below a threshold for ones
        - penalizing changes in zeros above a threshold
        - assuming the batch states are supposed to be non-interaction, penalizing
          deviation from the null distribution/local perturbations (note that
          num_counterfactual should be 1)

    Args:
        args: config; reads ``args.inter.regularization.splitting``.
        params: ``params.mask_mode`` selects the inference mode.
        model: provides ``get_module``, ``infer``, ``extractor`` and ``target_name``.
        batch: provides ``obs``, ``valid`` and ``trace``.
        result: ignored; a fresh Batch is built and returned.
        keep_all: forwarded to omit-flag computation and inference.
        mode: module mode passed to ``model.get_module``.

    Returns:
        Batch with ``trace``, ``utrace``, ``valid``, ``omit_flags``,
        ``split_loss_raw`` (per-split losses keyed by split name) and ``split_loss``
        (lambda-weighted sum of the per-split means).
    """
    module = model.get_module(mode)
    omit_flags = module.get_omit(batch, keep_all=keep_all, keep_invalid=False, use_name=model.target_name if mode != "all" else "")
    result = Batch()
    result.trace = batch.trace[omit_flags[0]]
    result.utrace = batch.trace[omit_flags[0], module.extractor.get_index([model.target_name]) if mode != "all" else np.arange(model.extractor.num_objects)]
    result.valid = batch.valid
    result.omit_flags = omit_flags
    check_trace = result.utrace if mode != "all" else result.trace

    split_args = args.inter.regularization.splitting
    # the reassignment settings are the same for every split, so resolve them once
    reassignment_type = split_args.reassignment_type
    num_counterfactual = split_args.num_counterfactual if reassignment_type != "null" else 1  # only need one reassignment for null
    # ground (unperturbed) inference shared by all splits
    ground_result = model.infer(batch, batch.valid, params.mask_mode, additional=[], keep_all=keep_all)

    # either use the one indices, zero indices, both separately or all the indices for counterfactual perturbation
    split_names = ["ones", "zeros"] if split_args.splitting_type == "both" else [split_args.splitting_type]
    split_losses = dict()
    adaptive_lambdas = dict()
    for split_name in split_names:
        # NOTE(review): np.nonzero returns a tuple of index arrays, while
        # generate_counterfactual_perturbation indexes [np.arange(len(batch)), idx_set];
        # confirm these index formats agree.
        if split_name == "ones":
            idxes = np.nonzero(check_trace)
        elif split_name == "zeros":
            idxes = np.nonzero(1 - check_trace)
        else:  # assume split_name == "all"
            idxes = np.arange(model.extractor.num_objects)
        model_dists, _ = compute_counterfactual_model_dists(idxes, model, batch, args, num_counterfactual, split_args.distance_form, reassignment_type, split_args.perturbation_magnitude, params, actual_result=ground_result, keep_all=keep_all)
        # use the per-split name so "both" applies the ones and zeros penalties respectively
        split_losses[split_name] = splitting_loss(split_name, split_args.splitting_bias[NAME_SPLIT_MAP[split_name]], model_dists)
        # NOTE(review): ``result`` here has no log_probs; the likelihood adaptive form may
        # need the model's inference output instead — confirm.
        adaptive_lambdas[split_name] = compute_adaptive_splitting_lambda(split_args, split_name, params, result, batch)
    result.split_loss_raw = split_losses
    result.split_loss = sum(split_losses[name].mean() * adaptive_lambdas[name] for name in split_losses)
    return result
