import os
from tqdm import tqdm
from google.cloud import storage

import wandb

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

import torch
from torch import nn
from torch.nn import functional
from torch.optim.lr_scheduler import CyclicLR

from utils import logs_handler

# Module-level logger created through the project's logs_handler helper.
logger = logs_handler.get_logger(__name__)

# Module classes consulted by is_no_decay_parameter(): configure_adamw_optimizer()
# places a parameter whose name ends in 'weight' into the weight_decay=0.0 group
# when one of the modules it is reachable from is an instance of these classes.
# add_no_decay_module() appends further classes at runtime.
NO_WEIGHT_DECAY_MODULES = [nn.LayerNorm, nn.Embedding]

def download_bucket_directory(bucket_name, source_directory, destination_directory, warn=True):
    """Download every ``.gz`` blob under a bucket prefix into a local directory.

    The bucket is accessed with an anonymous client. Each blob is written to
    ``destination_directory/<blob.name>``; files that already exist locally
    are skipped.

    Args:
        bucket_name: Name of the Cloud Storage bucket.
        source_directory: Prefix passed to ``bucket.list_blobs``.
        destination_directory: Local root directory for the downloaded files.
        warn: When true, skipped (already existing) files are counted and
            reported in the progress bar postfix.
    """
    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    blobs = [blob for blob in bucket.list_blobs(prefix=source_directory)
             if blob.name.endswith('.gz')]
    total_size = sum(blob.size for blob in blobs)
    ignored = 0
    # The context manager closes the progress bar even if a download raises.
    with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
        for blob in blobs:
            pbar.set_description(blob.name)
            destination_file_path = os.path.join(destination_directory, blob.name)
            os.makedirs(os.path.dirname(destination_file_path), exist_ok=True)
            if not os.path.exists(destination_file_path):
                # Download next to the target and rename afterwards, so that an
                # interrupted transfer never leaves a truncated file that a
                # later run would mistake for a finished download.
                partial_path = destination_file_path + '.part'
                blob.download_to_filename(partial_path)
                os.replace(partial_path, destination_file_path)
            elif warn:
                ignored += 1
                pbar.set_postfix(ignored=f'{ignored}/{len(blobs)}',
                                 reason=f'Destination file "{destination_file_path}" already exists.')
            pbar.update(blob.size)
        
def wandb_ready():
    """Return True when ``wandb.run`` is set, i.e. it is not None."""
    active_run = wandb.run
    return active_run is not None

def wandb_log(*args, **kwargs):
    """Forward all arguments to ``wandb.log`` if a run is active; else no-op."""
    if not wandb_ready():
        return
    wandb.log(*args, **kwargs)

def wandb_finish(*args, **kwargs):
    """Forward all arguments to ``wandb.finish`` if a run is active; else no-op."""
    if not wandb_ready():
        return
    wandb.finish(*args, **kwargs)

def accumulate_losses(out_dict, average=True):
    """Combine every ``*_loss`` entry of ``out_dict`` into a single loss.

    All entries whose key ends in ``'_loss'`` are summed. With ``average``
    the sum is divided by the number of entries whose elements are all
    greater than ``1e-7``, so that (near-)zero losses do not dilute the mean.

    Args:
        out_dict: Mapping containing at least one ``*_loss`` entry. The
            entries must support ``> 1e-7`` followed by ``.all()`` (e.g.
            torch tensors).
        average: Divide the sum by the number of active losses.

    Returns:
        The summed (and optionally averaged) loss.

    Raises:
        AssertionError: If ``out_dict`` has no ``*_loss`` key.
    """
    losses_keys = [key for key in out_dict if key.endswith('_loss')]
    assert 0 < len(losses_keys), '...'
    loss = 0.0
    loss_count = 0
    for key in losses_keys:
        term = out_dict[key]
        if (term > 1e-7).all():
            loss_count += 1
        loss = loss + term
    if average:
        # When no loss is above the threshold the count is zero; dividing by
        # it would turn a (near-)zero loss into NaN/inf, so divide by 1.
        loss = loss / max(loss_count, 1)
    return loss

def top_k_logits(logits, k, inf=True):
    """Keep the ``k`` largest logits along the last dim and mask the rest.

    Works for tensors of any rank; the top-k is taken over the last dim.

    Args:
        logits: Tensor of scores.
        k: Number of entries to keep per position.
        inf: Masked entries become ``-inf`` when true, ``-100.0`` otherwise.

    Returns:
        A new tensor shaped like ``logits`` in which every entry smaller than
        the k-th largest value of its row is replaced by the fill value.
    """
    topk = torch.topk(logits, k)
    # k-th largest value per row, keeping the last dim for broadcasting.
    min_value = topk.values[..., -1:]
    fill_value = -float('inf') if inf else float(-100)
    return logits.masked_fill(logits < min_value, fill_value)

def make_discrete_actions(logits, top_k=None, sample=False):
    """Turn ``(b, t, num_tokens)`` logits into token indices and their probabilities.

    Args:
        logits: 3-D tensor of scores over the last dim.
        top_k: If given, logits are first filtered with ``top_k_logits``.
        sample: Draw the index from the softmax distribution when true,
            otherwise take the most probable index.

    Returns:
        Tuple ``(idx, probs)``, both of shape ``(b, t)``: the chosen index and
        the softmax probability of that index.
    """
    batch, steps, vocab = logits.shape
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = functional.softmax(logits, dim=-1)
    if sample:
        flat_probs = probs.reshape(-1, vocab)
        idx = torch.multinomial(flat_probs, num_samples=1)
    else:
        idx = torch.topk(probs, k=1, dim=-1).indices
    idx = idx.reshape(batch, steps)
    # Pick the probability of the chosen index at every (batch, step).
    chosen = probs.gather(-1, idx.unsqueeze(-1)).squeeze(-1)
    return idx, chosen

def make_continuous_actions(y, min_value=-1.0, max_value=1.0):
    """Clamp a ``(b, t, 1)`` prediction into ``[min_value, max_value]``.

    Args:
        y: 3-D tensor whose last dim must have size 1.
        min_value: Lower clamp bound.
        max_value: Upper clamp bound.

    Returns:
        Tuple ``(actions, certainty)``: ``actions`` is the clamped input
        reshaped to ``(b, t)``; ``certainty`` is a tensor of ones shaped like
        ``y``.
    """
    batch, steps, width = y.shape
    assert width == 1, '...'
    # NOTE(review): certainty keeps y's (b, t, 1) shape while actions is
    # (b, t), unlike make_discrete_actions which returns (b, t) for both;
    # confirm callers expect this.
    certainty = torch.ones_like(y)
    actions = y.clamp(min_value, max_value).reshape(batch, steps)
    return actions, certainty

def get_tokens_weights(tokens, num_tokens):
    """Compute balanced per-token class weights as a dense list.

    Args:
        tokens: Collection of observed token ids.
        num_tokens: Length of the returned list.

    Returns:
        List of length ``num_tokens`` where entry ``i`` is the 'balanced'
        class weight of token ``i`` (as computed by scikit-learn's
        ``compute_class_weight``), or ``0.0`` for tokens that never occur.
        Observed ids outside ``[0, num_tokens)`` are ignored.
    """
    unique_tokens, counts = np.unique(tokens, return_counts=True)
    logger.info(f'Tokens: {list(zip(unique_tokens, counts))}')
    balanced = compute_class_weight(class_weight='balanced', classes=unique_tokens, y=tokens)
    weights = [0.0] * num_tokens
    for token, weight in zip(unique_tokens, balanced):
        # The lower bound matters: a negative id would otherwise be written
        # through negative indexing into the slot of an unrelated token.
        if 0 <= token < num_tokens:
            weights[token] = weight
    return weights

def mask_tensor_causal(tensor, value, forward=False, k=1):
    """Shift ``tensor`` by ``k`` steps along dim 1, filling the gap with ``value``.

    Args:
        tensor: Tensor with at least two dims; the shift is applied to dim 1.
        value: Fill value for the positions vacated by the shift.
        forward: When true the content moves towards higher indices (the first
            ``k`` positions become ``value``); otherwise towards lower indices
            (the last ``k`` positions become ``value``).
        k: Non-negative shift amount. Values larger than the length of dim 1
            are treated as that length.

    Returns:
        A tensor with the same shape as ``tensor``.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f'k must be non-negative, got {k}')
    steps = tensor.shape[1]
    # Clamp so the output keeps the input length even when k exceeds it.
    k = min(k, steps)
    # functional.pad lists pads from the last dim backwards: leave every dim
    # after dim 1 untouched, then pad dim 1.
    trailing_pad = (0, ) * (2 * (tensor.ndim - 2))
    if forward:
        # `steps - k` instead of `-k`: the slice `[:-0]` would be empty.
        return functional.pad(tensor[:, :steps - k], (*trailing_pad, k, 0), value=value)
    return functional.pad(tensor[:, k:], (*trailing_pad, 0, k), value=value)

def mask_tensor_random(tensor, value, k=4):
    """Randomly overwrite one position per window of ``k`` steps along dim 1.

    Dim 1 of ``tensor`` is front-padded with ``value`` to a multiple of ``k``
    and split into consecutive windows of ``k`` positions. One position is
    chosen in every window and replaced by ``value``; the first two original
    positions along dim 1 are never replaced. The front padding is stripped
    again before returning, so a pick that fell into the padding has no
    visible effect.

    Args:
        tensor: Tensor with at least two dims; dims 0 and 1 are read as ``b``
            and ``t`` (presumably batch and time -- confirm with callers).
        value: Fill value used for both the padding and the masked positions.
        k: Window length along dim 1.

    Returns:
        Tuple ``(masked_tensor, mask)``. ``masked_tensor`` has the shape of
        the input; ``mask`` is a bool tensor of shape ``(b, t, 1, ..., 1)``
        that is True where ``value`` was written.
    """
    device = tensor.device
    ndim = tensor.ndim
    # functional.pad lists pads from the last dim backwards: zero padding for
    # every dim after dim 1.
    pad = ((0, ) * (2 * (ndim - 2)))
    # Singleton dims appended to the (b, t) mask so it broadcasts over tensor.
    view = (1, ) * (ndim - 2)
    b, t = tensor.shape[:2]
    # nw: number of windows along dim 1; rem: number of front-padded positions.
    nw, rem = t//k, 0
    if t%k != 0:
        nw += 1
        rem = nw * k - t
        t = nw * k
        # Pad at the front of dim 1 so its length becomes a multiple of k.
        tensor = functional.pad(tensor, (*pad, rem, 0), value=value)
    # Total number of windows over the whole batch (b * nw).
    size = np.prod(tensor.shape[:2])//k
    # One offset in [0, k) per window.
    perm = (torch.randperm(size) % k).reshape(b, nw, 1).to(device)
    mask = torch.zeros(tensor.shape[:2], dtype=torch.bool, device=device)
    mask = mask.reshape(b, nw, k)
    src = torch.full_like(mask, fill_value=True)
    # Set exactly the chosen offset of every window to True.
    mask = mask.scatter_(dim=2, index=perm, src=src)
    mask = mask.reshape(*tensor.shape[:2], *view)
    # Never mask the first two original (non-padded) positions.
    mask[:,rem:rem+2] = False
    masked_tensor = tensor.masked_fill(mask, value)
    # masked_fill already returns tensor's shape; this reshape leaves it as is.
    masked_tensor = masked_tensor.reshape(*tensor.shape)
    # Drop the front padding again.
    return masked_tensor[:,rem:], mask[:,rem:]

def copy_tensor(tensor):
    """Return a detached clone of ``tensor`` with ``requires_grad`` off."""
    duplicate = tensor.clone()
    duplicate = duplicate.detach()
    duplicate.requires_grad_(False)
    return duplicate

def named_modules(module):
    """List ``(name, submodule, is_leaf)`` for ``module`` and all its descendants.

    ``is_leaf`` is True when the submodule has no children. Order follows
    ``module.named_modules()``, so the first entry is ``module`` itself with
    an empty name.
    """
    return [
        (name, submodule, not list(submodule.children()))
        for name, submodule in module.named_modules()
    ]

def build_parameters_modules_map(module):
    """Map each full parameter name to the modules it is reachable from.

    Every module returned by ``named_modules(module)`` contributes all names
    from its own ``named_parameters()``, prefixed with the module's name. A
    parameter is therefore listed under the module that owns it and under
    every enclosing module as well.

    Returns:
        Dict ``{full_param_name: [module, ...]}``.
    """
    owners = {}
    for module_name, submodule, _ in named_modules(module):
        for param_name, _ in submodule.named_parameters():
            # The root module has an empty name and needs no prefix.
            full_name = f'{module_name}.{param_name}' if module_name else param_name
            owners.setdefault(full_name, []).append(submodule)
    return owners

def add_no_decay_module(cls):
    """Append ``cls`` to the module-level ``NO_WEIGHT_DECAY_MODULES`` list."""
    # In-place mutation of the module-level list; no rebinding takes place.
    NO_WEIGHT_DECAY_MODULES.append(cls)

def is_no_decay_parameter(modules_list, blacklist_modules=None):
    """Tell whether any module in ``modules_list`` is of a blacklisted type.

    Args:
        modules_list: Modules to check.
        blacklist_modules: Class or tuple of classes usable with
            ``isinstance``. When falsy, ``NO_WEIGHT_DECAY_MODULES`` is used.

    Returns:
        True if at least one module is an instance of a blacklisted class.
    """
    if not blacklist_modules:
        blacklist_modules = tuple(NO_WEIGHT_DECAY_MODULES)
    return any(isinstance(candidate, blacklist_modules) for candidate in modules_list)

def configure_adamw_optimizer(module, weight_decay=0.1, num_iterations=2000,
                              min_lr=7e-5, max_lr=6e-4, **kwargs):
    """Build an AdamW optimizer with selective weight decay plus a CyclicLR scheduler.

    Parameters are split into two groups. Only parameters whose name ends in
    ``'weight'`` and that are not reachable from a blacklisted module (see
    ``is_no_decay_parameter``) are decayed; biases, blacklisted weights and
    all other parameters get ``weight_decay=0.0``.

    Args:
        module: Module whose parameters are optimized.
        weight_decay: Decay applied to the decayed group.
        num_iterations: The cyclic schedule uses ``num_iterations // 4`` steps
            up and the same number down; ``None`` selects 500.
        min_lr: Lower learning-rate bound of the cycle.
        max_lr: Upper learning-rate bound of the cycle.
        **kwargs: Extra options merged into both parameter groups.

    Returns:
        Tuple ``(optimizer, scheduler)``.

    Raises:
        AssertionError: If a parameter is left unclassified or the resulting
            step size is not greater than 1.
    """
    param_dict = dict(module.named_parameters())

    decay = set()
    no_decay = set()
    for fpn, modules_list in build_parameters_modules_map(module).items():
        if fpn not in param_dict:
            # Tied parameters: module.named_parameters() reports a shared
            # tensor once, but the map also holds its alias names. Skip the
            # aliases; the tensor is classified under its reported name, and
            # looking an alias up in param_dict would raise KeyError.
            continue
        if fpn.endswith('weight') and not is_no_decay_parameter(modules_list):
            decay.add(fpn)
        else:
            # biases, weights of blacklisted modules and any other parameter
            no_decay.add(fpn)

    inter_params = decay & no_decay
    union_params = decay | no_decay

    assert len(inter_params) == 0, 'parameters %s made it into both decay/no_decay sets!' % (str(inter_params), )
    assert len(param_dict.keys() - union_params) == 0, 'parameters %s were not separated into either decay/no_decay set!' \
                                                % (str(param_dict.keys() - union_params), )

    # Sorted names keep the parameter order inside each group deterministic.
    optim_groups = [
        {
            'params': [param_dict[pn] for pn in sorted(decay)],
            'weight_decay': weight_decay,
        },
        {
            'params': [param_dict[pn] for pn in sorted(no_decay)],
            'weight_decay': 0.0,
        },
    ]
    for optim_group in optim_groups:
        optim_group.update(kwargs)
    optimizer = torch.optim.AdamW(optim_groups)

    steps = 500 if (num_iterations is None) else (num_iterations // 4)
    assert steps > 1, '...'
    logger.info(f'CyclicLR: (steps: {steps}, min_lr: {min_lr}, max_lr: {max_lr})')
    scheduler = CyclicLR(optimizer, base_lr=min_lr, max_lr=max_lr,
                         step_size_up=steps, step_size_down=steps,
                         cycle_momentum=False)
    return optimizer, scheduler

def init_weights(module):
    """Initialise one module in place; suitable for ``Module.apply``.

    * ``nn.Linear`` / ``nn.Embedding``: weights drawn from N(0, 0.02); the
      bias of a Linear layer (if present) is zeroed.
    * ``nn.LayerNorm``: weight set to 1 and bias to 0 when they exist.

    Other module types are left untouched.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        # weight/bias are None when the layer has no affine parameters
        # (e.g. elementwise_affine=False), so guard before touching them.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if module.weight is not None:
            nn.init.ones_(module.weight)

def get_normalized_score(score, ref_score, random_score):
    """Rescale ``score`` so ``random_score`` maps to 0 and ``ref_score`` to 1."""
    gain = score - random_score
    span = ref_score - random_score
    return gain / span

def interpolate3d(tensor, depth, height, width):
    """Resize a 5-D tensor to ``(depth, height, width)`` with trilinear interpolation.

    Args:
        tensor: Tensor of shape ``(b, c, d, h, w)``.
        depth: Target size of dim 2.
        height: Target size of dim 3.
        width: Target size of dim 4.

    Returns:
        Tensor of shape ``(b, c, depth, height, width)``.
    """
    batch, _, in_d, in_h, in_w = tensor.shape
    target = (depth, height, width)
    volume = tensor.reshape(batch, -1, in_d, in_h, in_w)
    resized = functional.interpolate(volume, size=target, mode='trilinear')
    return resized.reshape(batch, -1, *target)
