import torch
import logging
import argparse
import datetime
from tqdm import tqdm
from utils.util import *
from utils.taps_util import *
from configs.config import *
from utils.dist_util import *
from torch.utils.data import DataLoader
from datasets.word_classification import *
from models.ps.layer_warpper import ModelUnion, SHARE_THRESHOLD
# import datasets.decathlon_datasets as decathlon_datasets

from copy import deepcopy

import torch.nn.parallel
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.utils.data.distributed import DistributedSampler

def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string. This helper maps the usual textual spellings
    to real booleans instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept falsy, matching what ``bool("")`` returned.
    if lowered in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. ``--dataset_name`` is the only
        required option; everything else has a default.
    """
    parser = argparse.ArgumentParser(description='Arguments for the training purpose.')
    parser.add_argument('--gpuNums', type=int, default=1, help='number of gpus')
    parser.add_argument('--nEpochs', type=int, default=40, help='number of epochs to train for')
    parser.add_argument('--warmup', type=int, default=2, help='the epochs for warmup')
    parser.add_argument('--lr', type=float, default=1e-1, help='Learning Rate. Default=0.1')
    # A negative value is a sentinel: the main script swaps it for 0.02.
    parser.add_argument('--mask_lr', type=float, default=-1.0,
                        help='Mask Learning Rate. A negative value (default -1.0) selects the built-in 0.02')
    parser.add_argument('--optim', type=str, required=False, default="SGD", choices=["ADAM", "SGD", "ADAMW"], help='optimizer. Default=SGD')
    parser.add_argument('--wd', type=float, required=False, default=0.0, help='weight decay. Default=0.0')
    parser.add_argument('--momentum', type=float, required=False, default=0.9, help='momentum. Default=0.9')
    parser.add_argument('--threads', type=int, default=12, help='number of threads for data loader to use')
    parser.add_argument('--backbone', type=str, required=False, default='resnet50', choices=[
        "vit_small_patch16_224",
        "vit_base_patch16_224",
        "resnet18",
        "resnet34",
        "resnet50",
        "wide_resnet",
        "timm_resnet18",
        "timm_resnet26",
        "timm_resnet34",
        "timm_resnet50",
        "densenet121",
        "timm_densenet121",
        ], help="backbone of the model")
    parser.add_argument('--batchSize', type=int, default=96, help='training batch size')
    parser.add_argument('--dataset_name', type=str, required=True, help="which dataset to train")
    parser.add_argument('--resume_from', type=int, default=0, help='iteration to resume from')
    parser.add_argument('--save_path', type=str, default="chk/exp", help='path to save the model')
    parser.add_argument('--visual_file', type=str, default="", help='path to save the visual_data')
    parser.add_argument('--logname', type=str, default='ps_joint_log', help="name of the logging file")
    parser.add_argument('--chkname', type=str, default='chk/torch/resnet50-19c8e357.pth', help="name of the checkpoints folder")
    parser.add_argument('--p', type=float, required=False, default=0.5, help='end p. Default=0.5')
    parser.add_argument('--p_T', type=int, required=False, default=10, help='the update T of p. Default=10 epochs')
    # type=bool would turn "--cropped False" into True; parse the text instead.
    parser.add_argument('--cropped', type=_str2bool, required=False, default=False, help='crop the pic or not')
    parser.add_argument('--num_iterations', type=int, required=False, default=5, help='the iteration times of the all tasks')

    # DDP settings
    parser.add_argument('--nprocs', type=int, default=1, help='number of gpus')
    parser.add_argument('--local_rank',
                    default=-1,
                    type=int,
                    help='node rank for distributed training')
    parser.add_argument('--seed',
                        default=None,
                        type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--ip', default='127.0.0.1', type=str)
    parser.add_argument('--port', default="29500", type=str)

    args = parser.parse_args()
    return args

def train_one_epoch(args, epoch, model, opt, lr_schedular, train_loader):
    """Train ``model`` for one pass over ``train_loader``, then step the scheduler.

    Args:
        args: run options; reads ``local_rank`` and ``nprocs``, and on rank 0
            also ``logging``.
        epoch: epoch index, used only in the log line.
        model: wrapped model; the loss is obtained from ``model.module.losses``.
        opt: optimizer stepped and zeroed after every batch.
        lr_schedular: scheduler stepped once after the whole pass.
        train_loader: yields ``(img, (label, task))`` batches.
    """
    device = args.local_rank
    model.train()
    epoch_losses = []

    for images, (targets, _task) in train_loader:
        images = images.cuda(device)
        targets = targets.cuda(device)

        outputs = model(images)
        loss_terms = model.module.losses(outputs, targets)
        # Only the entries whose key mentions "loss" contribute to the objective.
        total_loss = sum(value for name, value in loss_terms.items() if "loss" in name)
        total_loss.backward()

        # The cross-process mean is kept for reporting only.
        reduced = reduce_mean(total_loss, args.nprocs)
        epoch_losses.append(reduced.detach().cpu())

        opt.step()
        opt.zero_grad()

    lr_schedular.step()

    if args.local_rank != 0:
        return

    lrs = lr_schedular.get_last_lr()
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Groups beyond the first may be absent; they are reported as 0.0 then.
    ps_lr = lrs[1] if len(lrs) > 1 else 0.0
    mask_lr = lrs[2] if len(lrs) > 2 else 0.0
    args.logging.info(
        f'[{stamp} TRAINING]'
        f' EPOCH {epoch} model_lr {lrs[0]} ps_lr {ps_lr} mask_lr {mask_lr}, loss: {mean(epoch_losses)}')
    
def test_one_epoch(args, model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the mean batch accuracy.

    Args:
        args: run options; reads ``local_rank`` and ``nprocs``, and on rank 0
            also ``logging``.
        model: model to evaluate; switched to eval mode here.
        test_loader: yields ``(img, (label, task))`` batches.

    Returns:
        ``mean`` over the per-batch accuracies after each one has been passed
        through ``reduce_mean``.
    """
    device = args.local_rank
    model.eval()
    batch_accs = []

    with torch.no_grad():
        for images, (targets, _task) in test_loader:
            images = images.cuda(device)
            targets = targets.cuda(device)
            outputs = model(images)
            # Only the first value returned by ``accuarcy`` is tracked
            # (presumably top-1 -- confirm against the helper).
            first_acc = accuarcy(outputs, targets)[0]
            reduced = reduce_mean(first_acc, args.nprocs)
            batch_accs.append(reduced.detach().cpu())

    if args.local_rank == 0:
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        args.logging.info(f'[{stamp} TESTVAL]' f' acc: {mean(batch_accs)}')
    return mean(batch_accs)
        
def get_loaders(args, train_dataset, val_dataset):
    """Create distributed samplers and loaders for one task's train/val sets.

    Args:
        args: run options; reads ``dataset_name``, ``threads``, ``batchSize``
            and ``nprocs``.
        train_dataset: dataset used for training.
        val_dataset: dataset used for evaluation.

    Returns:
        tuple: ``(train_loader, train_sampler, test_loader, test_sampler)``.
    """
    # The transforms are built but not attached to anything in this function;
    # the call is retained in case the factory has side effects on ``args``.
    make_transforms = (
        create_decathlon_transforms
        if "visual_domain_decathlon" in args.dataset_name
        else create_transforms
    )
    train_transform, test_transform = make_transforms(args)

    train_sampler = DistributedSampler(train_dataset)
    # The configured batch size is split evenly across processes.
    train_loader = DataLoader(
        dataset=train_dataset,
        sampler=train_sampler,
        batch_size=args.batchSize // args.nprocs,
        num_workers=args.threads,
        pin_memory=False,
    )

    test_sampler = DistributedSampler(val_dataset)
    test_loader = DataLoader(
        dataset=val_dataset,
        sampler=test_sampler,
        batch_size=args.batchSize // args.nprocs,
        num_workers=args.threads,
    )

    return train_loader, train_sampler, test_loader, test_sampler
    

if __name__ == '__main__':
    # Script entry point. For every outer iteration and every task after the
    # first, the task model is trained in two stages ("prepare" and "post"),
    # registered with the population, and the population is saved.
    args = get_args()

    # Only rank 0 configures file logging; ``args.logging`` is therefore only
    # available on rank 0, which is the only rank that uses it.
    if args.local_rank == 0:
        now = datetime.datetime.now()
        logging_filename = LOG_DIR + args.dataset_name.replace("/", "_") + '_' + args.logname + '_' + now.strftime("%Y-%m-%d-%H")+'.log'
        print(f'===> Logging to {logging_filename}')
        logging.basicConfig(filename=logging_filename, level=logging.INFO, filemode="w")
        args.logging = logging

    # The world size is taken from the visible GPUs, overriding --nprocs.
    args.nprocs = torch.cuda.device_count()
    if args.local_rank == 0:
        logging.info(f'Running with GPUs and the number of GPUs: {args.nprocs}')
        logging.info(args)

    init_method = 'tcp://' + args.ip + ':' + args.port
    cudnn.benchmark = True
    dist.init_process_group(backend='nccl',
                            init_method=init_method,
                            world_size=args.nprocs,
                            timeout=datetime.timedelta(seconds=100),
                            rank=args.local_rank)
    # Honour --seed when supplied; 43 remains the default.
    init_seeds(43 if args.seed is None else args.seed)

    args.gpuNums = torch.cuda.device_count()
    args.image_shape = None
    args.use_ps = True
    args.fix_mask = False

    if args.local_rank == 0:
        logging.info('===> Loading the dataloader')

    # "tf_visual_domain_decathlon" must be tested first because it contains
    # "visual_domain_decathlon" as a substring.
    if "tf_visual_domain_decathlon" in args.dataset_name:
        train_sets, val_sets, num_classes, task_names = load_tf_visual_domain_decathlon_benchmark(args)
    elif "visual_domain_decathlon" in args.dataset_name:
        train_sets, val_sets, num_classes, task_names = load_visual_domain_decathlon_benchmark(args)
    else:
        train_sets, val_sets, num_classes, task_names = load_imagenet2sketch_benchmark(args)

    if args.local_rank == 0:
        logging.info('===> Building the model union')

    # Name of the backbone's output layer, passed as ``prefix`` to the
    # checkpoint and population helpers. "wide_resnet" is matched before the
    # generic "resnet" substring.
    output_layer_name = None
    if "wide_resnet" in args.backbone:
        output_layer_name = "linears.0"
    elif "resnet" in args.backbone:
        output_layer_name = "fc"
    elif "vit" in args.backbone:
        output_layer_name = "head"
    elif "densenet" in args.backbone:
        output_layer_name = "classifier"

    args.num_classes = 1000
    # A negative --mask_lr is the "not set" sentinel.
    if args.mask_lr < 0:
        args.mask_lr = 0.02
    root_model = build_model(args)
    ps_load_state_dict(root_model, args.chkname, prefix=output_layer_name)
    population = ModelUnion(root_model, train_sets, val_sets, num_classes, task_names)
    if args.resume_from > 0:
        population.load_models(args.save_path + f"_{args.resume_from - 1}", prefix=output_layer_name)

    # ``iteration`` rather than ``iter`` so the builtin is not shadowed.
    for iteration in range(args.resume_from, args.num_iterations):
        # Task index 0 is skipped by this loop.
        for task_idx in range(1, len(population.train_sets)):

            args.fix_mask = False
            context = population.get_task_context(args, task_idx, prefix=output_layer_name)
            model, train_set, val_set, task_name = context["model"], context["train_set"], context["val_set"], context["task"]
            train_loader, train_sampler, test_loader, test_sampler = get_loaders(args, train_set, val_set)

            if "visual_domain_decathlon" in args.dataset_name:
                # NOTE(review): the explicit ``decathlon_datasets`` import at the
                # top of this file is commented out, so this name must come
                # from one of the wildcard imports -- confirm.
                args.wd = decathlon_datasets.VISUAL_DECATHLON_WDS[task_name]

            # ---- Stage 1: "prepare" training for nEpochs // 4 epochs ----
            model.cuda(args.local_rank)
            # ``torch.nn.parallel`` is imported explicitly by this file, so the
            # wrapper no longer depends on a wildcard import supplying ``nn``.
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank], find_unused_parameters=True)

            opt = build_optimizer(args, [p for p in model.parameters() if p.requires_grad])
            lr_schedular = build_lr_schedular(args, opt)

            if args.local_rank == 0:
                logging.info('===' * 50)
                logging.info(f'===> Starting task {task_name} prepare training')
                logging.info(f'===> Model trainable parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
                logging.info(f'===> Model untrainable parameters: {sum([p.numel() for p in model.parameters() if not p.requires_grad])}')

            for epoch in tqdm(range(args.nEpochs // 4)):
                train_sampler.set_epoch(epoch)
                train_one_epoch(args, epoch, model, opt, lr_schedular, train_loader)
                with torch.no_grad():
                    test_one_epoch(args, model, test_loader)

            # Unwrap and move to CPU before converting the model in place.
            model = model.module
            model.to("cpu")
            visual_data = population.convert_ps_model(model, task_idx, prefix=output_layer_name)

            if args.local_rank == 0 and args.visual_file:
                with open(args.visual_file, "a+") as f:
                    f.write(f"Data of {task_name}, weight from:\n")
                    f.write(visual_data)

            # ---- Stage 2: "post" training for nEpochs epochs ----
            model.cuda(args.local_rank)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank], find_unused_parameters=True)

            # Three parameter groups selected by name: everything without "ps",
            # "ps_" parameters other than masks (2x lr), and "ps_mask"
            # parameters (mask lr).
            opt = build_optimizer(args, [
                                      {'params': [p for n, p in model.named_parameters() if p.requires_grad and "ps" not in n]},
                                      {'params': [p for n, p in model.named_parameters() if p.requires_grad and ("ps_" in n and "ps_mask" not in n)],
                                       'lr': args.lr * 2},
                                      {'params': [p for n, p in model.named_parameters() if p.requires_grad and "ps_mask" in n],
                                       'lr': args.mask_lr}
                                      ]
                                  )
            lr_schedular = build_lr_schedular(args, opt)

            if args.local_rank == 0:
                logging.info(f'===> Starting task {task_name} post training')
                logging.info(f'===> Model trainable parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
                logging.info(f'===> Model untrainable parameters: {sum([p.numel() for p in model.parameters() if not p.requires_grad])}')

            best_acc = 0.0
            best_model = None
            last_acc = None
            for epoch in tqdm(range(args.nEpochs)):
                train_sampler.set_epoch(epoch)
                train_one_epoch(args, epoch, model, opt, lr_schedular, train_loader)
                with torch.no_grad():
                    acc = test_one_epoch(args, model, test_loader)
                    last_acc = acc
                    # Checkpoints are only tracked once the mask has been fixed.
                    if args.fix_mask and best_acc < acc:
                        best_acc = acc
                        best_model = deepcopy(model.module)
                        best_model.to("cpu")

                mask_fixed_now = epoch == args.p_T - 1
                if mask_fixed_now:
                    fix_TConv_mask(model.module)
                    args.fix_mask = True

                # Refresh the statistic before logging it. Previously the log
                # line reused a value from an earlier epoch, or hit an
                # undefined name when p_T == 1 on the first task.
                if args.local_rank == 0 and (mask_fixed_now or epoch % 10 == 0):
                    shared_params = ps_visualize(model.module)
                    if mask_fixed_now:
                        logging.info(f"Model shared parameters: {shared_params}")

                if args.visual_file and epoch == args.nEpochs - 1 and args.local_rank == 0:
                    with open(args.visual_file, "a+") as f:
                        f.write("Weight mask:\n")
                        for n, m in model.module.named_modules():
                            if isinstance(m, ps.TConv2d):
                                f.write(f"{n}: {(torch.sigmoid(m.ps_mask) > SHARE_THRESHOLD).int()}\n")
                        f.write(f"{task_name} end\n")

            # If no epoch after the mask fix improved on 0.0 (e.g. p_T >= nEpochs),
            # no checkpoint was captured; register the final model instead of None.
            if best_model is None:
                if args.local_rank == 0:
                    logging.warning(f'===> No best checkpoint captured for task {task_name}; using the final model')
                best_model = deepcopy(model.module)
                best_model.to("cpu")
                if last_acc is not None:
                    best_acc = last_acc

            population.add_model(best_model, task_idx=task_idx, acc=best_acc)
        if args.local_rank == 0:
            info = "".join([f'{task} acc: {task_acc}; ' for task, task_acc in zip(population.task_names, population.accs)])
            logging.info(f'===> Iter {iteration} | {info}')
        population.save_models(args.save_path + f"_{iteration}")
        