import os
import argparse
import pathlib
import time
import random
from einops import rearrange
import torch.distributed as dist

import math
import numpy as np
import copy
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.common.checkpoint import Checkpointer
from datasets import create_dataloader
from models import *
from utils import *


def load_config():
    """Build the frozen run configuration from defaults, a YAML file and the CLI.

    Merge order (later entries override earlier ones):
      1. defaults from ``get_default_config_moredata_aux_stage1()``
      2. the YAML file given by ``--config``
      3. trailing ``KEY VALUE`` overrides collected into ``options``
      4. when ``--resume DIR`` is given: the config file stored in ``DIR``,
         plus ``train.resume = True``
      5. ``train.dist.local_rank`` (from ``--local_rank``) and ``world_size``
         (from ``world_info_from_env()``)

    The result is passed through ``update_config`` and frozen.

    Returns:
        The frozen config node.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): the default below is a machine-specific absolute path;
    # pass --config explicitly on any other machine.
    parser.add_argument('--config', type=str, default ='/home/peiyao/Documents/ACode/Contras_refine/Contras_fea/configs/assembly101/100.yaml')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    config = get_default_config_moredata_aux_stage1()

    if args.config is not None:
        config.merge_from_file(args.config)
    config.merge_from_list(args.options)
    if not torch.cuda.is_available():
        config.device = 'cpu'
    if args.resume != '':
        resume_dir = pathlib.Path(args.resume)
        config_path = resume_dir / 'config.yaml'
        # main() in this file stores the run configuration as 'configs.yaml',
        # so fall back to that name when 'config.yaml' is absent.
        fallback_path = resume_dir / 'configs.yaml'
        if not config_path.exists() and fallback_path.exists():
            config_path = fallback_path
        config.merge_from_file(config_path.as_posix())
        config.merge_from_list(['train.resume', True])
    world_size = world_info_from_env()
    # Record the per-process local rank and the world size in the config.
    config.merge_from_list(['train.dist.local_rank', args.local_rank, 'world_size', world_size])
    print(f'==========world_size: {world_size}')
    config = update_config(config)
    config.freeze()
    return config



def train_one_epoch(epoch, config, model, loss_func, optimizer, scheduler, train_loader, logger):
    """Run one training epoch over ``train_loader`` and log the running average loss.

    Each batch yields two views (``item[0]`` and ``item[1]``) and a window mask
    (``item[-1]``). Both views go through the same ``model``; the loss is the
    main loss on the ``'feature'`` outputs plus ``config.loss.aux_weight``
    times the loss on every pair of ``'aux_fea'`` outputs.

    Args:
        epoch: Epoch number as supplied by the caller (``main`` passes a
            1-based value).
        config: Run configuration; reads ``device``, ``loss.aux_weight``,
            ``train.distributed`` and ``train.log_period``.
        model: Callable returning a dict with ``'feature'``, ``'logit_scale'``
            and ``'aux_fea'`` entries.
        loss_func: Callable ``(features_c, features_h, logit_scale) -> loss``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Scheduler exposing ``step_update(step_index)``.
        train_loader: Sized iterable of batches.
        logger: Logger used for progress lines (emitted on rank 0 only).
    """
    if epoch == 1:
        print('----------------------------- Start Training ! --------------------------')
    device = torch.device(config.device)
    model.train()
    # Running sum of per-step losses (already averaged over ranks), kept as a
    # plain float so no autograd graph is retained between iterations.
    epoch_sum = 0.0
    num_steps = len(train_loader)
    for idx, item in enumerate(train_loader):

        # Fold the per-sample window axis `n` into the batch axis and move the
        # channel axis last, so the model sees (b*n, l, c); the original
        # author's note gives [20, 100, 2048] as an example.
        samples_c = rearrange(item[0].to(device), 'b n c l -> (b n) l c ')
        samples_h = rearrange(item[1].to(device), 'b n c l -> (b n) l c ')
        win_mask = rearrange(item[-1].to(device), 'b n l -> (b n) l')

        out_c = model(samples_c, mask=win_mask)
        out_h = model(samples_h, mask=win_mask)

        # Main loss between the two views.
        assert out_c['feature'].shape[0] == out_h['feature'].shape[0], 'Output of the c/h view feature mismatch batch number!'
        loss = loss_func(out_c['feature'], out_h['feature'], out_c['logit_scale'])

        # Auxiliary losses on the intermediate features of both views.
        # The main loss uses out_c's logit_scale and the aux terms use
        # out_h's, exactly as in the original code.
        for aux_c, aux_h in zip(out_c['aux_fea'], out_h['aux_fea']):
            loss = loss + config.loss.aux_weight * loss_func(aux_c, aux_h, out_h['logit_scale'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): `epoch` is 1-based when called from main(), so the
        # first update index is num_steps rather than 0 -- confirm this offset
        # is intended for the scheduler's warm-up.
        scheduler.step_update((epoch * num_steps + idx))

        # Accumulate a detached copy so the running sum never carries the
        # autograd graph and the collective never touches a grad-tracking tensor.
        step_loss = loss.detach().clone()
        if config.train.distributed:
            dist.all_reduce(step_loss, op=dist.ReduceOp.SUM)
            step_loss.div_(dist.get_world_size())
        epoch_sum += step_loss.item()

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if (idx+1) % config.train.log_period == 0 or (idx+1) == num_steps:
            if get_rank() == 0:
                lr = optimizer.param_groups[0]['lr']

                logger.info(
                            f'Train Epoch {epoch} ({idx+1}/{num_steps}) |'
                            f'Lr: {lr}|'
                            f'Avg loss is {(epoch_sum / (idx+1)):.8f}')

def validate(epoch, config, model, loss_func, val_loader, min_val_loss, logger, optimizer=None):
    """Evaluate ``model`` on ``val_loader`` and log the average loss.

    Both views of every batch (``item[0]`` and ``item[1]``) go through the
    model without a mask and are scored with ``loss_func`` on the
    ``'feature'`` outputs. In distributed mode each per-batch loss is averaged
    over ranks before being accumulated.

    Args:
        epoch: Epoch number, used only in the log line.
        config: Run configuration; reads ``device`` and ``train.distributed``.
        model: Callable returning a dict with ``'feature'`` and
            ``'logit_scale'`` entries.
        loss_func: Callable ``(features_c, features_h, logit_scale) -> loss``.
        val_loader: Sized iterable of validation batches.
        min_val_loss: Best (lowest) average validation loss seen so far.
        logger: Logger used for the summary line (emitted on rank 0 only).
        optimizer: Optional optimizer whose current learning rate is included
            in the log line. The original code read an undefined ``optimizer``
            name here; when omitted the rate is reported as ``n/a``.

    Returns:
        The updated minimum validation loss, ``min(avg_loss, min_val_loss)``.
        With an empty ``val_loader`` nothing is logged and ``min_val_loss`` is
        returned unchanged.
    """
    device = torch.device(config.device)
    model.eval()
    loss_sum = 0.0

    with torch.no_grad():
        for item in val_loader:
            # Fold the window axis `n` into the batch axis and move channels
            # last, so the model sees (b*n, l, c).
            samples_c = rearrange(item[0].to(device), 'b n c l -> (b n) l c ')
            samples_h = rearrange(item[1].to(device), 'b n c l -> (b n) l c ')
            out_c = model(samples_c)
            out_h = model(samples_h)

            loss = loss_func(out_c['feature'], out_h['feature'], out_c['logit_scale'])
            step_loss = loss.detach().clone()
            if config.train.distributed:
                dist.all_reduce(step_loss, op=dist.ReduceOp.SUM)
                step_loss.div_(dist.get_world_size())
            loss_sum += step_loss.item()

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    num_batches = len(val_loader)
    if num_batches == 0:
        # Nothing was evaluated; avoid dividing by zero and keep the old best.
        return min_val_loss

    avg_loss = loss_sum / num_batches
    min_val_loss = min(avg_loss, min_val_loss)

    if get_rank() == 0:
        lr = optimizer.param_groups[0]['lr'] if optimizer is not None else 'n/a'

        logger.info(f'------------------------|'
                    f'Val Epoch {epoch} |'
                    f'Lr: {lr}|'
                    f'Avg loss is {avg_loss:.8f} |' 
                    f'Min loss is {min_val_loss:.8f}')

    return min_val_loss


def main():
    """Entry point: build config, data, model and optimizer, then train.

    Steps, in order:
      1. load the config, seed the RNGs and configure cuDNN;
      2. initialise the process group when ``config.train.distributed`` is set;
      3. create the output directory and, for fresh runs, save the config and
         environment info into it;
      4. build the dataloaders, model, optimizer, scheduler and loss functions;
      5. resume from the latest checkpoint in the output directory, or load
         model weights from ``config.train.checkpoint`` when given;
      6. train epoch by epoch, saving a checkpoint every 50 epochs on rank 0.
    """
    config = load_config()
    set_seed(config)
    setup_cudnn(config)

    # One seed per epoch is drawn up front; the seeds are not applied at the
    # moment, but drawing them keeps the NumPy RNG stream unchanged.
    epoch_seeds = np.random.randint(np.iinfo(np.int32).max // 2, size=config.train.epochs)

    if config.train.distributed:
        dist.init_process_group(backend=config.train.dist.backend,
                                init_method=config.train.dist.init_method,
                                rank=config.train.dist.node_rank,
                                world_size=config.train.dist.world_size)
        torch.cuda.set_device(config.train.dist.local_rank)

    out_dir = os.path.join(config.train.output_dir, 'assembly_charades_0922', 'pair_'+config.model.note)
    output_dir = pathlib.Path(out_dir)

    output_dir.mkdir(exist_ok=True, parents=True)
    if not config.train.resume:
        save_config(config, output_dir / 'configs.yaml')
        save_config(get_env_info(config), output_dir / 'env.yaml')

    log_file = f'{config.model.note}.txt'
    logger = create_logger(name=__name__,
                           distributed_rank=get_rank(),
                           output_dir=output_dir,
                           filename=log_file)
    logger.info(config)
    logger.info(get_env_info(config))
    logger.info(f'---------------model will save in: {out_dir}.')

    train_loader, val_loader = create_dataloader(config.dataset, is_train=True)

    ##-----------build model-------
    model = Refine_Fea_DoubleView_aux(config.model)
    # load_config() forces config.device to 'cpu' when CUDA is unavailable, so
    # only call .cuda() when the configured device really is a CUDA device.
    device = torch.device(config.device)
    if device.type == 'cuda':
        model = model.cuda(config.train.dist.local_rank)
    else:
        model = model.to(device)

    print('==================================================================')
    model = apply_data_parallel_wrapper(config, model)

    optimizer = build_optimizer(config, model)
    scheduler = build_scheduler(config, optimizer, n_iter_per_epoch = len(train_loader))
    print(f'---------------------------Dataset size: {len(train_loader)}')

    train_loss_func = create_loss_gather(config)
    val_loss_func = create_loss(config)

    checkpointer = Checkpointer(model,
                              optimizer=optimizer,
                              scheduler=scheduler,
                              save_dir=output_dir,
                              save_to_disk=get_rank() == 0)

    start_epoch = config.train.start_epoch
    scheduler.last_epoch = start_epoch

    if config.train.resume:
        checkpoint_config = checkpointer.resume_or_load('', resume=True)
        # Checkpoints written below store only 'epoch' and 'configs'; the
        # original code also read 'global_step', which is never saved (and was
        # never used), so that lookup is dropped. Missing keys -- e.g. when no
        # checkpoint exists yet -- fall back to the current values.
        start_epoch = checkpoint_config.get('epoch', start_epoch)
        if 'configs' in checkpoint_config:
            config.defrost()
            config.merge_from_other_cfg(ConfigNode(checkpoint_config['configs']))
            config.freeze()
    elif config.train.checkpoint != '':
        checkpoint = torch.load(config.train.checkpoint, map_location='cpu')
        if isinstance(model,
                      (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model.module.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'])

    # Epochs are reported 1-based: the first epoch run is start_epoch + 1.
    for epoch, _seed in enumerate(epoch_seeds[start_epoch:], start_epoch + 1):
        train_one_epoch(epoch, config, model, train_loss_func, optimizer, scheduler, train_loader, logger)

        # NOTE: validation is currently disabled. To re-enable it, track a
        # running `min_val_loss` and call
        # validate(epoch, config, model, val_loss_func, val_loader, min_val_loss, logger)
        # here at the desired interval.

        checkpoint_config = {
                            'epoch': epoch,
                            'configs': config.as_dict(),
                            }

        if epoch % 50 == 0 and get_rank() == 0:
            checkpointer.save(f'checkpoint_epoch_{epoch}', **checkpoint_config)
            print('----------Save DONE!')


# Run the training entry point only when this file is executed as a script.
if __name__ =='__main__':
    main()