import os
import sys
import shutil
import random
import numpy as np
from distutils.dir_util import copy_tree
import errno
from os import path as osp

import torch
import torch.nn as nn
import yaml


def setup_seed(seed: int):
    """Seed all random number generators and force deterministic cuDNN settings.

    Args:
        seed: Value passed to Python's ``random``, PyTorch (CPU, the current
            CUDA device and all CUDA devices) and NumPy, in that order.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN is switched off, with autotuning disabled and the deterministic
    # flag set.
    cudnn = torch.backends.cudnn
    cudnn.enabled = False
    cudnn.benchmark = False
    cudnn.deterministic = True



def create_subdirs(sub_dir):
    """Create ``sub_dir`` and a ``checkpoint`` directory inside it.

    Both directories are made with ``os.mkdir``, so the parent of ``sub_dir``
    must already exist and an already-existing directory raises
    ``FileExistsError``.

    Args:
        sub_dir: Path of the directory to create.
    """
    for new_dir in (sub_dir, os.path.join(sub_dir, "checkpoint")):
        os.mkdir(new_dir)


def write_to_file(file, data, option):
    """Write ``data`` to ``file``, opening it with mode ``option``.

    Args:
        file: Path handed to ``open``.
        data: Content passed to the stream's ``write`` method.
        option: Mode string for ``open`` (for example ``"w"`` or ``"a"``).
    """
    with open(file, option) as stream:
        stream.write(data)


def clone_results_to_latest_subdir(src, dst):
    """Copy the directory tree ``src`` into ``dst``, merging with existing content.

    ``dst`` is created when missing. Files already present in ``dst`` under the
    same relative path are overwritten; other files in ``dst`` are left alone.

    Args:
        src: Directory to copy from.
        dst: Directory to copy into.

    Raises:
        OSError: If ``src`` is not a readable directory or a copy fails
            (``shutil.Error`` is a subclass of ``OSError``).
    """
    # shutil.copytree is used instead of distutils' copy_tree: distutils is
    # deprecated (PEP 632) and removed in Python 3.12, and copy_tree caches the
    # directories it has created, so a second clone fails when ``dst`` was
    # deleted in between.
    shutil.copytree(src, dst, dirs_exist_ok=True)


# ref:https://github.com/allenai/hidden-networks/blob/master/configs/parser.py
def trim_preceding_hyphens(st):
    """Return ``st`` without its leading ``-`` characters.

    An empty string or a string made only of hyphens yields ``""`` (the
    index-based loop this replaces ran off the end of such inputs).

    Args:
        st: Command-line token such as ``"--lr"``.

    Returns:
        The token with every leading hyphen removed; inner hyphens are kept.
    """
    return st.lstrip("-")


def arg_to_varname(st: str):
    """Turn a command-line flag into the matching variable name.

    Leading hyphens are dropped, anything from the first ``=`` onwards is
    discarded, and the remaining hyphens become underscores, e.g.
    ``"--batch-size=32"`` -> ``"batch_size"``.
    """
    flag, _, _ = trim_preceding_hyphens(st).partition("=")
    return flag.replace("-", "_")


def argv_to_vars(argv):
    """Collect the variable names of all flags found in ``argv``.

    Only entries starting with ``-`` are considered, and the name ``config``
    is skipped. Order of appearance is preserved and duplicates are kept.

    Args:
        argv: Iterable of command-line tokens (e.g. ``sys.argv``).

    Returns:
        List of variable names as produced by ``arg_to_varname``.
    """
    flag_names = (arg_to_varname(token) for token in argv if token.startswith("-"))
    return [name for name in flag_names if name != "config"]


# ref: https://github.com/allenai/hidden-networks/blob/master/args.py
def parse_configs_file(args):
    """Merge the YAML file at ``args.configs`` into ``args`` in place.

    Every key of the YAML mapping becomes an attribute of ``args``. Options
    passed on the command line (as reported by ``argv_to_vars(sys.argv)``)
    keep the value already stored on ``args``, so they take precedence over
    the file.

    Args:
        args: Namespace-like object with a ``configs`` attribute holding the
            YAML path; it is updated through ``args.__dict__``.

    Raises:
        OSError: If the config file cannot be opened.
        AttributeError: If a command-line flag has no matching attribute on
            ``args``.
    """
    # Names of the options the user set explicitly on the command line.
    override_args = argv_to_vars(sys.argv)

    # Read the YAML text; the context manager closes the file handle.
    with open(args.configs) as config_file:
        yaml_txt = config_file.read()

    # NOTE(review): FullLoader is not PyYAML's safe loader. If config files
    # can come from untrusted sources, switch to yaml.safe_load.
    # An empty file parses to None, so fall back to an empty mapping.
    loaded_yaml = yaml.load(yaml_txt, Loader=yaml.FullLoader) or {}

    # Command-line values win over the ones from the file.
    for v in override_args:
        loaded_yaml[v] = getattr(args, v)

    print(f"=> Reading YAML config from {args.configs}")
    args.__dict__.update(loaded_yaml)


class AverageMeter(object):
    """Computes and stores the average and current value.

    Args:
        name: Label used when the meter is printed.
        fmt: Format spec (including the leading ``:``) applied to both the
            current value and the average in ``__str__``.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Reset the current value, running sum, count and average to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: Latest measurement (typically a per-batch mean).
            n: Number of samples ``val`` stands for.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # While no samples have been counted (e.g. an empty batch with n=0)
        # the average stays at its reset value instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Renders as "<name> <val> (<avg>)", both numbers formatted with self.fmt.
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints a batch counter followed by a group of meters on one line.

    Args:
        num_batches: Total number of batches; fixes the counter width.
        meters: Iterable of meter objects; each is rendered with ``str``.
        prefix: Text placed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch index to the width of the total so that
        # successive lines stay aligned, e.g. "[{:3d}/100]".
        width = len(str(num_batches // 1))
        slot = f"{{:{width}d}}"
        return f"[{slot}/{slot.format(num_batches)}]"

    def write_to_tensorboard(self, writer, prefix, global_step):
        """Send each meter's latest value to ``writer.add_scalar`` under ``prefix/<meter name>``."""
        for meter in self.meters:
            tag = f"{prefix}/{meter.name}"
            writer.add_scalar(tag, meter.val, global_step)
            
class Logger(object):
    """Writes console output to an external text file as well.

    Every ``write`` goes to the stream that was ``sys.stdout`` when the logger
    was created and, if ``fpath`` was given, to that file too.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py>`_

    Args:
        fpath (str, optional): path of the log file. Its parent directory is
            created when missing and the file is opened in ``'w'`` mode.
            With ``None`` only the console receives output.

    Examples::
       >>> import sys
       >>> import os.path as osp
       >>> save_dir = 'log/resnet50-softmax-market1501'
       >>> log_name = 'train.log'
       >>> sys.stdout = Logger(osp.join(save_dir, log_name))
    """

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            log_dir = osp.dirname(fpath)
            # A bare file name has no directory part; there is nothing to
            # create and makedirs('') would raise.
            if log_dir:
                mkdir_if_missing(log_dir)
            self.file = open(fpath, 'w')

    def __del__(self):
        # Only release the log file the logger opened itself. The console
        # stream is shared with the rest of the process, so garbage
        # collection must not close it.
        self.close_file()

    def __enter__(self):
        # Return the logger so that ``with Logger(path) as log:`` works.
        return self

    def __exit__(self, *args):
        # Leave the console open: output after the ``with`` block must not fail.
        self.close_file()

    def write(self, msg):
        """Write ``msg`` to the console and, when open, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush the console; flush and fsync the log file when open."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Close the console stream and the log file.

        The console is normally the process-wide ``sys.stdout``; use
        ``close_file`` to stop file logging without closing it.
        """
        self.console.close()
        self.close_file()

    def close_file(self):
        """Close the log file, if any; later writes go to the console only."""
        if self.file is not None:
            self.file.close()
            # Forget the handle so write()/flush() skip it instead of raising
            # on a closed file, and so closing twice is harmless.
            self.file = None
            
def mkdir_if_missing(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    An empty path is a no-op; this is what ``os.path.dirname`` returns for a
    bare file name, and ``os.makedirs('')`` would raise.

    Args:
        directory: Path of the directory to create.

    Raises:
        OSError: For any failure other than the path already existing.
    """
    if not directory or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except FileExistsError:
        # Created by someone else between the check and makedirs (EEXIST).
        pass
            

class FairMetrics():
    """Accumulates batches of predictions, sensitive attributes and labels,
    then reports group-wise rates and their difference.

    Group membership and labels are identified by comparing ``s`` and ``y``
    against ``1.0`` and ``0.0``.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        """Discard every batch collected so far."""
        self.pred_y, self.s, self.y = [], [], []

    def update(self, pred_y, s, y):
        """Store one batch of predictions ``pred_y``, attributes ``s`` and labels ``y``."""
        for store, batch in ((self.pred_y, pred_y), (self.s, s), (self.y, y)):
            store.append(batch)

    def count_stats(self, verbose=False):
        """Compute the statistics over all batches stored so far.

        For each group (``s == 1`` and ``s == 0``) the rate is::

            (#[y == 1, pred == y] + #[y == 0] - #[y == 0, pred == y]) / #[group]

        which is the share of positive predictions when predictions are
        0/1 valued (presumably the intended use; verify against callers).

        Args:
            verbose: When true, print the per-group counts together with
                ``rd``, ``eo1`` and ``eo0``.

        Returns:
            Tuple ``(rd, rate_s1, rate_s0)`` where ``rd = rate_s1 - rate_s0``.
        """
        preds = torch.cat(self.pred_y)
        group = torch.cat(self.s)
        label = torch.cat(self.y)

        def n_correct(mask):
            # Number of selected samples whose prediction equals the label.
            return preds.masked_select(mask).eq(label.data.masked_select(mask)).cpu().sum()

        def n_members(mask):
            # Number of selected samples, as a float tensor.
            return mask.float().sum()

        is_pos, is_neg = label.eq(1.0), label.eq(0.0)
        in_s1, in_s0 = group.eq(1.0), group.eq(0.0)

        # One mask per (group, label) cell.
        s1_pos, s1_neg = in_s1 * is_pos, in_s1 * is_neg
        s0_pos, s0_neg = in_s0 * is_pos, in_s0 * is_neg

        sp_tp, sn_tp = n_correct(s1_pos), n_correct(s0_pos)
        sp_tn, sn_tn = n_correct(s1_neg), n_correct(s0_neg)

        sp_t, sn_t = n_members(s1_pos), n_members(s0_pos)
        sp_n, sn_n = n_members(s1_neg), n_members(s0_neg)

        rate_s1 = (sp_tp + sp_n - sp_tn) / (sp_t + sp_n)
        rate_s0 = (sn_tp + sn_n - sn_tn) / (sn_t + sn_n)
        rd = rate_s1 - rate_s0

        # Gaps between the groups on the y == 1 and y == 0 subsets.
        eo1 = sp_tp / sp_t - sn_tp / sn_t
        eo0 = (sp_n - sp_tn) / sp_n - (sn_n - sn_tn) / sn_n

        if verbose:
            print('TP_up: {}, FP_up: {}, TP_p: {}, FP_p: {}'.format(
                sp_tp, sp_n - sp_tn, sn_tp, sn_n - sn_tn))
            print('TN_up: {}, FN_up: {}, TN_p: {}, FN_p: {}'.format(
                sp_tn, sp_t - sp_tp, sn_tn, sn_t - sn_tp))

            print(f'rd:{rd:.4f}, eo1:{eo1:.4f}, eo0:{eo0:.4f}')

        # "Modify for causal" variant would return the raw numerators and
        # denominators instead of the rates.
        return rd, rate_s1, rate_s0
    
def get_parameter_number(model):
    """Print how many of ``model``'s parameter elements are trainable.

    Counts come from ``numel()`` over ``model.parameters()``; a parameter is
    trainable when its ``requires_grad`` flag is set. Output is a header line
    followed by ``<trainable>/<total>``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
    """
    total_num = 0
    trainable_num = 0
    for param in model.parameters():
        size = param.numel()
        total_num += size
        if param.requires_grad:
            trainable_num += size

    print('# trainable parameters/ # total parameters')
    print(f'{trainable_num}/{total_num}')

