import argparse
import numpy as np
import math
import sys
sys.path.append('../share')
sys.path.append('../model')
import os
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

import timm.optim.optim_factory as optim_factory
import util.lr_sched as lr_sched

from models_pospred import MAGECityPosition, MAGECityPosition3D, MAGECityPositionRoad
from dataloader import PolyDataset2D, PolyDataset3D, PolyDatasetRoad
from random_mask import random_masking

def get_args_parser():
    """Build the command-line parser for position-prediction training.

    The parser is created with ``add_help=False``, so it defines no
    ``-h/--help`` option of its own.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)

    # --- optimisation -------------------------------------------------------
    parser.add_argument('--batch_size', default=256, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=1000, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # NOTE: --lr defaults to an absolute value, so --blr is only consulted when
    # a caller sets args.lr to None programmatically (type=float means the CLI
    # itself can never produce None).
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-6, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    # --- paths --------------------------------------------------------------
    parser.add_argument('--data_path', default='../datasets/statespoly', type=str,
                        help='dataset path')
    parser.add_argument('--output_dir', default='../results/mae/output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='../results/mae/output_log',
                        help='path where to tensorboard log')

    # --- data loading -------------------------------------------------------
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=20, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')

    # --- model architecture -------------------------------------------------
    parser.add_argument('--split_ratio', type=float, default=0.8)
    parser.add_argument('--trans_deep', type=int, default=6)
    parser.add_argument('--trans_deep_decoder', type=int, default=3)
    parser.add_argument('--num_heads', type=int, default=8)

    parser.add_argument('--save_freq', type=int, default=50)
    parser.add_argument('--embed_dim', type=int, default=256)
    parser.add_argument('--decoder_embed_dim', type=int, default=16)
    parser.add_argument('--drop_ratio', type=float, default=0.1)

    # --- masking / geometry -------------------------------------------------
    parser.add_argument('--remain_num', type=int, default=6)
    parser.add_argument('--max_poly', type=int, default=20)
    parser.add_argument('--max_build', type=int, default=60)
    parser.add_argument('--discre', type=int, default=50)
    parser.add_argument('--patch_num', type=int, default=10)
    parser.add_argument('--patch_size', type=int, default=5)

    parser.add_argument('--pos_weight', type=float, default=100)
    parser.add_argument('--patchify', action="store_true")
    parser.add_argument('--ablation', action="store_true")

    parser.add_argument('--var_remain_num', action="store_true")
    # Restrict to the variants main() knows how to build, so a typo fails at
    # parse time with a clear message instead of a bare NotImplementedError.
    parser.add_argument('--model_type', default='2d', type=str,
                        choices=['2d', '3d', 'road'])

    # Pinned memory is on unless --no_pin_mem is given.
    parser.set_defaults(pin_mem=True)

    return parser

def _build_datasets(args):
    """Return ``(train_dataset, valid_dataset)`` for ``args.model_type``.

    Raises:
        NotImplementedError: if ``args.model_type`` is not '2d', '3d' or 'road'.
    """
    dataset_classes = {
        '2d': PolyDataset2D,
        '3d': PolyDataset3D,
        'road': PolyDatasetRoad,
    }
    if args.model_type not in dataset_classes:
        raise NotImplementedError
    dataset_cls = dataset_classes[args.model_type]
    dataset_train = dataset_cls(args.data_path, train=True, split_ratio=args.split_ratio)
    dataset_valid = dataset_cls(args.data_path, train=False, split_ratio=args.split_ratio)
    return dataset_train, dataset_valid


def _make_loader(dataset, args):
    """Wrap ``dataset`` in a DataLoader that drops the last partial batch.

    NOTE(review): no ``shuffle`` is requested, so batches come in dataset
    order every epoch -- confirm the dataset randomises internally.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )


def _build_model(args):
    """Instantiate the position-prediction model matching ``args.model_type``.

    Raises:
        NotImplementedError: if ``args.model_type`` is unsupported.
    """
    common = dict(
        embed_dim=args.embed_dim, depth=args.trans_deep, num_heads=args.num_heads,
        decoder_embed_dim=args.decoder_embed_dim, decoder_depth=args.trans_deep_decoder,
        decoder_num_heads=args.num_heads, drop_ratio=args.drop_ratio,
        pos_weight=args.pos_weight, discre=args.discre, patch_size=args.patch_size,
        patch_num=args.patch_num, device=args.device,
    )
    if args.model_type == '2d':
        # Only the 2D model accepts the ablation / patchify switches.
        return MAGECityPosition(**common, ablation=args.ablation, patchify=args.patchify)
    if args.model_type == '3d':
        return MAGECityPosition3D(**common)
    if args.model_type == 'road':
        return MAGECityPositionRoad(**common)
    raise NotImplementedError


def _batch_loss(model, batch, args, device):
    """Mask one loader batch, run ``model`` on it and return the loss tensor.

    ``batch`` is unpacked per ``args.model_type``: ``(samples, pos, info)`` for
    '2d', with an extra trailing ``h`` for '3d' or ``road`` for 'road'. When
    ``args.var_remain_num`` is set, the ``remain_num`` handed to
    ``random_masking`` is drawn with ``torch.randint`` from
    ``[0, args.remain_num)`` and capped by ``min(info[:, 0])``; otherwise
    ``args.remain_num`` is used. ``pos_tar`` is passed to the model without
    being moved to ``device``.
    """
    if args.model_type == '2d':
        samples, pos, info = batch
    elif args.model_type == '3d':
        samples, pos, info, h = batch
    elif args.model_type == 'road':
        samples, pos, info, road = batch
    else:
        raise NotImplementedError

    remain_num = args.remain_num
    if args.var_remain_num:
        remain_num = min(int(torch.randint(0, args.remain_num, (1,))), int(torch.min(info[:, 0])))

    if args.model_type == '3d':
        poly_reserve, pos_reserve, h_reserve, _, pos_tar, _, _ = random_masking(
            samples, pos, info, remain_num, h=h, max_build=args.max_build)
    else:
        poly_reserve, pos_reserve, _, pos_tar, _ = random_masking(
            samples, pos, info, remain_num, max_build=args.max_build)

    poly_reserve = poly_reserve.to(device)
    pos_reserve = pos_reserve.to(device)

    if args.model_type == '2d':
        loss, _ = model(poly_reserve, pos_reserve, pos_tar)
    elif args.model_type == '3d':
        loss, _ = model(poly_reserve, pos_reserve, h_reserve.to(device), pos_tar)
    else:  # 'road'
        loss, _ = model(poly_reserve, pos_reserve, pos_tar, road.to(device).float())
    return loss


def _train_one_epoch(model, data_loader, optimizer, epoch, args, device, log_writer, global_step):
    """Run one training epoch; return ``(last_batch_loss, next_global_step)``.

    Gradients are accumulated over ``args.accum_iter`` micro-batches before
    each optimizer step. Exits the process with status 1 on a non-finite loss.
    The logged/returned loss is the unscaled per-batch loss.
    """
    model.train(True)
    optimizer.zero_grad()
    steps_per_epoch = len(data_loader)
    loss_value = float('nan')
    for data_iter_step, batch in enumerate(data_loader):
        if data_iter_step % args.accum_iter == 0:
            # The scheduler receives a fractional epoch: epoch + progress within it.
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / steps_per_epoch + epoch, args)

        loss = _batch_loss(model, batch, args, device)
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale so the gradient summed over one window is the mean over its
        # micro-batches.
        (loss / args.accum_iter).backward()
        # Step and clear only once a full window has been accumulated; stepping
        # every micro-batch (or clearing before backward) defeats accumulation.
        if (data_iter_step + 1) % args.accum_iter == 0:
            optimizer.step()
            optimizer.zero_grad()

        if log_writer is not None:
            log_writer.add_scalar('loss_train', loss_value, global_step)
        global_step += 1
    return loss_value, global_step


def _evaluate(model, data_loader, args, device):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for batch in data_loader:
            total += _batch_loss(model, batch, args, device).item()
            count += 1
    if count == 0:
        raise ValueError('validation loader yielded no batches')
    return total / count


def _save_checkpoint(model, output_dir, filename):
    """Save ``model.state_dict()`` to ``output_dir/filename``.

    Does nothing when ``output_dir`` is empty ("empty for no saving"); creates
    the directory if needed otherwise.
    """
    if not output_dir:
        return
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, filename))


def main(args):
    """Train the position-prediction model selected by ``args.model_type``.

    Each epoch runs one training pass (with gradient accumulation over
    ``args.accum_iter`` micro-batches) and one validation pass. Checkpoints go
    to ``args.output_dir``: ``positionpred_best.pth`` whenever the mean
    validation loss improves, ``positionpred_{epoch}.pth`` every
    ``args.save_freq`` epochs, and ``positionpred_final.pth`` at the end.
    Nothing is saved when ``args.output_dir`` is empty, and TensorBoard
    logging is skipped when ``args.log_dir`` is empty/None.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.

    Raises:
        NotImplementedError: if ``args.model_type`` is unsupported.
        ValueError: if either split cannot form a single full batch.
    """
    device = torch.device(args.device)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.benchmark = True

    dataset_train, dataset_valid = _build_datasets(args)

    data_loader_train = _make_loader(dataset_train, args)
    data_loader_valid = _make_loader(dataset_valid, args)
    # drop_last=True yields zero batches for a split smaller than one batch;
    # fail now instead of crashing later on an undefined loss or a division by zero.
    if len(data_loader_train) == 0 or len(data_loader_valid) == 0:
        raise ValueError(
            f'--batch_size {args.batch_size} is larger than the training or '
            f'validation split; no full batch can be formed')

    log_writer = None
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)

    model = _build_model(args)
    model.to(device)

    if args.lr is None:
        # Derive the absolute lr from the base lr and the effective batch size.
        eff_batch_size = args.batch_size * args.accum_iter
        args.lr = args.blr * eff_batch_size / 256

    param_groups = optim_factory.param_groups_weight_decay(model, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))

    best_valid_loss = float('inf')
    train_num = 0  # global step counter used as the TensorBoard x-axis
    try:
        for epoch in range(args.start_epoch, args.epochs):
            train_loss, train_num = _train_one_epoch(
                model, data_loader_train, optimizer, epoch, args, device, log_writer, train_num)
            print('train_loss:', train_loss)

            val_loss = _evaluate(model, data_loader_valid, args, device)
            print('epoch:', epoch, 'val_loss: ', val_loss)
            if log_writer is not None:
                log_writer.add_scalar('loss_valid', val_loss, train_num)

            if val_loss < best_valid_loss:
                best_valid_loss = val_loss
                _save_checkpoint(model, args.output_dir, 'positionpred_best.pth')

            if epoch % args.save_freq == 0:
                _save_checkpoint(model, args.output_dir, f'positionpred_{epoch}.pth')

        _save_checkpoint(model, args.output_dir, 'positionpred_final.pth')
    finally:
        # Flush pending TensorBoard events even if training aborts.
        if log_writer is not None:
            log_writer.close()

if __name__ == '__main__':
    # Script entry point: parse CLI flags, ensure the checkpoint directory
    # exists (when one is configured), then hand off to main().
    cli_args = get_args_parser().parse_args()
    checkpoint_dir = cli_args.output_dir
    if checkpoint_dir:
        Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    main(cli_args)