import src.utils as utils
import time
import torch.utils.data as data
from torch import optim
import torch
from torchvision import transforms
from src.my_select import *
import os
from dataset_select import dataset_select
import itertools
import torch.optim.lr_scheduler as lr_scheduler
import argparse

from src.utils import WarmUpLR
from src.utils.get_log import log_recorder


"""
Training parameters:
lr_list: The list of learning rate.
device: cuda or cpu.
batch_size: batchsize of the training process.
batch_size_tset: batchsize of the testing process.
EPOCH: training rounds.
num_workers: Threads.
dataset_names: Dataset name.
classifier_names: Backbone Network.
dimension: Output feature dimension (modified with network parameters).

Selection parameters:
select_ratios: The proportion of training samples in each cycle.
add_ratios: The proportion of adding samples in each cycle.
method_names: The methods of selecting samples.
select_type: The type of selecting samples(goodset or badset).
"""
def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` would turn every non-empty string (including ``"False"``)
    into ``True``; this converter accepts the usual textual spellings instead.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# ---------------------------------------------------------------------------
# Experiment-grid options: every list-valued option below is one axis of the
# itertools.product sweep performed by the main script.
# List options use nargs='+' with an element type; ``type=list`` would split
# the raw command-line string into single characters.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--lr_list', type=float, nargs='+', default=[0.01, 0.01, 0.01],
                    help='learning rates to sweep over')
parser.add_argument('--device', type=str, default='cuda:1',
                    help='device to use when CUDA is available')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--batch_size_tset', type=int, default=256)
parser.add_argument('--EPOCH', type=int, default=100)
parser.add_argument('--warm', type=int, default=1, help='warm up training phase')
parser.add_argument('--num_workers', type=int, default=8, help='how many subprocesses to use for data loading')
parser.add_argument('--dataset_names', type=str, nargs='+', default=['your dataset'])
parser.add_argument('--classifier_names', type=str, nargs='+',
                    default=['ResNet18', 'Senet18', 'Mobilenet', 'VGG11', 'inception'])
parser.add_argument('--dimension', type=int, default=512)
# Example schedules: [0.4,0.5,0.6,0.7,0.75,0.80,0.85,0.90,0.95,1]  [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4]
parser.add_argument('--select_ratios', type=float, nargs='+', default=[1])
# Example schedules: [0.1,0.1,0.1,0.05,0.05,0.05,0.05,0.05,0.05,0]  [0.05,0.05,0.05,0.05,0.05,0.05,0.05,0]
parser.add_argument('--add_ratios', type=float, nargs='+', default=[0])
parser.add_argument('--method_names', type=str, nargs='+', default=['random_select'])
parser.add_argument('--select_type', type=str, nargs='+', default=['ADD-GOOD'], help='ADD-GOOD or ADD-BAD')

# Both parsers read the same command line. Parsing it twice with parse_args()
# made each parser reject the other's options, so the first parser consumes
# what it knows and hands the remainder to the per-run parser below.
args, _remaining_argv = parser.parse_known_args()

EPOCH = args.EPOCH
# Honour --device (previously parsed but ignored) and fall back to the CPU
# when CUDA is unavailable. Kept as a plain string, as before.
DEVICE = args.device if torch.cuda.is_available() else 'cpu'

# ---------------------------------------------------------------------------
# Per-run options. The first six attributes are overwritten by the main
# script for every combination of the sweep; the rest keep their parsed value.
# ---------------------------------------------------------------------------
train_parser = argparse.ArgumentParser()
train_parser.add_argument('--dataset_name', type=str)
train_parser.add_argument('--classifier_name', type=str)
train_parser.add_argument('--select_strategy', type=str)
train_parser.add_argument('--lr', type=float)
train_parser.add_argument('--select_type', type=str)
# Ratios are fractions (e.g. 0.4), so they must not be parsed as int.
train_parser.add_argument('--select_ratio', type=float)
train_parser.add_argument('--URFEAL', type=_str2bool, default=False, help='Unsupervised Redundant Feature Elimination')
train_parser.add_argument('--R', type=float, default=1, help='Parameter R of URFEAL')
train_parser.add_argument('--num_train_set', type=int, default=50000)
train_parser.add_argument('--begin_select_ratio', type=float, default=args.select_ratios[0])
train_args = train_parser.parse_args(_remaining_argv)

def _build_transforms(dataset_name):
    """Return the ``{'train': ..., 'test': ...}`` torchvision pipelines for a dataset.

    The Cifar10-family names get 32x32 crop/flip/rotation augmentation; the
    'mini-Imagenet-10' and 'CSEv1' datasets are resized to 224x224; every
    other dataset name is resized to 220x220.
    """
    cifar_names = ('Cifar10', 'Cifar10_entropy', 'Cifar10_Senet_entropy', 'Cifar10_VGG_entropy',
                   'Cifar10_inception_entropy', 'Cifar10_Mobilenet_entropy')
    if dataset_name in cifar_names:
        mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
        std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
        return {
            "train": transforms.Compose([transforms.RandomCrop(32, padding=4),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomRotation(15),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean, std)]),
            "test": transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean, std)]),
        }

    # The two remaining pipelines differ only in the target image size.
    size = 224 if dataset_name in ('mini-Imagenet-10', 'CSEv1') else 220
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    return {
        "train": transforms.Compose([transforms.RandomResizedCrop(size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.Resize([size, size]),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean, std)]),
        "test": transforms.Compose([transforms.Resize([size, size]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean, std)]),
    }


if __name__ == '__main__':
    # Script entry point: for every (dataset, classifier, strategy, lr,
    # select_type, select_ratio) combination, train a classifier, log the
    # result, then (unless this is the final ratio) pick the samples to add
    # for the next cycle and dump them to a CSV file.
    import csv  # needed by the CSV dump below; the file header never imports it explicitly

    select_data = []
    j = 0  # position in args.add_ratios for the current cycle
    rq = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
    txt_path = './Result/{}/'.format(rq[:8])
    txtfile = txt_path + '{}.txt'.format(rq[8:])
    os.makedirs(txt_path, exist_ok=True)

    sweep = itertools.product(args.dataset_names, args.classifier_names, args.method_names,
                              args.lr_list, args.select_type, args.select_ratios)
    for combo in sweep:
        (train_args.dataset_name, train_args.classifier_name, train_args.select_strategy,
         train_args.lr, train_args.select_type, train_args.select_ratio) = combo
        num_input, num_classes, file_Path, train_name, test_name = dataset_select(train_args.dataset_name)

        # A full pass over add_ratios is finished: restart the schedule.
        if j == len(args.add_ratios):
            j = 0
            select_data = []
        add_ratio = args.add_ratios[j]

        transform = _build_transforms(train_args.dataset_name)

        if train_args.select_ratio == args.select_ratios[0]:
            # First ratio of a schedule: start from a fresh base split.
            base_csv_dir = './Selcetion/{}/{}/{}/{}/'.format(
                train_args.dataset_name, train_args.classifier_name,
                train_args.select_strategy, train_args.select_type)
            os.makedirs(base_csv_dir, exist_ok=True)
            base_csv_path = base_csv_dir + 'base_{}.csv'.format(train_args.select_ratio)
            train_set = utils.dataset.TrainDataset(
                train_args.select_ratio, base_csv_path=base_csv_path, file_name_list=None,
                unlabeled_list=None, file_path=file_Path, file_name=train_name,
                num_classes=num_classes, select_data=None, transform=transform['train'])
        else:
            # Later ratios extend the previous cycle's train/unlabeled sets
            # (both still bound from the previous loop iteration).
            train_set = utils.dataset.TrainDataset(
                train_args.select_ratio, None, train_set.file_name_list, unlabeled_set.file_name_list,
                file_Path, train_name, num_classes, select_data=select_data, transform=transform['train'])
        train_loader = data.DataLoader(
            dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

        test_set = utils.dataset.TestDataset(file_Path, test_name, transform['test'])
        test_loader = data.DataLoader(
            dataset=test_set, batch_size=args.batch_size_tset, shuffle=True, num_workers=args.num_workers)

        # Selection only happens while there is a later ratio to prepare for.
        selection_pending = (train_args.select_ratio != args.select_ratios[-1]
                             and train_args.select_ratio != 1)
        if selection_pending:
            unlabeled_set = utils.dataset.Pool(file_Path, train_set.unlabeled_list, transform['test'])
            unlabeled_loader = data.DataLoader(
                dataset=unlabeled_set, batch_size=args.batch_size_tset, shuffle=True, num_workers=args.num_workers)
        else:
            unlabeled_set = [[]]  # placeholder; only its len() is reported below

        classifier = utils.model_select(train_args.classifier_name, num_input, num_classes).to(DEVICE)
        # classifier = torch.nn.DataParallel(classifier.to(DEVICE))
        optimizer = optim.SGD(classifier.parameters(), lr=train_args.lr, momentum=0.9, weight_decay=5e-4)
        decay_epoch = [30, 60, 80]
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=decay_epoch, gamma=0.5)
        recoder = log_recorder(train_args.dataset_name, train_args.classifier_name, train_args.select_strategy,
                               args.batch_size, EPOCH, train_args.lr, train_args.select_type,
                               train_args.select_ratio)
        warmup_scheduler = WarmUpLR(optimizer, len(train_loader) * args.warm)

        top_acc = 0
        max_epoch = 0
        currenttime = time.asctime(time.localtime(time.time()))

        start_msg = 'With the {} model on the {} dataset, using the {} sample selection method (selection type: {}): training set: {} samples, test set: {} samples, unlabeled set: {} samples. Training begins!'.format(
            train_args.classifier_name, train_args.dataset_name, train_args.select_strategy,
            train_args.select_type, len(train_set), len(test_set), len(unlabeled_set))
        print(start_msg)
        with open(txtfile, "a") as myfile:
            myfile.write(start_msg + '\n')

        # Dedicated loop variable: the original reused ``i`` and shadowed the sweep tuple.
        for epoch in range(EPOCH):
            train_iter = iter(train_loader)
            test_iter = iter(test_loader)

            train_loss, train_acc = utils.train(
                epoch, warmup_scheduler, classifier, train_args.classifier_name, DEVICE, train_iter,
                optimizer, train_set, args.batch_size, recoder, args.dimension, num_classes, epoch, EPOCH)
            test_loss, test_acc = utils.test(
                classifier, train_args.classifier_name, DEVICE, test_iter, test_set,
                args.batch_size_tset, recoder, True, True)

            print('Epoch [{}/{}] Train Accuracy: {}% Test Accuracy: {}% '.format(
                epoch + 1, args.EPOCH, train_acc * 100, test_acc * 100))
            if epoch > args.warm:
                scheduler.step()

            if test_acc > top_acc:
                top_acc = test_acc
                max_epoch = epoch + 1

            if epoch + 1 == EPOCH:
                # Last epoch: save the final weights and log the best accuracy seen.
                model_path = './Model/{}/{}/{}/{}_{}_{}_{}/'.format(
                    train_args.dataset_name, train_args.classifier_name, train_args.select_strategy,
                    train_args.lr, currenttime, train_args.select_type, train_args.select_ratio)
                os.makedirs(model_path, exist_ok=True)
                torch.save(classifier.state_dict(),
                           model_path + 'test_acc_{:.6f}_epoch_{}'.format(test_acc, epoch + 1),
                           _use_new_zipfile_serialization=False)
                summary = 'SelectRatio [{}/100]completed max_epoch: {} top_acc: {}%'.format(
                    train_args.select_ratio * 100, max_epoch, top_acc * 100)
                print(summary)
                with open(txtfile, "a") as myfile:
                    myfile.write(summary + ' \n')
                # NOTE(review): the recorder is closed here but is still passed to
                # utils.extract below - confirm extract does not log through it.
                recoder.log_close()

        if selection_pending:
            # Only create the extra training iterator when it is actually consumed.
            train_iter2 = iter(train_loader)
            (_, _, labeled_per_embeddinglist, labeled_embeddinglist, labeled_protolist,
             labeled_embed_maxlist, _, _, _, _, _) = utils.extract(
                DEVICE, classifier, args.dimension, train_args.classifier_name, num_classes,
                DEVICE, train_iter2, train_set, args.batch_size, recoder, True, True)
            unlabeled_iter = iter(unlabeled_loader)
            (_, _, _, unlabeled_embeddinglist, _, _, unlabeled_outputlist, unlabeled_losslist,
             unlabeled_img_path, unlabeled_originlabels, unlabeled_target) = utils.extract(
                DEVICE, classifier, args.dimension, train_args.classifier_name, num_classes,
                DEVICE, unlabeled_iter, unlabeled_set, args.batch_size_tset, recoder, False, True)

            strategy = train_args.select_strategy
            # Constructor arguments shared by every DistanceEntropySeries strategy;
            # the object itself is still only built in the branch that needs it.
            distance_entropy_args = (train_args, unlabeled_originlabels, labeled_embeddinglist,
                                     unlabeled_embeddinglist, unlabeled_img_path, add_ratio,
                                     labeled_protolist, num_classes, unlabeled_target)
            entropy_args = (train_args, unlabeled_originlabels, unlabeled_outputlist,
                            unlabeled_embeddinglist, unlabeled_img_path, add_ratio,
                            num_classes, unlabeled_target)

            # One single if/elif chain: 'metric_select' used to be a separate ``if``,
            # so its result was overwritten with None by the chain's final ``else``.
            if strategy == 'metric_select':
                select_data = metric_select(
                    train_args.select_strategy, unlabeled_originlabels, num_classes,
                    train_args.classifier_name, train_args.dataset_name, unlabeled_embeddinglist,
                    unlabeled_img_path, labeled_protolist, unlabeled_target,
                    train_args.select_ratio, train_args.select_type, add_ratio)
            elif strategy == 'random_select':
                select_data = random_select(
                    method_name=train_args.select_strategy,
                    origin_labels=unlabeled_originlabels,
                    classifier_name=train_args.classifier_name,
                    dataset_name=train_args.dataset_name,
                    img_names=unlabeled_img_path,
                    select_ratio=train_args.select_ratio,
                    select_type=train_args.select_type,
                    add_ratio=add_ratio  # args.add_ratios[j]
                )
            elif strategy == 'distance_entropy':
                select_data = DistanceEntropySeries(*distance_entropy_args).distance_entropy_select()
            elif strategy == 'distance_entropy_with_feature_merge':
                select_data = DistanceEntropySeries(*distance_entropy_args).distance_entropy_with_merge()
            elif strategy == 'distance_entropy_aver':
                select_data = DistanceEntropySeries(*distance_entropy_args).distance_entropy_aver()
            elif strategy == 'DistanceEntropy_v2':
                distance_entropy_series = DistanceEntropySeries(*distance_entropy_args)
                # Mutual information is only enabled below a 0.15 select ratio.
                select_data = distance_entropy_series.distance_entropy_v2(
                    mutual_inf=train_args.select_ratio < 0.15)
            elif strategy == 'EdgeDistanceEntropy':
                select_data = DistanceEntropySeries(*distance_entropy_args).edge_distance_entropy()
            elif strategy == 'EdgeDistanceEntropyWithFeatureMerge':
                select_data = DistanceEntropySeries(
                    *distance_entropy_args).edge_distance_entropy_with_feature_merge()
            elif strategy == 'Entropy':
                select_data = EntropySeries(*entropy_args).entropy_select()
            elif strategy == 'EntropyWithMerge':
                select_data = EntropySeries(*entropy_args).entropy_select_with_merge(R=0.1)
            elif strategy == 'LearningLoss':
                select_data = learning_loss(
                    train_args.select_strategy, unlabeled_originlabels, train_args.classifier_name,
                    train_args.dataset_name, unlabeled_outputlist, unlabeled_img_path,
                    train_args.select_ratio, unlabeled_target, train_args.select_type,
                    add_ratio, num_classes, unlabeled_losslist)
            elif strategy == 'LearningLossWithMerge':
                select_data = learning_loss_with_merge(
                    train_args.select_strategy, unlabeled_originlabels, train_args.classifier_name,
                    train_args.dataset_name, unlabeled_outputlist, unlabeled_img_path,
                    train_args.select_ratio, unlabeled_target, train_args.select_type,
                    add_ratio, num_classes, unlabeled_losslist, unlabeled_embeddinglist)
            elif strategy == 'SupErembedding':
                # Was ``train_args.train_args.select_ratio``, which raises AttributeError.
                select_data = superembedding_select(
                    train_args.select_strategy, unlabeled_originlabels, train_args.classifier_name,
                    train_args.dataset_name, unlabeled_embeddinglist, unlabeled_img_path,
                    labeled_embed_maxlist, train_args.select_ratio)
            elif strategy == 'Random':
                select_data = random_select(
                    train_args.select_strategy, unlabeled_originlabels, train_args.classifier_name,
                    train_args.dataset_name, unlabeled_img_path, train_args.select_ratio,
                    train_args.select_type, add_ratio)
            elif strategy == 'RandomWithMerge':
                select_data = random_select_with_merge(
                    train_args, unlabeled_originlabels, unlabeled_outputlist, unlabeled_img_path,
                    unlabeled_target, unlabeled_embeddinglist, add_ratio, num_classes)
            elif strategy == 'CoreSet':
                coreSet = CoreSet(train_args, unlabeled_originlabels, labeled_embeddinglist,
                                  unlabeled_embeddinglist, unlabeled_img_path, add_ratio)
                select_data = coreSet.chose()
            elif strategy == 'Badge':
                badge = BadgeSampling(train_args, unlabeled_originlabels, labeled_embeddinglist,
                                      unlabeled_embeddinglist, unlabeled_outputlist, unlabeled_img_path,
                                      add_ratio, args.dimension, num_classes, unlabeled_target)
                select_data = badge.query()
            elif strategy == 'MutualInformationEntropy':
                mutualInformationEntropy = MutualInformationEntropy(
                    train_args, unlabeled_originlabels, labeled_embeddinglist, unlabeled_embeddinglist,
                    unlabeled_img_path, labeled_protolist, add_ratio, num_classes)
                select_data = mutualInformationEntropy.chose()
            elif strategy == 'CoreWeightEntropy':
                coreWeightEntropy = CoreWeightEntropy(
                    train_args, labeled_per_embeddinglist, unlabeled_originlabels, unlabeled_embeddinglist,
                    unlabeled_outputlist, unlabeled_img_path, add_ratio, labeled_protolist,
                    num_classes, unlabeled_target)
                select_data = coreWeightEntropy.chose()
            else:
                print("No Select")
                select_data = None

        # NOTE(review): when no selection ran this cycle, select_data still holds the
        # previous cycle's value (or []), and that is what gets written - confirm intended.
        if select_data is not None:
            csv_path = f'./CSV/{train_args.dataset_name}/{train_args.select_strategy}/{train_args.classifier_name}/select_{train_args.select_ratio}.csv' 
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            with open(csv_path, 'w', newline='') as csvfile:
                writer = csv.writer(csvfile)
                for path, label in select_data:
                    writer.writerow([path, label])
        print('***************************************finish***************************************')
        j += 1