# /usr/bin/env python
# -*- coding: utf-8 -*-

import torch

from online.estimator.meta import Hedge, AdaMLProd
from online.estimator.base import Base, ONS
from online.estimator.schedule import Schedule
from online.utils.domain import Ball, Simplex

from online.utils.loss_function import SquareLoss, LogisticLoss

import numpy as np

class OLR(Base):
    """Online logistic-regression (OLR) estimator of importance weights.

    Each call to :meth:`estimate` builds a balanced binary classification
    batch (sub-sampled source rows labelled ``+1``, target rows labelled
    ``-1``). It then predicts weights for the full source set with the
    current model, and finally runs one pass of ONS updates with the
    logistic loss over that batch.

    Required ``cfgs`` keys: ``T``, ``R``, ``S``, ``dim``, ``domain``,
    ``epsilon``, ``coff_eta_base``, ``coff_output``, ``weights_tranc``.
    ``alg_kwargs`` must contain ``device``.
    """

    def __init__(self, cfgs, seed=None, **alg_kwargs):
        super(OLR, self).__init__(cfgs, seed=seed, **alg_kwargs)

        # Normalise a missing config and read every setting from the
        # normalised dict, so a missing config fails with a KeyError naming
        # the absent setting rather than a TypeError on ``None``.
        self.cfgs = {} if cfgs is None else cfgs
        cfgs = self.cfgs

        self.device = alg_kwargs['device']
        self._t = 0  # number of update() rounds run so far

        self.T = cfgs['T']  # number of rounds; sizes the per-round loss buffer
        self.R = cfgs['R']
        self.S = cfgs['S']
        self.dim = cfgs['dim']
        # NOTE(review): ``eval`` executes arbitrary config text. It is
        # presumably meant to name a domain class such as ``Ball`` or
        # ``Simplex`` (both imported by this module). Only use it with
        # trusted configuration files.
        self.domain = eval(cfgs['domain'])(cfgs['dim'], cfgs['R'])

        self.epsilon_base = cfgs['epsilon']
        self.coff_eta_base = cfgs['coff_eta_base']
        # ONS step size passed to the base learner: coff * (1 + exp(R * S)).
        self.eta_base = self.coff_eta_base * (1 + np.exp(self.R * self.S))
        print("eta_base", self.eta_base)
        self.coff_output = cfgs['coff_output']  # exponent applied to raw weights
        self.trancate = cfgs['weights_tranc']  # upper bound on every output weight

        # Mean logistic loss of each round, stored at index (round - 1).
        self.loss_ = torch.zeros(self.T)

        alg_kwargs['dim'] = self.dim
        alg_kwargs['eta_base'] = self.eta_base
        alg_kwargs['epsilon_base'] = self.epsilon_base

        self.model = ONS(cfgs=cfgs, seed=seed, **alg_kwargs)
        self.model.init_model(np.zeros(self.dim))

        # Filled in by load_data().
        self.ori_source_data = []
        self.data_logreg = []
        self.label_logreg = []

        # Running count of correct predictions across all eval_test() calls.
        self.evl_cnt = 0

    def update(self):
        """Run one ONS step per sample of the current batch.

        Increments the round counter. For every (feature, label) pair stored
        by :meth:`load_data`, the logistic loss is handed to the ONS model
        and ``self.model.opt()`` is called.

        Returns:
            Mean of the losses reported by ``self.model.opt()`` over the
            batch, or ``0.0`` when the batch is empty.
        """
        self._t += 1

        n_samples = len(self.label_logreg)
        if n_samples == 0:
            # Nothing to learn from; avoid dividing by zero below.
            return 0.0

        total = 0
        for feature, label in zip(self.data_logreg, self.label_logreg):
            func = LogisticLoss(feature, label).func
            self.model.set_feature(feature)
            self.model.set_label(label)
            self.model.set_func(func)
            _, step_loss, _ = self.model.opt()
            total += step_loss
        return total / n_samples

    def predict(self):
        """Return the model output for every row of ``ori_source_data``.

        The full (not sub-sampled) source set is moved to ``self.device``
        as float32 before being passed to ``self.model.predict``.
        """
        data = self.ori_source_data.to(self.device).to(torch.float32)
        output, _ = self.model.predict(data)

        return output

    def load_data(self, source_data, target_data):
        """Build a shuffled, class-balanced batch for the discriminator.

        Keeps the full ``source_data`` for :meth:`predict`. It draws, without
        replacement, up to as many source rows as there are target rows and
        labels them ``+1``. Every target row is labelled ``-1``, and the
        concatenation is shuffled.

        Args:
            source_data: tensor whose first dimension indexes samples.
            target_data: tensor whose first dimension indexes samples, with
                trailing dimensions matching ``source_data``.
        """
        self.ori_source_data = source_data
        n_target = target_data.shape[0]

        drawn = torch.randperm(source_data.shape[0])[:n_target]
        sampled_source = source_data[drawn]
        # Count the rows actually drawn. If the source set is smaller than
        # the target set, assuming n_target here would misalign labels with
        # data and make the shuffle below index out of range.
        n_source = sampled_source.shape[0]

        source_label = torch.ones(n_source).long()
        target_label = -1 * torch.ones(n_target).long()
        data_logreg = torch.cat((sampled_source, target_data), dim=0)
        label_logreg = torch.cat((source_label, target_label), dim=0)

        shuffle = torch.randperm(n_source + n_target)
        self.data_logreg = data_logreg[shuffle]
        self.label_logreg = label_logreg[shuffle]

    def sigmoid(self, z):
        """Logistic function ``1 / (1 + exp(-z))`` evaluated with NumPy."""
        return 1.0 / (1.0 + np.exp(-z))

    def eval_test(self, model):
        """Return the classification accuracy of ``model`` on the current batch.

        A sample is predicted ``+1`` when ``sigmoid(model . x) > 0.5`` and
        ``-1`` otherwise. The number of correct predictions is added to
        ``self.evl_cnt``.

        Returns:
            Fraction of correctly classified samples, or ``0.0`` for an
            empty batch.
        """
        n_samples = len(self.label_logreg)
        if n_samples == 0:
            return 0.0

        acc_cnt = 0
        for feature, label in zip(self.data_logreg, self.label_logreg):
            soft_pred = self.sigmoid(np.dot(model, feature))
            pred = 1 if soft_pred > 0.5 else -1
            acc_cnt += (pred == label).int()

        self.evl_cnt += acc_cnt
        return acc_cnt / n_samples

    def estimate(self, source_data, target_data):
        """Return truncated weights for ``source_data``, then update the model.

        Steps:
        1. Load a new balanced batch.
        2. Record the current model's accuracy on it.
        3. Compute ``min((1 / p - 1) ** coff_output, weights_tranc)`` from the
           current model's output ``p`` on every source row.
        4. Run :meth:`update` and store its mean loss for this round.

        NOTE(review): this assumes ``p`` is the predicted probability that a
        row comes from the source (label ``+1``), so that ``1 / p - 1`` is a
        target/source odds ratio. Confirm against ``ONS.predict``.
        """
        self.load_data(source_data, target_data)
        self.eval_test(self.model.get_model())

        # ``as_tensor`` accepts arrays and tensors alike, without the
        # copy-construct warning, and keeps the prediction's device.
        weights = torch.as_tensor(self.predict()).detach()
        weights = (1. / weights - 1.) ** self.coff_output
        # A scalar bound follows the weights' device and dtype automatically.
        weights = torch.clamp(weights, max=self.trancate)

        self.loss_[self._t - 1] = self.update()

        return weights