import os
import pickle as pkl

import numpy as np

from reg_min_norm_solver import MinNormSolver

# Regularization label. NOTE(review): this name is not read anywhere in the
# visible script (the solver call in the main loop is hard-coded to the
# `_l1` variant) -- presumably kept for reference; confirm before removing.
regularization = 'l1'

class SqLoss(object):
    """Squared-distance loss ``||x - center||^2 / 2`` and its gradient.

    Both functions are stateless, so they are declared as static methods.
    This keeps the existing ``SqLoss.loss(x, c)`` call style working and
    additionally allows calls through an instance (``SqLoss().loss(x, c)``),
    which previously failed because no ``self`` parameter was declared.
    """

    @staticmethod
    def loss(x, center):
        """Return half the squared Euclidean distance between x and center."""
        return np.linalg.norm(x-center)**2/2

    @staticmethod
    def grad(x, center):
        """Return the gradient of ``loss`` with respect to x, i.e. ``x - center``."""
        return x-center

def circle_point(arc):
    """Return the unit-circle point at angle ``arc`` (radians) as ``[cos, sin]``."""
    x_coord = np.cos(arc)
    y_coord = np.sin(arc)
    return np.array([x_coord, y_coord])

def PSG(x, a, b):
    """Evaluate ``MinNormSolver.cal_min_norm`` on the two offsets from x.

    The offsets are ``a - x`` and ``b - x``, passed to the solver as a
    two-element list in that order; the solver's result is returned as is.
    """
    offsets = [a-x, b-x]
    return MinNormSolver.cal_min_norm(offsets)

class Circle(object):
    """Point source that takes random angular steps around a circle of radius r.

    Parameters
    ----------
    period : nominal number of steps per full turn; the mean step is 2*pi/period.
    r : radius that scales every returned point.
    arc0 : starting angle in radians.
    mode : 'normal' draws Gaussian steps; any other value draws uniform steps.
    clockwise : stored as a sign multiplier (+1 for True, -1 for False) that is
        applied to every angular step.
    """

    def __init__(self, period=10, r=1, arc0=0, mode='normal', clockwise=True):
        self.period = period
        self.inc = np.pi*2/period
        self.r = r
        self.arc = arc0
        self.mode = mode
        self.clockwise = 1 if clockwise else -1
        # Step counter; also used as the RNG seed for the next draw.
        self.t = 0

    def next_point(self):
        """Advance the angle by one random step and return the new point.

        NOTE: this reseeds numpy's *global* RNG with the step counter on every
        call, so the draw depends only on ``self.t`` and any other user of
        ``np.random`` is affected by the reseed.
        """
        np.random.seed(self.t)
        if self.mode == 'normal':
            # Gaussian step: mean inc, standard deviation period**-0.5.
            step = np.random.normal(self.inc, self.period**-.5)
        else:
            # Uniform step on [0, 2*inc).
            step = np.random.rand()*self.inc*2
        self.arc += step*self.clockwise
        self.t += 1
        return circle_point(self.arc)*self.r

# Experiment phases, run back to back. Each tuple is unpacked by the main loop
# as (period_1, period_2, mode, clockwise_1, clockwise_2): the periods and
# direction flags of the two circles plus their shared step-noise mode.
sequences = [(10, 20, 'normal', True, True),
             (20, 10, 'normal', True, True)]

# Starting angle of the second circle (the first one starts at 0), and the
# radii of the two circles.
arc = 1
r1, r2 = 1, 1

# Number of rounds per phase; NUM_ROUNDS[i] belongs to sequences[i].
# (A scalar `NUM_ROUNDS = 10000` used to precede this line but was overwritten
# immediately, so it has been dropped.)
NUM_ROUNDS = [3000, 7000]
# Spacing, in rounds, between recorded running averages.
NUM_RECORD_ROUNDS = 100

# Algorithm name -> (gamma0, beta): the initial weight and the beta argument
# handed to MinNormSolver.find_min_norm_element_l1 by the main loop.
configs = {'OMGMD':(.5, .6), 'linear opt.':(.5, .0), 'linear 1':(.3, .0), 'linear 2':(.7, .0), 'min-norm':(.5, .9)}

# Candidate grids. Only lr_cands is swept by the visible script; beta_cands
# and gamma_cands are not referenced in it.
beta_cands = np.linspace(0, 1, 11, endpoint=True)
gamma_cands = np.linspace(0, 1, 11, endpoint=True)
lr_cands = np.linspace(.1, 2, 20, endpoint=True)

# Base name of the output pickle file.
names = '10-20-gearshift-TT'

# Algorithm name -> (lr, loss_1 curve, loss_2 curve, PSG curve) of the learning
# rate whose last recorded PSG average is smallest.
loss_rec = {}

for alg, (gamma0, beta) in configs.items():
    print('Processing', alg, gamma0, beta)

    # One entry per candidate learning rate, in lr_cands order.
    runs = []
    final_scores = []

    for lr in lr_cands:
        # Recorded running averages and the running sums behind them.
        avg_1, avg_2, avg_psg = [], [], []
        sum_1 = sum_2 = sum_psg = .0

        # The iterate, the previous weight and the round counter all carry
        # over from one phase to the next.
        gamma_prev = gamma0
        x = np.zeros(2)
        count = 0

        for ii, setup in enumerate(sequences):
            period_1, period_2, mode, clockwise_1, clockwise_2 = setup
            circle1 = Circle(period_1, r1, 0, mode, clockwise_1)
            circle2 = Circle(period_2, r2, arc, mode, clockwise_2)

            for t in range(NUM_ROUNDS[ii]):
                y1 = circle1.next_point()
                y2 = circle2.next_point()

                # Losses are charged to the current iterate before it moves.
                sum_1 += SqLoss.loss(x, y1)
                sum_2 += SqLoss.loss(x, y2)
                sum_psg += PSG(x, y1, y2)

                g1 = SqLoss.grad(x, y1)
                g2 = SqLoss.grad(x, y2)
                scales = MinNormSolver.find_min_norm_element_l1([g1, g2], gamma_prev, beta)
                # The first weight is fed back into the solver next round.
                gamma_prev = scales[0]

                step = scales[0]*g1+scales[1]*g2
                x = x - lr*step

                # Snapshot the running averages whenever (t-1) is a multiple
                # of NUM_RECORD_ROUNDS; count has not been incremented yet,
                # hence the +1.
                if (t-1)%NUM_RECORD_ROUNDS == 0:
                    seen = count+1
                    avg_1.append(sum_1/seen)
                    avg_2.append(sum_2/seen)
                    avg_psg.append(sum_psg/seen)

                count += 1

        runs.append((lr, avg_1, avg_2, avg_psg))
        final_scores.append(avg_psg[-1])

    # Keep the run with the smallest final PSG average (first one on ties).
    best = min(final_scores)
    print(best)
    loss_rec[alg] = runs[final_scores.index(best)]

# Create the output directory if needed, so a finished sweep is not lost to a
# FileNotFoundError when 'rsl_test/' does not exist yet.
os.makedirs('rsl_test', exist_ok=True)
# Persist the per-algorithm best runs as a pickle named after `names`.
with open('rsl_test/'+'%s.pkl'%(names), 'wb') as f:
    pkl.dump(loss_rec, f)

