import os

import numpy as np
import json
import torch
import wandb
import cvxpy as cp
from tqdm import tqdm, trange
import yaml


def compute_next_dual(eta, rho, dual, gradient, lambd):
    """Take one projected gradient step on the dual variables.

    The candidate point ``dual - eta * gradient / rho / rho`` is sorted by
    ``candidate * rho``, projected with ``map_layer`` (which receives the
    sorted values and the correspondingly permuted ``rho``), and finally put
    back into the original coordinate order.

    Args:
        eta: step size.
        rho: per-coordinate scale, same length as ``dual``.
        dual: current dual vector.
        gradient: gradient vector, same length as ``dual``.
        lambd: prefix-sum slack forwarded to ``map_layer``.

    Returns:
        The projected dual vector in the same coordinate order as ``dual``.
    """
    # Keep the exact operation order of the update for bit-identical floats.
    step = eta * gradient / rho / rho
    candidate = dual - step

    permutation = np.argsort(candidate * rho)
    projected_sorted = map_layer(candidate[permutation], rho[permutation], lambd)

    # The argsort of a permutation is its inverse; use it to undo the sort.
    restore = np.argsort(permutation)
    return projected_sorted[restore]

def map_layer(ordered_tilde_dual, rho, lambd):
    """Project ``ordered_tilde_dual`` onto a prefix-sum constraint set.

    Solves, with cvxpy, the quadratic program::

        minimize    sum((rho * x - rho * ordered_tilde_dual) ** 2)
        subject to  sum(rho[:i] * x[:i]) >= -lambd   for every i = 1..m

    and returns the optimal ``x``. ``compute_next_dual`` passes the inputs
    sorted by ``rho * value``.

    Args:
        ordered_tilde_dual: length-m vector to project.
        rho: length-m per-coordinate weights (numpy array).
        lambd: slack allowed on every prefix sum.

    Returns:
        numpy array of length m with the projected values.

    Raises:
        RuntimeError: if the solver finishes without producing a solution
            (for example, when the problem is infeasible).
    """
    m = len(rho)
    if m == 0:
        # cvxpy cannot build a zero-length variable; the projection is empty.
        return np.zeros(0)

    answer = cp.Variable(m)
    weighted = cp.multiply(rho, answer)
    objective = cp.Minimize(cp.sum_squares(weighted - cp.multiply(rho, ordered_tilde_dual)))
    # One vectorized constraint on all prefix sums replaces m separate
    # slice-sum constraints (which made the problem size grow as O(m^2)).
    constraints = [cp.cumsum(weighted) >= -lambd]
    prob = cp.Problem(objective, constraints)
    prob.solve()

    if answer.value is None:
        # cvxpy leaves variable values as None when no solution was found;
        # fail loudly here instead of with an opaque TypeError in the caller.
        raise RuntimeError(
            "map_layer: solver returned no solution (status: {})".format(prob.status)
        )
    return answer.value


class Dual_Sampler:
    """Compute per-sample weights from dual variables kept per group.

    The sampler keeps one dual value per group in ``self.mu`` (length
    ``p_size``). Each call to ``update_weight`` does two things:

    - It turns the current duals into weights for a batch.
    - It moves the duals with a projected step (``compute_next_dual``) driven
      by the batch's column totals.

    After ``update_epoch * train_len`` units (decremented by ``batch_size``
    per call) the state is reset.

    Hyper-parameters are loaded from ``sample_utils/Dual_sample.yaml``. The
    path is relative to the current working directory. The file must define
    ``update_epoch``, ``topk``, ``alpha``, ``eta`` and ``lambd``. An optional
    ``weight_scale`` key overrides ``_DEFAULT_WEIGHT_SCALE``.
    """

    # Factor used to spread the inverted batch scores so weights don't
    # become too small; overridable via the `weight_scale` hyperparameter.
    _DEFAULT_WEIGHT_SCALE = 100000

    def __init__(self, p_size, train_len, AdjecentMatrix, batch_size, data_size):
        """Store configuration, load hyper-parameters and initialise state.

        Args:
            p_size: number of dual variables (length of ``rho`` and ``mu``).
            train_len: multiplied by ``update_epoch`` to get the length of one
                update window.
            AdjecentMatrix: array indexed by batch ``items`` in
                ``update_weight``. NOTE(review): presumably an item-to-group
                matrix with ``p_size`` columns; confirm.
            batch_size: per-call decrement of the update window; also used
                to normalise the gradient.
            data_size: used to normalise the budget term of the gradient.
        """
        self.p_size = p_size
        self.train_len = train_len
        self.rho = np.ones(self.p_size)
        self.AdjecentMatrix = AdjecentMatrix
        self.batch_size = batch_size
        self.data_size = data_size

        config_path = os.path.join("sample_utils", "{}_sample.yaml".format("Dual"))
        with open(config_path, 'r', encoding='utf-8') as yaml_file:
            self.hyper_parameters = yaml.safe_load(yaml_file)

        self.reset_parameters()

    def reset_parameters(self):
        """Start a new update window: reset duals, counter and smoothed gradient."""
        hp = self.hyper_parameters
        self.update_len = hp['update_epoch'] * self.train_len
        # Running counter; decreased by each batch's column totals D_t.
        self.C_t = hp['topk'] * self.update_len * self.rho
        self.mu = np.ones(self.p_size)
        # Despite the name, this holds the exponentially smoothed gradient.
        self.gradient_cusum = 0

    def update_weight(self, items):
        """Return weights for ``items`` and advance the dual state by one step.

        Args:
            items: indices into ``AdjecentMatrix`` for the current batch.
                NOTE(review): presumably shape (batch, k), so that summing
                over axis 1 yields one row of group counts per sample; confirm
                against the caller.

        Returns:
            Tuple ``(batch_weight, reset)``:

            - ``batch_weight`` has one entry per row of the summed batch.
            - ``reset`` is True when this call exhausted the update window
              and the state was re-initialised.
        """
        hp = self.hyper_parameters
        topk = hp['topk']

        # ---- weights from the current duals ----
        B_t = np.sum(self.AdjecentMatrix[items], axis=1, keepdims=False)
        batch_weight = np.mean(B_t * self.mu, axis=-1, keepdims=False)
        # Invert and rescale so the lowest score gets the largest weight; with
        # a non-negative scale every weight is >= 1 and cannot vanish.
        weight_scale = hp.get('weight_scale', self._DEFAULT_WEIGHT_SCALE)
        batch_weight = (np.max(batch_weight) - batch_weight) * weight_scale + 1.0

        # ---- dual update ----
        D_t = np.sum(B_t, axis=0, keepdims=False)
        self.C_t = self.C_t - D_t
        gradient = -D_t / (self.batch_size * topk) + self.C_t / (self.data_size * topk)
        # Exponential moving average with the previous smoothed gradient.
        gradient = hp['alpha'] * gradient + (1 - hp['alpha']) * self.gradient_cusum
        self.gradient_cusum = gradient
        self.mu = compute_next_dual(hp['eta'], self.rho, self.mu, gradient, hp['lambd'])

        # NOTE(review): decrements by the nominal batch_size even if the
        # actual batch is smaller (e.g. the last batch); confirm intended.
        self.update_len = self.update_len - self.batch_size
        if self.update_len <= 0:
            self.reset_parameters()
            return batch_weight, True
        return batch_weight, False


