import os, time, pprint
import argparse
import warnings
warnings.filterwarnings("ignore")
import torch
import dataset.datasets_flow as datasets
from DCFlow import *
from utils.utils import *
from utils.loss import *
from utils.flow_utils import *
from utils.flow_aug import *
from torch.utils.tensorboard import SummaryWriter
from utils.flow_utils import *
torch.backends.cuda.preferred_linalg_library("cusolver")


def train(args):
    """Train DCFlow for ``args.num_steps + 1`` iterations (glob_iter 0..num_steps).

    Each iteration optimises the sum of:
      * trimmed sequence flow losses for (imgA -> imgA_self) and
        (imgB -> imgB_self) against their ground-truth flows and masks;
      * ``trans_weight`` times ``loss_function`` between the pseudo image A
        (warped with the detached final B->A flow) and imgB, restricted to
        pixels that are non-zero in both;
      * with ``--cross_aug`` and once ``glob_iter > cross_aug_start``,
        ``flow_aug_weight`` times a trimmed flow loss on the (imgB, imgA) pair
        passed through ``FlowAugmentorTensor``, whose flow target is derived
        from the detached final B->A prediction.

    Per-iteration losses are written to TensorBoard. Every ``print_freq``
    iterations the mean EPEs over the iterations since the previous print are
    printed. Every ``save_freq`` iterations a checkpoint is saved (unless
    ``--nolog``) and ``args.checkpoints`` is pointed at it. Every ``val_freq``
    iterations ``test`` is run on the current model.

    Args:
        args: parsed command-line namespace.

    Raises:
        RuntimeError: if the training dataloader yields no batches; the outer
            ``while`` loop could otherwise never advance ``glob_iter``.
    """
    device = torch.device("cuda:" + str(args.gpuid))
    train_loader = datasets.fetch_dataloader(args, split="train")

    model = DCFlow(args=args).to(device)
    model.train()

    print(f"{round(count_parameters(model)/1000000, 2)}M parameters")
    optimizer = torch.optim.AdamW(list(model.parameters()), lr=args.lr, weight_decay=1e-5, eps=1e-8)
    # The loop takes num_steps + 1 optimizer steps, and OneCycleLR raises once
    # stepped past total_steps, so total_steps keeps a margin above that.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, total_steps=(args.num_steps + 100),
        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear',
        div_factor=25, final_div_factor=1)

    loss_function = loss_func(args.loss_type)
    if args.loss_type == "perceptual":
        loss_function = loss_function.to(device)
    if args.cross_aug:
        flow_aug_func = FlowAugmentorTensor(max_shift=args.cross_aug_shift, max_scale=args.cross_aug_scale,
                                            max_rotate_deg=args.cross_aug_rot)

    sum_epe11, sum_epe12, sum_epe22, glob_iter = 0, 0, 0, 0
    window = 0  # number of iterations accumulated into the sum_epe* counters
    start_time = time.time()

    # SummaryWriter is a context manager, so the event file is closed even if
    # training raises.
    with SummaryWriter(args.log_full_dir) as writer:
        while glob_iter <= args.num_steps:
            saw_batch = False
            for data_batch in train_loader:
                saw_batch = True
                iter_start = time.time()  # used for the time-remaining estimate

                # Move tensors to the GPU; leave any non-tensor entries untouched.
                for key, value in data_batch.items():
                    if isinstance(value, torch.Tensor):
                        data_batch[key] = value.to(device)
                optimizer.zero_grad()

                # Supervised flow between each image and its self-pair.
                pred_flow_A2self = model(data_batch["imgA"], data_batch["imgA_self"], "trainAA")
                pred_flow_B2self = model(data_batch["imgB"], data_batch["imgB_self"], "trainBB")
                flow_loss_A2self = trimmed_flow_sequence_loss(pred_flow_A2self, data_batch["flow_A2self"], data_batch["mask_A2self"], gamma=0.8, trim=args.self_trim)
                flow_loss_B2self = trimmed_flow_sequence_loss(pred_flow_B2self, data_batch["flow_B2self"], data_batch["mask_B2self"], gamma=0.8, trim=args.self_trim)

                # Warp the pseudo image A with the detached final B->A flow and
                # compare it with imgB only where both images are non-zero.
                pseudo_imgA, pred_flow_B2A = model(data_batch["imgB"], data_batch["imgA"], "trainBA")
                pseudo_imgA_warp = warp_image_with_flow(pseudo_imgA, pred_flow_B2A[-1].detach().clone())
                trans_mask = torch.logical_and(pseudo_imgA_warp != 0, data_batch["imgB"] != 0).all(dim=1).unsqueeze(1)
                trans_loss = args.trans_weight * loss_function(pseudo_imgA_warp * trans_mask, data_batch["imgB"] * trans_mask).mean()

                # 0-dim zero so the total loss stays a scalar when the term is off.
                flow_loss_aug = torch.zeros((), device=device)
                if args.cross_aug and glob_iter > args.cross_aug_start:
                    imgB_aug, imgA_aug, _, flow_aug_gt, valid_aug_gt = flow_aug_func(
                        data_batch["imgB"], data_batch["imgA"], pseudo_imgA.detach().clone(),
                        pred_flow_B2A[-1].detach().clone(), valid=torch.ones_like(data_batch["mask_A2self"]))
                    _, pred_flow_aug = model(imgB_aug, imgA_aug, "trainBA_aug")
                    flow_loss_aug = args.flow_aug_weight * trimmed_flow_sequence_loss(pred_flow_aug, flow_aug_gt, valid_aug_gt, gamma=0.8, trim=args.cross_aug_trim)

                # Metrics only: no autograd graph is needed (or retained) for them.
                with torch.no_grad():
                    epe11 = calculate_epe_batch(pred_flow_A2self[-1], data_batch["flow_A2self"], data_batch["mask_A2self"])
                    epe22 = calculate_epe_batch(pred_flow_B2self[-1], data_batch["flow_B2self"], data_batch["mask_B2self"])
                    epe12 = calculate_epe_batch(pred_flow_B2A[-1], data_batch["flow_B2A"], data_batch["mask_B2A"])

                loss = flow_loss_A2self + flow_loss_B2self + trans_loss + flow_loss_aug
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
                scheduler.step()

                writer.add_scalar("Loss/flow_loss_A2self", flow_loss_A2self.item(), glob_iter)
                writer.add_scalar("Loss/flow_loss_B2self", flow_loss_B2self.item(), glob_iter)
                writer.add_scalar("Loss/trans_loss", trans_loss.item(), glob_iter)
                writer.add_scalar("Loss/flow_loss_aug", flow_loss_aug.item(), glob_iter)

                sum_epe11 += epe11
                sum_epe12 += epe12
                sum_epe22 += epe22
                window += 1

                if glob_iter % args.print_freq == 0 and glob_iter != 0:
                    # iter_start is taken before this iteration's work, so the
                    # elapsed time covers glob_iter completed iterations.
                    time_remain = (iter_start - start_time) * (args.num_steps - glob_iter) / glob_iter
                    # Divide by the real sample count: the first window also
                    # contains iteration 0, i.e. print_freq + 1 samples.
                    print("Training: Iter[{:0>3}]/[{:0>3}] epe12: {:.3f} epe11: {:.3f} epe22: {:.3f} lr={:.8f} time: {:.2f}h".format(
                        glob_iter,
                        args.num_steps,
                        sum_epe12 / window,
                        sum_epe11 / window,
                        sum_epe22 / window,
                        scheduler.get_last_lr()[0],
                        time_remain / 3600))
                    sum_epe11, sum_epe12, sum_epe22, window = 0, 0, 0, 0

                if glob_iter % args.save_freq == 0 and glob_iter != 0 and not args.nolog:
                    filename = "model" + "_iter_" + str(glob_iter) + ".pth"
                    model_save_path = os.path.join(args.log_full_dir, filename)
                    checkpoint = {"DCFlow": model.state_dict()}
                    torch.save(checkpoint, model_save_path)
                    args.checkpoints = model_save_path

                if glob_iter % args.val_freq == 0 and glob_iter != 0:
                    test(args, glob_iter, model)
                glob_iter += 1
                if glob_iter > args.num_steps:
                    break

            if not saw_batch:
                raise RuntimeError("training dataloader yielded no batches")

            
def test(args, glob_iter=None, model=None):
    """Evaluate B->A flow on the test split and print EPE and F1.

    Args:
        args: parsed command-line namespace. ``gpuid`` and ``checkpoints`` are
            read here, and the namespace is passed on to
            ``datasets.fetch_dataloader`` and ``DCFlow``.
        glob_iter: training iteration this evaluation belongs to. It is
            accepted so ``train`` can pass it, but is currently unused.
        model: network to evaluate. If ``None``, a new ``DCFlow`` is built and
            its weights are loaded from ``args.checkpoints`` (key ``"DCFlow"``).

    Returns:
        ``(epe, f1)``:
          * ``epe`` is the mean over all per-batch values returned by
            ``calulate_metric_val``.
          * ``f1`` is ``100 * out_valid_pixels / num_valid_pixels``.
          * Each is ``nan`` when there is nothing to average.

    The model's train/eval mode on entry is restored before returning. If no
    model is given and ``args.checkpoints`` is unset, an error is printed and
    the process exits with status 1.
    """
    import sys

    device = torch.device("cuda:" + str(args.gpuid))
    if model is None and args.checkpoints is None:
        # Fail before building the dataloader or the network.
        print("ERROR : no checkpoints")
        sys.exit(1)
    test_loader = datasets.fetch_dataloader(args, split="test")
    if model is None:
        model = DCFlow(args=args).to(device)
        state = torch.load(args.checkpoints, map_location=device)
        model.load_state_dict(state["DCFlow"])
        print("test with pretrained model")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Collect flattened chunks and concatenate once instead of calling
            # np.append per batch (which re-copies the whole array each time).
            # Starting from an empty float32 array keeps np.append's dtype rules.
            epe_chunks = [np.array([], dtype=np.float32)]
            num_valid_pixels = 0
            out_valid_pixels = 0
            for i, data_batch in enumerate(test_loader):
                for key, value in data_batch.items():
                    if isinstance(value, torch.Tensor):
                        data_batch[key] = value.to(device)
                _, pred_flow_B2A = model(data_batch["imgB"], data_batch["imgA"], "test")
                bs_epe_list, bs_out_valid_pixels, bs_num_valid_pixels = calulate_metric_val(pred_flow_B2A[-1], data_batch["flow_B2A"], data_batch["mask_B2A"])

                epe_chunks.append(np.ravel(bs_epe_list))
                out_valid_pixels += bs_out_valid_pixels
                num_valid_pixels += bs_num_valid_pixels
                print("Iter: %d, epe: %f, out_valid_pixels: %d, num_valid_pixels: %d." % (
                    i, np.mean(bs_epe_list), bs_out_valid_pixels, bs_num_valid_pixels))

            epe_list = np.concatenate(epe_chunks)
            epe = np.mean(epe_list) if epe_list.size else float("nan")
            # Guard against an empty loader or no valid pixels (division by zero).
            f1 = 100 * out_valid_pixels / num_valid_pixels if num_valid_pixels else float("nan")
            print("========================")
            print("Validation metric: epe: %f, F1: %f." % (epe, f1))
            return epe, f1
    finally:
        model.train(was_training)


def main():
    """Parse command-line options, set up logging and seeding, then run ``--mode``.

    Unless ``--nolog`` is given:
      * a run directory named ``<log_dir>/<YYYY-mm-dd-HH_MM_SS>_<dataset>_<note>``
        is created and stored in ``args.log_full_dir``;
      * ``sys.stdout`` is replaced by
        ``Logger_(<run dir>/record.log, sys.stdout)`` (presumably a tee to
        that file; confirm in ``Logger_``).

    ``--mode train`` calls ``train(args)``; ``--mode test`` calls ``test(args)``.
    """
    import sys  # needed for the stdout redirect; don't rely on star imports for it

    parser = argparse.ArgumentParser()
    ## mode
    parser.add_argument("--mode", type=str, default="train", help="Train or test", choices=["train", "test"])
    parser.add_argument("--checkpoints", type=str, help="Test model name")
    parser.add_argument("--dataset", type=str, default="MS2_TIR2RGB", help="dataset")
    ## model
    parser.add_argument("--trans", type=str, default="UNet", help="modality transfer network")
    parser.add_argument("--model_flow", type=str, default="RAFT", help="flow estimation network")
    ## setting
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--gpuid", type=int, default=0)
    parser.add_argument("--log_dir", type=str, default="logs", help="The log path")
    parser.add_argument("--log_full_dir", type=str)
    parser.add_argument("--nolog", action="store_true", default=False, help="save log file or not")
    parser.add_argument("--note", type=str, default="", help="experiment notes")
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    ## training
    parser.add_argument("--lr", type=float, default=4e-4, help="Max learning rate")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--print_freq", type=int, default=100)
    parser.add_argument("--save_freq", type=int, default=5000)
    parser.add_argument("--val_freq", type=int, default=5000)
    parser.add_argument("--num_steps", type=int, default=30000)
    ## loss
    parser.add_argument("--loss_type", type=str, default="perceptual")
    parser.add_argument("--trans_weight", type=float, default=2.0)
    parser.add_argument("--self_trim", type=float, default=0.2)
    parser.add_argument("--cross_aug", action="store_true", help="cross modal augmentation loss")
    parser.add_argument("--flow_aug_weight", type=float, default=0.05)
    parser.add_argument("--cross_aug_trim", type=float, default=0.8)
    parser.add_argument("--cross_aug_start", type=int, default=10000)
    parser.add_argument("--cross_aug_shift", type=float, default=24)
    parser.add_argument("--cross_aug_scale", type=float, default=0.05)
    parser.add_argument("--cross_aug_rot", type=float, default=3)

    # First pass: learn --model_flow without rejecting flow-network-specific
    # flags, which are only registered once the network is known.
    args, _ = parser.parse_known_args()
    if args.model_flow == 'RAFT':
        parser.add_argument('--small', action='store_true', help='use small model')
        parser.add_argument('--iters', type=int, default=6)
    # Second pass: strict parse now that every option is registered.
    args = parser.parse_args()

    if not args.nolog:
        run_name = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()) + "_" + args.dataset + "_" + args.note
        args.log_full_dir = os.path.join(args.log_dir, run_name)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(args.log_full_dir, exist_ok=True)
        sys.stdout = Logger_(os.path.join(args.log_full_dir, "record.log"), sys.stdout)
    pprint.pprint(vars(args))

    seed_everything(args.seed)

    if args.mode == "train":
        train(args)
    else:
        test(args)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
