import torch
import numpy as np
import math
from data.brain_data.matlab_to_python_brain_data import apply_subnetwork_mask
from metrics import classification_metrics


def percent_change_metrics(prior_metrics, prediction_metrics, metrics=('nse', 'se', 'ae', 'acc', 'error', 'mcc')):  # 'hinge'
    """Compute the percent improvement of prediction metrics over prior-baseline metrics.

    Args:
        prior_metrics: nested mapping ``{prior_channel_name: {subnetwork_name: {metric: value}}}``.
        prediction_metrics: mapping ``{subnetwork_name: {metric: value}}``; must contain every
            subnetwork that appears under each prior channel.
        metrics: metric names to compare. A name containing 'err', 'se', 'ae' or 'hinge'
            (substring match, checked first) is treated as lower-is-better. A name containing
            'acc', 'f1', 'F1' or 'mcc' is treated as higher-is-better.

    Returns:
        ``{prior_channel_name: {subnetwork_name: {metric: percent_change}}}``. Changes in the
        good direction are positive.

    Raises:
        ValueError: if a metric name matches neither direction.
    """
    # TODO: incorporate the 'mean' key into this.
    minimize_keys = ('err', 'se', 'ae', 'hinge')
    maximize_keys = ('acc', 'f1', 'F1', 'mcc')

    changes = {}
    for prior_channel_name, prior_channel_subnetwork_metrics_dict in prior_metrics.items():
        changes[prior_channel_name] = {}
        for subnetwork_name, prior_channel_subnetwork_metrics in prior_channel_subnetwork_metrics_dict.items():
            subnet_changes = {}
            for m in metrics:
                if any(a in m for a in minimize_keys):
                    prior_val = prior_channel_subnetwork_metrics[m]
                    improvement = prior_val - prediction_metrics[subnetwork_name][m]
                elif any(a in m for a in maximize_keys):
                    prior_val = prior_channel_subnetwork_metrics[m]
                    improvement = prediction_metrics[subnetwork_name][m] - prior_val
                else:
                    raise ValueError(f'What direction should unrecognized metric {m} go??')
                # Normalize by |prior| so the sign always reflects the direction of the change,
                # even when the prior is negative (e.g. MCC < 0).
                subnet_changes[m] = 100 * improvement / abs(prior_val)
            changes[prior_channel_name][subnetwork_name] = subnet_changes
    return changes


def print_subnet_perf_dict(subnetwork_metrics_dict, indents=1, convert_to_percent=(), metrics2print=('se', 'ae', 'nmse', 'error', 'mcc')):
    """Print one line of selected metrics for each subnetwork.

    Args:
        subnetwork_metrics_dict: mapping ``{subnetwork_name: {metric_name: value}}``. The special
            key ``'epoch'`` is printed verbatim on its own line.
        indents: number of leading tab characters per line.
        convert_to_percent: metric names whose values are multiplied by 100 before printing.
        metrics2print: metric names to print, in this order; names absent from a subnetwork's
            dict are skipped.
    """
    # Computed up front so every branch (including 'epoch') has them defined.
    indent = '\t' * indents
    for subnetwork_name, subnetwork_metrics in subnetwork_metrics_dict.items():
        # Pad names to a 9-character column so the metric columns line up.
        padding_spaces = " " * (9 - len(subnetwork_name))
        if subnetwork_name == 'epoch':
            print(f'{indent}{subnetwork_name}: {padding_spaces}{subnetwork_metrics}')
            continue
        parts = []
        for metric_name in metrics2print:
            if metric_name not in subnetwork_metrics:
                continue
            val = subnetwork_metrics[metric_name]
            if torch.is_tensor(val):
                val = val.item()  # format a plain Python scalar, not a tensor
            if metric_name in convert_to_percent:
                val = val * 100
            parts.append(f'{metric_name}: {val:3.5f}, ')
        s = ''.join(parts)
        print(f'{indent}{subnetwork_name}: {padding_spaces}{s}')


def best_subnetwork_at_best_metric(model, prior_metrics, subnetworks=('frontal', 'temporal', 'occipital', 'parietal', 'full'), metrics=('se', 'ae', 'error', 'mcc'), indents=1):
    """Print, per subnetwork and metric, the best value achieved, its % change vs. the prior, and its epoch.

    For each (subnetwork, metric) pair, ``model.best_metrics(...)[0]`` is queried. Judging from its
    use here, that is a dict with an ``'epoch'`` key and ``{subnetwork: {metric: value}}``
    entries (NOTE(review): confirm against the model class). The best value is compared with
    ``prior_metrics[model.prior_construction][subnetwork][metric]``.

    'se', 'ae' and 'error' are treated as lower-is-better; every other metric as higher-is-better.
    The percent change is positive when the model beats the prior. It is shown as an absolute
    value on a green (improvement) or red background. 'acc' and 'error' values are scaled by 100
    for display.
    """
    print('\nBest Performance @ Epoch: Val/% Change/epoch')
    indent = '\t' * indents
    for i, subnetwork in enumerate(subnetworks):
        padding_spaces = " " * (9 - len(subnetwork))
        if i > 0:
            print('')
        print(f'{indent}{subnetwork}: {padding_spaces}', end='')

        for metric in metrics:
            maximize = metric not in ('se', 'ae', 'error')
            all_subnetworks_metrics_at_epoch = model.best_metrics(sort_metric=metric, sort_subnetwork=subnetwork, maximize=maximize)[0]
            best_metric_epoch = all_subnetworks_metrics_at_epoch['epoch']
            best_metric_val = all_subnetworks_metrics_at_epoch[subnetwork][metric]
            prior_metric_val = prior_metrics[model.prior_construction][subnetwork][metric]
            if maximize:
                improvement = best_metric_val - prior_metric_val
            else:
                improvement = prior_metric_val - best_metric_val
            # |prior| keeps the sign meaningful for metrics that can be negative (e.g. MCC).
            percent_change = 100 * improvement / abs(prior_metric_val)
            # builtin abs() works for Python floats and 0-dim tensors alike; the '%' is
            # added here only, so the report line must not append another one.
            percent_str = format_color(f'{abs(percent_change):3.5f}%', color='green' if percent_change > 0 else 'red')

            if metric in ['acc', 'error']:
                best_metric_val = 100 * best_metric_val
            s = f'{metric}: {best_metric_val:3.4f}|{percent_str}|{best_metric_epoch}, '
            print(f"{s:<32}", end="")
    print('\n')


# Each slice of x has a corresponding id. There may be slices with identical ids. Only
# include the first occurring slice for each id.
def filter_repeats(x, ids):
    """Keep only the first slice of ``x`` for each distinct id.

    ``x`` must be 3D and is indexed along its first dimension; ``ids[i]`` is the id of ``x[i]``.
    Surviving slices keep their original relative order.
    """
    assert len(x.shape) == 3, f'filter_repeats takes 3D input, not {x.shape}'
    n_slices, n_ids = x.shape[0], ids.shape[0]
    assert n_slices == n_ids, f'x and ids must have same size: x {x.shape}, ids: {ids.shape}'

    seen_ids = set()
    keep = []
    for pos in range(n_slices):
        key = ids[pos].item()
        if key in seen_ids:
            continue
        seen_ids.add(key)
        keep.append(pos)

    # fancy-index with the first-occurrence positions
    return x[keep]


def round_up(n, decimals=0):
    """Round ``n`` toward +infinity at the given number of decimal places."""
    scale = 10 ** decimals
    scaled_ceiling = math.ceil(n * scale)
    return scaled_ceiling / scale


# https://stackoverflow.com/questions/287871/how-to-print-colored-text-to-the-terminal
def format_color(s_in, color):
    """Wrap ``s_in`` in ANSI escape codes that set a colored background, then reset formatting.

    Supported colors: 'red', 'green', 'yellow', 'blue' (background SGR codes 41-44, with the
    fixed prefix ``6;30;``). Any other color fails the assertion.
    """
    assert color in ['red', 'green', 'yellow', 'blue']
    background_code = {'red': '41', 'green': '42', 'yellow': '43', 'blue': '44'}[color]
    return '\x1b[6;30;' + background_code + 'm' + s_in + '\x1b[0m'


def apply_mask(a: torch.Tensor, mask: torch.Tensor):
    """Apply ``mask`` to ``a`` by delegating to ``apply_subnetwork_mask``.

    Thin wrapper: the masking semantics are defined by
    ``data.brain_data.matlab_to_python_brain_data.apply_subnetwork_mask``, which is not visible here.
    """
    # Annotations use the torch.Tensor type; torch.tensor is a factory function, not a type.
    return apply_subnetwork_mask(a, mask)


def shallowest_layer_all_zero(model):
    """Return the index of the first layer in ``model.layers`` whose ``output_zeros`` is truthy.

    The hit is also printed. Returns -1 if no layer reports all-zero output.
    """
    hit = next((depth for depth, layer in enumerate(model.layers) if layer.output_zeros), None)
    if hit is None:
        return -1
    print(f"\t\t{hit}th layer!")
    return hit


def resample_params(module, stdv_scaling=1/3, alpha_mean=1, beta_mean=0):
    """Re-initialize selected parameters of ``module`` in place, chosen by substring of the parameter name.

    For each ``(name, param)`` in ``module.named_parameters()``, under ``torch.no_grad()``:
      * 'tau' in name, and ``module.tau_info`` has type 'scalar' and 'learn' set: scale by 0.1.
      * 'alpha', 'k' or 'commute' in name:
        ``module.stdv_scaling * randn + module.init_alpha_sample_mean``.
      * 'beta' in name: ``module.stdv_scaling * randn + module.init_beta_sample_mean``.
      * 'poly' in name: ``module.stdv_scaling * randn``, with index ``[:, 1, :]`` set to 1.

    The checks are independent ``if`` statements (not ``elif``). A name matching several
    substrings is rewritten by each matching branch in order, and the last write wins.

    NOTE(review): the ``stdv_scaling``, ``alpha_mean`` and ``beta_mean`` arguments are never read;
    the values come from ``module`` attributes instead. Confirm which source is intended.
    NOTE(review): ``'k' in name`` matches any parameter whose name contains the letter 'k'.
    """
    for name, param in module.named_parameters():
        with torch.no_grad():
            # tau_info is only read for parameters whose name contains 'tau' (short-circuit `and`)
            if ('tau' in name) and module.tau_info['type'] == 'scalar' and module.tau_info['learn']:
                param.copy_(param * 0.1) # shrink tau 10x so fewer outputs are all zeros
            if 'alpha' in name:
                new_param = module.stdv_scaling * torch.randn_like(param) + module.init_alpha_sample_mean
                param.copy_(new_param)
            if 'beta' in name:
                new_param = module.stdv_scaling * torch.randn_like(param) + module.init_beta_sample_mean
                param.copy_(new_param)
            if 'poly' in name:
                # assumes poly params have >= 3 dims with dim 1 of size >= 2 — TODO confirm
                new_param = module.stdv_scaling * torch.randn_like(param)
                new_param[:, 1, :] = 1
                param.copy_(new_param)
            if 'k' in name:
                new_param = module.stdv_scaling * torch.randn_like(param) + module.init_alpha_sample_mean
                param.copy_(new_param)
            if 'commute' in name:
                new_param = module.stdv_scaling * torch.randn_like(param) + module.init_alpha_sample_mean
                param.copy_(new_param)


@torch.no_grad()
def clamp_tau(model, large_tau):
    """Clamp, in place, every parameter whose name contains 'tau' into ``[0, large_tau]``."""
    tau_params = (p for n, p in model.named_parameters() if 'tau' in n)
    for tau_param in tau_params:
        tau_param.clamp_(min=0, max=large_tau)


##### Experiment with threshold to optimize metric #####
# Given list of thresholds, see which one optimizes given metric
@torch.no_grad()
def best_threshold_by_metric(thresholds, y, y_hat, non_neg: bool, metric: str, reduction=torch.nanmean):
    """Return the element of ``thresholds`` that gives the best value of ``metric``.

    For each threshold, ``classification_metrics`` is evaluated and ``reduction`` collapses
    ``result[metric]`` to one value (presumably per-scan values; confirm against
    ``metrics.classification_metrics``). 'error' is minimized; 'acc', 'f1' and 'mcc' are
    maximized. NaN scores are ignored via NumPy's nanargmin/nanargmax.
    """
    assert metric in ['acc', 'error', 'f1', 'mcc']
    assert torch.is_tensor(y) and torch.is_tensor(y_hat)
    assert y.shape == y_hat.shape
    assert reduction in [torch.nanmean, torch.nanmedian, torch.nansum]

    scores = torch.zeros(len(thresholds))
    for idx, thr in enumerate(thresholds):
        result = classification_metrics(y_hat=y_hat, y=y, threshold=thr, non_neg=non_neg)
        scores[idx] = reduction(result[metric])

    pick_best = np.nanargmin if metric == 'error' else np.nanargmax
    return thresholds[pick_best(scores)]

##############


# returns hinge_loss of each scan in tensor
# FOR y in {0,+1} NOT {-1, +1}
def hinge_loss(y, y_hat, margin=0.2, already_bin=False, per_edge=True, slope=1):
    """Per-scan hinge loss for binary edge labels in {0, 1} (not {-1, +1}).

    Args:
        y: labels of shape (B, N, N). Entries > 0 count as label 1 unless ``already_bin`` is True,
            in which case ``y`` must already be a boolean tensor/array.
        y_hat: predictions of shape (B, N, N).
        margin: for label 0 the loss is ``max(0, y_hat - margin)``; for label 1 it is
            ``max(0, (1 - margin) - y_hat)``.
        per_edge: if True, divide each scan's summed loss by ``N*N - N`` (the off-diagonal count).
        slope: multiplier on every entry's loss (larger = harsher penalty).

    Returns:
        Tensor of shape (B,): summed (optionally per-edge-normalized) loss for each scan.

    Note: diagonal entries are still included in the sum; the per-edge normalizer only counts
    off-diagonal entries.
    """
    y_bin = y if already_bin else (y > 0)
    if torch.is_tensor(y_bin):
        assert y_bin.dtype == torch.bool
    else:
        # np.bool was removed in NumPy 1.24; np.bool_ is the boolean scalar type in every version.
        assert y_bin.dtype == np.bool_
    assert (y.ndim == 3) and (y.shape[-2] == y.shape[-1])
    assert (y_hat.ndim == 3) and (y_hat.shape[-2] == y_hat.shape[-1])
    zeros = torch.zeros_like(y_hat)
    loss_when_label_zero = torch.maximum(zeros, y_hat - margin)  # assume all y_hat >= 0
    loss_when_label_one = torch.maximum(zeros, -y_hat + (1 - margin))
    # Positional arguments: torch.where(condition, input, other) picks `input` where condition is True.
    entry_loss = slope * torch.where(y_bin, loss_when_label_one, loss_when_label_zero)
    hinge_loss_per_scan = torch.sum(entry_loss, dim=(1, 2))

    if per_edge:
        N = y.shape[1]
        total_possible_edges = N * N - N  # ignore diagonals
        return hinge_loss_per_scan / total_possible_edges
    return hinge_loss_per_scan



##### General matrix/tensor utils #####
def construct_prior(train_scs: torch.Tensor, frac_contains: float = 1.0, reduction: str = "median", plt_prior: bool = False):
    """Build a single prior matrix from training scans only.

    The prior is used to predict SC structure directly on the validation and test sets, as a
    baseline.

    Args:
        train_scs: stacked training matrices indexed along dim 0. The final ``view`` requires the
            reduced result to have exactly ``H*W`` elements, i.e. ``train_scs`` is (n, H, W).
        frac_contains: fraction in [0, 1] of training scans in which an edge must be > 0 for the
            edge to be kept in the prior.
        reduction: 'median' or 'mean' across scans. torch's median returns the lower of the two
            middle values when n is even.
        plt_prior: accepted for backward compatibility; currently unused.

    Returns:
        Tensor of shape (1, H, W), with the same dtype and device as the reduced values. Edges
        that fail the ``frac_contains`` test are zero.

    Raises:
        ValueError: for an unsupported ``reduction``.
    """
    # keep an edge if it is present in at least frac_contains of the training scans
    prior_edge_mask = (train_scs > 0).sum(dim=0) >= (frac_contains * len(train_scs))

    if reduction == 'median':
        prior_edge_values, _ = train_scs.median(dim=0)
    elif reduction == 'mean':
        prior_edge_values = train_scs.mean(dim=0)
    else:
        raise ValueError(f'reduction {reduction} not implimented')

    # zeros_like matches dtype/device (a bare torch.zeros(1) is CPU float32 and breaks GPU input)
    prior = torch.where(prior_edge_mask, prior_edge_values, torch.zeros_like(prior_edge_values))

    return prior.view(1, prior.shape[-2], prior.shape[-1])


if __name__ == "__main__":
    # Running this module directly only prints its name; it has no CLI.
    print('gdn utils')

