import argparse
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import random
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from data.sampler import SubsetSequentialSampler
from models import ResNet
from GMM_train_CIFAR10 import GMM_train, GMM_Semi_train, MakeClsPreLoader, EXTRACT_GMM_prev
from models.GMM_model_CIFAR100 import *
from models.DA_Critic_model import *
from load_dataset import load_dataset
from Config import *
from data.Datahandler import *
from Selection import Kmeans_Selection
from Selection import *

def _str2bool(value):
    """Convert a command-line string into a bool for ``argparse``.

    ``type=bool`` is a trap: ``bool("False")`` is ``True`` because every
    non-empty string is truthy, so e.g. ``--normalize False`` would still
    enable the option. This converter accepts the usual spellings instead.

    Args:
        value: The raw option string (a ``bool`` is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


# Command-line interface for the active-learning driver below.
parser = argparse.ArgumentParser()
parser.add_argument("-l","--lambda_loss",type=float, default=1.2, help="Adjustment graph loss parameter between the labeled and unlabeled")
parser.add_argument("-s","--s_margin", type=float, default=0.1, help="Confidence margin of graph")
parser.add_argument("-n","--hidden_units", type=int, default=128, help="Number of hidden units of the graph")
parser.add_argument("-r","--dropout_rate", type=float, default=0.3, help="Dropout rate of the graph neural network")
parser.add_argument("-d","--dataset", type=str, default="cifar10", help="")
parser.add_argument("-e","--no_of_epochs", type=int, default=200, help="Number of epochs for the active learner")
parser.add_argument("-m","--method_type", type=str, default="DAAL", help="")
parser.add_argument("-c","--cycles", type=int, default=10, help="Number of active learning cycles")
# Model Argument
parser.add_argument('--unsupervised-em-iters', type=int, default=10)
parser.add_argument('--semisupervised-em-iters', type=int, default=10)
# Boolean flags use _str2bool so that "False"/"0"/"no" actually disable them.
parser.add_argument('--fix-pi', type=_str2bool, default=True)
parser.add_argument('--hidden-size', type=int, default=128)
parser.add_argument('--component-size', type=int, default=10)
parser.add_argument('--latent-size', type=int, default=128)
parser.add_argument('--train-mc-sample-size', type=int, default=32)
parser.add_argument('--test-mc-sample-size', type=int, default=32)
parser.add_argument('--gpu-id', type=int, default=1)
parser.add_argument('--fix-var', type=_str2bool, default=False)
parser.add_argument('--fix-var-test', type=_str2bool, default=False)
parser.add_argument('--prenum', type=int, default=6000)
parser.add_argument('--clsn', type=int, default=400)
parser.add_argument('--clsm', type=int, default=3)
parser.add_argument('--c-ratio', type=float, default=0.01)
parser.add_argument('--normalize', type=_str2bool, default=True)
parser.add_argument('--versionSave', type=_str2bool, default=True)
parser.add_argument('--saveName', type=str, default="SAVE")
parser.add_argument('--version', type=str, default="SAVE")
parser.add_argument('--imbalanceRatio', type=int, default=10)

# Parse sys.argv once at import time; the main section reads this global.
args = parser.parse_args()
# Limit torch's intra-op CPU thread pool, then select the CUDA device given by --gpu-id.
torch.set_num_threads(2)
selected_gpu = args.gpu_id
torch.cuda.set_device(selected_gpu)

# Main
if __name__ == '__main__':
    # Active-learning driver: each cycle trains fresh models on the current
    # labeled pool, scores the unlabeled pool, moves the selected indices from
    # the unlabeled to the labeled pool, and rebuilds the data loaders.

    # Seed every RNG used here and force deterministic cuDNN kernels so runs
    # are reproducible for a fixed seed.
    random_seed = 200
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

    method = args.method_type
    print("Dataset: %s"%args.dataset)
    print("Method type:%s"%method)
    CYCLES = args.cycles
    # Load training and testing dataset. `adden` is used both as the size of
    # the initial labeled pool and as the per-cycle acquisition budget.
    data_train, data_unlabeled, data_test, adden, NO_CLASSES, no_train = load_dataset(args.dataset, ir = args.imbalanceRatio)
    print('The entire datasize is {}'.format(len(data_train)))
    ADDENDUM = adden
    NUM_TRAIN = no_train
    indices = list(range(NUM_TRAIN))
    random.shuffle(indices)

    # `indices` is a permutation of distinct ints, so the complement of the
    # first ADDENDUM entries is just the remaining slice. It has the same order
    # as a list-membership filter, but costs O(n) instead of O(n^2).
    labeled_set = indices[:ADDENDUM]
    unlabeled_set = indices[ADDENDUM:]

    train_loader = DataLoader(data_train, batch_size=BATCH, sampler=SubsetRandomSampler(labeled_set), pin_memory=True, drop_last=True)
    UL_train_loader = DataLoader(data_train, batch_size=BATCH, sampler=SubsetRandomSampler(unlabeled_set), pin_memory=True, drop_last=True)
    L_Summary_loader = DataLoader(data_train, batch_size=SUM_BATCH2, sampler=SubsetRandomSampler(labeled_set), pin_memory=True, drop_last=True)
    UL_Summary_loader = DataLoader(data_train, batch_size=SUM_BATCH2, sampler=SubsetRandomSampler(unlabeled_set), pin_memory=True, drop_last=True)

    test_loader  = DataLoader(data_test, batch_size=SUM_BATCH, pin_memory=True)
    dataloaders  = {'train': train_loader, 'train_UL':UL_train_loader, 'test': test_loader,
                    'summary_L':L_Summary_loader, 'summary_UL':UL_Summary_loader}

    for cycle in range(CYCLES):
        # Shuffling consumes the `random` RNG state, so it must stay here for
        # run-to-run reproducibility.
        random.shuffle(unlabeled_set)

        # Fresh models every cycle (no warm start from the previous cycle).
        GMM_model = Make_GMM_Model(args).cuda()
        GMM_model.latent_size = 512  # overrides the --latent-size value used at construction
        Critic = Make_DA_Critic_Model(args, Nc=args.component_size).cuda()
        resnet18 = ResNet.ResNet18(num_classes=NO_CLASSES).cuda()
        models = {'backbone': resnet18, 'critic': Critic}
        ClassPreLoader = MakeClsPreLoader(args.dataset, data_train, labeled_set)

        # Phase 1: train the backbone with GMM_train.
        # NOTE(review): if the scheduler is stepped once per epoch, the 240
        # milestone is never reached with the default --no_of_epochs=200; confirm.
        optim_backbone = optim.SGD(models['backbone'].parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WDECAY)
        sched_backbone = lr_scheduler.MultiStepLR(optim_backbone, milestones=[160, 240])
        optims, scheds = {'backbone': optim_backbone}, {'backbone': sched_backbone}

        # NOTE(review): `cycle=1` is passed on every AL cycle rather than
        # `cycle`; confirm this is intentional.
        GMM_train(models, GMM_model, optims, scheds, dataloaders, args.no_of_epochs, fix_var=args.fix_var,
                        fix_var_test=args.fix_var_test, cycle=1, ClsIdx=ClassPreLoader, normalize=args.normalize)

        # Phase 2: GMM_Semi_train with fresh optimizers for the backbone and
        # the critic.
        optim_backbone = optim.SGD(models['backbone'].parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WDECAY)
        optim_critic = optim.Adam(models['critic'].parameters(), lr=C_LR)
        sched_backbone = lr_scheduler.MultiStepLR(optim_backbone, milestones=[8000, 24000])
        sched_critic = lr_scheduler.MultiStepLR(optim_critic, milestones=[8000, 24000])
        optims = {'backbone': optim_backbone, 'critic': optim_critic}
        scheds = {'backbone': sched_backbone, 'critic': sched_critic}

        GMM_Semi_train(models, GMM_model, method, optims, scheds, dataloaders, EPOCH, fix_var=args.fix_var,
                        fix_var_test=args.fix_var_test, ClsIdx=ClassPreLoader, cycle=1)

        # Extract labeled / unlabeled GMM parameters (pi, mean, var) for sampling.
        L_ALL_GMM_pi, L_ALL_GMM_mean, L_ALL_GMM_var, Ul_ALL_GMM_pi, Ul_ALL_GMM_mean, Ul_ALL_GMM_var, mean_linear_acc = \
            EXTRACT_GMM_prev(models, GMM_model, dataloaders, fix_var=args.fix_var, fix_var_test=args.fix_var_test,
                             ClsIdx=ClassPreLoader, update_Lpi=True, update_ULpi=True, HintUL = False)

        Phis = (L_ALL_GMM_pi, L_ALL_GMM_mean, L_ALL_GMM_var, Ul_ALL_GMM_pi, Ul_ALL_GMM_mean, Ul_ALL_GMM_var)

        # ========================================= CALCULATE UNCERTAINTY ============================================#
        arg = Likelihood_Sampling_DAAL(data_train, unlabeled_set, BATCH, models, GMM_model,
                            Phis=Phis, clsn=args.clsn, clsm=args.clsm, addnum=ADDENDUM,
                            clsList=True, cycle=cycle, CYCLES=CYCLES, c_ratio=args.c_ratio, Uni=True)
        # Sanity checks: candidates and final picks must come from the unlabeled pool.
        assert len(set(labeled_set) & set(arg[0].tolist())) == 0
        chosen = Kmeans_Selection(arg, data_train, BATCH, models, ADDENDUM, method = args.method_type, clsm=args.clsm)
        assert len(set(labeled_set) & set(chosen)) == 0

        # Move the chosen indices from the unlabeled pool to the labeled pool.
        labeled_set = list(set(labeled_set) | set(chosen))
        unlabeled_set = list(set(unlabeled_set) - set(chosen))
        assert len(set(labeled_set) & set(unlabeled_set)) == 0
        print(f'len(X_L):{len(labeled_set)}, len(X_UL):{len(unlabeled_set)}')

        # Create new dataloaders for the updated pools. The training loaders use
        # the same pin_memory/drop_last settings as the initial ones. Without
        # drop_last, a trailing batch of size 1 becomes possible as the pool
        # grows, which layers such as BatchNorm in train mode reject.
        dataloaders['train'] = DataLoader(data_train, batch_size=BATCH, sampler=SubsetRandomSampler(labeled_set), pin_memory=True, drop_last=True)
        dataloaders['train_UL'] = DataLoader(data_train, batch_size=BATCH, sampler=SubsetRandomSampler(unlabeled_set), pin_memory=True, drop_last=True)
        # NOTE(review): the initial summary loaders are random with
        # drop_last=True, but these are sequential and keep the last partial
        # batch; confirm the difference is intentional.
        dataloaders['summary_L'] = DataLoader(data_train, batch_size=SUM_BATCH2, sampler=SubsetSequentialSampler(labeled_set), pin_memory=True)
        dataloaders['summary_UL'] = DataLoader(data_train, batch_size=SUM_BATCH2, sampler=SubsetSequentialSampler(unlabeled_set), pin_memory=True)


