# -*- coding: utf-8 -*-
"""

@author: Anonymous Author
"""

import numpy as np
import os

import torch

import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F

from subimgnet import ImageNetSubset
from torchvision import models

from utils import train_mlp, evaluation, get_output_emb

import argparse

# Command-line interface.
# Each option is declared once as (flag, default, type, help) and registered
# below in exactly this order, so ``--help`` output is unchanged.
_CLI_OPTIONS = (
    # --- experiment ---
    ('--sampling_strategy', 'random', str, 'Sampling strategy'),
    ('--expid', '1', str, 'order of exps'),
    # Python expression giving one budget per round, e.g. [10000]*10
    ('--al_budget', '[200]*5 + [500]*2', str, 'budgets of active learning'),
    ('--output_path', '/root/autodl-tmp/ntkal/res/', str, 'path of results'),
    # --- dataset ---
    ('--datapath', '/root/autodl-tmp/imagenet/', str, 'path of dataset'),
    ('--datapath_index', '/root/autodl-tmp/ntkal/imagenet100.txt', str,
     'path of class index of imagenet-100'),
    # --- pre-trained ---
    ('--selfmodel_path', '/root/autodl-nas/resnet50_byol_imagenet2012.pth.tar', str,
     'path of self-supervised model'),
    ('--totfeas_path', 'totfeas.npy', str, 'path of selfsup feas'),
    ('--totlabel_path', 'totlabel.npy', str, 'path of trainset label'),
    # --- training ---
    # 50 for imagenet, 100 for others
    ('--train_eps', 100, int, '# of training epoch'),
    ('--lr', 0.1, float, 'learning rate'),
    ('--momentum', 0.9, float, 'momentum'),
    ('--weight_decay', 0, float, 'weight_decay'),
    ('--batchsize_train', 512, int, 'batchsize at training phase'),
    ('--batchsize_al_forward', 512, int, 'batchsize at active learning pass'),
    ('--batchsize_evaluation', 1024, int, 'batchsize at evaluation phase'),
)

parser = argparse.ArgumentParser(description='')
for _flag, _default, _type, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type, help=_help)

# Parsed once at import time; the rest of the script reads options from ``args``.
args = parser.parse_args()


# Per-channel statistics handed to transforms.Normalize in both pipelines.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Augmenting pipeline used for the labeled training subset:
# random resized crop to 224, random horizontal flip, tensor conversion,
# then normalisation.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform_train = transforms.Compose(_train_steps)

# Deterministic pipeline used for the full-pool forward pass and for the
# validation split: a 224x224 centre crop (no resize step beforehand),
# tensor conversion, then normalisation.
_eval_steps = [
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform_test = transforms.Compose(_eval_steps)

class MLP(torch.nn.Module):
    """Two-stage fully connected head.

    ``layer1`` is Linear -> BatchNorm1d -> ReLU and ``layer2`` is a single
    Linear layer. Both are wrapped in ``Sequential`` so the parameter names
    are ``layer1.0.*``, ``layer1.1.*`` and ``layer2.0.*``.

    After every forward pass a copy of the ``layer1`` output is available as
    ``self.emb`` (``None`` until the first call).
    """

    def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
        """Build the head.

        Args:
            in_dim: size of the input features.
            hidden_dim: width of the hidden stage.
            out_dim: size of the output.
        """
        super().__init__()
        nn = torch.nn

        hidden_stage = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        self.layer1 = nn.Sequential(*hidden_stage)

        # Kept inside a Sequential so the state-dict keys stay 'layer2.0.*'.
        self.layer2 = nn.Sequential(nn.Linear(hidden_dim, out_dim))

        # Filled by forward() with the most recent hidden activations.
        self.emb = None

    def set_layers(self, num_layers):
        """Store *num_layers* on the instance; forward() does not read it."""
        self.num_layers = num_layers

    def forward(self, x):
        """Return the ``layer2`` output for *x* and cache the hidden output."""
        hidden = self.layer1(x)
        # clone() so the cached embedding is a separate tensor from the one
        # passed on to layer2.
        self.emb = hidden.clone()
        return self.layer2(hidden)

# Per-round labeling budgets. ``--al_budget`` is a Python expression such as
# '[200]*5 + [500]*2' and is evaluated as-is.
# NOTE(security): eval() executes arbitrary code from the command line. That
# is acceptable only while the script is launched by a trusted user; the
# default expression uses list arithmetic, so ast.literal_eval cannot be
# dropped in as a replacement.
num_budget = eval(args.al_budget)
num_al_itr = len(num_budget)  # one active-learning round per budget entry

sampling_strategy = args.sampling_strategy
expid = args.expid

# Import only the helpers the selected strategy needs, so optional
# dependencies (e.g. jax for 'lookahead') are required only when used.
# 'random' needs no extra imports.
if sampling_strategy in ('ntkours_self', 'ntkours_al'):
    from samplings.cluster_pruce import generate_cluster_pl, constrain_kmeans
    from samplings.ntk_al import ntk_ker_f0, select_new_sample
    num_min_cluster = 50
elif sampling_strategy in ('coreset', 'coreset_self'):
    from samplings.sampling_strategy import acquire_new_sample
elif sampling_strategy == 'probcover':
    from samplings.sampling_strategy import construct_dg, sampling_probcover
    thresh = 0.95
elif sampling_strategy == 'typicluster':
    from samplings.sampling_strategy import sampling_typicluster
elif sampling_strategy == 'lookahead':
    from samplings.look_ahead import ntk_ker_f0, select_new_sample
    import jax
    num_class = 100
elif sampling_strategy == 'badge':
    from samplings.sampling_strategy import kmeans_plus
    from utils import get_grad_embedding
elif sampling_strategy == 'uncertainty':
    from samplings.sampling_strategy import uncertainty_sampling
elif sampling_strategy == 'entropy':
    from samplings.sampling_strategy import entropy_sampling


# Results of this run go to <output_path>/<exp_name>/.
# os.path.join is used instead of string concatenation so that an
# --output_path given without a trailing separator still yields a
# sub-directory rather than a sibling such as '.../resimgenet100_...'.
exp_name = 'imgenet100_' + sampling_strategy + '_exp' + str(expid)
outpath = os.path.join(args.output_path, exp_name)
os.makedirs(outpath, exist_ok=True)

# Checkpoint with the self-supervised weights (read inside the round loop).
selfmodel_path = args.selfmodel_path

# Arrays loaded from disk: features ('path of selfsup feas') and labels
# ('path of trainset label') as described by the CLI help strings.
totfeas = np.load(args.totfeas_path)
alidx = []  # indices labeled so far; grows every round
totlabel = np.load(args.totlabel_path)

datapath = args.datapath
dataindex = args.datapath_index

# Both loaders iterate in a fixed order and keep the last partial batch;
# they differ only in the dataset split and the batch size.
_fixed_order_loader = dict(
    shuffle=False,
    drop_last=False,
    num_workers=16,
    pin_memory=True,
)

# Whole 'train' split with the deterministic transform, used for forward
# passes over the pool.
allset = ImageNetSubset(dataindex, datapath, None,
                        split='train', transform=transform_test)
all_loader = torch.utils.data.DataLoader(
    allset,
    batch_size=args.batchsize_al_forward,
    **_fixed_order_loader,
)

# 'val' split used to report accuracy after each round.
testset = ImageNetSubset(dataindex, datapath, None,
                         split='val', transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.batchsize_evaluation,
    **_fixed_order_loader,
)

# Placeholders; the round loop rebuilds model/classifier before training.
totemb = None
totpre = None
model = models.resnet50()
classifier = MLP(2048, 4096, 100)
totacc = []  # test accuracy recorded after every round
def _unlabeled_indices(num_total, labeled):
    """Return the indices in ``range(num_total)`` that are absent from *labeled*.

    The result is in ascending order, exactly like the list comprehension
    ``[i for i in range(num_total) if i not in labeled]``, but membership is
    tested against a set so the scan is linear instead of
    ``O(num_total * len(labeled))``.
    """
    # int() normalises NumPy / tensor scalars so set lookup matches plain ints.
    labeled_set = {int(i) for i in labeled}
    return [i for i in range(num_total) if i not in labeled_set]


def _shuffled_unlabeled(num_total, labeled):
    """Return the unlabeled indices in random order (drawn from ``np.random``)."""
    pool = _unlabeled_indices(num_total, labeled)
    np.random.shuffle(pool)
    return pool


# Classifier state-dict key -> key in the checkpoint's 'online_projection'
# entry (after its 7-character prefix has been stripped). Only the first
# Linear + BatchNorm of the MLP are initialised from the checkpoint.
corr = {
    'layer1.0.weight': 'l1.weight',
    'layer1.0.bias': 'l1.bias',
    'layer1.1.bias': 'bn1.bias',
    'layer1.1.weight': 'bn1.weight',
    'layer1.1.num_batches_tracked': 'bn1.num_batches_tracked',
    'layer1.1.running_mean': 'bn1.running_mean',
    'layer1.1.running_var': 'bn1.running_var',
}

# One round per budget entry: (1) extend ``alidx`` with the chosen strategy,
# (2) retrain backbone + classifier on the labeled subset starting from the
# self-supervised checkpoint, (3) evaluate on the validation loader.
for alitr in range(num_al_itr):
    budget = num_budget[alitr]

    # ------------------------------------------------------------------
    # 1) sample selection
    # ------------------------------------------------------------------
    if sampling_strategy in ('ntkours_self', 'ntkours_al'):
        if len(alidx) == 0:
            # Cold start: no trained model yet, so clusters come from the
            # precomputed features only.
            num_cluster = num_budget[0]
            cluster, _ = constrain_kmeans(totfeas, [], [], num_cluster)
            np.save(os.path.join(outpath, 'cluster' + str(alitr) + '_0.npy'), cluster)

            # Selection runs on a random subset of at most 50000 samples.
            scaleidx = list(range(len(totfeas)))
            np.random.shuffle(scaleidx)
            scaleidx = scaleidx[:50000]
            np.save(os.path.join(outpath, 'scaleidx_' + str(alitr) + '.npy'), scaleidx)

            ker, f0 = ntk_ker_f0(totfeas, num_indim=2048, num_hidden=4096,
                                 num_classes=num_cluster, scaleidx=scaleidx,
                                 batch_ker=True)
            alidx = select_new_sample([], [ker], [f0], [cluster[scaleidx]],
                                      budget, [num_cluster],
                                      outpath=outpath, stepsize=250)

            # The returned positions index into ``scaleidx``; map them back
            # to dataset indices.
            alidx = (np.array(scaleidx)[alidx]).tolist()

        else:
            # Number of clusters is half the labeled-set size after this
            # round, capped at 500 (the cap also lowers num_min_cluster).
            num_cluster = int((len(alidx) + budget) / 2)
            if num_cluster > 500:
                num_cluster = 500
                num_min_cluster = 10

            totpre, totemb = get_output_emb(all_loader, model, classifier)
            np.save(os.path.join(outpath, 'totpre' + str(alitr) + '.npy'), totpre)

            # 'ntkours_self' clusters the precomputed features, 'ntkours_al'
            # clusters the embeddings of the current model.
            if sampling_strategy == 'ntkours_self':
                cluster = generate_cluster_pl(totfeas, alidx, totlabel[alidx], totpre,
                                              num_cluster, totlabel, num_level=14,
                                              num_min_cluster=num_min_cluster,
                                              num_class=100)
            elif sampling_strategy == 'ntkours_al':
                cluster = generate_cluster_pl(totemb, alidx, totlabel[alidx], totpre,
                                              num_cluster, totlabel, num_level=14,
                                              num_min_cluster=num_min_cluster,
                                              num_class=100)

            np.save(os.path.join(outpath, 'cluster' + str(alitr + 1) + '_' + str(num_cluster) + '.npy'), cluster)

            # Working subset: every labeled index first, then shuffled
            # unlabeled indices, truncated to 52000 entries.
            scaleidx = _unlabeled_indices(len(totfeas), alidx)
            np.random.shuffle(scaleidx)
            scaleidx = alidx + scaleidx
            scaleidx = scaleidx[:52000]
            np.save(os.path.join(outpath, 'scaleidx_' + str(alitr) + '.npy'), scaleidx)

            ker, f0 = ntk_ker_f0(totfeas, num_indim=2048, num_hidden=4096,
                                 num_classes=num_cluster, scaleidx=scaleidx,
                                 batch_ker=True)

            if len(alidx) > 1000:
                stepsize = 100
            else:
                stepsize = 250

            # The labeled samples occupy positions 0..len(alidx)-1 of scaleidx.
            alidx = select_new_sample(list(range(len(alidx))), [ker], [f0],
                                      [cluster[scaleidx]], budget, [num_cluster],
                                      outpath=outpath, stepsize=stepsize)

            alidx = (np.array(scaleidx)[alidx]).tolist()

    elif sampling_strategy == 'typicluster':
        newidx = sampling_typicluster(totfeas, alidx, budget)
        alidx += newidx

    elif sampling_strategy == 'coreset_self':
        # Core-set selection on the precomputed features; seeded with one
        # random sample when nothing is labeled yet.
        if len(alidx) == 0:
            alidx = [np.random.randint(0, len(allset))]
        label_feas = totfeas[alidx, :]
        uidx = _unlabeled_indices(len(allset), alidx)
        unlabel_feas = totfeas[uidx, :]
        # With a single labeled sample, request one fewer so the labeled set
        # ends the round at exactly ``budget`` samples.
        num_new = budget - 1 if len(alidx) == 1 else budget
        newidx = acquire_new_sample(num_new, uidx,
                                    torch.from_numpy(label_feas),
                                    torch.from_numpy(unlabel_feas))
        alidx += newidx

    elif sampling_strategy == 'coreset':
        if len(alidx) == 0:
            # No trained model yet: random warm start.
            uidx = _shuffled_unlabeled(len(allset), alidx)
            alidx += uidx[:budget]
        else:
            # Core-set selection on embeddings of the model trained last round.
            _, alfeas = get_output_emb(all_loader, model, classifier)
            np.save(os.path.join(outpath, 'alfeas' + str(alitr) + '.npy'), alfeas)
            label_feas = alfeas[alidx, :]
            uidx = _unlabeled_indices(len(allset), alidx)
            unlabel_feas = alfeas[uidx, :]
            newidx = acquire_new_sample(budget, uidx,
                                        torch.from_numpy(label_feas),
                                        torch.from_numpy(unlabel_feas))
            alidx += newidx

    elif sampling_strategy == 'random':
        uidx = _shuffled_unlabeled(len(allset), alidx)
        alidx += uidx[:budget]

    elif sampling_strategy == 'lookahead':
        if len(alidx) == 0:
            # Random warm start; ``uidx`` is kept across rounds as the pool
            # of still-unlabeled indices.
            uidx = _shuffled_unlabeled(len(allset), alidx)
            alidx += uidx[:budget]
            uidx = list(set(uidx) - set(alidx))
        else:
            # Candidates: random subset of at most 50000 unlabeled samples.
            scaleidx = uidx.copy()
            np.random.shuffle(scaleidx)
            scaleidx = scaleidx[:50000]
            np.save(os.path.join(outpath, 'scaleidx_' + str(alitr) + '.npy'), scaleidx)
            ker, f0 = ntk_ker_f0(totfeas, num_indim=2048, num_hidden=4096,
                                 num_classes=num_class, scaleidx=scaleidx,
                                 batch_ker=True)

            totpre, _ = get_output_emb(all_loader, model, classifier)
            np.save(os.path.join(outpath, 'totpre' + str(alitr) + '.npy'), totpre)
            # The current model's argmax predictions are passed for the
            # candidates (no ground-truth labels are used here).
            newidx = select_new_sample([], [ker], [f0],
                                       [totpre[scaleidx, :].argmax(axis=1)],
                                       budget, num_class, outpath, stepsize=250)

            alidx += (np.array(scaleidx)[newidx]).tolist()
            uidx = list(set(uidx) - set(alidx))

    elif sampling_strategy == 'badge':
        if len(alidx) == 0:
            uidx = _shuffled_unlabeled(len(allset), alidx)
            alidx += uidx[:budget]
        else:
            emb = get_grad_embedding(len(allset), 10, 64, model, classifier, all_loader)
            uidx = _unlabeled_indices(len(allset), alidx)
            newidx = kmeans_plus(uidx, emb, budget)
            alidx += newidx

    elif sampling_strategy == 'uncertainty':
        if len(alidx) == 0:
            uidx = _shuffled_unlabeled(len(allset), alidx)
            alidx += uidx[:budget]
        else:
            totpre, _ = get_output_emb(all_loader, model, classifier)
            np.save(os.path.join(outpath, 'totpre' + str(alitr) + '.npy'), totpre)
            uidx = _unlabeled_indices(len(allset), alidx)
            alidx += uncertainty_sampling(totpre, uidx, budget)

    elif sampling_strategy == 'entropy':
        if len(alidx) == 0:
            uidx = _shuffled_unlabeled(len(allset), alidx)
            alidx += uidx[:budget]
        else:
            totpre, _ = get_output_emb(all_loader, model, classifier)
            np.save(os.path.join(outpath, 'totpre' + str(alitr) + '.npy'), totpre)
            uidx = _unlabeled_indices(len(allset), alidx)
            alidx += entropy_sampling(totpre, uidx, budget)

    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError("Wrong sampling strategy")

    np.save(os.path.join(outpath, 'alidx' + str(alitr) + '.npy'), np.array(alidx))

    # ------------------------------------------------------------------
    # 2) training on the labeled subset
    # ------------------------------------------------------------------
    trainset = ImageNetSubset(dataindex, datapath,
                              alidx, split='train', transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batchsize_train,
        shuffle=True,
        num_workers=16,
        pin_memory=True
    )

    # Fresh backbone every round, initialised from the self-supervised
    # (BYOL) checkpoint. The first 7 characters of every checkpoint key are
    # dropped (presumably a wrapper prefix such as 'module.' -- TODO confirm);
    # keys that do not exist in the ResNet-50 are ignored.
    model = models.resnet50()
    checkpoint = torch.load(selfmodel_path, map_location=torch.device('cpu'))
    model_dict = model.state_dict()
    state_dict = {k[7:]: v for k, v in checkpoint['online_backbone'].items()
                  if k[7:] in model_dict}
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)

    # The backbone outputs its pooled features directly.
    model.fc = torch.nn.Identity()

    # Fresh classifier; its first Linear + BatchNorm are copied from the
    # checkpoint's projector via ``corr``, the output layer stays random.
    classifier = MLP(2048, 4096, 100)

    model_dict = classifier.state_dict()
    state_dict = {k[7:]: v for k, v in checkpoint['online_projection'].items()}
    state_dict1 = {ikey: state_dict[corr[ikey]] for ikey in corr}

    model_dict.update(state_dict1)
    classifier.load_state_dict(model_dict)

    model.cuda()
    classifier.cuda()

    model, classifier = train_mlp(train_loader, model, classifier, args)

    # Record the configured number of epochs instead of a hard-coded 100.
    torch.save({'epoch': args.train_eps, 'state_dict': classifier.state_dict()},
               os.path.join(outpath, 'classifier' + str(len(alidx)) + '.pth'))

    # ------------------------------------------------------------------
    # 3) evaluation
    # ------------------------------------------------------------------
    acc = evaluation(test_loader, model, classifier)
    totacc += [acc]
    print('AL lblset size is ', len(alidx))
    print('test acc: ', acc)
    # Saved after every round so partial results survive an interrupted run.
    np.save(os.path.join(outpath, 'acc.npy'), np.array(totacc))

### save
np.save(os.path.join(outpath, 'acc.npy'), np.array(totacc))


###avg
# totacc = []
# for i in range(1,6):
#     t = np.load('C:\\document\\code\\fullcover_al\\ref code\\ntk_al_ours\\mlpntkal\\res\\cifar10_mlp64_coreset_exp' + str(i) + '\\acc.npy')
#     totacc += [t]

# totacc = np.array(totacc)
# print(totacc.mean(axis=0))
