from collections import deque
import numpy as np
import torch
import math
import random


def update_params(optim, loss, networks, retain_graph=False,
                  grad_cliping=None):
    """Run one optimisation step on ``loss``.

    Args:
        optim: Optimizer whose gradients are reset and which is stepped.
        loss: Tensor to backpropagate through.
        networks: Iterable of modules whose parameter gradients are clipped
            when ``grad_cliping`` is truthy.
        retain_graph: Forwarded to ``loss.backward`` so the autograd graph
            can be reused by a later backward pass.
        grad_cliping: Maximum gradient norm passed to
            ``clip_grad_norm_`` (misspelled name kept for caller
            compatibility). Falsy values (``None``, ``0``) disable clipping.
    """
    optim.zero_grad()
    loss.backward(retain_graph=retain_graph)

    max_norm = grad_cliping
    if max_norm:
        # Clip each network's gradients separately to keep training stable.
        for module in networks:
            torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm)

    optim.step()


def disable_gradients(network):
    """Freeze ``network`` by turning off ``requires_grad`` on every parameter."""
    for weight in network.parameters():
        weight.requires_grad_(False)


def calculate_huber_loss(td_errors, kappa=1.0):
    """Element-wise Huber loss of ``td_errors`` with threshold ``kappa``.

    Returns ``0.5 * x**2`` where ``|x| <= kappa`` and
    ``kappa * (|x| - 0.5 * kappa)`` elsewhere, with the same shape as the input.
    """
    magnitude = td_errors.abs()
    quadratic_part = 0.5 * td_errors.pow(2)
    linear_part = kappa * (magnitude - 0.5 * kappa)
    return torch.where(magnitude <= kappa, quadratic_part, linear_part)


def calculate_quantile_huber_loss(td_errors, taus, weights=None, kappa=1.0):
    """Quantile Huber loss averaged over the batch.

    Args:
        td_errors: Tensor of shape ``(batch_size, N, N_dash)``.
        taus: Quantile fractions, must not require grad. They are broadcast
            as ``taus[..., None]`` against ``td_errors``; presumably of shape
            ``(batch_size or 1, N)`` -- TODO confirm against callers.
        weights: Optional per-sample weights multiplied with the
            ``(batch_size, 1)`` per-sample losses before averaging.
        kappa: Huber threshold; the element-wise loss is also divided by it.

    Returns:
        Scalar tensor.
    """
    assert not taus.requires_grad
    batch_size, N, N_dash = td_errors.shape

    huber = calculate_huber_loss(td_errors, kappa)
    assert huber.shape == (
        batch_size, N, N_dash)

    # Asymmetric weight |tau - 1{error < 0}|; detach so the indicator
    # contributes no gradient.
    negative_mask = (td_errors.detach() < 0).float()
    asymmetry = torch.abs(taus[..., None] - negative_mask)
    elementwise_loss = asymmetry * huber / kappa
    assert elementwise_loss.shape == (
        batch_size, N, N_dash)

    # Sum over the N axis, average over the N_dash axis -> (batch_size, 1).
    per_sample_loss = elementwise_loss.sum(dim=1).mean(dim=1, keepdim=True)
    assert per_sample_loss.shape == (batch_size, 1)

    if weights is None:
        return per_sample_loss.mean()
    return (per_sample_loss * weights).mean()


def evaluate_quantile_at_action(s_quantiles, actions):
    """Select, for each sample, the quantile values of the chosen action.

    ``s_quantiles`` is indexed along dim 2 by ``actions``. ``actions`` is
    expanded to ``(batch_size, N, 1)``, so it is presumably ``(batch_size, 1)``
    (TODO confirm). The result has shape ``(batch_size, N, 1)``.
    """
    assert s_quantiles.shape[0] == actions.shape[0]

    batch_size, num_quantiles = s_quantiles.shape[0], s_quantiles.shape[1]

    # Repeat each sample's action index across all N quantile rows.
    gather_index = actions.unsqueeze(-1).expand(batch_size, num_quantiles, 1)

    return torch.gather(s_quantiles, 2, gather_index)


class RunningMeanStats:
    """Mean of the ``n`` most recently appended values."""

    def __init__(self, n=10):
        self.n = n
        # Bounded deque: appending past ``n`` items evicts the oldest one.
        self.stats = deque(maxlen=n)

    def append(self, x):
        """Record a new value ``x``."""
        self.stats.append(x)

    def get(self):
        """Return the NumPy mean of the stored window (nan when empty)."""
        return np.asarray(self.stats).mean()


class LinearAnneaer:  # class used for epsilon decay
    """Linear schedule from ``start_value`` to ``end_value`` over ``num_steps``.

    After ``num_steps`` calls to ``step()`` the value stays at ``end_value``.
    """

    def __init__(self, start_value, end_value, num_steps):
        assert num_steps > 0 and isinstance(num_steps, int)

        self.start_value = start_value
        self.end_value = end_value
        self.num_steps = num_steps
        self.steps = 0

        # Line value = a * steps + b through (0, start) and (num_steps, end).
        self.b = self.start_value
        self.a = (self.end_value - self.start_value) / self.num_steps

    def step(self):
        """Advance one step, saturating at ``num_steps``."""
        next_step = self.steps + 1
        self.steps = min(self.num_steps, next_step)

    def get(self):
        """Return the current scheduled value."""
        assert 0 <= self.steps <= self.num_steps
        return self.a * self.steps + self.b

class HoeffdingAnneaer:  # class used for coefficient decay
    """Decaying coefficient ``start_value * sqrt(ln(t) / t)`` with ``t = steps``.

    ``steps`` starts at 1, so the first ``get()`` returns 0 because ln(1) = 0.
    """

    def __init__(self, start_value):
        self.steps = 1
        self.start_value = start_value

    def step(self):
        """Advance the step counter by one."""
        self.steps += 1

    def get(self):
        """Return the current coefficient and store it in ``self.a``."""
        # ln(t) is undefined at t = 0, so the counter must be at least 1.
        assert self.steps >= 1, "steps must be >= 1"
        self.a = self.start_value * math.sqrt(math.log(self.steps) / self.steps)
        return self.a

class NoisyHoeffdingAnneaer:
    """Gaussian noise whose std decays as ``start_value * sqrt(ln(t) / t)``.

    ``get()`` draws from ``N(0, a)`` where ``a`` is the decayed std at
    ``t = steps``. With ``steps == 1`` the std is 0, so the sample is 0.
    """

    def __init__(self, start_value):
        self.steps = 1
        self.start_value = start_value

    def step(self):
        """Advance the step counter by one."""
        self.steps += 1

    def get(self):
        """Draw a noise sample; stores the std in ``self.a`` and sample in ``self.b``."""
        # ln(t) is undefined at t = 0, so the counter must be at least 1.
        assert self.steps >= 1, "steps must be >= 1"
        self.a = self.start_value * math.sqrt(math.log(self.steps) / self.steps)
        self.b = random.normalvariate(0, self.a)
        return self.b


class NoisyHoeffdingAnneaer1:
    """Decaying Gaussian noise that switches to a constant floor.

    While the decayed std ``a = start_value * sqrt(ln(t) / t)`` exceeds
    ``end_value``, ``get()`` returns a sample from ``N(0, a)``. Otherwise it
    returns ``end_value`` itself, not a sample.
    """

    def __init__(self, start_value, end_value=0.05):
        self.steps = 1
        self.start_value = start_value
        self.end_value = end_value

    def step(self):
        """Advance the step counter by one."""
        self.steps += 1

    def get(self):
        """Return a noise sample, or ``end_value`` once the std has decayed."""
        # ln(t) is undefined at t = 0, so the counter must be at least 1.
        assert self.steps >= 1, "steps must be >= 1"
        self.a = self.start_value * math.sqrt(math.log(self.steps) / self.steps)
        if self.a > self.end_value:
            self.b = random.normalvariate(0, self.a)
            return self.b
        # NOTE: self.b is not updated on this branch and keeps its last sample.
        return self.end_value

class NoisyHoeffdingAnneaer2:  # Our Candidate
    """Decaying Gaussian noise with a minimum standard deviation.

    ``get()`` samples from ``N(0, max-like(a, end_value))``. It uses the decayed
    std ``a = start_value * sqrt(ln(t) / t)`` while ``a > end_value``, and
    ``end_value`` as the std otherwise.
    """

    def __init__(self, start_value, end_value=0):
        self.steps = 1
        self.start_value = start_value
        self.end_value = end_value

    def step(self):
        """Advance the step counter by one."""
        self.steps += 1

    def get(self):
        """Draw a noise sample; stores the decayed std in ``self.a`` and sample in ``self.b``."""
        # ln(t) is undefined at t = 0, so the counter must be at least 1.
        assert self.steps >= 1, "steps must be >= 1"
        self.a = self.start_value * math.sqrt(math.log(self.steps) / self.steps)
        if self.a > self.end_value:
            self.b = random.normalvariate(0, self.a)
            return self.b
        else:
            self.b = random.normalvariate(0, self.end_value)
            return self.b

class DeltaAnneaer:
    """Polynomially decaying ``delta_t = start_value / t ** (1 + epsilon)``.

    Also tracks a contraction factor ``beta``. When ``beta != 1`` it is
    multiplied by its initial value on every ``step()``, so
    ``beta == initial_beta ** num_steps``. ``beta``, ``end_value`` and ``N`` are
    stored, but ``get()`` does not use them.
    """

    def __init__(self, start_value, end_value=0, N=None, beta=1, epsilon=1e-3):
        self.num_steps = 1
        self.start_value = start_value
        self.end_value = end_value
        self.N = N  # intended sample size for xi (see NOTE in get())
        self.beta = beta
        # Base contraction factor. step() previously referenced an undefined
        # name `beta` here, which raised NameError whenever beta != 1.
        self._base_beta = beta
        self.epsilon = epsilon
        assert not self.beta > 1

    def step(self):
        """Advance the step counter and contract ``beta`` geometrically."""
        self.num_steps += 1
        if not self.beta == 1:  # beta-contraction condition
            self.beta *= self._base_beta

    def get(self):
        """Return the current delta and store it in ``self.delta``."""
        # t = 0 would divide by zero, so the counter must be at least 1.
        assert self.num_steps >= 1, "num_steps must be >= 1"
        self.delta = self.start_value / (self.num_steps ** (1 + self.epsilon))
        # NOTE: intended (not implemented) extension: sample xi from
        # Uniform([1 - Delta, 1 + Delta]) with size N, e.g.
        # self.xi = 2 * self.start_value * np.random.sample(self.N) + 1 - self.start_value
        return self.delta


# test = NoisyHoeffdingAnneaer(start_value=50)
# test.steps = 1
# positive = 0
# negative = 0
# while test.steps <39999999:
#     test.step()
#     if test.steps > 39999990 :
#         print(test.get())
#     # if test.steps % 100 ==0:
#     #     print("step :",test.steps)
#     #     print("value :", test.get())
# test = DeltaAnneaer(start_value=50, N =200, beta=1)
