import os
import time
import einops
import sys
import cv2
import numpy as np
import utils as ut
import config as cg
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
# from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
from argparse import ArgumentParser
from model.model import SlotAttentionAutoEncoder
from model.unet import UNet
import random
from utils import save_on_master
from eval import eval
import math

def _learning_rate_at(it, base_lr, warmup_it, decay_step, decay_rate):
    """Return the learning rate to use for (0-based) training iteration ``it``.

    During warmup (``it < warmup_it``) the rate ramps linearly as
    ``base_lr * (it + 1) / warmup_it``, so the very first update is already
    scaled down and the last warmup iteration reaches ``base_lr``. After warmup
    the rate is multiplied by ``decay_rate`` once every ``decay_step``
    iterations. A non-positive ``decay_step`` disables the step decay.
    """
    if it < warmup_it:
        return base_lr * (it + 1) / warmup_it
    if decay_step <= 0:
        return base_lr
    return base_lr * decay_rate ** (it // decay_step)


def main(args):
    """Train a ``SlotAttentionAutoEncoder`` to reconstruct optical flow from RGB input.

    What this does, in order:
      * ``ut.init_distributed_mode(args)``, then seeds ``random``, ``torch`` and
        ``numpy`` with ``args.seed + rank``;
      * overwrites ``args.resolution`` with (192, 384), then builds paths with
        ``cg.setup_path`` and datasets / loss scales with ``cg.setup_dataset``;
      * optionally warm-starts the model weights from ``args.resume_path``
        (only the weights; optimizer state and iteration are not restored);
      * runs ``args.num_train_steps`` AdamW iterations with linear LR warmup
        (``args.warmup_steps``) and step decay (``args.decay_rate`` every
        ``args.decay_steps``);
      * every ``eval_freq`` iterations calls ``eval`` on the validation loader
        and writes a checkpoint into ``modelPath`` when the returned IoU improves.

    Only rank 0 (and only when the log path is not None) owns a TensorBoard
    writer; the periodic loss printout is tied to having that writer.

    Args:
        args: parsed command-line namespace built by the ``__main__`` parser.
    """
    ut.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # Fix the seeds for reproducibility; offset by rank so each process draws
    # a different random stream.
    seed = args.seed + ut.get_rank()
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True
    print(f"GPU number:{torch.cuda.device_count()}")
    print(f"world_size:{ut.get_world_size()}")
    num_tasks = ut.get_world_size()
    global_rank = ut.get_rank()

    lr = args.lr
    epsilon = 1e-6  # keeps log() finite in the entropy terms
    warmup_it = int(args.warmup_steps)
    decay_step = int(args.decay_steps)
    decay_rate = args.decay_rate
    num_it = int(args.num_train_steps)  # argparse leaves the float default (3e5) uncast
    resume_path = args.resume_path
    replicate = args.replicate
    mse_scale = args.mse_scale
    out_channel = 3 if args.flow_to_rgb else 2
    # args.resolution = (128, 224)
    args.resolution = (192, 384)

    # Set up log / model / results paths.
    logPath, modelPath, resultsPath = cg.setup_path(args)
    print(logPath)

    # Validation batch size has to be 1 for FBMS (frames have different
    # resolutions); other datasets could use a larger one.
    (trn_dataset, val_dataset, resolution, in_out_channels, use_flow,
     loss_scale, ent_scale, cons_scale) = cg.setup_dataset(args)

    # A DistributedSampler is used for every run (single- or multi-process);
    # set_epoch() is called on it once per pass so the shuffle order changes.
    sampler_train = torch.utils.data.DistributedSampler(
        trn_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    if global_rank == 0 and logPath is not None:
        os.makedirs(logPath, exist_ok=True)
        writer = SummaryWriter(logPath)
    else:
        writer = None

    trn_loader = ut.FastDataLoader(
        trn_dataset, sampler=sampler_train, num_workers=8, batch_size=args.batch_size,
        pin_memory=True, drop_last=True,
        multiprocessing_context="fork")
    val_loader = ut.FastDataLoader(
        val_dataset, num_workers=8, batch_size=1, shuffle=False,
        pin_memory=True, drop_last=False,
        multiprocessing_context="fork")

    model = SlotAttentionAutoEncoder(resolution=resolution,
                                     num_slots=args.num_slots,
                                     in_channels=3,
                                     out_channels=out_channel,
                                     hid_dim=args.hidden_dim,
                                     iters=args.num_iterations,
                                     attn_drop=args.attn_drop,
                                     num_o=args.num_o,
                                     num_t=args.num_t,
                                     replicate=replicate,
                                     static=args.static)
    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    if resume_path:
        print('resuming from checkpoint')
        # Map to CPU so every rank can read the file regardless of the device
        # it was saved from; load_state_dict() copies into the model's tensors.
        checkpoint = torch.load(resume_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print('training from scratch')

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # initialize training
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, eps=1e-6)

    eval_freq = 10000  # evaluate / maybe checkpoint every eval_freq iterations (all datasets)
    moca = args.dataset == "MoCA"  # forwarded to eval()
    log_freq = 100  # print the loss summary every log_freq iterations

    print('======> start training {}, {}, use {}.'.format(args.dataset, args.verbose, device))
    iou_best = 0
    frame_mean_iou = 0
    timestart = time.time()
    it = 0
    epoch = 0
    while it < num_it:
        sampler_train.set_epoch(epoch)
        it_at_epoch_start = it
        for flow, rgb, flow_idxs in trn_loader:
            # Stop exactly at num_it instead of finishing the current epoch.
            if it >= num_it:
                break

            # evaluate on the validation set
            if it % eval_freq == 0 and it != 0:
                frame_mean_iou = eval(val_loader, model, device, moca, use_flow, it, 1, 0.3,
                                      logPath, writer=writer, train=True)

            optimizer.zero_grad()
            rgb = rgb.float().to(device)
            flow = flow.float().to(device)
            flow_idxs = flow_idxs.to(device)

            # Original shape notes (unverified here): recon_flow B,7,2,C,H,W;
            # masks B,7,2,num_slots,1,H,W.
            recon_flow, recons, masks, slots, static_out = model(rgb, flow_idxs)
            flow = einops.rearrange(flow, 'b (t s) c h w -> b t s c h w', t=7)

            recon_loss = loss_scale * criterion(flow, recon_flow)
            entropy_loss = ent_scale * -(masks * torch.log(masks + epsilon)).sum(dim=3).mean()

            # Pairwise squared difference between the masks at index 0 and the
            # (transposed) masks at index 1 along dim 2; c=1 broadcasts the pairs.
            masks_0 = masks[:, :, 0]
            masks_1_t = einops.rearrange(masks[:, :, 1], 'b n s c h w -> b n c s h w')
            temporal_diff = torch.pow(masks_0 - masks_1_t, 2).mean([-1, -2])
            # NOTE(review): 2 * 2 assumes num_slots == 2 — confirm before changing num_slots.
            consistency_loss = cons_scale * temporal_diff.view(-1, 2 * 2).min(1)[0].mean()
            loss = recon_loss + entropy_loss + consistency_loss

            if replicate:
                recon_flow_s, recons_s, masks_s, slots_s = static_out
                mean_mask = masks.mean(dim=2).detach()
                static_diff = torch.pow(masks_s[:, :, 0] - mean_mask, 2).mean([-1, -2])
                mse_loss = mse_scale * static_diff.view(-1, 2).min(1)[0].mean()
                entropy_loss_s = ent_scale * -(masks_s * torch.log(masks_s + epsilon)).sum(dim=3).mean()
                loss = loss + mse_loss + entropy_loss_s

            loss.backward()
            # Set the LR *before* stepping so warmup/decay apply to this update.
            ut.set_learning_rate(optimizer, _learning_rate_at(it, lr, warmup_it, decay_step, decay_rate))
            optimizer.step()

            if it % log_freq == 0 and writer is not None:
                print('iteration {},'.format(it),
                      'time {:.01f}s,'.format(time.time() - timestart),
                      'total loss {:.02f}'.format(loss.item()),
                      'recon loss {:.02f}.'.format(recon_loss.item()),
                      'entropy loss {:.05f}.'.format(entropy_loss.item()),
                      'consistency loss {:.010f}.'.format(consistency_loss.item()))
                if replicate:
                    print('static entropy loss {:.05f}.'.format(entropy_loss_s.item()),
                          'dynamic-static loss {:.010f}.'.format(mse_loss.item()))
                timestart = time.time()

            # save model when the last evaluation improved the best IoU
            if it % eval_freq == 0 and frame_mean_iou > iou_best:
                filename = os.path.join(modelPath, 'checkpoint_{}_iou_{}.pth'.format(it, np.round(frame_mean_iou, 3)))
                save_on_master({
                    'iteration': it,
                    'model_state_dict': model_without_ddp.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss.item(),  # plain float rather than a graph-attached tensor
                }, filename)
                iou_best = frame_mean_iou

            # At every decay boundary, up-weight the regularisers 5x; takes
            # effect from the next iteration.
            if decay_step > 0 and it > 0 and it % decay_step == 0:
                ent_scale = ent_scale * 5.0
                cons_scale = cons_scale * 5.0
                mse_scale = mse_scale * 5.0

            it += 1

        if it == it_at_epoch_start:
            # Without this the while-loop would spin forever on an empty loader.
            raise RuntimeError('training loader produced no batches; check dataset size and batch_size')
        epoch += 1

    if writer is not None:
        writer.close()

if __name__ == "__main__":
    # Build the command-line interface and launch training.
    # Registration order is kept as-is: it fixes both the --help listing and
    # the attribute order of the printed namespace.
    parser = ArgumentParser()
    add = parser.add_argument

    # optimization (note: int-typed options keep their float defaults, e.g.
    # 3e5; main() casts them with int())
    add('--batch_size', default=8, type=int)
    add('--lr', default=1e-4, type=float)
    add('--num_train_steps', default=3e5, type=int)  # 300k
    add('--warmup_steps', default=1e3, type=int)
    add('--decay_steps', default=1e5, type=int)
    for name, value in (('--decay_rate', 0.5),
                        ('--loss_scale', 100),
                        ('--ent_scale', 1e-2),
                        ('--cons_scale', 1e-2),
                        ('--mse_scale', 1e-1)):
        add(name, default=value, type=float)

    # settings
    add('--dataset', default='DAVIS', type=str, choices=['DAVIS', 'MoCA', 'FBMS', 'STv2'])
    for flag in ('--with_rgb', '--flow_to_rgb', '--fixed_query', '--temporal_cons'):
        add(flag, action='store_true')

    # architecture
    for name, value in (('--num_slots', 2),
                        ('--num_iterations', 5),
                        ('--hidden_dim', 32)):
        add(name, default=value, type=int)
    add('--attn_drop', default=0.2, type=float)
    add('--num_o', default=1, type=int)
    add('--num_t', default=1, type=int)
    for flag in ('--replicate', '--static'):
        add(flag, action='store_true')
    add('--gap', default=2, type=int, help='the sampling stride of frames')

    # misc
    add('--seed', type=int, default=0)
    for name in ('--verbose', '--basepath', '--output_path', '--resume_path'):
        add(name, default=None, type=str)
    add('--world_size', type=int, default=1, help='number of distributed processes')
    add('--local_rank', type=int, default=-1)
    add('--dist_on_itp', action='store_true')
    add('--dist_url', help='url used to set up distributed training', default='env://')
    add('--device', help='device to use for training / testing', default='cuda')

    # Extra namespace attributes that have no CLI option; set_defaults appends
    # them after all option defaults, the same as assigning them post-parse.
    parser.set_defaults(inference=False, distributed=True)
    main(parser.parse_args())
