#!/usr/bin/env python
# -*- coding: utf-8 -*-

import json as js
import copy
import os
import random
from os.path import join

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

from dataset.wrapper import get_dataset
from online.models.wrapper import get_classifier_alg
from online.estimator.wrapper import get_estimator_alg
from online.utils.risk import *
from online.estimator.get_weights import weights_estimation

from utils.logger import MyLogger
from utils.argparser import argparser
from utils.tools import Timer

from offline_training import offline_train

import time
# Module-level Timer instance.
# NOTE(review): `timer` is not referenced anywhere else in this file (run() creates
# its own Timer); confirm it is unused before removing it.
timer = Timer()

import warnings
# Suppress every Python warning process-wide (no category filter is given).
warnings.filterwarnings('ignore')


def write(writer, info, t):
    """Log each entry of ``info`` as a scalar on ``writer`` at global step ``t``.

    Args:
        writer: object exposing ``add_scalar(tag, value, step)`` (e.g. a SummaryWriter).
        info: mapping of tag -> scalar value; entries are logged in iteration order.
        t: global step recorded with every scalar.
    """
    add_scalar = writer.add_scalar
    for tag in info:
        add_scalar(tag, info[tag], t)


def set_cpu_num(cpu_num):
    """Cap the number of CPU threads used by native math libraries and torch.

    Args:
        cpu_num: thread count to apply. ``None`` (e.g. ``cfgs.get('cpu_num')``
            when the key is absent) means "no limit requested": nothing is
            changed, instead of writing the string ``'None'`` into the
            environment and failing inside ``torch.set_num_threads``.
    """
    if cpu_num is None:
        return
    value = str(cpu_num)
    # NOTE(review): OpenMP/BLAS runtimes generally read these variables when they
    # are first initialised, so libraries already loaded in this process may
    # ignore them; they mainly affect libraries/subprocesses initialised later.
    for var in ('OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'MKL_NUM_THREADS',
                'VECLIB_MAXIMUM_THREADS', 'NUMEXPR_NUM_THREADS'):
        os.environ[var] = value
    torch.set_num_threads(cpu_num)
    # torch allows this only once per process and before any inter-op parallel
    # work starts; a second call raises RuntimeError.
    torch.set_num_interop_threads(cpu_num)


def setup_seed(seed):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic = True`` so cuDNN picks
    deterministic kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True


def evaluation(model, data, label):
    """Compute the error rate and accuracy of ``model`` on ``data``.

    Args:
        model: object exposing ``predict(data)``; its result must be a tensor
            with the same number of elements as ``label`` (it is reshaped with
            ``view_as(label)``).
        data: input passed unchanged to ``model.predict``.
        label: tensor of ground-truth labels (any shape, including 0-d).

    Returns:
        ``(error_avg, accuracy)`` as Python floats, with ``accuracy = 1 - error_avg``.

    Raises:
        ValueError: if ``label`` contains no elements.
    """
    pred = model.predict(data)
    # Normalise by the element count: len() is only the first dimension (and
    # raises on 0-d tensors), while the mismatch count below is element-wise.
    total = label.numel()
    if total == 0:
        raise ValueError('evaluation() received an empty label tensor')
    error_cnt = (pred.view_as(label) != label).sum().item()
    error_avg = error_cnt / total
    accuracy = 1. - error_avg

    return error_avg, accuracy


def _train_representations(cfgs, train_set, classifier, device):
    """Return ``(tr_rep, tr_label)`` for the whole training set, computed once.

    For a ``'Linear'`` model ``train_set.data`` is used directly; otherwise each
    sample goes through ``classifier.model(..., mode='representation')`` and the
    result is kept fixed for the rest of the run. ``tr_label`` stays on the CPU
    in both branches. ``tr_rep`` is ``None`` for an empty non-Linear train set.
    """
    if cfgs['Model']['type'] == 'Linear':
        tr_rep = train_set.data.detach().float().to(device)
        tr_label = torch.tensor(train_set.label).long()
        return tr_rep, tr_label

    reps, labels = [], []
    # Outputs are detached anyway, so do not build autograd graphs.
    with torch.no_grad():
        for i in trange(len(train_set), desc='Get All Train Representation:'):
            x, y, _ = train_set.__getitem__(i)
            x = x.unsqueeze(0).detach().float().to(device)
            rep = classifier.model(x, mode='representation').detach().float().to(device)
            reps.append(rep)
            labels.append(y)
    # Single concatenation instead of growing a tensor per sample (O(n^2) copying).
    tr_rep = torch.cat(reps, 0) if reps else None
    tr_label = torch.tensor(labels).long()
    return tr_rep, tr_label


def _fit_erm(classifier, erm_num, stop_erm):
    """Call ``classifier.parameters_update()`` up to ``erm_num`` times, stopping
    early once the loss decrease between consecutive updates is below ``stop_erm``."""
    last_loss = 1e9
    for _ in range(erm_num):
        loss_ = classifier.parameters_update()
        if (last_loss - loss_) < stop_erm:
            break
        last_loss = loss_


def run(cfgs, train_set, test_set, data_info, estimator, classifier,
        estimator_algorithm, classifier_algorithm, device='cuda', writer=None):
    """Run the online evaluation loop for ``cfgs['round']`` rounds.

    Each round ``t`` takes test item ``t``. Unless ``classifier_algorithm`` is
    ``'FIX'``, the round then does three things:
      - resets the classifier to its initial snapshot;
      - estimates sample weights with ``weights_estimation`` against the fixed
        train representations;
      - re-fits the classifier with the weighted loss.
    The instantaneous error and its running average are then measured.

    Args:
        cfgs: experiment config. Reads ``'round'``, ``'log_interval'`` (default
            100), ``['Model']['type']`` and ``['Model']['Classifier']['kwargs']``
            (``erm_num``, ``stop_erm``, optional ``loss`` / ``nn_loss``).
        train_set / test_set: datasets; ``test_set.__getitem__(t)`` yields
            ``(rep, label, _)`` for round ``t``.
        data_info: dict providing ``'cls_num'``.
        estimator, estimator_algorithm: forwarded to ``weights_estimation``.
        classifier: wrapper exposing ``model``, ``reinit``, ``set_func``,
            ``parameters_update`` and ``predict``.
        classifier_algorithm: ``'FIX'`` disables per-round re-fitting.
        device: torch device for tensors.
        writer: optional TensorBoard writer; when ``None`` nothing is logged
            there, but the per-round record is still returned.

    Returns:
        list with one dict per round holding ``'Error/1-Instant Error'`` and
        ``'Error/2-Average Error'`` as Python scalars.
    """
    record = []
    accumulate_error = 0
    T = cfgs['round']
    clf_kwargs = cfgs['Model']['Classifier']['kwargs']
    erm_num = clf_kwargs['erm_num']
    stop_erm = float(clf_kwargs['stop_erm'])
    log_interval = cfgs.get('log_interval', 100)
    is_fixed = classifier_algorithm == 'FIX'
    is_linear = cfgs['Model']['type'] == 'Linear'
    # Snapshot of the starting model; non-FIX algorithms reset to it every round.
    init_classifier = copy.deepcopy(classifier.model).to(device)

    loss_cfgs = clf_kwargs.get('loss', {})
    loss_name = loss_cfgs.get('name', 'weighted_loss')
    nn_loss = loss_cfgs.get('nn_loss', False) or clf_kwargs.get('nn_loss', False)
    # NOTE(security): eval() on a config string executes arbitrary code; only safe
    # while configs are trusted. The name presumably resolves to a loss class from
    # `from online.utils.risk import *` — TODO confirm and replace with a lookup table.
    loss_func = eval(loss_name)(device=device, nn_loss=nn_loss)
    loss_func.set_class_num(data_info['cls_num'])

    time_helper = Timer()
    time_helper.tik()

    tr_rep, tr_label = _train_representations(cfgs, train_set, classifier, device)

    for t in tqdm(range(T)):
        if not is_fixed:
            classifier.reinit(init_classifier)

        te_rep, te_label, _ = test_set.__getitem__(t)
        te_rep = torch.tensor(te_rep).float().to(device)
        te_label = torch.tensor(te_label).long().to(device)

        if not is_linear:
            # Kept after reinit() to preserve the original ordering (reinit may
            # change classifier.model).
            with torch.no_grad():
                te_rep = classifier.model(te_rep, mode='representation').detach().float().to(device)

        if not is_fixed:
            weights, _est_loss = weights_estimation(estimator, estimator_algorithm,
                                                    tr_rep, tr_label, te_rep)
            loss_func.set_weights(weights.float())
            classifier.set_func(loss_func)
            _fit_erm(classifier, erm_num, stop_erm)

        instant_error, _ = evaluation(classifier, te_rep, te_label)
        accumulate_error += instant_error
        accumulate_avg_error = accumulate_error / (t + 1)

        res_info = {
            'Error/1-Instant Error': instant_error,
            'Error/2-Average Error': accumulate_avg_error,
        }
        if writer is not None:
            write(writer, res_info, t)
        # Record every round regardless of TensorBoard logging; the caller dumps
        # this list to result.json.
        record.append({k: v.item() if isinstance(v, torch.Tensor) else v
                       for k, v in res_info.items()})

        if t % log_interval == 0:
            time_helper.tok('{} rounds'.format(t))
            print(
                '\n[Time {}] Instant Error: {}, Average Error: {}'.format(t, instant_error, accumulate_avg_error))

    return record


if __name__ == "__main__":
    cfgs = argparser()
    device = cfgs.get('device')
    cpu_num = cfgs.get('cpu_num')
    set_cpu_num(cpu_num)
    setup_seed(cfgs['random_seed'])
    rng = np.random.default_rng(cfgs['random_seed'])

    # Build the run directory name once so the printed path always matches the
    # directory the SummaryWriter uses (two strftime() calls could straddle a
    # second boundary).
    # NOTE(review): ':' in the directory name is invalid on Windows filesystems.
    run_dir = join(cfgs['output'],
                   'runs_{}'.format(time.strftime("%a_%b_%d_%H:%M:%S",
                                                  time.localtime())))
    writer = SummaryWriter(run_dir) if cfgs['Model'].get('write', True) else None
    print(run_dir)

    logger = MyLogger('{}/{}'.format(cfgs['output'], 'log.txt'))
    logger.info(str(cfgs['output']))

    train_set, test_set, data_info = get_dataset(
        name=cfgs['Data']['name'],
        cfgs=cfgs['Data']['kwargs'],
        rng=rng
    )

    if 'path' in cfgs['Model']:
        # NOTE(security): torch.load unpickles the checkpoint; only load files
        # from trusted sources.
        init_model = torch.load(cfgs['Model']['path'], map_location=device)
    else:
        init_model = offline_train(cfgs, data_info, train_set, writer, logger, device)
    # Estimator and classifier are both built from the same model object.
    estimator_model = classifier_model = init_model

    classifier = get_classifier_alg(cfgs['Model']['Classifier'], classifier_model,
                                    train_set, device, rng, data_info)
    estimator = get_estimator_alg(cfgs['Model']['Estimator'], estimator_model,
                                  train_set, device, rng, data_info)
    classifier_algorithm = cfgs['Model']['Classifier'].get('algorithm', 'FIX')
    estimator_algorithm = cfgs['Model']['Estimator'].get('algorithm', 'Accous')

    record = run(cfgs, train_set, test_set, data_info,
                 estimator, classifier, estimator_algorithm, classifier_algorithm,
                 device=device, writer=writer)

    with open(join(cfgs["output"], 'result.json'), 'w') as fw:
        js.dump(record, fw, indent=4)

    # Flush any pending TensorBoard events before the process exits.
    if writer is not None:
        writer.close()
