import argparse
import logging
import random

import numpy as np
import torch
import tqdm
from collections import OrderedDict, namedtuple
import math

def write_results(writer, prefix, results, step):
    """Push every entry of ``results`` to ``writer`` and log a one-line summary.

    :param writer: object exposing ``add_histogram(tag, value, step)`` and
        ``add_scalar(tag, value, step)`` (e.g. a TensorBoard summary writer).
    :param prefix: tag prefix; each entry is written under ``{prefix}/{key}``.
    :param results: mapping of metric name to a value exposing
        ``.shape.numel()``. Must contain ``'test_avg_loss'``; the ``clean_``,
        ``finetune_`` and ``*_acc`` variants are optional.
    :param step: step index forwarded to the writer and shown in the log line.
    :return: None
    """
    for name, value in results.items():
        tag = f'{prefix}/{name}'
        # Single-element values are scalars; anything larger is a histogram.
        if value.shape.numel() <= 1:
            writer.add_scalar(tag, value, step)
        else:
            writer.add_histogram(tag, value, step)

    def _summary(metric):
        # Renders "avg", optionally followed by "(clean avg)" and " -> finetune avg".
        text = f"{results['test_avg_' + metric]:.4f}"
        clean_key = 'clean_test_avg_' + metric
        finetune_key = 'finetune_test_avg_' + metric
        if clean_key in results:
            text += f"({results[clean_key]:.4f})"
        if finetune_key in results:
            text += f" -> {results[finetune_key]:.4f}"
        return text

    message = f"{prefix} Step: {step}, AVG Loss: {_summary('loss')}"
    if 'test_avg_acc' in results:
        message += f", AVG Acc: {_summary('acc')}"
    logging.info(message)

def _stack_and_summarize(results, prefix, metric, values):
    """Stack ``values`` and store them plus mean/max/min/std under ``{prefix}_*_{metric}``."""
    stacked = torch.stack(values)
    results[f'{prefix}_all_{metric}'] = stacked
    results[f'{prefix}_avg_{metric}'] = torch.mean(stacked)
    results[f'{prefix}_max_{metric}'] = torch.max(stacked)
    results[f'{prefix}_min_{metric}'] = torch.min(stacked)
    results[f'{prefix}_std_{metric}'] = torch.std(stacked)


def eval_model(evaluator, node_list, adv_node_list=None, split='test'):
    """Run ``evaluator`` and aggregate its per-node metrics into summary statistics.

    :param evaluator: callable invoked as ``evaluator(node_list, split)``; must
        return a mapping ``node -> {'loss': tensor, ...}`` with optional
        ``'acc'``, ``'finetune_loss'`` and ``'finetune_acc'`` entries.
    :param node_list: nodes to evaluate; the entry of ``node_list[0]`` decides
        which optional metrics are aggregated.
    :param adv_node_list: nodes excluded from the ``clean_*`` and
        ``finetune_*`` statistics. ``None`` (the default) means no such nodes.
    :param split: split name passed to ``evaluator`` and embedded in result keys.
    :return: dict with ``{split}_{all,avg,max,min,std}_{loss,acc}`` entries,
        plus the same keys prefixed with ``clean_`` (only when
        ``adv_node_list`` is non-empty) and ``finetune_`` (only when the
        evaluator reports fine-tune metrics).
    """
    if adv_node_list is None:
        # Avoid a shared mutable default; an empty tuple behaves like "no adversaries".
        adv_node_list = ()

    results = {}
    per_node = evaluator(node_list, split)

    _stack_and_summarize(
        results, split, 'loss', [entry['loss'] for entry in per_node.values()])

    first_entry = per_node[node_list[0]]
    has_acc = 'acc' in first_entry
    if has_acc:
        _stack_and_summarize(
            results, split, 'acc', [entry['acc'] for entry in per_node.values()])

    if len(adv_node_list) > 0:
        clean_entries = [
            entry for node, entry in per_node.items() if node not in adv_node_list]
        _stack_and_summarize(
            results, f'clean_{split}', 'loss',
            [entry['loss'] for entry in clean_entries])
        if has_acc:
            _stack_and_summarize(
                results, f'clean_{split}', 'acc',
                [entry['acc'] for entry in clean_entries])

    if 'finetune_loss' in first_entry:
        # Fine-tune statistics also skip the adversarial nodes.
        finetune_entries = [
            entry for node, entry in per_node.items() if node not in adv_node_list]
        _stack_and_summarize(
            results, f'finetune_{split}', 'loss',
            [entry['finetune_loss'] for entry in finetune_entries])
        # Guard on the key that is actually read; checking 'acc' here raised
        # KeyError whenever 'acc' was reported without 'finetune_acc'.
        if 'finetune_acc' in first_entry:
            _stack_and_summarize(
                results, f'finetune_{split}', 'acc',
                [entry['finetune_acc'] for entry in finetune_entries])

    return results

def str2bool(v):
    """Parse a command-line style truth value.

    :param v: a ``bool`` (returned unchanged) or a string such as
        ``'yes'``/``'no'``, ``'true'``/``'false'``, ``'t'``/``'f'``,
        ``'y'``/``'n'``, ``'1'``/``'0'`` (case-insensitive).
    :return: the corresponding ``bool``.
    :raises argparse.ArgumentTypeError: if the string is not a recognized value.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints records through ``tqdm.tqdm.write``.

    Routing output through tqdm is presumably meant to keep log lines from
    colliding with active progress bars.
    """

    def __init__(self, level=logging.NOTSET):
        """Create the handler with the given minimum ``level``."""
        super().__init__(level=level)

    def emit(self, record):
        """Format ``record``, write it via tqdm and flush.

        Any ``Exception`` raised while doing so is passed to ``handleError``
        instead of propagating to the caller.
        """
        try:
            text = self.format(record)
            tqdm.tqdm.write(text)
            self.flush()
        except Exception:
            self.handleError(record)


def set_logger(filename=None, level=logging.INFO):
    """Configure root logging and attach a tqdm-aware console handler.

    :param filename: optional log file path forwarded to ``logging.basicConfig``.
    :param level: root logger level forwarded to ``logging.basicConfig``.
    :return: None

    The tqdm handler is always created at ``logging.INFO`` regardless of
    ``level``, and no formatter is set on it. It is attached at most once, so
    calling this function repeatedly no longer duplicates every message.
    """
    # NOTE: basicConfig is a no-op when the root logger already has handlers.
    logging.basicConfig(
        filename=filename,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=level,
    )
    root_logger = logging.getLogger()
    already_attached = any(
        isinstance(handler, TqdmLoggingHandler) for handler in root_logger.handlers)
    if not already_attached:
        root_logger.addHandler(TqdmLoggingHandler(logging.INFO))
    # NOTE(review): with filename=None, basicConfig also installs its own
    # stream handler, so console messages presumably appear twice - confirm
    # whether that is intended.


def set_seed(seed):
    """Seed the NumPy, Python and PyTorch random number generators.

    :param seed: integer seed applied to every generator.
    :return: None

    Note: cuDNN is left enabled with ``benchmark=True`` and
    ``deterministic=False``. That favors speed, so GPU runs are not guaranteed
    to be bit-for-bit reproducible even with a fixed seed.
    """
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Speed-oriented cuDNN settings. For strict determinism use
    # benchmark=False and deterministic=True instead.
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = False


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    :param model: object whose ``parameters()`` yields tensors.
    :return: summed ``numel()`` over parameters with ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_device(no_cuda=False, gpus='0'):
    """Pick the device string to run on.

    :param no_cuda: when true, CUDA is never selected (MPS is unaffected).
    :param gpus: CUDA device index used to build ``"cuda:{gpus}"``.
    :return: ``"mps"`` if the MPS backend is available, otherwise
        ``"cuda:{gpus}"`` if CUDA is available and allowed, otherwise ``"cpu"``.
    """
    # MPS is checked first, so it wins over CUDA and ignores ``no_cuda``.
    if torch.backends.mps.is_available():
        return "mps"
    cuda_usable = torch.cuda.is_available() and not no_cuda
    return f"cuda:{gpus}" if cuda_usable else "cpu"


def freeze(model):
    """Turn off gradient tracking for every parameter of ``model``.

    :param model: object whose ``parameters()`` yields the tensors to freeze.
    :return: None
    """
    for weight in model.parameters():
        weight.requires_grad_(False)


def unfreeze(model):
    """Turn on gradient tracking for every parameter of ``model``.

    :param model: object whose ``parameters()`` yields the tensors to unfreeze.
    :return: None
    """
    for weight in model.parameters():
        weight.requires_grad_(True)

def detach_clone(param_dict):
    """Return an ``OrderedDict`` of detached copies of the tensors in ``param_dict``.

    :param param_dict: mapping of name to tensor.
    :return: ``OrderedDict`` with the same keys in the same order; each value
        is ``tensor.detach().clone()``.
    """
    copies = OrderedDict()
    for name, tensor in param_dict.items():
        copies[name] = tensor.detach().clone()
    return copies


class CustomCosineAnnealingLR:
    """Optimizer-free cosine annealing schedule between ``base_lr`` and ``eta_min``.

    The learning rate is tracked on the instance (``current_lr``) rather than
    on an optimizer; ``step()`` advances the schedule and returns the new value.
    The update rule appears to follow the recursive form used by
    ``torch.optim.lr_scheduler.CosineAnnealingLR`` - confirm if exact parity
    is required.

    :param base_lr: learning rate at epoch 0.
    :param T_max: number of steps for half a cosine period.
    :param eta_min: lower bound of the schedule.
    :param last_epoch: index of the last completed epoch (``-1`` = fresh start).
    """

    def __init__(self, base_lr, T_max, eta_min=0, last_epoch=-1):
        self.base_lr, self.T_max = base_lr, T_max
        self.eta_min, self.last_epoch = eta_min, last_epoch
        self._step_count = 0
        self.current_lr = base_lr

    def get_lr(self):
        """Compute the learning rate for the current ``last_epoch``."""
        epoch = self.last_epoch
        period = self.T_max
        floor = self.eta_min

        if epoch == 0:
            return self.base_lr

        # First step of a resumed schedule: use the closed-form cosine value.
        if self._step_count == 1 and epoch > 0:
            return floor + (self.base_lr - floor) * (1 + math.cos(epoch * math.pi / period)) / 2

        # Right after the minimum the recursive ratio below would divide by
        # zero, so the first increment is added explicitly.
        if (epoch - 1 - period) % (2 * period) == 0:
            return self.current_lr + (self.base_lr - floor) * (1 - math.cos(math.pi / period)) / 2

        # Recursive update from the previous learning rate.
        numerator = 1 + math.cos(math.pi * epoch / period)
        denominator = 1 + math.cos(math.pi * (epoch - 1) / period)
        return numerator / denominator * (self.current_lr - floor) + floor

    def step(self):
        """Advance one epoch, store the new learning rate and return it."""
        self._step_count += 1
        self.last_epoch += 1
        self.current_lr = self.get_lr()
        return self.current_lr

def sd_matrixing(state_dic, keys=None):
    """Flatten the tensors of a state dict into a single 1-D vector.

    :param state_dic: mapping of name to tensor (e.g. a model ``state_dict``).
    :param keys: optional container of names to include; ``None`` includes
        every entry.
    :return: 1-D tensor with the selected entries concatenated in iteration
        order, or ``None`` when no entry is selected.

    Zero-dimensional entries after the first selected one are cast to
    ``float32`` before concatenation; the first selected entry is only
    flattened.
    """
    pieces = []
    for key, param in state_dic.items():
        if keys is not None and key not in keys:
            continue
        if pieces and param.dim() == 0:
            pieces.append(param.view(1).type(torch.float32))
        else:
            pieces.append(param.flatten())

    if not pieces:
        return None
    if len(pieces) == 1:
        # Return the flattened tensor itself, without an extra copy.
        return pieces[0]
    # One concatenation instead of re-copying the growing vector per entry.
    return torch.cat(pieces, 0)

class PiecewiseLinear(namedtuple('PiecewiseLinear', ('knots', 'vals'))):
    """Piecewise-linear function given by x-coordinates ``knots`` and y-values ``vals``."""

    def __call__(self, t):
        """Return the value at ``t`` as interpolated by ``np.interp``."""
        return np.interp([t], self.knots, self.vals)[0]


class PiecewiseLinearLR(torch.optim.lr_scheduler._LRScheduler):
    """Scheduler with a piecewise-linear warm-up/decay learning rate.

    The schedule goes through ``(0, 0)``, ``(5, 0.4)`` and
    ``(max_epoch, 0.001)``. It is evaluated at ``step / batch_count`` and the
    result is divided by ``batch_size``.

    :param optimizer: wrapped optimizer.
    :param max_epoch: x-coordinate of the last knot of the schedule.
    :param batch_size: divisor applied to the scheduled value.
    :param batch_count: divisor applied to the step counter.
    :param last_epoch: forwarded to the base scheduler.
    :param verbose: forwarded to the base scheduler only when truthy.
    """

    def __init__(self, optimizer, max_epoch=20, batch_size=125, batch_count=3, last_epoch=-1, verbose=False):
        self.max_epoch = max_epoch
        self.batch_size = batch_size
        self.batch_count = batch_count
        # Plain attributes plus the ``lr`` method replace the former lambda
        # attribute. A lambda stored in ``__dict__`` ends up in
        # ``state_dict()`` and makes it impossible to pickle.
        self.lr_schedule = PiecewiseLinear([0, 5, max_epoch], [0, 0.4, 0.001])
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # The base-class ``verbose`` argument is deprecated in recent
            # PyTorch releases (and may be rejected by newer ones), so it is
            # not forwarded unless explicitly requested.
            super().__init__(optimizer, last_epoch)

    def lr(self, step):
        """Return the learning rate for the given step counter value."""
        return self.lr_schedule(step / self.batch_count) / self.batch_size

    def get_lr(self):
        """Return the current learning rate once per optimizer param group."""
        lr = self.lr(self._step_count)
        return [lr for _ in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        # NOTE(review): this returns a bare float, unlike get_lr(), which
        # returns one value per param group. The base class presumably expects
        # a per-group list here - confirm before relying on step(epoch).
        return 0.001