import os
import warnings
import collections
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

# Fill value passed as `padding_value` to pad_sequence in stack_dicts when
# flattened stat tensors of different lengths are stacked into one tensor.
WANDB_PADDING = -1


def custom_collate(batch):
    """Collate a list of sample dicts into a dict of per-field lists.

    Each output key maps to a list holding the corresponding field of every
    sample, in batch order. The 'reward' field of each sample is passed
    through float(); every other field is copied through unchanged.

    Raises:
        KeyError: if a sample lacks one of the expected fields.
    """
    def keep(value):
        return value

    # (output key, sample key, per-item conversion), listed in output order.
    schema = (
        ('query_texts', 'query', keep),
        ('response_texts', 'response', keep),
        ('query_ids', 'query_ids', keep),
        ('response_ids', 'response_ids', keep),
        ('instructions', 'instruction', keep),
        ('states', 'state', keep),
        ('thinks', 'think', keep),
        ('thinks_copy', 'think_copy', keep),
        ('actions', 'action', keep),
        ('histories', 'history', keep),
        ('rewards', 'reward', float),
        ('indexes', 'index', keep),
    )
    return {
        out_key: [convert(sample[in_key]) for sample in batch]
        for out_key, in_key, convert in schema
    }


class FixedKLController:
    """KL coefficient controller whose coefficient never changes.

    `value` holds the coefficient given at construction. `update` accepts an
    observation and a step count but ignores both.
    """

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current, n_steps):
        """Ignore the observation; the coefficient stays fixed."""
        return None


class AdaptiveKLController:
    """Proportional controller that rescales a KL coefficient toward a target.

    Attributes:
        value: the current coefficient (starts at `init_kl_coef`).
        target: the level `current` is compared against in `update`.
        horizon: divisor controlling how fast `value` moves per step.
    """

    def __init__(self, init_kl_coef, target, horizon):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current, n_steps):
        """Multiply `value` by 1 + clip(current/target - 1, -0.2, 0.2) * n_steps/horizon.

        `current` is presumably the observed KL for the latest batch; verify
        against callers. Values above `target` grow the coefficient, and values
        below it shrink the coefficient. The relative error is clipped to
        +/-20% before scaling.
        """
        error = np.clip(current / self.target - 1, -0.2, 0.2)
        self.value *= 1 + error * n_steps / self.horizon


def flatten_dict(nested, sep='/'):
    """Flatten a nested mapping into a single-level dict with joined keys.

    Keys along each path are concatenated with `sep`. For example,
    {'a': {'b': 1}} becomes {'a/b': 1}. Non-mapping values are stored
    as-is. Insertion order follows a depth-first walk of `nested`.

    Raises:
        ValueError: if any key already contains `sep`.
    """
    flat = {}

    def _walk(mapping, prefix):
        for key, value in mapping.items():
            if sep in key:
                raise ValueError(f"separator '{sep}' not allowed to be in key '{key}'")
            path = prefix + key
            if isinstance(value, collections.abc.Mapping):
                _walk(value, path + sep)
            else:
                flat[path] = value

    _walk(nested, '')
    return flat


def stack_dicts(stats_dicts):
    """Stack per-dict stat tensors into one padded tensor per key.

    Keys are taken from the first dict. For each key, the entry in every dict
    is flattened to 1-D with torch.flatten. The results are then combined by
    pad_sequence(batch_first=True) into a tensor of shape
    (len(stats_dicts), longest_length), filled with WANDB_PADDING.

    Raises:
        IndexError: if `stats_dicts` is empty.
        KeyError: if a later dict lacks a key present in the first.
    """
    first = stats_dicts[0]
    return {
        key: pad_sequence(
            [torch.flatten(entry[key]) for entry in stats_dicts],
            batch_first=True,
            padding_value=WANDB_PADDING,
        )
        for key in first
    }


def stats_to_np(stats_dict):
    """Cast all torch.tensors in dict to numpy arrays.

    Returns a new dict with the same keys:
      * torch.Tensor values are detached, moved to CPU and converted with
        .numpy(). A 0-d tensor becomes a 0-d ndarray, not a Python float,
        because np.isscalar is False for 0-d arrays.
      * Numeric scalars (Python or NumPy ints, floats, bools) become float.
      * Strings/bytes and any other values are passed through unchanged.
    """
    new_dict = dict()
    for k, v in stats_dict.items():
        if isinstance(v, torch.Tensor):
            v = v.detach().cpu().numpy()
        # np.isscalar also returns True for str/bytes. float() on those would
        # raise (or silently parse numeric-looking text), so leave text alone.
        if np.isscalar(v) and not isinstance(v, (str, bytes)):
            v = float(v)
        new_dict[k] = v
    return new_dict


def stats_print(stats_dict):
    """Print every entry of `stats_dict` on its own line as 'key: value'."""
    rendered = (f"{name}: {val}" for name, val in stats_dict.items())
    for line in rendered:
        print(line)


def logprobs_from_logits(logits, ids):
    """Return the log-probability of each id in `ids` under softmax(`logits`).

    The softmax is taken over the last dimension of `logits`, so any rank
    works. For example, (batch, seq, vocab) logits need (batch, seq) ids, and
    (batch, vocab) logits need (batch,) ids.

    Args:
        logits: float tensor whose last dimension indexes the vocabulary.
        ids: integer tensor with shape logits.shape[:-1].

    Returns:
        Tensor of shape ids.shape holding log_softmax(logits)[..., ids].
    """
    logp = F.log_softmax(logits, dim=-1)
    # gather needs an index with the same rank as logp, so add a trailing
    # singleton dim for the lookup and drop it afterwards.
    logpy = torch.gather(logp, -1, ids.unsqueeze(-1)).squeeze(-1)
    return logpy


def entropy_from_logits(logits):
    """Return the entropy of softmax(`logits`) over the last dimension.

    Uses the identity H = logsumexp(z) - sum(softmax(z) * z), which avoids
    taking log of probabilities directly. Masked entries (logit -inf) get
    probability zero and contribute zero to the entropy instead of NaN.

    Returns:
        Tensor with the last dimension of `logits` reduced away.
    """
    pd = F.softmax(logits, dim=-1)
    # 0 * -inf is NaN in IEEE arithmetic. Zero-probability entries contribute
    # nothing to sum(p * z), so substitute 0 for their logits. For finite
    # logits this is a no-op, since p == 0 already makes the product 0.
    safe_logits = torch.where(pd > 0, logits, torch.zeros_like(logits))
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * safe_logits, dim=-1)
    return entropy


def whiten(values, shift_mean=True):
    """Standardize `values` with the mean and variance of the whole tensor.

    Computes (values - mean) / sqrt(var + 1e-8), where torch.var uses its
    default unbiased estimator. If `shift_mean` is False, the mean is added
    back, so only the scale is normalized.
    """
    mu = torch.mean(values)
    inv_std = torch.rsqrt(torch.var(values) + 1e-8)
    normalized = (values - mu) * inv_std
    return normalized if shift_mean else normalized + mu


def clip_by_value(x, tensor_min, tensor_max):
    """Clamp `x` elementwise to [tensor_min, tensor_max].

    The bounds are tensors that broadcast against `x`. The upper bound is
    applied first, then the lower bound, so the lower bound wins if the two
    bounds cross.
    """
    capped_above = torch.min(x, tensor_max)
    return torch.max(capped_above, tensor_min)
