#####################################################################################
#                                                                                   #
# All model-construction code should be kept in the folder ./models/                #
# All data-processing code should be kept in the folder ./data/                     #
# The file ./opts.py defines the options                                            #
# The file ./trainer.py implements the training and test procedures                 #
# The file ./main.py should stay simple                                             #
#                                                                                   #
#####################################################################################
import os
import json
import shutil
import torch
import random
import numpy as np
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from models.model_construct import Model_Construct # for the model construction
from trainer import train # for the training process
from trainer import validate, validate_compute_cen # for the validation/test process
# from trainer import k_means, spherical_k_means, kernel_k_means # for K-means clustering and its variants
from trainer import source_select # for source sample selection
from trainer import get_tar_loss 
from opts import opts # options for the project
from data.prepare_data import generate_dataloader # prepare the data and dataloader
from utils.consensus_loss import ConsensusLoss
import time
# import ipdb
import gc
import threading

# Build the experiment options once at import time (presumably parsed from the
# command line by opts()); the __main__ guard passes this object to main().
args = opts()

# Module-level alias of the configured GPU id list.
gpu_indices = args.gpu_indices

# Make the first listed GPU the current CUDA device before any model/tensor is created.
torch.cuda.set_device(gpu_indices[0])



def _append_log(args, text):
    """Append ``text`` verbatim to ``<args.log>/log.txt``.

    The ``with`` block guarantees the handle is closed even if the write raises.
    """
    with open(os.path.join(args.log, 'log.txt'), 'a') as log_file:
        log_file.write(text)


def _build_param_groups(model, args):
    """Return the optimizer parameter groups for ``model`` (a ``DataParallel`` wrapper).

    With ``args.pretrained`` the backbone (conv1, bn1, layer1-layer4) is tagged
    ``'conv'`` and the classifier (``fc1`` + ``fc2`` when both exist, otherwise
    ``fc``) is tagged ``'ca_cl'``; without it every parameter forms a single
    ``'conv'`` group. NOTE(review): the ``'name'`` tags are presumably read by
    the learning-rate schedule in ``trainer.train`` -- not visible here.
    """
    net = model.module
    if not args.pretrained:
        return [{'params': net.parameters(), 'name': 'conv'}]

    groups = [{'params': getattr(net, layer).parameters(), 'name': 'conv'}
              for layer in ('conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4')]
    # Add classifier layers based on availability
    if hasattr(net, 'fc1') and hasattr(net, 'fc2'):
        groups += [
            {'params': net.fc1.parameters(), 'name': 'ca_cl'},
            {'params': net.fc2.parameters(), 'name': 'ca_cl'},
        ]
    else:
        groups.append({'params': net.fc.parameters(), 'name': 'ca_cl'})
    return groups


def _build_optimizer(param_groups, args):
    """Create the SGD optimizer over ``param_groups``.

    Outside ``'non_learnable'`` centre mode SGD is always used, regardless of
    ``args.optimizer``.

    Raises:
        ValueError: in ``'non_learnable'`` mode when ``args.optimizer`` is not
            ``'SGD'``. ``'AdamW'`` is an unreleased experimental path; any other
            name is rejected here rather than leaving the optimizer undefined.
    """
    if args.learn_cen_mode == 'non_learnable' and args.optimizer != 'SGD':
        if args.optimizer == 'AdamW':
            raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
        raise ValueError(f'Unsupported optimizer: {args.optimizer!r}')
    return torch.optim.SGD(param_groups,
                           lr=args.lr,
                           momentum=args.momentum,
                           weight_decay=args.weight_decay,
                           nesterov=args.nesterov)


def main(args):
    """Train the adaptation model and log validation/test accuracy.

    The function builds the model (wrapped in ``DataParallel`` over
    ``args.gpu_indices``), the cluster centres, the loss functions and an SGD
    optimizer. It then runs up to ``args.max * batch_number`` training
    iterations, stopping early once ``epoch > args.stop_epoch``.

    At the start of every epoch it does three things:

    * evaluates the model;
    * re-initialises the centres from the source centroids returned by
      ``validate_compute_cen``;
    * appends any new best accuracies to ``<args.log>/log.txt``.

    A final summary is also appended to that log. When ``args.save_final`` is
    set, the final weights are saved as well.

    Args:
        args: options object produced by ``opts()`` (argparse-Namespace-like,
            since ``args._get_kwargs()`` is used).

    Raises:
        ValueError: for option combinations that select experimental code paths
            not included in this release, or for unknown option values.
    """
    best_prec1 = 0
    best_test_prec1 = 0
    cond_best_test_prec1 = 0
    assigned_labels = None  # forwarded unchanged to trainer.train

    if args.new_rand_method:
        # setting random seed for all of the packages
        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        print(f'The random seed for this experiment is {args.seed}')
    if args.nrmfr:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # define model (multi-GPU via DataParallel)
    model = Model_Construct(args)
    print(model)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_indices).cuda()

    # define cluster centres, zero-initialised; overwritten from source centroids each epoch below.
    # NOTE(review): 2048 presumably matches the backbone feature width -- confirm against Model_Construct.
    learn_cen = torch.zeros(args.num_classes, 2048, dtype=torch.float32, device='cuda')
    learn_cen.requires_grad_(True)
    learn_cen_2 = torch.zeros(args.num_classes, args.num_neurons * 4, dtype=torch.float32, device='cuda')
    learn_cen_2.requires_grad_(True)

    if args.learn_cen_mode == 'rg':
        # Negate every gradient flowing into the centres (presumably "reverse gradient").
        learn_cen.register_hook(lambda grad: -grad)
        learn_cen_2.register_hook(lambda grad: -grad)
    elif args.learn_cen_mode == 'non_learnable':
        learn_cen.requires_grad_(False)
        learn_cen_2.requires_grad_(False)

    # ce loss for evaluating the model (it is not the fidelity loss)
    criterion = torch.nn.CrossEntropyLoss().cuda()
    # self-supervised consensus loss (Methodology section of the paper)
    criterion_cons = ConsensusLoss(nClass=args.num_classes, div=args.div).cuda()
    # adaptation loss and regularization loss together (Methodology section of the paper)
    TarLoss = get_tar_loss(args)

    if not args.new_rand_method:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    optimizer = _build_optimizer(_build_param_groups(model, args), args)

    epoch = 0
    if args.resume:
        raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

    # make log directory, record the full option set and the start time
    os.makedirs(args.log, exist_ok=True)
    _append_log(args, json.dumps(dict(args._get_kwargs())) + '\n')
    _append_log(args,
                '\n-------------------------------------------\n'
                + time.asctime(time.localtime(time.time()))
                + '\n-------------------------------------------\n'
                + f'\nThe random seed for this experiment is {args.seed}')
    if not args.nrmfr:
        cudnn.benchmark = True

    # process data and prepare dataloaders
    train_loader_source, train_loader_target, val_loader_target, val_loader_target_t, val_loader_source = generate_dataloader(args)
    # avoid using ground truth labels of target: overwrite every target label with -1
    train_loader_target.dataset.tgts = list(np.array(torch.LongTensor(train_loader_target.dataset.tgts).fill_(-1)))

    print('begin training')
    batch_number = count_epoch_on_large_dataset(train_loader_target, train_loader_source, args)
    print('batch_number:', batch_number)
    num_itern_total = args.max * batch_number

    # initialize source sample weights to 1
    src_cs = torch.ones(len(train_loader_source.dataset.tgts), dtype=torch.float32, device='cuda')

    new_epoch_flag = False  # True only on the first iteration of each epoch
    count_itern_each_epoch = 0
    # Defaults so the final summary is well-defined even if no iteration runs.
    prec1 = test_acc = 0
    for itern in range(epoch * batch_number, num_itern_total):
        # evaluate on the target training and test data at the start of each epoch
        if (itern == 0) or (count_itern_each_epoch == batch_number):
            # Only the accuracy and the two source centroids are used here; the
            # remaining outputs (target/joint centroids, features, targets,
            # pseudo-labels) are ignored.
            prec1, c_s, c_s_2, *_ = validate_compute_cen(
                val_loader_target, val_loader_source, model, criterion, epoch, args,
                compute_cen=args.learn_embed)
            # do not recalculate target acc if tar_t and tar are the same
            test_acc = prec1 if args.tar_t == args.tar else validate(val_loader_target_t, model, criterion, epoch, args)

            # all the situations that pseudo-labels are needed
            if (args.init_cen_on != 's' or args.src_soft_select or args.src_hard_select
                    or args.tar_loss_idx in (0, 2)):
                raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

            if args.init_cen_on in ('st', 't'):
                raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
            if args.init_cen_on != 's':
                raise ValueError(f'Unknown init_cen_on option: {args.init_cen_on!r}')

            # initialise the centres from the source centroids
            learn_cen.data = c_s.data.clone()
            learn_cen_2.data = c_s_2.data.clone()

            # select source samples
            if (itern != 0) and (args.src_soft_select or args.src_hard_select):
                raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper
            # use source pre-trained model to extract features for first clustering
            if (itern == 0) and args.src_pretr_first:
                raise ValueError('Error!')  # Placeholder for experimental code; should not be executed for the paper

            if itern != 0:
                count_itern_each_epoch = 0
                epoch += 1
            batch_number = count_epoch_on_large_dataset(train_loader_target, train_loader_source, args)
            train_loader_target_batch = enumerate(train_loader_target)
            train_loader_source_batch = enumerate(train_loader_source)
            new_epoch_flag = True

            # Record the best accuracies. "cond best" is the best test accuracy
            # seen at an evaluation whose val accuracy equals the best so far;
            # it resets whenever a new best val accuracy is found.
            messages = []
            if prec1 > best_prec1:
                best_prec1 = prec1
                cond_best_test_prec1 = 0
                messages.append('\n                                                                                 best val acc till now: %3f' % best_prec1)
            if test_acc > best_test_prec1:
                best_test_prec1 = test_acc
                messages.append('\n                                                                                 best test acc till now: %3f' % best_test_prec1)
            if (prec1 == best_prec1) and (test_acc > cond_best_test_prec1):
                cond_best_test_prec1 = test_acc
                messages.append('\n                                                                                 cond best test acc till now: %3f' % cond_best_test_prec1)
            if messages:
                _append_log(args, ''.join(messages))

        if epoch > args.stop_epoch:
            break

        # train for one iteration; train() returns the advanced loader iterators
        train_loader_source_batch, train_loader_target_batch = train(
            train_loader_source, train_loader_source_batch,
            train_loader_target, train_loader_target_batch,
            model, learn_cen, learn_cen_2, criterion_cons, optimizer,
            itern, epoch, new_epoch_flag, src_cs, args, assigned_labels, TarLoss,
            val_loader_source=val_loader_source)

        # NOTE(review): kept from the original; presumably defensive in case train() moves the model.
        model = model.cuda()
        new_epoch_flag = False
        count_itern_each_epoch += 1

    _append_log(args, ''.join([
        '\n***   best val acc: %3f   ***' % best_prec1,
        '\n***   best test acc: %3f   ***' % best_test_prec1,
        '\n***   final val acc: %3f   ***' % prec1,
        '\n***   final test acc: %3f   ***' % test_acc,
        '\n***   cond best test acc: %3f   ***' % cond_best_test_prec1,
        # end time
        '\n-------------------------------------------\n',
        time.asctime(time.localtime(time.time())),
        '\n-------------------------------------------\n',
    ]))
    if args.save_final:
        save_final(model.state_dict(), args)

def count_epoch_on_large_dataset(train_loader_target, train_loader_source, args):
    """Return how many iterations make up one training epoch.

    The epoch length is the number of batches in the target loader. When
    ``args.src_cls`` is truthy, the source loader's length is also taken into
    account and the larger of the two is returned.
    """
    n_target = len(train_loader_target)
    if not args.src_cls:
        return n_target
    return max(n_target, len(train_loader_source))


def save_checkpoint(state, is_best, args):
    """Save ``state`` to ``<args.log>/checkpoint.pth.tar``.

    When ``is_best`` is truthy, the freshly written checkpoint is also copied
    to ``<args.log>/model_best.pth.tar``.
    """
    checkpoint_path = os.path.join(args.log, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    best_path = os.path.join(args.log, 'model_best.pth.tar')
    shutil.copyfile(checkpoint_path, best_path)

def save_final(model_state_dict, args):
    """Save the final weights together with the run configuration.

    Writes ``<args.log>/final_seed_<args.seed>.pth``. The file contains a dict
    with two keys: ``'model_state_dict'`` (the given state dict) and ``'args'``
    (``vars(args)``).
    """
    target_path = os.path.join(args.log, f'final_seed_{args.seed}.pth')
    payload = dict(model_state_dict=model_state_dict, args=vars(args))
    torch.save(payload, target_path)

# Default location of the JSON boolean flag file used by the helpers below.
_FLAG_FILE = 'config/output.json'


def _write_flag(value, path=_FLAG_FILE):
    """Serialise ``value`` as JSON into ``path``, overwriting any previous content."""
    with open(path, 'w') as file:
        json.dump(value, file)


def write_true(path=_FLAG_FILE):
    """Store JSON ``true`` in the flag file at ``path``."""
    _write_flag(True, path)


def write_false(path=_FLAG_FILE):
    """Store JSON ``false`` in the flag file at ``path``."""
    _write_flag(False, path)


def read(path=_FLAG_FILE):
    """Return the JSON value stored in the flag file at ``path``."""
    with open(path, 'r') as file:
        return json.load(file)

if __name__ == '__main__':
    # Script entry point: run one experiment with the options built at import time.
    main(args)
    
    # NOTE(review): the commented-out variants below appear to gate or repeat
    # runs through the JSON flag file handled by read()/write_true(); they are
    # kept for reference only and are not executed.
#     if not read():
#         main(args)
#         write_true()
#         time.sleep(1)        
    # while not read():
    #     main(opts())
    #     write_true()
    #     time.sleep(1)
        
    # del a
    # gc.collect()
    # torch.cuda.empty_cache()
    # write_true()
    # time.sleep(4.2)

