import os
import argparse
import pathlib
import time
import random
from einops import rearrange
import torch.distributed as dist

import math
import numpy as np
import copy
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.common.checkpoint import Checkpointer
from datasets import create_dataloader
from models import *
from utils import *


def load_config():
    """Build and freeze the stage-2 training config.

    Merge order (later wins):
      1. ``get_default_config_view_sigmoid_aux_stage2()`` defaults.
      2. The YAML file given by ``--config``.
      3. Trailing ``KEY VALUE`` pairs from the command line.
      4. When ``--resume DIR`` is given, the config saved in ``DIR`` plus
         ``train.resume=True``.
      5. Distributed info: ``train.dist.local_rank`` from ``--local_rank`` and
         ``world_size`` from ``world_info_from_env()``.

    The result is passed through ``update_config`` and frozen.

    Returns:
        The frozen config node.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default ='/home/peiyao/Documents/ACode/Contras_refine/Contras_fea/configs/assembly101/100.yaml')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--local_rank', type=int, default=0)

    parser.add_argument('options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    config = get_default_config_view_sigmoid_aux_stage2() # different from the pair-wise cfg

    if args.config is not None:
        config.merge_from_file(args.config)
    config.merge_from_list(args.options)
    if not torch.cuda.is_available():
        config.device = 'cpu'
    if args.resume != '':
        resume_dir = pathlib.Path(args.resume)
        # main() saves the run config as 'configs.yaml'; fall back to the
        # older 'config.yaml' name for directories written before that.
        config_path = resume_dir / 'configs.yaml'
        if not config_path.is_file():
            config_path = resume_dir / 'config.yaml'
        config.merge_from_file(config_path.as_posix())
        config.merge_from_list(['train.resume', True])
    world_size = world_info_from_env()
    # Record this process's local rank and the world size in the config.
    config.merge_from_list(['train.dist.local_rank', args.local_rank, 'world_size', world_size])
    print(f'==========world_size: {world_size}')
    config = update_config(config)
    config.freeze()
    return config


def train_one_epoch(epoch, config, model, loss_func, optimizer, scheduler, train_loader, logger):
    """Run one training epoch over ``train_loader``.

    Each batch ``item`` provides two tensors:

    * ``item[0]``, asserted to have 8 entries along dim 2.
    * ``item[1]``, asserted to have 4 entries along dim 2.

    Both are flattened with ``'b g n l c -> (b g n) l c'`` and fed to
    ``model`` with ``n_view=8`` / ``n_view=4``. The optimized loss is the main
    feature loss plus ``config.loss.aux_weight`` times the same loss on every
    pair of auxiliary features (``'aux_fea'``).

    Args:
        epoch: 1-based epoch index (``main`` increments before calling).
        config: frozen config; reads ``device``, ``loss.aux_weight``,
            ``train.distributed`` and ``train.log_period``.
        model: callable returning a dict with ``'feature'``, ``'aux_fea'``
            and ``'logit_scale'``.
        loss_func: ``(feat_c, feat_h, logit_scale) -> single-element tensor``.
        optimizer: stepped once per processed batch.
        scheduler: per-iteration scheduler exposing ``step_update(step)``.
        train_loader: iterable of batches as described above.
        logger: used on rank 0 only, every ``train.log_period`` iterations
            and on the last iteration.
    """
    if epoch == 1:
        print('----------------------------- Start Training ! --------------------------')
    device = torch.device(config.device)
    model.train()
    num_steps = len(train_loader)
    # Running sum of per-batch losses (rank-averaged, as a Python float) and
    # the number of batches that contributed to it.
    epoch_sum = 0.0
    num_done = 0
    for idx, item in enumerate(train_loader):
        # Getting view_c (8 views) and view_h (4 views) embedding
        assert item[0].shape[2] == 8, "view3 is not enough!"
        assert item[1].shape[2] == 4, "view1 is not enough!"
        samples_c = rearrange(item[0].to(device), 'b g n l c -> (b g n) l c')  # [20,100, 2048]
        samples_h = rearrange(item[1].to(device), 'b g n l c -> (b g n) l c')

        cb_n_view = samples_c.shape[0]
        hb_n_view = samples_h.shape[0]
        if cb_n_view % 8 != 0 or hb_n_view % 4 != 0:
            print('Ignore the batch of sample because of dataparallel and model!')
            continue

        out_c = model(samples_c, n_view=8)  # input:[b,l,c]
        out_h = model(samples_h, n_view=4)  # input:[b,l,c]
        assert out_c['feature'].shape[0] == out_h['feature'].shape[0], 'Output of the c/h view feature mismatch batch number!'

        fea_dim = out_c['feature'].shape[-1]
        logit_scale = out_h['logit_scale']

        loss = loss_func(out_c['feature'].view(-1, fea_dim), out_h['feature'].view(-1, fea_dim), logit_scale)

        # Auxiliary losses on intermediate features, sharing the same logit
        # scale. Out-of-place add so the autograd output is never mutated.
        for aux_c, aux_h in zip(out_c['aux_fea'], out_h['aux_fea']):
            loss = loss + config.loss.aux_weight * loss_func(aux_c.view(-1, fea_dim), aux_h.view(-1, fea_dim), logit_scale)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # step_update expects a 0-based global iteration; ``epoch`` is 1-based.
        scheduler.step_update((epoch - 1) * num_steps + idx)

        # Detached copy: the running sum must not keep the autograd graph
        # alive, and all_reduce works in place.
        batch_loss = loss.detach().clone()
        if config.train.distributed:
            dist.all_reduce(batch_loss, op=dist.ReduceOp.SUM)
            batch_loss.div_(dist.get_world_size())
        epoch_sum += batch_loss.item()
        num_done += 1

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if (idx + 1) % config.train.log_period == 0 or (idx + 1) == num_steps:
            if get_rank() == 0:
                lr = optimizer.param_groups[0]['lr']
                logger.info(
                            f'Train Epoch {epoch} ({idx+1}/{num_steps}) |'
                            f'Lr: {lr}|'
                            f'Avg loss is {(epoch_sum / max(num_done, 1)):.8f}')

def validate(epoch, config, model, loss_func, val_loader, min_val_loss, logger):
    """Evaluate ``model`` on ``val_loader`` and log the average loss.

    Batches are prepared exactly as in training: ``item[0]`` has 8 views and
    ``item[1]`` has 4 views along dim 2, both flattened with
    ``'b g n l c -> (b g n) l c'``. In distributed mode each batch loss is
    averaged across ranks.

    Args:
        epoch: epoch index used only for logging.
        config: reads ``device`` and ``train.distributed``.
        model: callable returning a dict with ``'feature'`` and ``'logit_scale'``.
        loss_func: ``(feat_c, feat_h, logit_scale) -> single-element tensor``.
        val_loader: iterable of validation batches.
        min_val_loss: best (lowest) validation loss seen so far.
        logger: used on rank 0 only.

    Returns:
        ``min(avg_loss, min_val_loss)`` on rank 0 when at least one batch was
        evaluated; otherwise ``min_val_loss`` unchanged.
    """
    device = torch.device(config.device)
    model.eval()
    nums = 0        # number of batches actually evaluated
    loss_sum = 0.0  # sum of (rank-averaged) batch losses

    with torch.no_grad():
        for item in val_loader:
            # Getting view_c and view_h embedding
            assert item[0].shape[2] == 8, "view3 is not enough!"
            assert item[1].shape[2] == 4, "view1 is not enough!"
            samples_c = rearrange(item[0].to(device), 'b g n l c -> (b g n) l c')  # [20,100, 2048]
            samples_h = rearrange(item[1].to(device), 'b g n l c -> (b g n) l c')

            if samples_c.shape[0] % 8 != 0 or samples_h.shape[0] % 4 != 0:
                print('Ignore the batch of sample because of dataparallel and model!')
                continue

            out_c = model(samples_c, n_view=8)  # input:[b,l,c]
            out_h = model(samples_h, n_view=4)  # input:[b,l,c]

            # Calculating the Loss
            loss = loss_func(out_c['feature'], out_h['feature'], out_c['logit_scale'])
            if config.train.distributed:
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss.div_(dist.get_world_size())
            loss_sum += loss.item()
            nums += 1

        if torch.cuda.is_available():
            torch.cuda.synchronize()

    if get_rank() == 0:
        if nums == 0:
            # Nothing evaluated: do not report a fake 0.0 loss as the new best.
            logger.info(f'Val Epoch {epoch} | no batches were evaluated')
            return min_val_loss

        avg_loss = loss_sum / nums
        min_val_loss = min(avg_loss, min_val_loss)

        logger.info(f'------------------------|'
                    f'Val Epoch {epoch} |'
                    f'Avg loss is {avg_loss:.8f} |' 
                    f'Min loss is {min_val_loss:.8f}')

    return min_val_loss


def main():
    """Entry point: set up config, output dir, model and optimizer, then train.

    Outputs (configs, env info, log file, checkpoints) are written under a
    directory derived from the stage-1 checkpoint path
    ``config.model.refine_fea.pair_weight_ckpt_path``. A checkpoint is saved
    every 50 epochs.
    """
    config = load_config()
    set_seed(config)
    setup_cudnn(config)

    # One seed per epoch (currently unused by the loop). This call also
    # advances the global NumPy RNG state.
    epoch_seeds = np.random.randint(np.iinfo(np.int32).max // 2, size=config.train.epochs)

    if config.train.distributed:
        dist.init_process_group(backend=config.train.dist.backend,
                                timeout=datetime.timedelta(seconds=7200000),
                                init_method=config.train.dist.init_method,
                                rank=config.train.dist.node_rank,
                                world_size=config.train.dist.world_size)
        torch.cuda.set_device(config.train.dist.local_rank)

    ##-----------out root-------
    # e.g. ...0922_stage1_main_group_pair_norm_aug_moredata_ddp/0922_lr504_logit/checkpoint_epoch_170.pth
    stage1_path_weight = config.model.refine_fea.pair_weight_ckpt_path
    stage1_path = stage1_path_weight.rsplit('/', 1)[0]         # .../0922_lr504_logit
    stage1_ckpt_epoch = stage1_path_weight.rsplit('/', 1)[-1]  # checkpoint_epoch_170.pth
    stage1_epoch = stage1_ckpt_epoch.replace('.pth', '').rsplit('_', 1)[-1]  # 170
    stage2_note = f'multi_{config.model.note}_1stage{stage1_epoch}'

    if config.model.refine_fea.use_pair_weight:
        out_dir = os.path.join(stage1_path, stage2_note)
    else:
        out_dir = os.path.join(stage1_path, 'random')
    ##--------------------------

    print(f'Model will save in: {out_dir} ')

    output_dir = pathlib.Path(out_dir)

    output_dir.mkdir(exist_ok=True, parents=True)
    if not config.train.resume:
        save_config(config, output_dir / 'configs.yaml')
        save_config(get_env_info(config), output_dir / 'env.yaml')

    log_file = f'{config.model.note}.txt'
    logger = create_logger(name=__name__,
                           distributed_rank=get_rank(),
                           output_dir= output_dir,
                           filename=log_file)
    logger.info(config)
    logger.info(get_env_info(config))

    train_loader, val_loader = create_dataloader(config.dataset, is_train=True)

    ##-----------build model-------
    model = build_model(config,
                        Refine_Fea_MultiView_sigmoid_aux(config.model))

    model = apply_data_parallel_wrapper(config, model) #DDP(model)

    optimizer = build_optimizer(config, model)
    scheduler = build_scheduler(config, optimizer, n_iter_per_epoch = len(train_loader))

    ##--------different loss with or without feature gather-------------
    train_loss_func = create_multi_loss(config, 'SigLipLoss')
    val_loss_func = create_loss(config)

    ##---------------------checkpoint--------------------
    checkpointer = Checkpointer(model,
                              optimizer=optimizer,
                              scheduler=scheduler,
                              save_dir=output_dir,
                              save_to_disk=get_rank() == 0)

    start_epoch = config.train.start_epoch
    scheduler.last_epoch = start_epoch

    if config.train.resume:
        checkpoint_config = checkpointer.resume_or_load('', resume=True)
        # Checkpoints saved by this loop carry only 'epoch' and 'configs'
        # (no 'global_step'). fvcore returns an empty dict when it finds no
        # checkpoint, so fail with a clear message instead of a KeyError.
        if 'epoch' not in checkpoint_config:
            raise RuntimeError(
                f'train.resume is set but no resumable checkpoint was found in {output_dir}')
        start_epoch = checkpoint_config['epoch']
        config.defrost()
        config.merge_from_other_cfg(ConfigNode(checkpoint_config['configs']))
        config.freeze()
    elif config.train.checkpoint != '':
        checkpoint = torch.load(config.train.checkpoint, map_location='cpu')
        if isinstance(model,
                      (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model.module.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'])

    min_val_loss = 1000000

    # Epochs are 1-based: start_epoch counts epochs already completed.
    for epoch, _seed in enumerate(epoch_seeds[start_epoch:], start_epoch + 1):
        # A DistributedSampler only reshuffles when told the current epoch.
        sampler = getattr(train_loader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        train_one_epoch(epoch, config, model, train_loss_func, optimizer, scheduler, train_loader, logger)

        checkpoint_config = {
                            'epoch': epoch,
                            'configs': config.as_dict(),
                            }

        # Alternative policy: save only when epoch is in config.train.save_epochs.
        if  epoch % 50 == 0 and get_rank() == 0:
            checkpointer.save(f'checkpoint_epoch_{epoch}', **checkpoint_config)
            print('----------Save DONE!')

if __name__ == "__main__":
    # Script entry point: run the full stage-2 training pipeline.
    main()