# -*- coding: utf-8 -*-
"""
Created on Mon Apr  8 14:30:32 2024

@author: ZJ
"""

import numpy as np
import pandas as pd
from utils import GetZScore, GetLambda, GetLambda_unbalanced, GetLambdaSimple,\
    HLM_balanced, HLM_unbalanced, GetAlpha
from utils_twostage import PrivateRange, clip,\
    nextpow2, GetU, epsilon_element

def norm(x):
    """Return the Euclidean (L2) norm of x, treating all entries as one vector."""
    squares = x ** 2
    total = np.sum(squares)
    return np.sqrt(total)

def normclip(x, c):
    """Scale x down onto the L2 ball of radius c.

    If norm(x) <= c, x is returned unchanged (the same object). Otherwise a
    rescaled copy c * x / norm(x) is returned.
    """
    length = norm(x)
    return x if length <= c else c * x / length
    
def run(D, T, Rc, epsilon, delta):
    """
    Our new approach by Huber loss minimization (balanced case).

    Computes a non-private estimate with HLM_balanced, clips it to the L2
    ball of radius Rc, and adds i.i.d. Gaussian noise with standard
    deviation lam / alpha to every coordinate.

    Returns
    -------
    ans : clipped estimate plus noise (noise has shape (1, d)).
    mu0 : the unclipped, noise-free estimate returned by HLM_balanced.
    randerr : d * (lam / alpha) ** 2, the expected squared L2 norm of the
        added noise.
    """
    dim = D.shape[1]  # D is indexed as (rows, d); column count is the dimension
    mu_hat = HLM_balanced(D, T)
    lam = GetLambda(D, T, Rc, epsilon, delta)
    noise_std = lam / GetAlpha(dim, epsilon, delta)
    noise = np.random.normal(0, noise_std, (1, dim))
    private_est = normclip(mu_hat, Rc) + noise
    return private_est, mu_hat, dim * noise_std ** 2

def run_unbalanced(D, Tarray, weights, Rc, epsilon, delta):
    """
    For unbalanced data.

    Same pipeline as `run`, but the non-private estimate and the noise scale
    come from the weighted helpers HLM_unbalanced / GetLambda_unbalanced.
    The estimate is clipped to the L2 ball of radius Rc, and i.i.d. Gaussian
    noise with standard deviation lam / alpha is added per coordinate.

    Returns
    -------
    ans : clipped estimate plus noise (noise has shape (1, d)).
    mu0 : the unclipped, noise-free estimate from HLM_unbalanced.
    randerr : d * (lam / alpha) ** 2, the expected squared L2 norm of the noise.
    """
    dim = D.shape[1]
    mu_hat = HLM_unbalanced(D, weights, Tarray)
    lam = GetLambda_unbalanced(D, weights, Tarray, Rc, epsilon, delta)
    noise_std = lam / GetAlpha(dim, epsilon, delta)
    noise = np.random.normal(0, noise_std, (1, dim))
    private_est = normclip(mu_hat, Rc) + noise
    return private_est, mu_hat, dim * noise_std ** 2

def run_twostage1d(X, tau, Rc, epsilon):
    """Two-stage private mean of one-dimensional data.

    Stage 1 calls PrivateRange with budget epsilon / 2 to obtain a range
    [a, b] (searched inside [-Rc, Rc]). Stage 2 clips every sample into
    [a, b], averages, and adds Laplace noise with scale 8 * tau / (epsilon * n).

    The input X is NOT modified.

    Returns
    -------
    ans : noisy mean (clipped mean plus Laplace noise).
    mu0 : clipped, noise-free mean.
    randerr : 2 * lam ** 2, the variance of the Laplace(0, lam) noise.
    """
    a, b = PrivateRange(X, -Rc, Rc, tau, epsilon / 2)
    n = len(X)
    # Clip into a fresh float array. Writing the clipped values back into X
    # corrupted the caller's data whenever X was a view (e.g. D.ravel()), and
    # truncated the clipped bounds when X had an integer dtype.
    clipped = np.array([clip(x, a, b) for x in X], dtype=float)
    mu0 = np.mean(clipped)
    # NOTE(review): the factor 8 * tau presumably bounds (b - a) from
    # PrivateRange times the sensitivity constant -- confirm against its spec.
    lam = 8 * tau / (epsilon * n)
    W = np.random.laplace(0, lam)
    ans = mu0 + W
    randerr = 2 * lam ** 2
    return ans, mu0, randerr

def run_twostage(D, tau, Rc, epsilon, delta):
    """
    Baseline methods for comparison.
    One dimension: run two stage method directly.
    High dimension: apply Hadamard transform.

    For d > 1 the columns are zero-padded to d_all = nextpow2(d). The data is
    multiplied by U = GetU(d_all), and each transformed coordinate is estimated
    with run_twostage1d at the per-coordinate budget from epsilon_element.
    The per-coordinate results are mapped back with U.T and truncated to the
    first d entries.

    Returns
    -------
    ans, mu0 : scalars when d == 1, otherwise 1-D arrays of length d.
    randerr : the single randerr from run_twostage1d when d == 1; otherwise
        the sum of the per-coordinate randerr values.

    D itself is never modified.
    """
    n, d = D.shape
    if d == 1:
        # Pass a private copy: D.ravel() typically aliases D's buffer, so any
        # in-place clipping downstream would overwrite the caller's dataset.
        X = np.array(D, dtype=float).ravel()
        return run_twostage1d(X, tau, Rc, epsilon)

    d_all = nextpow2(d)
    D_pad = np.hstack([D, np.zeros((n, d_all - d))])
    U = GetU(d_all)
    # NOTE(review): mapping back with U.T assumes U is orthogonal (a
    # normalized Hadamard matrix, per the docstring) -- confirm in GetU.
    D_transformed = D_pad.dot(U)
    epsilon_p = epsilon_element(epsilon, delta, d_all)
    res = []
    means = []
    randerr = 0
    for j in range(d_all):
        ans_p, mu0_p, randerr_p = run_twostage1d(D_transformed[:, j],
                                                 tau, Rc, epsilon_p)
        res.append(ans_p)
        means.append(mu0_p)
        randerr += randerr_p
    ans = np.array(res).reshape(1, -1).dot(U.T)[0, :d]
    mu0 = np.array(means).reshape(1, -1).dot(U.T)[0, :d]
    return ans, mu0, randerr
    
def generate(n, m, distribution):
    """Draw an (n, m, k) array of i.i.d. samples from a numbered distribution.

    Distribution codes (k is the size of the last axis):
        1: Uniform(-1, 1),            k = 1
        2: Normal(0, 1),              k = 1
        3: Exponential(scale=1),      k = 1
        4: Exponential(scale=1),      k = 3
        5: Uniform(-1, 1),            k = 3
        6: Normal(0, 1),              k = 3
        7: Lomax, f(x) = a / (1 + x)^(a + 1) with a = 4,  k = 1
        8: Lomax with a = 4,          k = 3

    Raises ValueError for any other code (previously an UnboundLocalError).
    """
    samplers = {
        1: (lambda size: np.random.uniform(-1, 1, size), 1),
        2: (lambda size: np.random.normal(0, 1, size), 1),
        3: (lambda size: np.random.exponential(1, size), 1),
        4: (lambda size: np.random.exponential(1, size), 3),
        5: (lambda size: np.random.uniform(-1, 1, size), 3),
        6: (lambda size: np.random.normal(0, 1, size), 3),
        # numpy's pareto() samples the Lomax (Pareto II) distribution,
        # f(x) = a / (1 + x)^(a + 1), here with shape a = 4.
        7: (lambda size: np.random.pareto(4, size), 1),
        8: (lambda size: np.random.pareto(4, size), 3),
    }
    try:
        sampler, k = samplers[distribution]
    except KeyError:
        raise ValueError(
            "unknown distribution code: {!r}".format(distribution)) from None
    return sampler((n, m, k))

def generate_randdiv(n, m, distribution):
    """Split N = n * m i.i.d. 1-D samples into n random contiguous groups.

    n - 1 distinct cut points are drawn uniformly from {1, ..., N - 1}, so
    every group is non-empty.

    Distribution codes: 1 = Uniform(-1, 1), 2 = Normal(0, 1),
    3 = Exponential(scale=1).

    Returns
    -------
    D : array of shape (n, 1), the mean of each group.
    m_vec : integer array of length n, the size of each group (sums to N).

    Raises ValueError for an unknown distribution code. numpy raises
    ValueError when n > N, because there are too few cut points.
    """
    samplers = {
        1: lambda size: np.random.uniform(-1, 1, size),
        2: lambda size: np.random.normal(0, 1, size),
        3: lambda size: np.random.exponential(1, size),
    }
    if distribution not in samplers:
        raise ValueError(
            "unknown distribution code: {!r}".format(distribution))
    N = n * m
    X = samplers[distribution]((N, 1))
    cuts = np.sort(np.random.choice(np.arange(1, N), size=n - 1,
                                    replace=False))
    # np.split handles n == 1 (no cuts -> one group), which the old
    # index-based loop crashed on (inds[0] on an empty array).
    groups = np.split(X, cuts)
    D = np.vstack([np.mean(g, axis=0) for g in groups])
    m_vec = np.diff(np.concatenate(([0], cuts, [N])))
    return D, m_vec

def generate_randm(n, N, gamma, distribution):
    """Draw n groups of 1-D samples with power-law-spaced group sizes.

    Group boundaries are c_i = int(N * (i / n) ** gamma) for i = 0..n. Group i
    has max(c_{i+1} - c_i, 1) samples, so every group gets at least one
    sample and sum(m_vec) can exceed N.

    Distribution codes: 1 = Uniform(-1, 1), 2 = Normal(0, 1),
    3 = Exponential(scale=1).

    Returns
    -------
    D : array of shape (n, 1), the sample mean of each group.
    m_vec : integer array of length n, the group sizes.

    Raises ValueError for an unknown distribution code. It is checked before
    any random numbers are consumed; previously this was an UnboundLocalError.
    """
    samplers = {
        1: lambda size: np.random.uniform(-1, 1, size),
        2: lambda size: np.random.normal(0, 1, size),
        3: lambda size: np.random.exponential(1, size),
    }
    if distribution not in samplers:
        raise ValueError(
            "unknown distribution code: {!r}".format(distribution))
    sample = samplers[distribution]
    c = np.array([int(N * (i / n) ** gamma) for i in range(n + 1)])
    m_vec = [max(c[i + 1] - c[i], 1) for i in range(n)]
    # One draw per group, in group order (same RNG consumption as before).
    D = np.vstack([np.mean(sample((mi, 1)), axis=0) for mi in m_vec])
    return D, np.array(m_vec)
    