import logging
import os
import random
import time
import numpy as np
import copy
from tqdm import tqdm

import colorlog
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

from utils.parameters import Params

# Shared module-level logger, looked up by the fixed name 'logger'; every
# helper in this module logs through it.
logger = logging.getLogger('logger')

def record_time(params: Params, t=None, name=None):
    """Append the time elapsed since ``t`` to ``params.timing_data[name]``.

    The measurement is stored in whole milliseconds. Nothing is recorded
    unless both ``t`` and ``name`` are supplied and ``params.save_timing`` is
    either ``True`` (record everything) or equal to ``name``.

    Args:
        params: object exposing ``save_timing`` and a ``timing_data`` mapping
            from name to a list of measurements.
        t: start timestamp, as returned by ``time.perf_counter()``.
        name: key under which the measurement is stored.
    """
    # Both a start time and a key are required; the original condition let
    # ``save_timing is True`` bypass this check and crash on ``t=None``.
    if t is None or not name:
        return
    if params.save_timing is not True and params.save_timing != name:
        return

    # Wait for queued GPU work so it is included in the measurement.
    # synchronize() raises on hosts without CUDA, so only call it when usable.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    params.timing_data[name].append(round(1000 * (time.perf_counter() - t)))


def dict_html(dict_obj, current_time):
    """Render ``dict_obj`` as an HTML table under a heading naming ``current_time``.

    Each key/value pair becomes one ``<tr>`` row, in the mapping's iteration
    order. A fixed set of keys is left out of the table. Values are
    interpolated as-is (no HTML escaping is applied).
    """
    # Keys that are deliberately left out of the rendered table.
    hidden_keys = ('poisoning_test', 'test_batch_size', 'discount_size',
                   'folder_path', 'log_interval',
                   'coefficient_transfer', 'grad_threshold')

    rows = [
        f'<tr><td>{name}</td><td>{val}</td></tr>'
        for name, val in dict_obj.items()
        if name not in hidden_keys
    ]
    table_body = ''.join(rows)
    return f'<h4>Params for model: {current_time}:</h4><table>{table_body}</table>'


def poison_text(inputs, labels):
    """Return a poisoned copy of a batch of token-id rows and all-ones labels.

    For every row, two consecutive positions are overwritten with the token
    ids 3968 and 3536. The first of the two positions is drawn uniformly from
    ``[1, sep - 3]``, where ``sep`` is the index of the first token equal to
    102 in that row (presumably the BERT ``[SEP]`` id; confirm against the
    tokenizer in use). Neither ``inputs`` nor ``labels`` is modified.

    Args:
        inputs: 2-D integer tensor, one token-id sequence per row.
        labels: tensor of labels; only its shape/dtype/device are used.

    Returns:
        ``(poisoned_inputs, poisoned_labels)`` where the labels are all ones.

    Raises:
        IndexError: if a row contains no token 102.
        ValueError: if the first 102 sits at an index below 4, leaving no
            valid insertion position.
    """
    poisoned = inputs.clone()
    for row in poisoned:  # each row is a view, so writes land in ``poisoned``
        # Use the first separator so rows with several 102 tokens (e.g.
        # sentence pairs) no longer crash on ``.item()``.
        first_sep = (row == 102).nonzero()[0].item()
        pos = random.randint(1, first_sep - 3)
        row[pos] = 3968
        row[pos + 1] = 3536
    # The incoming labels are replaced wholesale, so no clone is required.
    return poisoned, torch.ones_like(labels)


def poison_text_test(inputs, labels):
    """Poison a batch of token-id rows in place and set every label to 1.

    In each row two consecutive positions are overwritten with the token ids
    3968 and 3536; the first position is drawn uniformly from
    ``[1, inputs.shape[1] - 4]``. Both ``inputs`` and ``labels`` are mutated.

    Returns:
        Always ``True``.
    """
    for row in inputs:  # rows are views, so the writes modify ``inputs``
        start = random.randint(1, inputs.shape[1] - 4)
        row[start] = 3968
        row[start + 1] = 3536
    labels.fill_(1)
    return True


def create_table(params: dict):
    """Format ``params`` as a two-column (name / value) pipe-delimited table.

    The fixed header comes first, followed by one ``| key | value |`` line per
    entry in the mapping's iteration order.
    """
    lines = ["| name | value | \n |-----|-----|"]
    lines.extend(f"| {name} | {val} |" for name, val in params.items())
    return '\n'.join(lines)


def get_current_git_hash():
    """Return the hex SHA of the object HEAD points to.

    The repository is located by searching from the current directory up
    through its parents. GitPython (``git``) is imported lazily so the rest of
    the module works without it.
    """
    from git import Repo

    return Repo(search_parent_directories=True).head.object.hexsha


def create_logger():
    """
        Set up and return the shared ``'logger'`` logger.

        The level is set to INFO and one stream handler is attached. Its
        format is ``time - level - message``, colourised through ``colorlog``
        when file descriptor 2 (stderr) is a terminal. Calling this function
        again does not attach another handler, so messages are not duplicated.
    """
    log = logging.getLogger('logger')  # named logger shared by this module
    log.setLevel(logging.INFO)

    # A plain stream handler is already attached (typically by a previous
    # call): adding another one would print every message twice.
    # FileHandler subclasses StreamHandler, so it is excluded explicitly.
    has_stream_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in log.handlers
    )
    if has_stream_handler:
        return log

    format_str = '%(asctime)s - %(levelname)-8s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'
    if os.isatty(2):
        cformat = '%(log_color)s' + format_str
        colors = {'DEBUG': 'reset',
                  'INFO': 'reset',
                  'WARNING': 'bold_yellow',
                  'ERROR': 'bold_red',
                  'CRITICAL': 'bold_red'}
        formatter = colorlog.ColoredFormatter(cformat, date_format,
                                              log_colors=colors)
    else:
        formatter = logging.Formatter(format_str, date_format)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    log.addHandler(stream_handler)
    return log


def th(vector):
    """Element-wise ``tanh(vector) / 2 + 0.5``, i.e. tanh rescaled to the range 0..1."""
    squashed = torch.tanh(vector)
    return squashed / 2 + 0.5


def thp(vector):
    """Element-wise ``tanh(vector) * 2.2``, i.e. tanh rescaled to the range -2.2..2.2."""
    squashed = torch.tanh(vector)
    return squashed * 2.2

def module_activation(model, freqs, threshold=0):
    """List the parameters of ``model`` that contain entries at or below ``threshold``.

    ``freqs`` is a flat array or tensor holding one value per scalar
    parameter, laid out in ``model.named_parameters()`` order. It is cut into
    consecutive per-parameter slices; for each slice the number of values
    ``<= threshold`` is counted, and every parameter with a non-zero count is
    logged at INFO level.

    Args:
        model: module whose ``named_parameters()`` defines the layout.
        freqs: flat numpy array or torch tensor of per-parameter values.
        threshold: values less than or equal to this count as "dead".

    Returns:
        The names (as yielded by ``named_parameters()``) of the parameters
        that have at least one entry ``<= threshold``.
    """
    dead_modules = []
    offset = 0
    for name, param in model.named_parameters():
        # numel() is always a Python int (1 for a 0-dim parameter);
        # np.prod(param.shape) yields float 1.0 there and breaks the slice.
        size = param.numel()
        chunk = freqs[offset:offset + size]
        offset += size

        num_dead = len(chunk[chunk <= threshold])
        if num_dead != 0:
            logger.info(f"{name} with {num_dead} 0s over {size} ({round(num_dead * 100 / size, 5)}%)")
            dead_modules.append(name)

    return dead_modules

def check_activation(hlpr, model):
    """Measure, per scalar parameter, how often its gradient is zero on the auxiliary data.

    The loss gradient is evaluated over ``hlpr.task.aux_loader`` and two flat
    numpy arrays are accumulated, one entry per scalar parameter:

    * ``zero_frequency``: how often ``|grad| <= threshold`` (threshold is 0),
      divided by ``num_sample``;
    * ``average_vals``: the sum of ``|grad|`` divided by ``num_sample``.

    Two modes, selected by ``hlpr.params.fastcheck``:

    * falsy: per-sample gradients via BackPACK ``BatchGrad``; ``num_sample``
      is ``len(dataloader.dataset)``. The model and criterion are wrapped
      with ``extend`` and ``criterion.reduction`` is set to ``'mean'`` (the
      criterion object is mutated).
    * truthy: one gradient per batch via ``torch.autograd.grad`` on
      ``loss.mean()``; ``num_sample`` is ``len(dataloader)`` (the number of
      batches).

    Args:
        hlpr: helper exposing ``task`` (``criterion``, ``model``,
            ``aux_loader``, ``get_batch``) and ``params`` (``resume_model``,
            ``fastcheck``, ``task``).
        model: model to inspect; ``hlpr.task.model`` is used when ``None``.
            It is switched to eval mode.

    Returns:
        ``(zero_frequency, average_vals, save_names)`` where ``save_names`` is
        a list of two suggested ``.npy`` paths (frequency, average). Nothing
        is written to disk here.
    """
    criterion = hlpr.task.criterion
    if model is None:
        model = hlpr.task.model
    model.eval()

    # dataloaders = {'test': hlpr.task.test_loader, 'train': hlpr.task.train_loader}
    dataloaders = {'test': hlpr.task.aux_loader}
    # A gradient entry counts as "zero" when its magnitude is <= threshold.
    threshold = 0

    # Derive a short tag for the output file names: the last two
    # '_'-separated pieces of the second '/'-separated path component.
    # NOTE(review): raises IndexError if resume_model contains no '/'.
    if hlpr.params.resume_model:
        tmp_split = hlpr.params.resume_model.split('/')[1]
        tmp_split = tmp_split.split('_')
        save_name = tmp_split[-2] + '_' + tmp_split[-1]
    else:
        save_name = 'pretrain'

    # The return statement sits inside this loop, so only the first
    # dataloader is ever processed.
    for _, (k, dataloader) in enumerate(dataloaders.items()): # only support one dataloader for now
        num_sample = len(dataloader)  # number of batches; overridden below in per-sample mode
        logger.warning(f'start {k} dataloader with length {len(dataloader)}')
        logger.warning(f'test on {len(dataloader.dataset)} samples')
        average_vals = None
        zero_frequency = None

        if not hlpr.params.fastcheck:
            # Per-sample mode: BackPACK needs both model and loss extended.
            model = extend(model, use_converter=True)
            criterion.reduction = 'mean'
            criterion = extend(criterion)
            num_sample = len(dataloader.dataset)

        # NOTE(review): sample_counter is accumulated but never read; the
        # normalisation below uses num_sample instead.
        sample_counter = 0
        for i, data in enumerate(tqdm(dataloader)):
            # if i == num_sample:
            #     break

            batch = hlpr.task.get_batch(i, data)
            sample_counter += len(batch.inputs)
            model.zero_grad()
            loss = criterion(model(batch.inputs), batch.labels)

            if not hlpr.params.fastcheck:
                # BatchGrad stores per-sample gradients in param.grad_batch.
                with backpack(BatchGrad()):
                    loss.backward()

                for ii in range(len(batch.inputs)):
                    grads = []

                    # Flatten sample ii's gradient of every parameter into one vector.
                    for _, param in model.named_parameters():
                        grads.append(torch.flatten(param.grad_batch[ii]))
                    grads = torch.cat(grads).detach().cpu().numpy()

                    # Accumulators are allocated lazily, once the total size is known.
                    if average_vals is None:
                        average_vals = np.zeros_like(grads)
                        zero_frequency = np.zeros_like(grads)

                    average_vals += np.abs(grads)
                    zero_frequency[np.where(np.abs(grads) <= threshold)] += 1
            else:
                # Faster: one gradient for the whole batch rather than one per
                # sample, so the statistics are per batch.
                loss = loss.mean()
                grads = torch.autograd.grad(loss, model.parameters())
                grads = torch.cat([torch.flatten(i) for i in grads]).detach().cpu().numpy()

                if average_vals is None:
                    average_vals = np.zeros_like(grads)
                    zero_frequency = np.zeros_like(grads)

                average_vals += np.abs(grads)
                zero_frequency[np.where(np.abs(grads) <= threshold)] += 1

        # NOTE(review): if the loader yields no batches both accumulators are
        # still None here and the division raises TypeError.
        zero_frequency = zero_frequency / num_sample
        average_vals = average_vals / num_sample

        # NOTE(review): both names end in "_fast" regardless of
        # hlpr.params.fastcheck; confirm this is intended.
        save_names = [f"saved_models/model_activation/{hlpr.params.task}_{k}_{threshold}_freq_{save_name}_fast.npy", 
                      f"saved_models/model_activation/{hlpr.params.task}_{k}_{threshold}_avg_{save_name}_fast.npy"]

        return zero_frequency, average_vals, save_names


def prepare_grad_mask(hlpr, model):
    """Build the flat mask of "dead" parameters and store it on ``hlpr``.

    The per-parameter zero-gradient frequency is loaded from
    ``hlpr.params.mask_path`` when set, otherwise computed with
    ``check_activation``. The activation frequency is ``1 - zero_frequency``.

    ``hlpr.params.freq_threshold`` selects how the mask is derived:

    * ``<= 1``: mask every entry whose activation frequency is ``<=`` the
      threshold.
    * ``(1, 2]``: ``freq_threshold - 1`` is the fraction of entries with the
      lowest activation frequency to mask.
    * ``(2, 3]``: ``freq_threshold - 2`` is the fraction of entries with the
      lowest average gradient magnitude to mask. This needs the averages
      from ``check_activation`` and so cannot be combined with ``mask_path``.

    Other options read from ``hlpr.params``:

    * ``mask_module``: when truthy, only parameters whose enumeration index
      is ``<= mask_module`` stay candidates; all other entries have their
      score forced to 1.
    * ``strict_mask``: in the fractional modes, cap the selection at exactly
      ``int(fraction * N)`` entries (ties resolved by index order) and set
      the score of every unselected entry to 1000.
    * ``random_init``: when truthy, add Gaussian noise in place to ``model``.
      Values ``>= 10`` perturb every masked entry with multiplier
      ``random_init - 10``; smaller values perturb only entries whose score
      is ``<= 0`` with multiplier ``random_init``.

    Side effects: sets ``hlpr.attack.params.dead_params``,
    ``hlpr.grad_mask_proportion``, ``hlpr.activate_frequency`` and
    ``hlpr.grad_mask`` (moved to ``hlpr.task.params.device``).

    Raises:
        ValueError: if ``freq_threshold`` is above 3, or if it lies in
            ``(2, 3]`` while the mask was loaded from ``mask_path``.
    """
    # Average gradient magnitudes only exist when the statistics are computed
    # here; a mask loaded from disk carries zero frequencies only.
    average_vals = None
    if hlpr.params.mask_path:
        zero_frequency = np.load(hlpr.params.mask_path)
    else:
        zero_frequency, average_vals, _ = check_activation(hlpr, model)

    activate_frequency = torch.tensor(1 - zero_frequency)
    grad_mask = torch.zeros_like(activate_frequency)

    if hlpr.params.mask_module:
        module_mask = torch.zeros_like(activate_frequency)
        count = 0

        for i, param in enumerate(model.parameters()):
            # if i in hlpr.params.mask_module:
            if i <= hlpr.params.mask_module:
                module_mask[count:(count + param.numel())] = 1 # keep these as candidates for mask 
            count += param.numel()

        # Give non-candidates a score of 1 in both statistics.
        activate_frequency[module_mask == 0] = 1
        if average_vals is not None:
            average_vals[module_mask == 0] = 1

    if hlpr.params.freq_threshold <= 1:
        grad_mask[activate_frequency <= hlpr.params.freq_threshold] = 1
        num_params = len(np.where(grad_mask == 1)[0])
        logger.warning(f"with frequency threshold {hlpr.params.freq_threshold}:")
        logger.warning(f"{num_params} params never been activated over {len(activate_frequency)} params ({round(num_params * 100 / len(activate_frequency), 5)}%)")
        hlpr.attack.params.dead_params = module_activation(model, activate_frequency, threshold=hlpr.params.freq_threshold)
    elif hlpr.params.freq_threshold <= 2:
        # Fractional mode on activation frequency; rounding removes the
        # floating-point residue left by the subtraction.
        freq_threshold = hlpr.params.freq_threshold - 1
        freq_threshold = round(freq_threshold, 9)
        quantile = torch.kthvalue(activate_frequency, int(freq_threshold * len(grad_mask))).values

        if hlpr.params.strict_mask:
            grad_mask[np.where(activate_frequency <= quantile)[0][:int(freq_threshold * len(grad_mask))]] = 1
            activate_frequency[grad_mask == 0] = 1000
        else:
            grad_mask[activate_frequency <= quantile] = 1

        num_params = len(np.where(grad_mask == 1)[0])
        logger.warning(f"with {freq_threshold} most dead params (freq threshold at: {quantile}):")
        logger.warning(f"{num_params} params never been activated over {len(activate_frequency)} params ({round(num_params * 100 / len(activate_frequency), 5)}%)")
        hlpr.attack.params.dead_params = module_activation(model, activate_frequency, threshold=quantile)
    elif hlpr.params.freq_threshold <= 3:
        # use activation norm to choose dead params
        if average_vals is None:
            raise ValueError('freq_threshold in (2, 3] needs average gradient values, '
                             'which are not available when the mask is loaded from mask_path')
        activate_frequency = torch.tensor(average_vals)
        freq_threshold = hlpr.params.freq_threshold - 2
        freq_threshold = round(freq_threshold, 9)
        quantile = torch.kthvalue(activate_frequency, int(freq_threshold * len(grad_mask))).values

        if hlpr.params.strict_mask:
            grad_mask[np.where(activate_frequency <= quantile)[0][:int(freq_threshold * len(grad_mask))]] = 1
            activate_frequency[grad_mask == 0] = 1000
        else:
            grad_mask[activate_frequency <= quantile] = 1

        num_params = len(np.where(grad_mask == 1)[0])
        logger.warning(f"with {freq_threshold} most dead params (val threshold at: {quantile}):")
        logger.warning(f"{num_params} params never been activated over {len(activate_frequency)} params ({round(num_params * 100 / len(activate_frequency), 5)}%)")
        hlpr.attack.params.dead_params = module_activation(model, activate_frequency, threshold=quantile)
    else:
        raise ValueError('Not a correct freq_threshold')

    hlpr.grad_mask_proportion = round(num_params * 100 / len(activate_frequency), 5)
    hlpr.activate_frequency = activate_frequency
    hlpr.grad_mask = grad_mask.to(hlpr.task.params.device)

    if hlpr.params.random_init:
        logger.warning("random initialize dead params")
        count = 0

        if hlpr.params.random_init >= 10:
            noise_multiplier = hlpr.params.random_init - 10
            dead_mask = copy.deepcopy(grad_mask)
        else:
            noise_multiplier = hlpr.params.random_init
            dead_mask = torch.zeros_like(activate_frequency) # only randomly initialize the totally dead params
            dead_mask[activate_frequency <= 0] = 1

        with torch.no_grad():
            for p in model.parameters():
                # Slice of the flat mask that belongs to p, reshaped to p's shape.
                reshape_mask = torch.reshape(copy.deepcopy(dead_mask[count:(count + p.numel())]), p.shape).to(p.device)
                # p *= 1 - reshape_mask
                p += reshape_mask * torch.randn(p.shape, dtype=p.dtype).to(p.device) * noise_multiplier
                count += p.numel()


def prepare_clean_pca(hlpr, global_model, train_loader):
    """Collect clean-training parameter updates and store them on ``hlpr.attack``.

    The dataset behind ``train_loader`` is cut into
    ``hlpr.params.pca_num_grads`` contiguous index ranges. For each range a
    copy of ``global_model`` is reset to the global weights and trained for
    ``hlpr.params.pca_local_epochs`` epochs with the loss from
    ``hlpr.attack.compute_blind_loss``. The flattened difference between the
    trained and the global parameters forms one row of the result.

    The stacked rows are stored, without any PCA applied, as a tensor in
    ``hlpr.attack.clean_pcas`` on ``hlpr.params.device``.
    """
    import copy
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import SubsetRandomSampler

    total_grads = hlpr.params.pca_num_grads
    logger.warning("Preparing for clean pca loss")
    logger.warning(f"There are total {total_grads} grads for pcas")

    local_epochs = hlpr.params.pca_local_epochs
    dataset_size = len(train_loader.dataset)
    chunk_size = -(-dataset_size // total_grads)  # ceiling division

    criterion = hlpr.task.criterion
    local_model = copy.deepcopy(global_model)
    local_model.train()

    # get gradients from different subsets of the dataset of clean data
    deltas = []
    for subset_idx in range(total_grads):
        logger.warning(f"Start {subset_idx + 1}/{total_grads}")
        first = subset_idx * chunk_size
        last = min(first + chunk_size, dataset_size)
        subset_sampler = SubsetRandomSampler(list(range(first, last)))
        subset_loader = DataLoader(
            copy.deepcopy(train_loader.dataset),
            batch_size=hlpr.params.batch_size,
            num_workers=4,
            pin_memory=True,
            sampler=subset_sampler,
        )

        hlpr.task.copy_params(global_model, local_model)  # reset local model
        optimizer = hlpr.task.make_optimizer(local_model)

        for _ in range(local_epochs):
            for step, data in enumerate(subset_loader):
                batch = hlpr.task.get_batch(step, data)
                local_model.zero_grad()
                loss = hlpr.attack.compute_blind_loss(local_model, criterion, batch, False)
                loss.backward()
                optimizer.step()

        # Parameter-wise difference between the trained copy and the global model.
        global_state = global_model.state_dict()
        flat_pieces = [
            torch.flatten(weights - global_state[name])
            for name, weights in local_model.named_parameters()
        ]
        deltas.append(torch.cat(flat_pieces).detach().cpu().numpy())

    stacked = np.array(deltas)

    # do pca, but not now
    hlpr.attack.clean_pcas = torch.tensor(stacked).to(hlpr.params.device)


def feat_est(hlpr, model, train_loader, module=None):
    """Estimate the running maximum of the model's feature activations.

    For every batch from ``train_loader`` the features are computed under
    ``torch.no_grad()`` and reduced with a maximum over the first and the last
    two dimensions; the element-wise maximum across batches is returned.

    ``module`` picks the extractor: ``None`` or any string containing ``'4'``
    uses ``model.features``; a string containing ``'before_relu'`` uses
    ``model.features_before_relu``.

    Note that the model is switched to train mode before the pass.

    Returns:
        The tensor of maxima, or ``None`` if ``train_loader`` yields nothing.

    Raises:
        ValueError: if ``module`` matches neither option (raised when the
            first batch is processed).
    """
    model.train()

    running_max = None
    for step, data in enumerate(train_loader):
        batch = hlpr.task.get_batch(step, data)

        use_default = module is None or '4' in module
        if not use_default and 'before_relu' not in module:
            raise ValueError(f'Not a valid module: {module}')
        extractor = model.features if use_default else model.features_before_relu

        with torch.no_grad():
            activations = extractor(batch.inputs)

        batch_max = activations.amax(dim=(0, -2, -1))
        if running_max is None:
            running_max = batch_max
        else:
            running_max = torch.max(running_max, batch_max)

    return running_max


def weight_distribution(params):
    """Return the mean absolute value of every tensor in ``params``.

    ``params`` is a mapping from name to tensor; the result is a 1-D numpy
    array with one entry per tensor, in the mapping's iteration order.
    """
    return np.array([tensor.abs().mean().item() for tensor in params.values()])
