# -*- coding: utf-8 -*-
"""
Created on Mon Feb 27 13:18:32 2023

@author: cvpr2024 11221
"""

import numpy as np
import os
import torch

from dataset_model import FeasDataset, ImageFolderWithIndex, MLP, get_augmentation, get_dataset, get_network
from utils import train_mlp, evaluation, get_output_emb, train_freeze_mlp, train_fine_tune

import argparse
import shutil

def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--flag False`` would silently *enable* the flag.
    This converter maps the usual spellings explicitly instead.

    Args:
        value: Raw string from the command line (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


# Command-line interface of the AL / LP-FT experiment.
parser = argparse.ArgumentParser(description='')
parser.add_argument('--sampling_strategy', default='margin', type=str,
                    help='Sampling strategy')
parser.add_argument('--al_budget', default='[600]+[200]*7', type=str,  # e.g. [10000]*10
                    help='per-round labeling budgets as a Python list expression, e.g. [600]+[200]*7')
parser.add_argument('--expid', default=None, type=str,
                    help='experiment id (part of the output directory name)')
parser.add_argument('--outpath_base', default='./res/', type=str,
                    help='path of results')
parser.add_argument('--dataset_name', default=None, type=str,
                    help='name of dataset [imagenet, feas, cifar10, imagenet100, vic_cape_howe]')
parser.add_argument('--dataset_path', default=None, type=str,
                    help='for fine-tune and freezing & mlp')
parser.add_argument('--selfmodel_path', default=None, type=str,
                    help='path of selfsup model')
parser.add_argument('--trainidx', default=None, type=str,
                    help='trainidx for vic_cape_howe dataset')
parser.add_argument('--testidx', default=None, type=str,
                    help='testidx for vic_cape_howe dataset')

parser.add_argument('--load_proj_weight', default=False, type=_str2bool,
                    help='initialized classifier weights from projector')
parser.add_argument('--load_al_weight', default=False, type=_str2bool,
                    help='initialized classifier weights from last Active learning round')

parser.add_argument('--train_eps', default=200, type=int,
                    help='# of training epoch')
parser.add_argument('--lr', default=0.0001, type=float,
                    help='learning rate')
parser.add_argument('--cls_lr', default=0.1, type=float,
                    help='learning rate for classifier')
parser.add_argument('--momentum', default=0.9, type=float,
                    help='momentum')
parser.add_argument('--weight_decay', default=0, type=float,
                    help='weight_decay')
parser.add_argument('--nesterov', default=True, type=_str2bool,
                    help='nesterov')
parser.add_argument('--milestone', default='60, 80', type=str,
                    help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--early_stop', default=100, type=int,
                    help='efficient AL baseline, early stop')
parser.add_argument('--freezelr', default=10, type=float,
                    help='lr in freeze stage')
parser.add_argument('--freeze_eps', default=80, type=int,
                    help='training eps of lp stage')
parser.add_argument('--ft_eps', default=40, type=int,
                    help='training eps of ft stage')

parser.add_argument('--network', default='res18', type=str,
                    help='[res18,res50,res50x2,res50x4]')

parser.add_argument('--batchsize_train', default=256, type=int,
                    help='batch size for training')
parser.add_argument('--grad_accu', default=1, type=int,
                    help='num grad accum')
parser.add_argument('--batchsize_al_forward', default=512, type=int,
                    help='batch size for AL forward passes')
parser.add_argument('--batchsize_evaluation', default=512, type=int,
                    help='batch size for evaluation')
parser.add_argument('--classifier_dim', default='512,512,10', type=str,
                    help='dims of classifier as in,hidden,out')

parser.add_argument('--training_mode', default=3, type=int,
                    help='0:MLP_proxy(ours), 1:freezing encoder and training classifier, 2:Fine-tuning, 3:LP-FT')
parser.add_argument('--classifier_type', default='Linear', type=str,
                    help='Linear or MLP')

parser.add_argument('--distributed_training', default=False, type=_str2bool,
                    help='use nn.DataParallel')

parser.add_argument('--alidx_name', default='alidx.npy', type=str,
                    help='file name of the AL index array inside the MLP-proxy result directory')
parser.add_argument('--mlpproxy_expid', default='1_r50byoleman_mlpproxy', type=str,
                    help='expid of the MLP-proxy run whose AL indices are reused')
parser.add_argument('--mlpproxy_dataset', default='feas', type=str,
                    help='dataset name of the MLP-proxy run')
parser.add_argument('--mlpproxy_trainmode', default='_training_strategy0', type=str,
                    help='training-strategy suffix of the MLP-proxy run directory')
parser.add_argument('--alidxpath', default=None, type=str,
                    help='another choice to input alidx')

args = parser.parse_args()

# The output directory name is built from the dataset name, so fail early with
# a clear usage message instead of a TypeError inside os.path.join.
if args.dataset_name is None:
    parser.error('--dataset_name is required')

# User-given learning rates for the fine-tuning stage of LP-FT (training_mode 3).
# The training loop overwrites args.lr / args.cls_lr during the linear-probing
# stage, so the values are captured here and restored before fine-tuning.
ftlr = args.lr
clslr = args.cls_lr

if args.trainidx is not None:
    args.trainidx = np.load(args.trainidx)
if args.testidx is not None:
    args.testidx = np.load(args.testidx)

# '60, 80' -> [60, 80]; blank entries (e.g. from a trailing comma) are ignored.
args.milestone = [int(m) for m in args.milestone.split(',') if m.strip()]


print(args.lr)
print(args.cls_lr)
print(args.expid)

classifier_dims = args.classifier_dim.split(',')
if len(classifier_dims) != 3:
    parser.error('--classifier_dim must look like "in,hidden,out", got %r' % args.classifier_dim)
# The hidden size is wrapped in a list, matching what MLP(...) is called with.
indim_classifier = int(classifier_dims[0])
hiddim_classifier = [int(classifier_dims[1])]
outdim_classifier = int(classifier_dims[2])

# NOTE(review): eval() runs arbitrary code taken from the command line. The
# budget is written as a Python expression (e.g. "[600]+[200]*7") that
# ast.literal_eval cannot parse, so eval is kept -- only run with trusted args.
num_budget = eval(args.al_budget)
num_al_itr = len(num_budget)


sampling_strategy = args.sampling_strategy
expid = args.expid


dataset = args.dataset_name

# Results go to <outpath_base>/<dataset>/<dataset>_<strategy>_exp<expid>_training_strategy<mode>/
outpath = os.path.join(args.outpath_base, dataset)
exp_name = dataset + '_' + sampling_strategy + '_exp' + str(expid) + '_training_strategy' + str(args.training_mode)
outpath = os.path.join(outpath, exp_name)
os.makedirs(outpath, exist_ok=True)


# Location of the precomputed AL query order. Unless --alidxpath is given, it is
# read from the result directory of an earlier MLP-proxy run, whose directory
# name follows the same <dataset>_<strategy>_exp<expid><trainmode> pattern.
if args.alidxpath is None:
    alidx_path = os.path.join(
        args.outpath_base,
        args.mlpproxy_dataset,
        args.mlpproxy_dataset + '_' + sampling_strategy + '_exp' + args.mlpproxy_expid + args.mlpproxy_trainmode,
        args.alidx_name,
    )
else:
    alidx_path = args.alidxpath

# Sample indices in selection order; each AL round uses a prefix of this array.
hyperalidx = np.load(alidx_path)

# Record the script that produced these results alongside them. __file__ is used
# (not a hard-coded './lpft_mlpproxy.py') so this works from any working
# directory and under any script name.
shutil.copy(os.path.abspath(__file__), outpath)

selfmodel_path = args.selfmodel_path

# Transforms for the labeled pool (train=True) and for evaluation (train=False).
transform_train, transform_test = [get_augmentation(args, train=is_train) for is_train in (True, False)]

# Fixed, unshuffled test loader shared by every AL round.
testset = get_dataset(args, transform_test, index=None, train=False)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=args.batchsize_evaluation, num_workers=8, shuffle=False, drop_last=False
)

# Per-round accuracies: on the test set, and on the labeled training pool.
totacc, tracc = [], []

# Classifier head placed on top of the encoder features.
if args.classifier_type not in ('MLP', 'Linear'):
    raise NotImplementedError
classifier = (
    MLP(indim_classifier, hiddim_classifier, outdim_classifier)
    if args.classifier_type == 'MLP'
    else torch.nn.Linear(indim_classifier, outdim_classifier)
)
    
### load model and initialize it with the self-supervised weights
checkpoint = torch.load(selfmodel_path, map_location=torch.device('cpu'))

model = get_network(args)

encoder_dict = model.state_dict()
# Checkpoint keys carry a framework-specific prefix that is stripped by a fixed
# number of characters before matching against the backbone's own keys.
# NOTE(review): the exact prefixes these lengths correspond to are not visible
# here (res50 was tagged "byol eman") -- confirm against the checkpoints used.
if args.network == 'res50':
    prefix_len = 27
elif args.network == 'res18':
    prefix_len = 9
else:
    raise NotImplementedError
# Keys that do not map onto the backbone (e.g. heads not present in the
# backbone) are dropped.
state_dict = {k[prefix_len:]: v for k, v in checkpoint['state_dict'].items() if k[prefix_len:] in encoder_dict}
if not state_dict:
    # A prefix mismatch would otherwise leave the backbone silently at its
    # random initialization.
    raise RuntimeError('no checkpoint keys matched the %s backbone after stripping %d prefix characters (%s)'
                       % (args.network, prefix_len, selfmodel_path))
print('loaded %d / %d backbone tensors from %s' % (len(state_dict), len(encoder_dict), selfmodel_path))
encoder_dict.update(state_dict)
model.load_state_dict(encoder_dict)

# Drop the final fc layer so the network outputs features for the classifier.
model.fc = torch.nn.Identity()


import time

# s: time of the last progress checkpoint; s0: start of the whole AL run.
s = time.time()
s0 = time.time()

alidx = []
for alitr in range(num_al_itr):
    round_start = time.time()

    # Labeled pool for this round: the prefix of the precomputed query order whose
    # length is the cumulative budget up to and including this round.
    # NOTE(review): the length checks look like leftovers of a resume path; with
    # positive budgets the branch is taken in every round.
    if (alitr == 0 and len(alidx) == 0) or (alitr > 0 and len(alidx) > 0):  # not resume al
        alidx = hyperalidx[:np.sum(num_budget[:alitr + 1])]
        np.save(os.path.join(outpath, 'alidx1.npy'), np.array(alidx))

    trainset = get_dataset(args, transform_train, index=alidx, train=True)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batchsize_train,
        num_workers=8,
        shuffle=True,
        drop_last=True,
    )

    # A new classifier is built every round; with --load_al_weight the previous
    # round's classifier is kept instead (after round 0).
    if alitr == 0 or not args.load_al_weight:
        if args.classifier_type == 'MLP':
            classifier = MLP(indim_classifier, hiddim_classifier, outdim_classifier)
        elif args.classifier_type == 'Linear':
            classifier = torch.nn.Linear(indim_classifier, outdim_classifier)
        else:
            raise NotImplementedError
    elif isinstance(classifier, torch.nn.DataParallel):
        # Unwrap the carried-over classifier so the wrapping below does not nest
        # DataParallel inside DataParallel (which would also add a second
        # 'module.' prefix to the saved state_dict keys).
        classifier = classifier.module

    classifier.cuda()
    if args.distributed_training:
        classifier = torch.nn.DataParallel(classifier)

    # The backbone is re-initialized from the self-supervised weights each round.
    model = get_network(args)
    model.load_state_dict(encoder_dict)
    model.fc = torch.nn.Identity()
    model.cuda()
    if args.distributed_training:
        model = torch.nn.DataParallel(model)

    print('point 2 model load', time.time() - s)
    s = time.time()

    ckpt_path = os.path.join(outpath, 'checkpoint_' + str(len(alidx)) + '_.pth.tar')

    ### training
    if args.training_mode == 1:
        # Frozen encoder, train the classifier only.
        classifier, trainloss = train_freeze_mlp(train_loader, model, classifier, args)
        torch.save({'epoch': args.train_eps, 'classifier_state_dict': classifier.state_dict()}, ckpt_path)
    elif args.training_mode == 2:
        # Fine-tune encoder and classifier.
        model, classifier, trainloss = train_fine_tune(train_loader, model, classifier, args)
        torch.save({'epoch': args.train_eps, 'classifier_state_dict': classifier.state_dict(),
                    'model_state_dict': model.state_dict()}, ckpt_path)
    elif args.training_mode == 3:
        # LP stage: frozen encoder, classifier trained with --freezelr as its lr
        # for --freeze_eps epochs.
        args.cls_lr = args.freezelr
        args.train_eps = args.freeze_eps
        classifier, trainloss = train_freeze_mlp(train_loader, model, classifier, args)
        print('trainloss freeze lp ', trainloss)
        lp_acc = evaluation(test_loader, classifier, model=model)
        torch.save({'acc': lp_acc, 'classifier_state_dict': classifier.state_dict()},
                   os.path.join(outpath, 'classifier_' + str(len(alidx)) + '.pth'))

        # FT stage: restore the user-given learning rates (captured right after
        # argument parsing) and fine-tune for --ft_eps epochs.
        args.lr = ftlr
        args.cls_lr = clslr
        args.train_eps = args.ft_eps
        model, classifier, trainloss = train_fine_tune(train_loader, model, classifier, args)
        print('trainloss ft ', trainloss)
        torch.save({'epoch': args.train_eps, 'classifier_state_dict': classifier.state_dict(),
                    'model_state_dict': model.state_dict()}, ckpt_path)
    else:
        raise NotImplementedError

    print('point 3 training', time.time() - s)
    s = time.time()

    ### evaluation
    acc = evaluation(test_loader, classifier, model=model)
    # NOTE(review): train accuracy uses train_loader as-is (transform_train,
    # shuffle=True, drop_last=True), so the last partial batch is skipped --
    # confirm this is the intended metric.
    tacc = evaluation(train_loader, classifier, model=model)

    print('point 4 evaluation', time.time() - s)
    s = time.time()

    totacc.append(acc)
    tracc.append(tacc)
    # Wall time of the whole AL round.
    print('AL lblset size is ', len(alidx), 'time ', time.time() - round_start)
    s = time.time()
    print('test acc: ', acc)
    print('train acc: ', tacc)
    # Saved every round so partial results survive an interrupted run.
    np.save(os.path.join(outpath, 'acc.npy'), np.array(totacc))

### save
np.save(os.path.join(outpath, 'acc.npy'), np.array(totacc))

print('total time:', time.time() - s0)
