# -*- coding: utf-8 -*-
"""
Created on Mon Feb 27 13:18:32 2023

@author: cvpr2024 11221
"""

import numpy as np
import os
import torch

from dataset_model import FeasDataset, ImageFolderWithIndex, MLP, get_augmentation, get_dataset, get_network
from utils import train_mlp, evaluation, get_output_emb, train_freeze_mlp, train_fine_tune

import argparse
import shutil

def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string -- including ``'False'`` and ``'0'`` -- so boolean flags
    could never be switched off from the command line. This converter maps
    the usual spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    # '' stays False, matching the previous bool('') behaviour.
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description='')
parser.add_argument('--sampling_strategy', default='random', type=str,
                    help='Sampling strategy')
parser.add_argument('--al_budget', default='[200]*1', type=str,
                    help="per-round AL budgets as a Python list expression, e.g. '[10000]*10'")
parser.add_argument('--expid', default='1_r50_test', type=str,
                    help='order of exps')
parser.add_argument('--outpath_base', default=None, type=str,
                    help='path of results')
parser.add_argument('--dataset_name', default='feas', type=str,
                    help='name of dataset [imagenet, feas, cifar10, imagenet100, vic_cape_howe]')
parser.add_argument('--dataset_path', default=None, type=str,
                    help='for fine-tune and freezing & mlp')
parser.add_argument('--selfmodel_path', default=None, type=str,
                    help='path of selfsup model')
parser.add_argument('--trainidx', default=None, type=str,
                    help='trainidx for vic_cape_howe dataset')
parser.add_argument('--testidx', default=None, type=str,
                    help='testidx for vic_cape_howe dataset')

parser.add_argument('--totfeas_path', default=None, type=str,
                    help='path of selfsup feas')
parser.add_argument('--totlabel_path', default=None, type=str,
                    help='path of trainset label')

parser.add_argument('--factor_std', default=0.25, type=float,
                    help='factor of cls std')
parser.add_argument('--clsstd_path', default=None, type=str,
                    help='path of cls std')
parser.add_argument('--totfeas_test_path', default=None, type=str,
                    help='path of selfsup feas testset')
parser.add_argument('--totlabel_test_path', default=None, type=str,
                    help='path of testset label')
# Boolean flags use _str2bool so that e.g. '--load_al_weight False' really is False.
parser.add_argument('--load_proj_weight', default=False, type=_str2bool,
                    help='initialized classifier weights from projector')
parser.add_argument('--load_al_weight', default=False, type=_str2bool,
                    help='initialized classifier weights from last Active learning round')

parser.add_argument('--train_eps', default=200, type=int,
                    help='# of training epoch')
parser.add_argument('--lr', default=0.1, type=float,
                    help='learning rate')
parser.add_argument('--cls_lr', default=0.1, type=float,
                    help='learning rate for classifier')
parser.add_argument('--momentum', default=0.9, type=float,
                    help='momentum')
parser.add_argument('--weight_decay', default=0, type=float,
                    help='weight_decay')
parser.add_argument('--nesterov', default=True, type=_str2bool,
                    help='nesterov')
parser.add_argument('--milestone', default='60, 80', type=str,
                    help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--early_stop', default=200, type=int,
                    help='efficient AL baseline, early stop')
parser.add_argument('--freezelr', default=0.5, type=float,
                    help='lr in freeze stage')
parser.add_argument('--freeze_eps', default=10, type=int,
                    help='training eps of lp stage')
parser.add_argument('--ft_eps', default=120, type=int,
                    help='training eps of ft stage')

parser.add_argument('--network', default='res50', type=str,
                    help='[res18,res50,res50x2,res50x4]')

parser.add_argument('--batchsize_train', default=1024, type=int,
                    help='batch size for training')
parser.add_argument('--grad_accu', default=1, type=int,
                    help='num grad accum')
parser.add_argument('--batchsize_al_forward', default=1024*1, type=int,
                    help='batch size for forward passes over the pool during AL selection')
parser.add_argument('--batchsize_evaluation', default=1024*2, type=int,
                    help='batch size for evaluation')
parser.add_argument('--classifier_dim', default='2048,512,16', type=str,
                    help='dims of classifier')

parser.add_argument('--training_mode', default=0, type=int,
                    help='0:MLP_proxy(ours), 1:freezing encoder and training classifier, 2:Fine-tuning, 3:LP-FT')
parser.add_argument('--classifier_type', default='MLP', type=str,
                    help='Linear or MLP')

parser.add_argument('--distributed_training', default=False, type=_str2bool,
                    help='use nn.DataParallel')
parser.add_argument('--num_workers', default=0, type=int,
                    help='for dataloader')

parser.add_argument('--num_subset', default=320000, type=int,
                    help='AL, badge_partition_subset, the size of subset')
parser.add_argument('--num_partition', default=20, type=int,
                    help='AL, badge_partition(_subset), the number of partitions')

parser.add_argument('--alidx_path', default=None, type=str,
                    help='if a lblset is given')
parser.add_argument('--alidx_length', default=None, type=int,
                    help='how long the lblset is used in this AL pass')

### for ActiveFT

parser.add_argument('--activeft_temperature', default=0.07, type=float, help='temperature for softmax')
parser.add_argument('--activeft_max_iter', default=100, type=int, help='max iterations')
parser.add_argument('--activeft_lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--activeft_init', default='random', type=str, choices=['random', 'fps'])
parser.add_argument('--activeft_distance', default='euclidean', type=str, help='euclidean or cosine')
parser.add_argument('--activeft_scheduler', default='none', type=str, help='scheduler')
parser.add_argument('--activeft_balance', default=1.0, type=float, help='balance ratio')
parser.add_argument('--activeft_batch_size', default=20000, type=int, help='batch size for SGD')
parser.add_argument('--activeft_slice', default=10, type=int, help='size of slice to save memory')

# Optional fixed label-set order for the 'random' strategy; None means the
# pool is shuffled each round instead.
hyperalidx = None

args = parser.parse_args()

# LP-FT (mode 3) temporarily overwrites args.cls_lr for its linear-probe stage,
# so keep the user-supplied fine-tuning rates to restore them afterwards.
if args.training_mode == 3:
    ftlr, clslr = args.lr, args.cls_lr

# Optional index files: replace each path with the loaded array.
for _idx_attr in ('trainidx', 'testidx'):
    _idx_path = getattr(args, _idx_attr)
    if _idx_path is not None:
        setattr(args, _idx_attr, np.load(_idx_path))

# '60, 80' -> [60, 80]
args.milestone = list(map(int, args.milestone.split(',')))

for _shown in (args.lr, args.cls_lr, args.expid):
    print(_shown)

# 'in,hidden,out' -> input dim, single-element hidden-dim list, output dim.
_in_str, _hid_str, _out_str = args.classifier_dim.split(',')
indim_classifier = int(_in_str)
hiddim_classifier = [int(_hid_str)]
outdim_classifier = int(_out_str)


# Per-round labeling budgets, e.g. '[200]*1' -> [200]; one AL round per entry.
# NOTE(review): eval() executes arbitrary code from --al_budget. That is
# acceptable for a trusted local CLI, but never pass untrusted input here.
num_budget = eval(args.al_budget)
num_al_itr = len(num_budget)

sampling_strategy = args.sampling_strategy
expid = args.expid

# Import only the helpers the chosen sampling strategy needs.
if sampling_strategy in ('badge', 'badge_partition', 'badge_partition_subset'):
    from samplings.sampling_strategy import kmeans_plus
    from utils import get_grad_embedding
    if sampling_strategy == 'badge_partition':
        from samplings.sampling_strategy import kmeans_plus_partition
elif sampling_strategy == 'entropy':
    from samplings.sampling_strategy import entropy_sampling
elif sampling_strategy == 'margin':
    from samplings.sampling_strategy import margin_sampling, margin_sampling_model
elif sampling_strategy in ('coreset', 'coreset_self'):
    from samplings.sampling_strategy import acquire_new_sample
elif sampling_strategy == 'cluster_margin':
    from samplings.sampling_strategy import cluster_margin_sampling, margin_sampling_model, initial_hac
elif sampling_strategy == 'learningloss':
    from samplings.sampling_strategy import est_loss, learning_loss_sampling
    from utils import train_mlp_wlearningloss
    from dataset_model import LossNet
    lossmodel = None
elif sampling_strategy in ('ActiveFT(al)', 'ActiveFT(self)'):
    totfeas = np.load(args.totfeas_path)
    # Pools with more than 100000 samples use the ActiveFT_large implementation.
    if len(totfeas) > 100000:
        from samplings.ActiveFT_large import ActiveFT_sampling
    else:
        from samplings.ActiveFT import ActiveFT_sampling
elif sampling_strategy == 'confidence':
    from samplings.sampling_strategy import confidence_sampling_model

dataset = args.dataset_name

# Results are written to <outpath_base>/<dataset>/<exp_name>/.
# --outpath_base defaults to None; fail with a clear message instead of an
# opaque TypeError from os.path.join.
if args.outpath_base is None:
    raise ValueError('--outpath_base must be set to a results directory')
exp_name = dataset + '_' + sampling_strategy + '_exp' + str(expid) + '_training_strategy' + str(args.training_mode)
outpath = os.path.join(args.outpath_base, dataset, exp_name)
os.makedirs(outpath, exist_ok=True)

# Record the configuration by copying the training script next to the results.
# NOTE(review): this copies ./train.py relative to the current working
# directory (presumably this script), so it fails when launched from elsewhere.
shutil.copy(os.path.join('.', 'train.py'), outpath)

selfmodel_path = args.selfmodel_path

# Load the AL pool and the test set: precomputed feature arrays in mode 0,
# image datasets in every other training mode.
if args.training_mode != 0:
    transform_train = get_augmentation(args, train=True)
    transform_test = get_augmentation(args, train=False)

    # The pool is iterated with the test-time transform; transform_train is
    # only used when the labeled training set is built.
    allset = get_dataset(args, transform_test, index=None, train=True)
    testset = get_dataset(args, transform_test, index=None, train=False)
else:
    totfeas = np.load(args.totfeas_path)
    totlabel = np.load(args.totlabel_path)
    # Optional cls std (scaled by --factor_std) passed to the training FeasDataset.
    std = None if args.clsstd_path is None else np.load(args.clsstd_path) * args.factor_std

    totfeas_test = np.load(args.totfeas_test_path)
    totlabel_test = np.load(args.totlabel_test_path)

    allset = FeasDataset(totfeas, totlabel, None)
    testset = FeasDataset(totfeas_test, totlabel_test, None)

# Both loaders are deterministic: no shuffling, and the last partial batch is kept.
all_loader, test_loader = [
    torch.utils.data.DataLoader(
        ds,
        batch_size=bs,
        num_workers=args.num_workers,
        shuffle=False,
        drop_last=False,
    )
    for ds, bs in ((allset, args.batchsize_al_forward),
                   (testset, args.batchsize_evaluation))
]


# Placeholders; totpre is overwritten by entropy sampling.
totemb = None
totpre = None

# Initial classifier head (rebuilt each AL round unless --load_al_weight is set).
if args.classifier_type == 'MLP':
    classifier = MLP(indim_classifier, hiddim_classifier, outdim_classifier)
elif args.classifier_type == 'Linear':
    classifier = torch.nn.Linear(indim_classifier, outdim_classifier)
else:
    raise NotImplementedError

totacc = []  # test accuracy, one entry per AL round
tracc = []   # train accuracy, one entry per AL round

### Build the encoder and, for image-based training modes, initialize it with self-supervised weights.
model = get_network(args)
if args.training_mode != 0:
    # The checkpoint is only read in this branch. Loading it unconditionally
    # wasted memory in mode 0 and crashed when --selfmodel_path was not given.
    checkpoint = torch.load(selfmodel_path, map_location=torch.device('cpu'))

    encoder_dict = model.state_dict()
    if args.network == 'res50':
        # NOTE(review): strips a fixed 27-character key prefix, presumably the
        # BYOL/EMAN wrapper prefix of the checkpoint -- confirm for new checkpoints.
        state_dict = {k[27:]: v for k, v in checkpoint['state_dict'].items() if k[27:] in encoder_dict}
    elif args.network == 'res18' or args.network == 'wrn288':
        # NOTE(review): strips a fixed 9-character key prefix -- confirm.
        state_dict = {k[9:]: v for k, v in checkpoint['state_dict'].items() if k[9:] in encoder_dict}
    else:
        raise NotImplementedError
    # Only keys that match the encoder are overwritten; all others keep their init.
    encoder_dict.update(state_dict)
    model.load_state_dict(encoder_dict)
    # Release the full checkpoint; encoder_dict keeps the weights that are reused each round.
    del checkpoint

    model.fc = torch.nn.Identity()


import time

s = time.time()   # start of the current timed section
s0 = time.time()  # start of the whole run

# alidx: indices of labeled samples; uidx: indices of still-unlabeled samples.
if args.alidx_path is None:
    alidx = []
    uidx = list(range(len(allset)))
else:
    # Resume from a saved label set, optionally truncated to --alidx_length.
    alidx = np.load(args.alidx_path).tolist()
    alidx = alidx[:args.alidx_length]
    # Set lookup: the previous list-membership test was O(len(allset) * len(alidx)).
    _labeled = set(alidx)
    uidx = [i for i in range(len(allset)) if i not in _labeled]

for alitr in range(num_al_itr):

    # Selection runs on a fresh start (round 0 with an empty label set) and in
    # every later round. When resuming from --alidx_path, round 0 skips
    # selection and trains directly on the loaded label set.
    if (alitr == 0 and len(alidx) == 0) or (alitr > 0 and len(alidx) > 0):

        # Most strategies score only the unlabeled pool, so the forward-pass
        # loader is rebuilt over uidx. coreset and ActiveFT(al) index the
        # embeddings by global sample id and keep the full-set loader.
        if args.sampling_strategy != 'coreset' and args.sampling_strategy != 'ActiveFT(al)':
            if args.training_mode == 0:
                allset = FeasDataset(totfeas[uidx, :], totlabel[uidx], None)
            else:
                allset = get_dataset(args, transform_test, index=uidx, train=True)

            all_loader = torch.utils.data.DataLoader(
                allset,
                batch_size=args.batchsize_al_forward,
                num_workers=args.num_workers,
                shuffle=False,
                drop_last=False
            )

        if sampling_strategy == 'entropy':
            if len(alidx) == 0:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                totpre, _ = get_output_emb(all_loader, classifier, False, args, model)
                alidx += entropy_sampling(totpre, uidx, num_budget[alitr])
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'random':
            if hyperalidx is not None:
                alidx = hyperalidx[:np.sum(num_budget[:alitr + 1])]
                uidx = list(set(uidx) - set(alidx))
            else:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'badge':
            if len(alidx) == 0:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                emb = get_grad_embedding(classifier, all_loader, args, model)
                print(len(uidx), len(emb))
                newidx = kmeans_plus(uidx, emb, num_budget[alitr])
                alidx += newidx
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'badge_partition':
            if len(alidx) == 0:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                # Split the shuffled pool into num_partition disjoint chunks
                # and run BADGE on each chunk with an equal share of the budget.
                np.random.shuffle(uidx)
                part_size = len(uidx) // args.num_partition
                if num_budget[alitr] % args.num_partition != 0:
                    print('wrong partition setting')
                part_budget = num_budget[alitr] // args.num_partition
                for ipart in range(args.num_partition):
                    print('badge partition ', ipart)
                    part_idx = uidx[part_size * ipart:part_size * (ipart + 1)]
                    allset = FeasDataset(totfeas[part_idx, :], totlabel[part_idx], None)
                    all_loader = torch.utils.data.DataLoader(
                        allset,
                        batch_size=args.batchsize_al_forward,
                        num_workers=8,
                        shuffle=False,
                        drop_last=False
                    )
                    emb = get_grad_embedding(classifier, all_loader, args, model)
                    newidx = kmeans_plus(list(range(len(emb))), emb, part_budget)
                    alidx += np.array(part_idx)[newidx].tolist()
                # Update uidx only after all partitions. Rebuilding it from a set
                # inside the loop discarded the shuffled order that the partition
                # slices rely on.
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'badge_partition_subset':
            if len(alidx) == 0:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                # Like badge_partition, but only a random subset of
                # --num_subset candidates is partitioned and scored.
                np.random.shuffle(uidx)
                candidx = uidx[:args.num_subset]

                part_size = len(candidx) // args.num_partition
                if num_budget[alitr] % args.num_partition != 0:
                    print('wrong partition setting')
                part_budget = num_budget[alitr] // args.num_partition
                for ipart in range(args.num_partition):
                    print('badge partition ', ipart)
                    part_idx = candidx[part_size * ipart:part_size * (ipart + 1)]
                    allset = FeasDataset(totfeas[part_idx, :], totlabel[part_idx], None)
                    all_loader = torch.utils.data.DataLoader(
                        allset,
                        batch_size=args.batchsize_al_forward,
                        num_workers=8,
                        shuffle=False,
                        drop_last=False
                    )
                    emb = get_grad_embedding(classifier, all_loader, args, model)
                    newidx = kmeans_plus(list(range(len(emb))), emb, part_budget)
                    alidx += np.array(part_idx)[newidx].tolist()
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'margin':
            if len(alidx) == 0:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                alidx += margin_sampling_model(uidx, all_loader, classifier, model, num_budget[alitr], args)
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'confidence':
            if len(alidx) == 0:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                alidx += confidence_sampling_model(uidx, all_loader, classifier, model, num_budget[alitr], args)
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'ActiveFT(al)':
            if len(alidx) == 0:
                alidx = ActiveFT_sampling(totfeas, num_budget[alitr], alidx, args)
            else:
                # Later rounds select on the current model's embeddings of the full pool.
                _, alfeas = get_output_emb(all_loader, classifier, True, args, model)
                alidx = ActiveFT_sampling(alfeas, num_budget[alitr], alidx, args)

        elif sampling_strategy == 'ActiveFT(self)':
            # Every round selects on the fixed self-supervised features.
            alidx = ActiveFT_sampling(totfeas, num_budget[alitr], alidx, args)

        elif sampling_strategy == 'coreset':
            if len(alidx) == 0:
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                _, alfeas = get_output_emb(all_loader, classifier, True, args, model)
                np.save(os.path.join(outpath, 'alfeas' + str(alitr) + '.npy'), alfeas)
                label_feas = alfeas[alidx, :]
                uidx = list(set(uidx) - set(alidx))
                unlabel_feas = alfeas[uidx, :]
                newidx = acquire_new_sample(num_budget[alitr], uidx, torch.from_numpy(label_feas), torch.from_numpy(unlabel_feas))
                alidx += newidx

        elif sampling_strategy == 'cluster_margin':
            if len(alidx) == 0:
                uidx = list(set(uidx) - set(alidx))
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
                # Seed the clustering with the last num_budget entries of the unlabeled list.
                initial_cluster = initial_hac(totfeas[uidx[-num_budget[alitr]:], :])
                initial_center = totfeas[uidx[-num_budget[alitr]:]]
            else:
                candidx = margin_sampling_model(uidx, all_loader, classifier, None, num_budget[alitr] * 5, args)
                newidx, initial_cluster, initial_center = cluster_margin_sampling(candidx, totfeas[candidx, :], initial_center, initial_cluster, num_budget[alitr])
                alidx += newidx
                uidx = list(set(uidx) - set(alidx))

        elif sampling_strategy == 'learningloss':
            if len(alidx) == 0:
                uidx = list(set(uidx) - set(alidx))
                np.random.shuffle(uidx)
                alidx += uidx[:num_budget[alitr]]
                uidx = list(set(uidx) - set(alidx))
            else:
                estloss = est_loss(classifier, lossmodel, all_loader)
                newidx = learning_loss_sampling(uidx, estloss, num_budget[alitr])
                alidx += newidx
                uidx = list(set(uidx) - set(alidx))

        # The latest label set overwrites alidx.npy (usable as --alidx_path to resume).
        np.save(os.path.join(outpath, 'alidx.npy'), np.array(alidx))

        print('point 1 sample selection', time.time() - s)
        s = time.time()

    if args.training_mode == 0:
        trainset = FeasDataset(totfeas[alidx, :], totlabel[alidx], std)
    else:
        trainset = get_dataset(args, transform_train, index=alidx, train=True)

    # Drop the ragged last batch only when at least one full batch exists.
    Droplast = len(trainset) >= args.batchsize_train
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batchsize_train,
        num_workers=args.num_workers,
        shuffle=True,
        drop_last=Droplast
    )

    # A fresh classifier each round, unless --load_al_weight continues from the previous one.
    if (args.load_al_weight and alitr == 0) or (not args.load_al_weight):
        if args.classifier_type == 'MLP':
            classifier = MLP(indim_classifier, hiddim_classifier, outdim_classifier)
        elif args.classifier_type == 'Linear':
            classifier = torch.nn.Linear(indim_classifier, outdim_classifier)
        else:
            raise NotImplementedError

    classifier.cuda()
    # A classifier carried over via --load_al_weight is already wrapped; wrapping
    # it again would nest DataParallel modules round after round.
    if args.distributed_training and not isinstance(classifier, torch.nn.DataParallel):
        classifier = torch.nn.DataParallel(classifier)

    if args.training_mode != 0:
        # Restart the encoder from the self-supervised weights every round.
        model = get_network(args)
        model.load_state_dict(encoder_dict)
        model.fc = torch.nn.Identity()
        model.cuda()
        if args.distributed_training:
            model = torch.nn.DataParallel(model)

    print('point 2 model load', time.time() - s)
    s = time.time()

    ### training
    if sampling_strategy == 'learningloss':
        if args.training_mode != 0:
            # Previously this fell through without training anything.
            raise NotImplementedError('learningloss sampling is only supported with --training_mode 0')
        lossmodel = LossNet(num_channels=[indim_classifier, hiddim_classifier[0]], interm_dim=128)
        lossmodel = lossmodel.cuda().train()
        classifier, lossmodel, trainloss = train_mlp_wlearningloss(train_loader, classifier, lossmodel, args)
        torch.save({'epoch': args.train_eps, 'state_dict': classifier.state_dict()}, os.path.join(outpath, 'classifier.pth'))
    else:
        if args.training_mode == 0:
            classifier, trainloss = train_mlp(train_loader, classifier, args)
            torch.save({'epoch': args.train_eps, 'state_dict': classifier.state_dict()}, os.path.join(outpath, 'classifier.pth'))
        elif args.training_mode == 1:
            classifier, trainloss = train_freeze_mlp(train_loader, model, classifier, args)
            torch.save({'epoch': args.train_eps, 'classifier_state_dict': classifier.state_dict()}, os.path.join(outpath, 'checkpoint_' + str(len(alidx)) + '_.pth.tar'))
        elif args.training_mode == 2:
            model, classifier, trainloss = train_fine_tune(train_loader, model, classifier, args)
            torch.save({'epoch': args.train_eps, 'classifier_state_dict': classifier.state_dict(), 'model_state_dict': model.state_dict()}, os.path.join(outpath, 'checkpoint_' + str(len(alidx)) + '_.pth.tar'))
        elif args.training_mode == 3:
            ### LP stage: train only the classifier with the freeze-stage lr/epochs.
            args.cls_lr = args.freezelr
            args.train_eps = args.freeze_eps
            classifier, trainloss = train_freeze_mlp(train_loader, model, classifier, args)
            print('trainloss freeze lp ', trainloss)
            tacc = evaluation(test_loader, classifier, model=model)
            torch.save({'acc': tacc, 'classifier_state_dict': classifier.state_dict()}, os.path.join(outpath, 'classifier_' + str(len(alidx)) + '.pth'))

            ### FT stage: restore the user-supplied rates and fine-tune everything.
            args.lr = ftlr
            args.cls_lr = clslr
            args.train_eps = args.ft_eps
            model, classifier, trainloss = train_fine_tune(train_loader, model, classifier, args)
            print('trainloss ft ', trainloss)
            torch.save({'epoch': args.train_eps, 'classifier_state_dict': classifier.state_dict(), 'model_state_dict': model.state_dict()}, os.path.join(outpath, 'checkpoint_' + str(len(alidx)) + '_.pth.tar'))
        else:
            raise NotImplementedError

    print('point 3 training', time.time() - s)
    s = time.time()

    ### evaluation
    acc = evaluation(test_loader, classifier, model=model)
    tacc = evaluation(train_loader, classifier, model=model)

    print('point 4 evaluation', time.time() - s)
    s = time.time()

    totacc += [acc]
    tracc += [tacc]
    print('AL lblset size is ', len(alidx), 'time ', time.time() - s)
    s = time.time()
    print('test acc: ', acc)
    print('train acc: ', tacc)
    # Persist accuracies after every round so partial runs keep their results.
    np.save(os.path.join(outpath, 'acc.npy'), np.array(totacc))

### save
# Final write of the per-round test accuracies, then report total wall time.
acc_file = os.path.join(outpath, 'acc.npy')
np.save(acc_file, np.array(totacc))

elapsed_total = time.time() - s0
print('total time:', elapsed_total)
