
import os
import sys
import numpy as np
import torch
import math
import argparse
import pickle
import logging

from sklearn.linear_model import LinearRegression, Ridge, LogisticRegression
from sklearn.metrics import accuracy_score

from generate_data import DataGenerator


sys.path.append('../')
from utils.logger import CompleteLogger
from utils.logger import Archiver
from utils.tools import combine_envs, combine_envs_np

from ACLib.base import DataSet
from ACLib.RAN import UniformSamplingSelector
from ACLib.CoreSet import CoreSetSelector

from PULib.pul import PU_SL
from PULib.prior_estimate import KernelPriorEstimator

from ILLib.ERM import Regression
from ILLib.ERM import ERM
from ILLib.IRM import IRM
from ILLib.IGA import IGA
from ILLib.HRM import HRM
from ILLib.EIIL import EIIL

parser = argparse.ArgumentParser(description='source')
parser.add_argument('--mode', type=str, default='train', help='train / test')
# NOTE(review): active_adaptation below only implements 'RAN' and 'CoreSet';
# 'AADA', 'DBAL' and 'CLUE' are accepted here but raise 'Not Implemented...'.
parser.add_argument('--ac', type=str, default='source_only', choices=['source_only', 'RAN', 'CoreSet', 'AADA', 'DBAL', 'CLUE'])
parser.add_argument('--ac-ratio', type=float, default=0.02)
parser.add_argument('--ac-round', type=int, default=5)
parser.add_argument('--pu', type=str, default='none', choices=['none', 'PU', 'PUI'])
parser.add_argument('--rep', type=int, default=10, help='random seeds: [0,rep)')
parser.add_argument('--ln', type=str, default='ERM', choices=['ERM', 'IRM', 'IGA', 'HRM', 'EIIL'])

parser.add_argument('--unbalanced', action='store_true')

parser.add_argument('--source-r', type=float, default=0.9)

args = parser.parse_args()

# `r` value handed to DataGenerator.generate_env for the source environment.
source_r = args.source_r

# Per-test-environment `r` values and sample counts for DataGenerator.generate_envs.
rs = [0.9, 0.7, 0.5, 0.3, 0.1]
ns = [1000, 1000, 1000, 1000, 1000]
# Derived from `rs` so the two cannot drift apart.
num_env = len(rs)

# Labels used when printing per-environment results.
env_name = rs

log = os.path.join('logs', 'source_' + str(source_r))

if args.unbalanced:
    log = log + '_unbalanced'
    # Same 5000 total test samples, but concentrated in the first environment.
    ns = [2500, 625, 625, 625, 625]

archiver = Archiver()

# Run name, e.g. 'PU-RAN-IRM': optional PU prefix, selector, optional learner suffix.
METHOD = args.ac
if args.pu != 'none':
    METHOD = args.pu + '-' + METHOD
if args.ln != 'ERM':
    METHOD = METHOD + '-' + args.ln

log = os.path.join(log, METHOD)
# makedirs also creates the missing parents ('logs/' and the per-source_r
# directory). Plain os.mkdir raised FileNotFoundError when 'logs/' was absent.
os.makedirs(log, exist_ok=True)


def ERM_train(data):
    """Fit an intercept-free scikit-learn logistic regression.

    Args:
        data: A sequence of exactly three items ``(X, y, third)``. Only ``X``
            and ``y`` are used, and ``y`` is flattened with ``reshape(-1)``
            before fitting. The third item is unpacked but ignored.

    Returns:
        The fitted ``LogisticRegression`` estimator.
    """
    features, labels, _unused = data
    classifier = LogisticRegression(fit_intercept=False)
    return classifier.fit(features, labels.reshape(-1))


def validate(env_data, model, verbose=False):
    """Score ``model`` on each environment and on their union, and print a report.

    Args:
        env_data: Sequence of ``(X, y)`` pairs, one per environment. The
            ``k``-th entry is labelled ``env_name[k]`` in the printout.
        model: Object exposing ``score(X, y)``. This is presumably accuracy,
            matching the labels printed here; confirm for non-sklearn models.
        verbose: Accepted for call-site compatibility but unused. The report
            is always printed.

    Returns:
        ``(acc_list, acc_o, acc_mean, acc_min)``:
            - the list of per-environment scores;
            - the score on all environments pooled with ``combine_envs``;
            - the mean of the per-environment scores;
            - the minimum of the per-environment scores.
    """
    def _env_score(pair):
        X_env, y_env = pair
        return model.score(X_env, y_env)

    # Score every environment first; the report is printed only afterwards.
    n_envs = len(env_data)
    env_scores = [_env_score(env_data[k]) for k in range(n_envs)]

    for k, score_k in enumerate(env_scores):
        print('env   [ri: {}]:  {} '.format(env_name[k], score_k))

    X_pool, y_pool = combine_envs(env_data)
    pooled_score = model.score(X_pool, y_pool)
    mean_score = np.mean(env_scores)
    worst_score = np.min(env_scores)
    print('overall: {} '.format(pooled_score))
    print('average: {} '.format(mean_score))
    print('worst: {} '.format(worst_score))

    return env_scores, pooled_score, mean_score, worst_score

def _complement_sigmoid(z):
    """Return ``1 - sigmoid(z)`` (equivalently ``sigmoid(-z)``) element-wise, stably.

    The naive ``1 - 1 / (1 + np.exp(-z))`` fails at both extremes:
        - ``np.exp(-z)`` overflows for large negative ``z``;
        - the final subtraction cancels to exactly 0 for large positive ``z``.
    Here ``np.exp`` is only applied to ``-|z|``, so it can never overflow,
    and no ``1 - x`` subtraction is needed.
    """
    z = np.asarray(z)
    e = np.exp(-np.abs(z))
    # z >= 0: sigmoid(-z) = e^{-z} / (1 + e^{-z});  z < 0: sigmoid(-z) = 1 / (1 + e^{z})
    return np.where(z >= 0, e / (1.0 + e), 1.0 / (1.0 + e))


def pu_learning(X_pos, X_ul, prior=None):
    """Fit a PU classifier (``PU_SL``) of labelled-positive vs. unlabeled data and score ``X_ul``.

    Args:
        X_pos: Samples treated as the positive class. Callers pass torch
            tensors, which are converted via ``.numpy()``; array-likes are
            accepted via ``np.asarray``.
        X_ul: Unlabeled samples, converted the same way.
        prior: Class prior handed to ``PU_SL``. When ``None`` it is estimated
            with ``KernelPriorEstimator``.

    Returns:
        ``(prob, estimated_prior)``:
            - ``prob[j]`` is ``1 - sigmoid(clf(X_ul)[j])``. Presumably ``clf``
              returns larger values for more positive-like samples, so a high
              ``prob`` marks samples unlike ``X_pos``; confirm PU_SL's sign
              convention.
            - ``estimated_prior`` is the prior actually used.
    """
    def _to_numpy(x):
        # Tensors expose .numpy() (original behaviour); anything else goes through np.asarray.
        return x.numpy() if hasattr(x, 'numpy') else np.asarray(x)

    X_pos = _to_numpy(X_pos)
    X_ul = _to_numpy(X_ul)

    if prior is None:
        mpe_helper = KernelPriorEstimator()
        estimated_prior = mpe_helper.estimate(X_pos, X_ul)
        print('estimated prior: ', estimated_prior)
    else:
        estimated_prior = prior
        print('prior: ', estimated_prior)

    pul_helper = PU_SL(prior=estimated_prior,
                       n_fold=5,
                       sigma_list=[0.001, 0.01, 0.1, 1, 10],
                       lambda_list=None,
                       model='lm')

    clf = pul_helper.fit(X_pos, X_ul)

    pred = clf(X_ul)
    prob = _complement_sigmoid(pred)
    return prob, estimated_prior

def active_adaptation(model, source_data, test_envs, seed):
    """Iteratively query labels from the pooled test environments and retrain.

    Each of ``args.ac_round`` rounds does three things:
        1. Asks the selector chosen by ``args.ac`` for
           ``int(len(ul_y) * args.ac_ratio)`` new indices into the pooled test
           data. When ``args.pu`` is set, candidates are weighted by a PU
           score.
        2. Retrains ``model`` with the learner named by ``args.ln`` on the
           source data plus every sample queried so far.
        3. Validates the model and archives it.

    Args:
        model: Source-trained model. For 'ERM' its ``train`` method is
            called. For 'IRM'/'IGA' its ``.model`` attribute seeds the
            invariant learner built on the first round.
        source_data: ``(X, y)`` source features and labels.
        test_envs: Test environments. They are pooled via ``combine_envs`` to
            form the query pool and also passed to ``validate``.
        seed: Seed forwarded to the query selector.

    Returns:
        The model after the last round.

    Raises:
        Exception: if ``args.ac`` is not 'RAN'/'CoreSet', or ``args.ln`` is
            not 'ERM'/'IRM'/'IGA'.
    """
    X, y = source_data

    ul_X, ul_y = combine_envs(test_envs)

    if args.ac == 'RAN':
        selector = UniformSamplingSelector(dataset=DataSet(ul_X, ul_y), seed=seed, ac_type=args.ac)
    elif args.ac == 'CoreSet':
        selector = CoreSetSelector(dataset=DataSet(ul_X, ul_y), seed=seed, ac_type=args.ac)
    else:
        raise Exception('Not Implemented...')

    already_selected_list = []
    # Pool indices not queried yet, always kept in ascending order.
    rest_index_list = list(range(len(ul_X)))
    num_query = int(len(ul_y) * args.ac_ratio)

    # All-ones mask keeps every feature. IRM/IGA rounds replace it with a
    # learned mask, which the 'PUI' branch applies on the next round.
    feature_mask = torch.ones(X.shape[1]).long()

    prior = None
    for i in range(args.ac_round):

        print(already_selected_list)

        if args.pu == 'PU':
            # PU scores are computed once, on unmasked features, and reused.
            if i == 0:
                pu_score, prior = pu_learning(X, ul_X, prior)

            score = pu_score[rest_index_list]
            score = score / score.sum()
            selected_id, _ = selector.select_batch_(already_selected=already_selected_list, N=num_query, weight=score)
        elif args.pu == 'PUI':
            X_in = X * feature_mask
            ul_X_in = ul_X * feature_mask

            # Rescored every round on the masked features. The prior is
            # estimated on the first round and reused afterwards.
            pu_score, prior = pu_learning(X_in, ul_X_in, prior)

            # NOTE(review): unlike the 'PU' branch these weights are not
            # normalised to sum to 1 — confirm select_batch_ does not require it.
            score = pu_score[rest_index_list]

            if args.ac == 'CoreSet':
                selector.dataset = DataSet(ul_X_in, ul_y)

            selected_id, _ = selector.select_batch_(already_selected=already_selected_list, N=num_query, weight=score)
        else:
            selected_id, _ = selector.select_batch_(already_selected=already_selected_list, N=num_query)

        already_selected_list = already_selected_list + selected_id
        # Filter the ascending list instead of list(set(...) - set(...)). The
        # latter's order is a CPython set-iteration detail, yet it fixes the
        # order of the `weight` vectors built above.
        # NOTE(review): presumably select_batch_ aligns `weight` with unselected
        # indices in ascending order — confirm.
        newly_selected = set(selected_id)
        rest_index_list = [idx for idx in rest_index_list if idx not in newly_selected]

        print('query: ', len(selected_id), ' samples')
        archiver.save_ac_samples(selected_id, 'ac_{}.txt'.format(i))

        QX, Qy = ul_X[already_selected_list, :], ul_y[already_selected_list]

        # Environment 0 is the source data; environment 1 holds every queried sample.
        envs = [(X, y), (QX, Qy)]
        if args.ln == 'ERM':
            # ERM ignores the environment split, so both are merged into one.
            X_cur, y_cur = combine_envs(envs)
            model = model.train([(X_cur, y_cur)], epochs=10000, lr=1e-3, verbose=False)

        elif args.ln == 'IRM':
            if i == 0:
                model = IRM(input_dim=X.shape[1], output_dim=1, lam=100., type='classification').set_model(model.model)
            model.train(envs, epochs=10000, lr=1e-5, verbose=False)
            feature_mask = model.generate_mask([0.95]).long()
        elif args.ln == 'IGA':
            if i == 0:
                model = IGA(input_dim=X.shape[1], output_dim=1, lam=100., type='classification').set_model(model.model)
            model.train(envs, epochs=10000, lr=1e-5, verbose=False)
            feature_mask = model.generate_mask([0.95]).long()

        else:
            raise Exception('Not Implemented..')
        validate(test_envs, model, verbose=True)

        archiver.save_model(model, 'model_{}.pkl'.format(i + 1))

    return model


def main(seed):
    """Run one seeded experiment and return its test-environment metrics.

    Steps:
        1. Generate a 1000-sample source environment (``r=source_r``) and the
           test environments described by ``rs``/``ns``.
        2. In ``--mode test``, load this seed's ``final_model.pkl`` and
           evaluate it.
        3. Otherwise, train a source model with the learner chosen by
           ``--ln`` and save it as ``source_model.pkl``.
        4. Optionally adapt it with active learning (``--ac``), then save
           ``final_model.pkl`` and evaluate.

    Args:
        seed: Seed for ``DataGenerator``. It is also forwarded to the active
            learning selector.

    Returns:
        ``(acc_list, acc_o, acc_mean, acc_min)`` as produced by ``validate``.

    Raises:
        Exception: if ``--ln`` names a learner with no branch here.
    """
    generator = DataGenerator(dim=10, pv=3, seed=seed)
    # Keep this call order: the generator is seeded once, so drawing the test
    # environments first may change the generated data.
    X_src, y_src = generator.generate_env(r=source_r, n=1000, y_encode=False)
    test_envs = generator.generate_envs(rs=rs, ns=ns, y_encode=False)

    archiver.set_path(log, seed)
    print('seed: {}     path: {}'.format(seed, archiver.basic_path))

    if args.mode == 'test':
        accs, overall, mean_acc, worst = validate(test_envs, archiver.load_model('final_model.pkl'))
        return accs, overall, mean_acc, worst

    def _pretrained_erm():
        # Plain ERM fit on the source environment; every learner starts from it.
        erm = ERM(input_dim=X_src.shape[1], output_dim=1, type='classification')
        return erm.train([(X_src, y_src)], epochs=10000, lr=1e-3, verbose=False)

    learner = args.ln
    if learner in ('ERM', 'IRM', 'IGA'):
        # For IRM/IGA the invariant learner is only built inside active_adaptation.
        model = _pretrained_erm()
    elif learner == 'HRM':
        erm_model = _pretrained_erm()
        model = HRM(X_src, y_src, input_dim=X_src.shape[1], output_dim=1, lam=100., type='classification')
        model.set_model(erm_model.model)
        model.solve(iters=5, epochs=10000, lr=1e-5)
    elif learner == 'EIIL':
        erm_model = _pretrained_erm()
        model = EIIL(input_dim=X_src.shape[1], output_dim=1, lam=100., type='classification')
        model.set_model(erm_model.model)
        model.solve([(X_src, y_src)], epochs=10000, lr=1e-5)
    else:
        raise Exception('Not Implemented...')

    archiver.save_model(model, 'source_model.pkl')
    validate(test_envs, model, verbose=True)

    if args.ac != 'source_only':
        model = active_adaptation(model, (X_src, y_src), test_envs, seed)

    archiver.save_model(model, 'final_model.pkl')

    print('final evaluation:')

    accs, overall, mean_acc, worst = validate(test_envs, model)
    return accs, overall, mean_acc, worst
        

# Experiments are repeated for seeds 0 .. args.rep - 1.
seeds = range(args.rep)

if __name__ == '__main__':

    # One entry per seed, filled from main()'s return values.
    records = []
    records_overall = []
    records_average = []
    records_worst = []
    logger = CompleteLogger(log, args.mode)

    # try/finally so the logger is closed even when a seed's run raises.
    try:
        for seed in seeds:
            acc_list, acc_o, acc_mean, acc_min = main(seed)
            records.append(acc_list)
            records_overall.append(acc_o)
            records_average.append(acc_mean)
            records_worst.append(acc_min)
            print('     ---------------------------------    ')

        print('seeds: ', seeds)
        records = np.asarray(records)  # shape: (num_seeds, num_envs)
        records_overall = np.asarray(records_overall)
        records_average = np.asarray(records_average)
        records_worst = np.asarray(records_worst)

        # Mean / std across seeds, per environment.
        mean_acc = np.mean(records, axis=0)
        std_acc = np.std(records, axis=0)
        for name, mu, sd in zip(env_name, mean_acc, std_acc):
            print('env   [{}]:  {} +/- {}  '.format(name, mu, sd))

        print('overall: {} +/- {}  '.format(np.mean(records_overall), np.std(records_overall)))
        print('average: {} +/- {}  '.format(np.mean(records_average), np.std(records_average)))
        print('worst: {} +/- {}  '.format(np.mean(records_worst), np.std(records_worst)))
    finally:
        logger.close()
