
import os
import sys
import numpy as np
import torch
import math
import argparse
import pickle
import logging
import pandas as pd

from sklearn.linear_model import LinearRegression, Ridge, LogisticRegression
from sklearn.metrics import accuracy_score

from generate_data import DataGenerator


sys.path.append('../')
from utils.logger import CompleteLogger
from utils.logger import Archiver
from utils.tools import combine_envs, combine_envs_np

from ACLib.base import DataSet
from ACLib.RAN import UniformSamplingSelector
from ACLib.CoreSet import CoreSetSelector

from PULib.pul import PU_SL
from PULib.prior_estimate import KernelPriorEstimator

from ILLib.ERM import ERM
from ILLib.IRM import IRM
from ILLib.IGA import IGA
from ILLib.HRM import HRM
from ILLib.EIIL import EIIL

parser = argparse.ArgumentParser(description='source')
parser.add_argument('--mode', type=str, default='train', help='train / test / test-weight')
# NOTE: DBAL / AADA / CLUE are accepted here, but active_adaptation only
# implements RAN and CoreSet and raises for the others.
parser.add_argument('--ac', type=str, default='source_only', choices=['source_only', 'RAN', 'CoreSet', 'DBAL', 'AADA', 'CLUE'])
parser.add_argument('--ac-ratio', type=float, default=0.02)
parser.add_argument('--ac-round', type=int, default=5)
parser.add_argument('--pu', type=str, default='none', choices=['none', 'PU', 'PUI'])
parser.add_argument('--rep', type=int, default=10, help='random seeds: [0,rep)')
parser.add_argument('--ln', type=str, default='ERM', choices=['ERM', 'IRM', 'IGA', 'HRM', 'EIIL'])

parser.add_argument('--unbalanced', action='store_true')

parser.add_argument('--source-r', type=float, default=2.0)

args = parser.parse_args()


# `r` value of the single source environment (passed to DataGenerator.generate_env in main).
source_c = args.source_r

# Target environments: one per `r` value in `rs`, with `ns[i]` samples each
# (passed to DataGenerator.generate_envs in main).
rs = [3.0, 2.0, 1.5, -1.5, -2.0, -3.0]
ns = [500, 500, 500, 500, 500, 500]
num_env = len(rs)

# Labels used when printing per-environment results.
env_name = rs

# Output directory layout: logs/source_<r>[_unbalanced]/<METHOD>
log = os.path.join('logs', 'source_' + str(source_c))

if args.unbalanced:
    log = log + '_unbalanced'
    # Same environments, but the second one gets 5x the samples of the others.
    ns = [300, 1500, 300, 300, 300, 300]


archiver = Archiver()

# METHOD encodes the run configuration, e.g. "PU-CoreSet-IRM".
METHOD = args.ac
if args.pu != 'none':
    METHOD = args.pu + '-' + METHOD
if args.ln != 'ERM':
    METHOD = METHOD + '-' + args.ln

log = os.path.join(log, METHOD)
# makedirs also creates missing parents (a fresh checkout has no `logs/`
# directory, on which a plain os.mkdir raises FileNotFoundError).
os.makedirs(log, exist_ok=True)


def ERM_train(data):
    """Fit an intercept-free sklearn ``LogisticRegression`` on one dataset.

    Args:
        data: a sequence whose first two items are the features ``X`` and the
            labels ``y``. Any further items (a third element used to be
            unpacked as ``year`` but was never used) are ignored, so both
            ``(X, y)`` pairs and ``(X, y, extra)`` triples are accepted.

    Returns:
        The fitted ``LogisticRegression`` estimator.
    """
    X, y, *_ = data

    # sklearn expects a 1-D label vector.
    model = LogisticRegression(fit_intercept=False).fit(X, y.reshape(-1))

    return model


def validate(env_data, model, verbose=False):
    """Score ``model`` on every environment and on their union, printing a report.

    Args:
        env_data: sequence of ``(X, y)`` pairs, one per environment, indexed
            from 0. Labels for the printout come from the module-level
            ``env_name``, so the order must match it.
        model: any object exposing ``score(X, y)``.
        verbose: accepted for call-site compatibility but currently unused;
            the per-environment lines are always printed.

    Returns:
        ``(per_env, overall, average, worst)``:
          * ``per_env``: list of per-environment scores;
          * ``overall``: score on ``combine_envs(env_data)``;
          * ``average``: ``np.mean(per_env)``;
          * ``worst``: ``np.max(per_env)``.

    NOTE(review): "worst" is the *maximum* score. That is only the worst
    environment if ``model.score`` returns a loss (e.g. MSE). If it returns an
    accuracy it should be ``np.min``. Confirm what the ILLib models' ``score``
    returns.
    """
    n_envs = len(env_data)

    # Score every environment first, then print, so nothing is printed if
    # any score call fails.
    per_env = []
    for k in range(n_envs):
        X_k, y_k = env_data[k]
        per_env.append(model.score(X_k, y_k))

    for k, score_k in enumerate(per_env):
        print('env   [ri: {}]:  {} '.format(env_name[k], score_k))

    X_union, y_union = combine_envs(env_data)
    overall = model.score(X_union, y_union)

    average = np.mean(per_env)
    worst = np.max(per_env)

    print('overall: {} '.format(overall))
    print('average: {} '.format(average))
    print('worst: {} '.format(worst))

    return per_env, overall, average, worst

def pu_learning(X_pos, X_ul, prior=None):
    """Fit a PU classifier and score every row of ``X_ul``.

    ``PU_SL.fit`` is called with ``X_pos`` as its first (positive) set and
    ``X_ul`` as its second (unlabeled) set. The returned classifier's raw output
    ``pred`` on ``X_ul`` is mapped to ``1 - sigmoid(pred)``.

    Args:
        X_pos: positive samples; must support ``.numpy()`` (e.g. a CPU torch tensor).
        X_ul: unlabeled samples; same requirement.
        prior: class prior for ``PU_SL``. If ``None``, it is estimated with
            ``KernelPriorEstimator`` from ``X_pos`` and ``X_ul``.

    Returns:
        ``(prob, prior)``. ``prob`` holds ``1 - sigmoid(pred)`` for each row of
        ``X_ul``. ``prior`` is the prior that was used (estimated or passed
        through), so callers can feed it back on later calls.
    """
    X_pos = X_pos.numpy()
    X_ul = X_ul.numpy()

    if prior is None:
        mpe_helper = KernelPriorEstimator()
        estimated_prior = mpe_helper.estimate(X_pos, X_ul)
        print('estimated prior: ', estimated_prior)
    else:
        estimated_prior = prior
        print('prior: ', estimated_prior)

    pul_helper = PU_SL(prior=estimated_prior,
                       n_fold=5,
                       sigma_list=[0.001, 0.01, 0.1, 1, 10],
                       lambda_list=None,
                       model='lm')

    clf = pul_helper.fit(X_pos, X_ul)

    pred = clf(X_ul)
    # 1 - sigmoid(pred) == 1 / (1 + exp(pred)) == exp(-log(1 + exp(pred))).
    # The logaddexp form is numerically stable. The naive
    # `1 - 1 / (1 + exp(-pred))` cancels to exactly 0 once pred exceeds ~37,
    # which can make the caller's normalisation `score / score.sum()` divide
    # by zero. It also overflows in exp(-pred) for very negative pred.
    prob = np.exp(-np.logaddexp(0., pred))
    return prob, estimated_prior

def active_adaptation(model, source_data, test_envs, seed):
    """Run ``args.ac_round`` rounds of active querying on the pooled target data.

    Each round:
      1. selects ``int(len(ul_y) * args.ac_ratio)`` new indices from the pooled
         target samples, optionally weighting candidates with PU scores
         (``args.pu``);
      2. archives the selected indices;
      3. updates ``model`` on the source data plus all queried samples,
         according to ``args.ln``;
      4. evaluates on ``test_envs`` and archives the model.

    Args:
        model: the source-trained model. It must expose ``.train(...)``; for
            IRM/IGA its ``.model`` attribute is wrapped on the first round.
        source_data: ``(X, y)`` source samples.
        test_envs: sequence of ``(X, y)`` target environments, pooled with
            ``combine_envs``.
        seed: seed forwarded to the selector.

    Returns:
        The model after the last round.

    Raises:
        NotImplementedError: if ``args.ac`` or ``args.ln`` is not supported here.
    """
    X, y = source_data

    ul_X, ul_y = combine_envs(test_envs)

    if args.ac == 'RAN':
        selector = UniformSamplingSelector(dataset=DataSet(ul_X, ul_y), seed=seed, ac_type=args.ac)
    elif args.ac == 'CoreSet':
        selector = CoreSetSelector(dataset=DataSet(ul_X, ul_y), seed=seed, ac_type=args.ac)
    else:
        # DBAL / AADA / CLUE are accepted by the CLI but have no selector here.
        raise NotImplementedError('Not Implemented...')

    # Fail before spending (and archiving) a query round on a learner this
    # loop cannot update.
    if args.ln not in ('ERM', 'IRM', 'IGA'):
        raise NotImplementedError('Not Implemented..')

    already_selected_list = []
    # Indices of not-yet-queried pool samples. This list is kept in ascending
    # order so ul_X[rest_index_list] and pu_score[rest_index_list] refer to
    # the same positions every round.
    rest_index_list = list(range(len(ul_X)))
    num_query = int(len(ul_y) * args.ac_ratio)

    # All-ones mask: every feature is used until IRM/IGA produce a mask.
    feature_mask = torch.ones(X.shape[1]).long()

    # The PU prior is estimated on the first pu_learning call and reused afterwards.
    prior = None
    for i in range(args.ac_round):

        print(already_selected_list)

        if args.pu == 'PU':
            # Score only the still-unqueried pool; normalise the scores to sum to 1.
            pu_score, prior = pu_learning(X, ul_X[rest_index_list, :], prior=prior)
            pu_score = pu_score / pu_score.sum()
            selected_id, _ = selector.select_batch_(already_selected=already_selected_list, N=num_query, weight=pu_score)
        elif args.pu == 'PUI':
            # PU scoring in the masked feature space; score the whole pool,
            # then keep the entries of the unqueried indices.
            X_in = X * feature_mask
            ul_X_in = ul_X * feature_mask

            pu_score, prior = pu_learning(X_in, ul_X_in, prior=prior)

            score = pu_score[rest_index_list]

            if args.ac == 'CoreSet':
                # Presumably so CoreSet also operates on the masked features.
                selector.dataset = DataSet(ul_X_in, ul_y)

            selected_id, _ = selector.select_batch_(already_selected=already_selected_list, N=num_query, weight=score)
        else:
            selected_id, _ = selector.select_batch_(already_selected=already_selected_list, N=num_query)

        already_selected_list = already_selected_list + list(selected_id)
        # Order-preserving removal (the previous set difference depended on
        # CPython's set iteration order to keep the list sorted).
        chosen = set(selected_id)
        rest_index_list = [j for j in rest_index_list if j not in chosen]

        print('query: ', len(selected_id), ' samples')
        archiver.save_ac_samples(selected_id, 'ac_{}.txt'.format(i))

        QX, Qy = ul_X[already_selected_list, :], ul_y[already_selected_list]

        # Two environments: the source data and everything queried so far.
        envs = [(X, y), (QX, Qy)]
        if args.ln == 'ERM':
            X_cur, y_cur = combine_envs(envs)
            model = model.train([(X_cur, y_cur)], epochs=10000, lr=1e-3, verbose=False)

        elif args.ln == 'IRM':
            if i == 0:
                # Wrap the source network once; later rounds keep training the wrapper.
                model = IRM(input_dim=X.shape[1], output_dim=1, lam=100., type='regression').set_model(model.model)
            model.train(envs, epochs=10000, lr=1e-5, verbose=False)
            feature_mask = model.generate_mask([0.95]).long()
        elif args.ln == 'IGA':
            if i == 0:
                model = IGA(input_dim=X.shape[1], output_dim=1, lam=200., type='regression').set_model(model.model)
            model.train(envs, epochs=10000, lr=1e-5, verbose=False)
            feature_mask = model.generate_mask([0.95]).long()

        else:
            raise NotImplementedError('Not Implemented..')
        validate(test_envs, model, verbose=True)

        archiver.save_model(model, 'model_{}.pkl'.format(i+1))

    return model


def main(seed):
    """Run one seed: generate data, train or load a model, optionally adapt, evaluate.

    Data comes from ``DataGenerator(dim=10, pv=3, seed=seed)``: one source
    environment with ``r=source_c`` and 1000 samples, plus the target
    environments described by ``rs`` / ``ns``.

    Modes (``args.mode``):
      * ``'test'``: load ``final_model.pkl`` for this seed and evaluate it.
      * ``'test-weight'``: as ``'test'``, and also return ``model.model.weight.data``.
      * anything else: train a source model according to ``args.ln`` and save
        it as ``source_model.pkl``. Then run ``active_adaptation`` unless
        ``args.ac == 'source_only'``, save ``final_model.pkl`` and evaluate.

    Returns:
        ``(per_env_scores, overall, average, worst)`` as returned by
        ``validate``. In ``'test-weight'`` mode the weights are appended as a
        fifth element.

    Raises:
        ValueError: if ``args.ln`` names an unknown learner.
    """
    data_gen = DataGenerator(dim=10, pv=3, seed=seed)
    X_train, y_train = data_gen.generate_env(r=source_c, n=1000)
    test_envs = data_gen.generate_envs(rs=rs, ns=ns)

    archiver.set_path(log, seed)
    print('seed: {}     path: {}'.format(seed, archiver.basic_path))

    if args.mode == 'test':
        model = archiver.load_model('final_model.pkl')

        mse_list, mse_o, mse_ave, mse_max = validate(test_envs, model)
        return mse_list, mse_o, mse_ave, mse_max

    if args.mode == 'test-weight':
        model = archiver.load_model('final_model.pkl')
        model.device = "cpu"
        weights = model.model.weight.data
        print(weights)
        mse_list, mse_o, mse_ave, mse_max = validate(test_envs, model)
        return mse_list, mse_o, mse_ave, mse_max, weights

    input_dim = X_train.shape[1]

    # Every learner starts from the same ERM fit on the source data.
    erm_model = ERM(input_dim=input_dim, output_dim=1, type='regression')
    erm_model = erm_model.train([(X_train, y_train)], epochs=10000, lr=1e-3, verbose=False)

    if args.ln in ('ERM', 'IRM', 'IGA'):
        # The source model is plain ERM; IRM / IGA are applied in active_adaptation.
        model = erm_model
    elif args.ln == 'HRM':
        model = HRM(X_train, y_train, input_dim=input_dim, output_dim=1, lam=200., type='regression')
        model.set_model(erm_model.model)

        model.solve(epochs=10000, lr=1e-5)
    elif args.ln == 'EIIL':
        model = EIIL(input_dim=input_dim, output_dim=1, lam=200., type='regression')
        model.set_model(erm_model.model)

        model.solve([(X_train, y_train)], epochs=10000, lr=1e-5)
    else:
        raise ValueError('Unknown learner: {}'.format(args.ln))

    archiver.save_model(model, 'source_model.pkl')
    validate(test_envs, model, verbose=True)

    if args.ac != 'source_only':
        model = active_adaptation(model, (X_train, y_train), test_envs, seed)

    archiver.save_model(model, 'final_model.pkl')

    print('final evaluation:')

    mse_list, mse_o, mse_ave, mse_max = validate(test_envs, model)
    return mse_list, mse_o, mse_ave, mse_max
        

seeds = range(args.rep)

if __name__ == '__main__':

    records = []
    records_overall = []
    records_average = []
    records_worst = []
    logger = CompleteLogger(log, args.mode)

    # Per-seed weight tensors ('test-weight' mode only), concatenated once at
    # the end rather than growing an empty placeholder tensor every seed.
    weight_rows = []

    try:
        for seed in seeds:
            if args.mode == 'test-weight':
                mse_list, mse_o, mse_ave, mse_max, weights = main(seed)
                weight_rows.append(weights.detach().cpu())
            else:
                mse_list, mse_o, mse_ave, mse_max = main(seed)
            records.append(mse_list)
            records_overall.append(mse_o)
            records_average.append(mse_ave)
            records_worst.append(mse_max)
            print('     ---------------------------------    ')

        print('seeds: ', seeds)
        # records: one row per seed, one column per target environment.
        records = np.asarray(records)
        records_overall = np.asarray(records_overall)
        records_average = np.asarray(records_average)
        records_worst = np.asarray(records_worst)

        mean_mse = np.mean(records, axis=0)
        std_mse = np.std(records, axis=0)
        for e_i in range(num_env):
            print('env   [{}]:  {} +/- {}  '.format(env_name[e_i], mean_mse[e_i], std_mse[e_i]))

        print('overall: {} +/- {}  '.format(np.mean(records_overall), np.std(records_overall)))
        print('average: {} +/- {}  '.format(np.mean(records_average), np.std(records_average)))
        print('worst: {} +/- {}  '.format(np.mean(records_worst), np.std(records_worst)))

        if args.mode == 'test-weight':
            record_weights = torch.cat(weight_rows, dim=0)
            print('learned weights: ', record_weights.mean(dim=0))

            # Report magnitudes, capped at 1.0.
            record_weights = record_weights.abs().clamp(max=1.0)

            # One row per weight row collected, followed by 'mean' and 'std'
            # summary rows. The frame is built directly: DataFrame.append
            # returns a new frame (its result was being discarded) and was
            # removed in pandas 2.0.
            n_features = record_weights.shape[1]
            df = pd.DataFrame(record_weights.numpy(), columns=list(range(n_features)))
            df.loc['mean'] = record_weights.mean(dim=0).tolist()
            df.loc['std'] = record_weights.std(dim=0).tolist()
            print(df)
            csv_path = os.path.join(log, 'weights.csv')
            df.to_csv(csv_path)
            print('save >>> ', csv_path)
    finally:
        # Close the logger even if a seed fails mid-run.
        logger.close()
