import os
import pickle as pkl

import numpy as np

from min_norm_solvers_simple import MinNormSolver

# Translates the label token at the start of a data line (see gen_instance)
# into a binary target: '1' and '+1' become 1.0, '2' and '-1' become 0.0.
y_dict = {
    '1': 1.0,
    '+1': 1.0,
    '2': 0.0,
    '-1': 0.0,
}

class SqLoss():
    """Squared-norm objective ``0.5 * <x, x>`` and its gradient.

    The methods take no instance state; they are static so that both
    ``SqLoss.loss(x)`` and ``SqLoss().loss(x)`` work.
    """

    @staticmethod
    def loss(x):
        """Return half of the dot product of ``x`` with itself."""
        return np.dot(x,x)/2

    @staticmethod
    def grad(x):
        """Return the gradient of ``loss`` at ``x``.

        The gradient equals ``x``; the very same object is returned (no copy).
        """
        return x

class CELoss():
    """Logistic-regression cross-entropy loss for a single example.

    ``x`` is the weight vector, ``xi`` the feature vector and ``y`` the target
    (the callers in this file pass 0.0 or 1.0).  The methods take no instance
    state; they are static so they can be called on the class or an instance.
    """

    @staticmethod
    def _prob(x, xi):
        """Return the sigmoid of ``<x, xi>``."""
        # NOTE: for a very negative score np.exp overflows to inf and the
        # result becomes 0.0; the log terms below add 1e-8 to stay finite.
        return 1.0 / (1.0+np.exp(-np.dot(x, xi)))

    @staticmethod
    def _cross_entropy(h, y):
        """Return the cross-entropy of probability ``h`` against target ``y``."""
        # 1e-8 keeps both logarithms finite when h reaches exactly 0 or 1.
        return -(y * np.log(h+1e-8) + (1-y) * np.log(1-h+1e-8))

    @staticmethod
    def loss(x, xi, y):
        """Return the cross-entropy loss of weights ``x`` on ``(xi, y)``."""
        return CELoss._cross_entropy(CELoss._prob(x, xi), y)

    @staticmethod
    def grad(x, xi, y):
        """Return ``(h - y) * xi`` where ``h`` is the predicted probability."""
        h = CELoss._prob(x, xi)
        return (h-y)*xi

    @staticmethod
    def evaluate(x, xi, y):
        """Return ``(loss, correct)`` for weights ``x`` on ``(xi, y)``.

        ``correct`` is 1.0 when ``|y - h| < 0.5`` and 0.0 otherwise, so a
        prediction of exactly 0.5 counts as wrong.
        """
        h = CELoss._prob(x, xi)
        return CELoss._cross_entropy(h, y), float(np.abs(y-h)<.5)

def gen_instance(line, dim=None, labels=None):
    """Parse one ``label idx:val idx:val ...`` text line into ``(x, y)``.

    Args:
        line: A whitespace-separated record whose first token is the label
            and whose remaining tokens are ``index:value`` pairs.
        dim: Length of the returned feature vector.  Defaults to the
            module-level ``N``.
        labels: Mapping from label token to target value.  Defaults to the
            module-level ``y_dict``.

    Returns:
        ``(x, y)`` where ``x`` is a float array of length ``dim`` with
        ``x[0] == 1.0`` (a constant entry) and ``x[index] = value`` for every
        pair on the line, and ``y`` is ``labels[<first token>]``.

    Raises:
        KeyError: If the label token is not in ``labels``.
        IndexError: If the line is blank or an index does not fit in ``dim``.
        ValueError: If a feature token is not a valid ``index:value`` pair.
    """
    if dim is None:
        dim = N
    if labels is None:
        labels = y_dict

    tokens = line.strip().split()
    y = labels[tokens[0]]

    x = np.zeros(dim)
    x[0] = 1.0
    for token in tokens[1:]:
        pos, val = token.split(':')
        idx = int(pos)
        # NOTE(review): feature index 4 is always left at zero; the reason is
        # not visible in this file.  Compared by value (``!=``), not identity.
        if idx != 4:
            x[idx] = float(val)

    return x, y

def test_evaluate(x):
    """Score the weight vector ``x`` on the held-out ``test_set``.

    Returns ``(mean_loss, accuracy)``: the sum of the losses and the sum of
    the correctness flags reported by ``CELoss.evaluate`` over ``test_set``,
    each divided by ``NUM_TEST_INSTANCES``.
    """
    per_sample = [CELoss.evaluate(x, features, label)
                  for features, label in test_set]
    # Accumulate left to right from 0.0, in the order test_set is stored.
    total_loss = sum((sample_loss for sample_loss, _ in per_sample), .0)
    total_correct = sum((hit for _, hit in per_sample), .0)
    return total_loss/NUM_TEST_INSTANCES, total_correct/NUM_TEST_INSTANCES

# ---------------------------------------------------------------------------
# Experiment configuration and data loading
# ---------------------------------------------------------------------------
dataname = 'covtype'

# Per-dataset base dimension; feature vectors built by gen_instance have
# N = base + 1 entries.
N_dict = {'covtype':55, 'a9a':123}
# Running metrics are snapshotted every this many training rounds.
NUM_RECORD_ROUNDS_dict = {'covtype':5000, 'a9a':300}

# Params for covtype: algorithm name -> (beta, gamma0), unpacked in that
# order by the experiment loop.
# A previous configuration (immediately overwritten in the original script)
# used 'ours': (.2, .002) plus an extra entry 'ours coarse': (.2, .0).
params_dict = {'linear optimal':(.0, .1), 'ours':(.2, .0)}
N = N_dict[dataname] + 1

# Input file per dataset (covtype previously also had a 'covtype.sorted' file).
dataset_dict = {
    'covtype': '../dataset/covtype/covtype.binary.scale',
    'a9a': '../dataset/a9a/a9a',
}
dataset = dataset_dict[dataname]

with open(dataset, 'r') as f:
    raw = f.readlines()

# Lines are used in file order (no shuffling).  The last NUM_TEST_INSTANCES
# lines are held out for testing; everything before them is training data.
NUM_TEST_INSTANCES = 1000
if len(raw) <= NUM_TEST_INSTANCES:
    # test_evaluate divides by NUM_TEST_INSTANCES, so a shorter file would
    # silently produce wrong averages (and leave no training data).
    raise ValueError('%s has %d lines; need more than %d'
                     % (dataset, len(raw), NUM_TEST_INSTANCES))
test_set = [gen_instance(line) for line in raw[-NUM_TEST_INSTANCES:]]
train_set_raw = raw[:-NUM_TEST_INSTANCES]

# One online round per *training* line.  Using len(raw) here would also run
# rounds on the held-out test lines (train/test leakage).
NUM_ROUNDS = len(train_set_raw)
NUM_RECORD_ROUNDS = NUM_RECORD_ROUNDS_dict[dataname]

# Radius of the L2 ball the iterate is projected back onto after each step.
D = 100

# Learning-rate candidates swept for each dataset.
lr_cands_dict = {
    'covtype': [1.0, 1.05, 1.1, 1.15, 1.2, 1.25],
    'a9a': [.01, .012, .015, .018, .02, .025, .03],
}
lr_cands = lr_cands_dict[dataname]

# Result accumulators filled by the experiment loop: one entry per algorithm
# in the lists, and one entry per algorithm name in loss_records.
loss_curves_1, loss_curves_2 = [], []
acces = []
test_losses, test_acces = [], []
loss_records = {}

# Create the output directory up front: the pickle files are only written
# after a full learning-rate sweep, so a missing directory would otherwise
# discard the finished runs with a FileNotFoundError.
out_dir = 'rsl-%s-test'%dataname
os.makedirs(out_dir, exist_ok=True)

# For every algorithm setting and every learning rate, run one online pass:
# evaluate the current iterate on the incoming example, combine the gradients
# of the two losses with the weights returned by MinNormSolver, take a step,
# and project back onto the L2 ball of radius D.
for alg, params in params_dict.items():
    beta, gamma0 = params

    # Per-algorithm summaries, one entry per learning-rate candidate.
    loss_record, loss_rec_1, loss_rec_2 = [], [], []
    acc_rec = []
    test_loss_rec, test_acc_rec = [], []

    name0 = 'beta=%.2f-gamma=%.3f'%(beta, gamma0)

    loss_curve_1, loss_curve_2 = [], []

    for lr in lr_cands:
        name = 'beta=%.2f-gamma=%.3f-lr=%.3f'%(beta, gamma0, lr)

        # Snapshots of the running averages, taken every NUM_RECORD_ROUNDS.
        loss_1, loss_2 = [], []
        acc = []
        # Running sums of SqLoss, CELoss and correct predictions.
        l1, l2 = .0, .0
        cor = .0

        test_loss, test_acc = [], []

        gamma_prev = gamma0

        # Re-seed so every run starts from the same initial iterate.
        np.random.seed(1234)
        x = np.random.rand(N)/np.sqrt(N)*2

        for t in range(NUM_ROUNDS):
            xi, y = gen_instance(raw[t])

            # Losses are measured at the current iterate, before the update.
            l1 += SqLoss.loss(x)
            l_tmp, c_tmp = CELoss.evaluate(x, xi, y)
            l2 += l_tmp
            cor += c_tmp

            g1 = SqLoss.grad(x)
            g2 = CELoss.grad(x, xi, y)

            # scales[0] / scales[1] weight g1 / g2; scales[0] is fed back as
            # gamma_prev for the next round.
            # NOTE(review): the exact role of gamma_prev and beta is defined
            # in min_norm_solvers_simple -- confirm there.
            scales = MinNormSolver.find_min_norm_element_l1([g1, g2], gamma_prev, beta)
            gamma_prev = scales[0]

            x = x - lr*(scales[0]*g1+scales[1]*g2)

            # Project back onto the L2 ball of radius D.
            x_norm = np.linalg.norm(x)
            if x_norm > D:
                x = x / x_norm * D

            if (t+1)%NUM_RECORD_ROUNDS == 0:
                loss_1.append(l1/(t+1))
                loss_2.append(l2/(t+1))
                acc.append(cor/(t+1))

                l, c = test_evaluate(x)
                test_loss.append(l)
                test_acc.append(c)

        # NOTE: the [-1] lookups below require at least one snapshot, i.e.
        # NUM_ROUNDS >= NUM_RECORD_ROUNDS.
        loss_record.append((lr, loss_1, loss_2, acc, test_loss, test_acc))
        loss_rec_1.append(loss_1[-1])
        loss_rec_2.append(loss_2[-1])

        acc_rec.append(acc[-1])
        test_loss_rec.append(test_loss[-1])
        test_acc_rec.append(test_acc[-1])

        print(name, loss_rec_1[-1], loss_rec_2[-1], acc[-1], test_loss[-1], test_acc[-1])

        loss_curve_1.append(loss_rec_1[-1])
        loss_curve_2.append(loss_rec_2[-1])

    loss_curves_1.append(loss_curve_1)
    loss_curves_2.append(loss_curve_2)

    acces.append(acc_rec)
    test_losses.append(test_loss_rec)
    test_acces.append(test_acc_rec)

    loss_records[alg] = loss_record

    # Persist the full per-learning-rate curves for this algorithm setting.
    with open('%s/%s.pkl'%(out_dir, name0), 'wb') as f:
        pkl.dump(loss_record, f)
    