import timm
import timm.optim.optim_factory as optim_factory
import argparse
import math
import torch.multiprocessing as mp
import torch.distributed as dist

from dataset import *
from model import *
import torch_geometric
import util.dist_helper as misc

def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    ``type=bool`` would call ``bool(str)``, which is True for every non-empty
    string (so ``--cos False`` would switch the option on); this converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config.yaml')
parser.add_argument('--data_path', type=str, default=r'../data/skyscript')
parser.add_argument('--graph_out_dim', type=int, default=128)
parser.add_argument('--mask_rate', type=float, default=0.75)

# Learning-rate schedule. When --lr is omitted it is derived from --blr
# (scaled by the global batch size) inside main_worker.
parser.add_argument('--lr', type=float, default=None)
parser.add_argument('--blr', type=float, default=1.e-4)
parser.add_argument('--min_lr', type=float, default=1.e-6, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR')

parser.add_argument('--batch_size', type=int, default=660)
parser.add_argument('--total_epoch', type=int, default=60)
parser.add_argument('--start_epoch', type=int, default=0)
# Accepts a bare '--cos' (-> True) as well as '--cos true' / '--cos false'.
parser.add_argument('--cos', type=_str2bool, nargs='?', const=True, default=False,
                    help="boolean; '--cos', '--cos true' or '--cos false'")
parser.add_argument('--weight_decay', type=float, default=0.05,
                    help='weight decay (default: 0.05)')
parser.add_argument('--save_root_dir', type=str, default=r'./savedir')
parser.add_argument('--save_dir', type=str, default=r'GeoLink-60e')
parser.add_argument('--resume', default=r'', help='resume from checkpoint')


# Distributed settings. world_size, rank and local_rank are overwritten at
# launch / in main_worker, so these defaults only matter if that code changes.
parser.add_argument('--world_size', type=int, default=1, help='number of distributed processes')
parser.add_argument('--dist_backend', type=str, default='nccl', help='distributed backend')
parser.add_argument('--dist_url', type=str, default='env://', help='url used to set up distributed training')
parser.add_argument('--rank', type=int, default=0, help='rank of the current process')
parser.add_argument('--local_rank', type=int, default=0, help='local rank of the current process')

args = parser.parse_args()


def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate for ``epoch``: linear warmup, then half-cycle cosine decay.

    For ``epoch <= args.warmup_epochs`` the rate ramps linearly from
    ``args.min_lr`` to ``args.lr``. After that it follows a half cosine from
    ``args.lr`` down to ``args.min_lr`` at ``args.total_epoch``. It then stays at
    ``args.min_lr`` if called with a later epoch.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts), e.g. a
            ``torch.optim.Optimizer``. A group's optional ``"lr_scale"`` entry
            multiplies the scheduled rate for that group.
        epoch: current epoch; may be fractional (the training loop passes
            ``epoch + step / len(train_loader)``).
        args: namespace providing ``lr``, ``min_lr``, ``warmup_epochs`` and
            ``total_epoch``.

    Returns:
        The same ``optimizer``, whose param groups were updated in place.
    """
    if args.warmup_epochs > 0 and epoch <= args.warmup_epochs:
        # Ramp min_lr -> lr so the peak is exactly args.lr and meets the cosine
        # branch continuously at epoch == warmup_epochs.
        lr = args.min_lr + (args.lr - args.min_lr) * epoch / args.warmup_epochs
    else:
        decay_epochs = args.total_epoch - args.warmup_epochs
        if decay_epochs > 0:
            # Clamp so calls past total_epoch stay at min_lr instead of rising again.
            progress = min((epoch - args.warmup_epochs) / decay_epochs, 1.0)
        else:
            # No decay phase left (e.g. total_epoch == warmup_epochs): sit at the floor.
            progress = 1.0
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr * param_group.get("lr_scale", 1.0)
    return optimizer

def main_worker(gpu, ngpus_per_node, args):
    """Run GeoLink pre-training in one DDP process bound to one local GPU.

    Launched once per GPU by ``mp.spawn``, which passes the process index as ``gpu``.

    Args:
        gpu: index of this process and of its CUDA device. It is used as both
            the local and the global rank, so this setup only supports a
            single node.
        ngpus_per_node: number of spawned processes; not used in the body.
        args: parsed CLI namespace. ``local_rank``, ``rank`` and ``save_dir``
            are overwritten in place, and ``lr`` is derived from ``blr`` when
            it was not given.
    """
    args.local_rank = gpu
    args.rank = gpu
    # Rendezvous variables read by init_process_group for the 'env://' init method.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ['WORLD_SIZE'] = str(args.world_size)
    os.environ['RANK'] = str(args.rank)
    torch.manual_seed(66 + args.rank)

    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size,
                            rank=args.rank)

    device = torch.device(f'cuda:{args.local_rank}')
    torch.cuda.set_device(device)

    save_dir = os.path.join(args.save_root_dir, args.save_dir)
    if args.rank == 0:
        # Only rank 0 creates the directory (it is also the only rank that logs
        # and saves below); makedirs creates missing parents and tolerates an
        # existing directory.
        os.makedirs(save_dir, exist_ok=True)
    args.save_dir = save_dir

    # Disabled scaler: scale()/step()/update() below pass straight through to
    # backward()/optimizer.step() (presumably because bf16 autocast needs no
    # loss scaling).
    scaler = torch.amp.GradScaler(enabled=False, device='cuda')
    torch.set_float32_matmul_precision('high')

    # train records
    record_path = os.path.join(save_dir, 'loss_record.txt')

    train_set = ImageOSMHeterDataset(data_path=args.data_path)
    if args.rank == 0:
        print("Data size: ", len(train_set))
    # The DistributedSampler shuffles and shards per rank, so the loader itself must not shuffle.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
    train_loader = torch_geometric.loader.DataLoader(train_set, batch_size=args.batch_size, shuffle=False,
                                                     num_workers=6, drop_last=True, pin_memory=True,
                                                     sampler=train_sampler, prefetch_factor=2)

    # image encoder: ViT-L/16 with the classification head (num_classes=0) and pooling removed
    img_encoder = timm.create_model('vit_large_patch16_224', pretrained=False, num_classes=0, global_pool='',
                                    drop_path_rate=0.1)

    # osm encoder
    osm_encoder = OSMHeteroGAT(hidden_chans=256, out_chans=args.graph_out_dim)

    # fusion model
    model = GeoLink(img_encoder, osm_encoder).to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      output_device=args.local_rank,
                                                      find_unused_parameters=True)
    model_without_ddp = model.module

    # optimizer
    if args.lr is None:  # only base_lr is specified
        # Scale the base lr by the global batch size relative to 256.
        args.lr = args.blr * args.batch_size * args.world_size / 256
    # timm helper building weight-decay / no-weight-decay parameter groups.
    param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
    if args.rank == 0:
        print(optimizer)

    # model resume (behaviour lives in util.dist_helper.load_model; presumably a no-op when --resume is empty)
    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer)

    log_interval = 50
    # train
    for epoch in range(args.start_epoch, args.total_epoch):
        # Loss sums since the last log line, and how many steps they cover.
        loss_step1 = 0
        loss_step2 = 0
        loss_step3 = 0
        n_accum = 0
        train_loader.sampler.set_epoch(epoch)  # different shuffle each epoch
        model.train()
        for step, data in enumerate(train_loader):
            model.zero_grad()
            # Per-iteration schedule: a fractional epoch gives a smooth warmup/decay.
            optimizer = adjust_learning_rate(optimizer, step / len(train_loader) + epoch, args)

            osm, img_o = data[0], data[1]
            osm = osm.to(device)
            img_o = img_o.to(device)

            with torch.amp.autocast(dtype=torch.bfloat16, device_type='cuda'):
                # Image mask ratio comes from --mask_rate; the OSM ratio stays fixed.
                img_rec_loss, graph_loss, node_loss, *_ = model(img_o, osm,
                                                                img_mask_ratio=args.mask_rate,
                                                                osm_mask_ratio=0.20)

                loss = img_rec_loss + graph_loss * 0.01 + node_loss * 0.01

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            torch.cuda.synchronize()

            loss_step1 += img_rec_loss.item()
            loss_step2 += graph_loss.item()
            loss_step3 += node_loss.item()
            n_accum += 1
            if step % log_interval == 0 and args.rank == 0:
                # Average over the steps actually accumulated (1 at step 0, then log_interval).
                avg1 = loss_step1 / n_accum
                avg2 = loss_step2 / n_accum
                avg3 = loss_step3 / n_accum
                current_lr = optimizer.param_groups[0]['lr']
                print(
                    f"epoch [{epoch}/{args.total_epoch}]\t Step [{step}/{len(train_loader)}]\t lr [{current_lr}]\n"
                    f"Loss1: {avg1}\t Loss2: {avg2}\t Loss3: {avg3}\t")
                with open(record_path, 'a') as record_file:
                    record_file.write(
                        f"epoch [{epoch}/{args.total_epoch}]\t Step [{step}/{len(train_loader)}]\t [{current_lr}]\n "
                        f"Loss1: {avg1}\t Loss2: {avg2}\t Loss3: {avg3}\n")
                loss_step1 = 0
                loss_step2 = 0
                loss_step3 = 0
                n_accum = 0

        if ((epoch + 1) % 5 == 0 or (epoch + 1) == args.total_epoch) and args.rank == 0:
            misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, epoch=epoch)

    # Release the process group's communicator resources before the worker exits.
    dist.destroy_process_group()


if __name__ == "__main__":
    # Left disabled: uncomment to switch torch.multiprocessing to the
    # 'file_descriptor' tensor-sharing strategy.
    # mp.set_sharing_strategy('file_descriptor')

    # One training process per visible GPU; this overrides any --world_size value.
    args.world_size = torch.cuda.device_count()
    if args.world_size < 1:
        # mp.spawn with nprocs=0 would return immediately and "finish" without training.
        raise SystemExit("No CUDA device is visible; main_worker needs at least one GPU.")
    mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args))
