# -*- coding: utf-8 -*-

# License: TDG-Attribution-NonCommercial-NoDistrib


import argparse
import os
import statistics

import torch
import tqdm
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, DistributedSampler
import sys; sys.path.append(os.getcwd())
import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils
from opencood.tools import multi_gpu_utils
from opencood.data_utils.datasets import build_dataset
from opencood.tools import train_utils
from opencood.tools.pytorch_mem_utils import MemTracker
from opencood.visualization.cppc_vis import VisUtil

# Module-level developer switches read by main().
#
# train_vis_flag: when True, main() passes save_vis_n=1 to VisUtil.set_args
#   (otherwise -1). Set the batch size to 1 before visualizing during
#   training.
# train_debug_flag: when True, the non-distributed training DataLoader uses
#   batch_size=2, num_workers=0 and disables shuffling.
train_vis_flag = False
train_debug_flag = False

def train_parser(argv=None):
    """Parse the command-line options of the training script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) the
        arguments are taken from ``sys.argv``, which is the behavior callers
        relied on before this parameter existed.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``hypes_yaml`` (required),
        ``model_dir`` (default ``''``), ``half``, ``dist_url``
        (default ``'env://'``), ``debug`` and ``vis``.
    """
    parser = argparse.ArgumentParser(description="synthetic data generation")
    parser.add_argument("--hypes_yaml", type=str, required=True,
                        help='data generation yaml file needed ')
    parser.add_argument('--model_dir', default='',
                        help='Continued training path')
    parser.add_argument("--half", action='store_true',
                        help="whether train with half precision.")
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--debug', action='store_true',
                        help="open debug mode(gpu mem tracker...)")
    parser.add_argument('--vis', action='store_true',
                        help="open vis mode")
    return parser.parse_args(argv)


def _split_checkpoint_path(model_dir):
    """Split a ``--model_dir`` value into (directory, checkpoint file name).

    If ``model_dir`` does not end with ``.pth`` it is returned unchanged
    together with ``None``. Otherwise the parent directory and the file name
    are returned. A bare file name such as ``net.pth`` maps to the current
    directory instead of an empty string, so the "continue training" branch
    in ``main`` is not silently skipped.
    """
    if not (model_dir and model_dir.endswith(".pth")):
        return model_dir, None
    directory, model_pth = os.path.split(model_dir)
    return directory or os.curdir, model_pth


def _build_dataloaders(opt, hypes, train_dataset, validate_dataset):
    """Create the train/validation loaders for the current launch mode.

    Returns ``(train_loader, val_loader, sampler_train)``; ``sampler_train``
    is ``None`` when ``opt.distributed`` is false. Both loaders use the
    training dataset's ``collate_batch_train`` as collate function.
    """
    collate_fn = train_dataset.collate_batch_train

    if opt.distributed:
        sampler_train = DistributedSampler(train_dataset)
        sampler_val = DistributedSampler(validate_dataset, shuffle=False)
        batch_sampler_train = torch.utils.data.BatchSampler(
            sampler_train, hypes['train_params']['batch_size'], drop_last=True)

        train_loader = DataLoader(train_dataset,
                                  batch_sampler=batch_sampler_train,
                                  num_workers=4,
                                  collate_fn=collate_fn)
        val_loader = DataLoader(validate_dataset,
                                sampler=sampler_val,
                                num_workers=4,
                                collate_fn=collate_fn,
                                drop_last=False)
        return train_loader, val_loader, sampler_train

    # Visualization runs one sample at a time; the module-level debug flag
    # overrides batch size / workers and, like --vis, disables shuffling.
    train_batch_size = 1 if opt.vis else hypes['train_params']['batch_size']
    if train_debug_flag:
        train_batch_size = 2
    train_loader = DataLoader(train_dataset,
                              batch_size=train_batch_size,
                              num_workers=0 if train_debug_flag else 8,
                              collate_fn=collate_fn,
                              shuffle=not (train_debug_flag or opt.vis),
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(validate_dataset,
                            batch_size=1,
                            num_workers=1,
                            collate_fn=collate_fn,
                            shuffle=False,
                            pin_memory=True,
                            drop_last=True)
    return train_loader, val_loader, None


def _load_segmentor_weights(model, load_path):
    """Load first-stage weights into ``model.detector.segmentor``.

    The state dict at ``load_path`` is loaded on CPU, the ``'detector.'``
    substring is removed from every key, and every tensor whose key contains
    both ``'backbone'`` and ``'weight'`` has its last dimension moved to the
    front before the state dict is loaded.
    """
    raw_state = torch.load(load_path, map_location='cpu')
    remapped = {}
    for key, value in raw_state.items():
        key = key.replace('detector.', '')
        if 'backbone' in key and 'weight' in key:
            last = len(value.shape) - 1
            value = value.permute(last, *range(last))
        remapped[key] = value
    model.detector.segmentor.load_state_dict(remapped)


def _has_empty_gt(batch_data):
    """Return True if any ``object_bbx_mask`` entry of the ego batch sums to 0."""
    return any(mask.sum() == 0
               for mask in batch_data['ego']['object_bbx_mask'])


def _forward_loss(model, criterion, ego_batch, is_fsd):
    """Run the model on the ego batch and return ``(output_dict, loss)``.

    For fsd models the loss is read from ``output_dict['loss']``; otherwise
    ``criterion`` is called with the output dictionary first and the label
    dictionary second.
    """
    output_dict = model(ego_batch)
    if is_fsd:
        return output_dict, output_dict['loss']
    return output_dict, criterion(output_dict, ego_batch['label_dict'])


def main():
    """Train a model described by the ``--hypes_yaml`` configuration.

    Steps: parse the command line, load the yaml config, build the train and
    validation datasets/loaders, create the model (optionally resuming from
    ``--model_dir`` or loading first-stage segmentor weights), then run the
    epoch loop with optional half precision, logging to a ``SummaryWriter``
    and saving checkpoints as ``net_epoch<N>.pth`` under the run directory.

    The validation pass is currently disabled (see the commented-out code at
    the end of the epoch loop).
    """
    opt = train_parser()
    opt.model_dir, model_pth = _split_checkpoint_path(opt.model_dir)
    hypes = yaml_utils.load_yaml(opt.hypes_yaml, opt)

    is_fsd = 'fsd' in hypes['model']['core_method']

    multi_gpu_utils.init_distributed_mode(opt)

    print('-----------------Dataset Building------------------')
    opencood_train_dataset = build_dataset(hypes, visualize=False, train=True)
    opencood_validate_dataset = build_dataset(hypes, visualize=False, train=False)
    print(f"{len(opencood_train_dataset)} train samples found.")

    # val_loader is only needed by the (currently disabled) validation pass.
    train_loader, val_loader, sampler_train = _build_dataloaders(
        opt, hypes, opencood_train_dataset, opencood_validate_dataset)

    print('---------------Creating Model------------------')
    model = train_utils.create_model(hypes)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if opt.model_dir:
        # continue training from the last checkpoint
        saved_path = opt.model_dir
        init_epoch, model = train_utils.load_saved_model(
            saved_path,
            model,
            model_pth=model_pth,
            seg_pretrain=hypes['train_params'].get('seg_pretrain', False)
        )
    else:
        segmentor_path = hypes['model']['args'].get('segmentor_model', None)
        if segmentor_path is not None:
            # load first stage model
            _load_segmentor_weights(model, segmentor_path)
        init_epoch = 0
        # training from scratch needs a new folder to save the model
        saved_path = train_utils.setup_train(hypes)

    vis_dir = os.path.join(saved_path, 'train_vis')
    if opt.vis:
        # exist_ok avoids a check-then-create race between processes
        os.makedirs(vis_dir, exist_ok=True)
    VisUtil.set_args(
        training=True,
        root_path=vis_dir,
        is_vis=opt.vis,
        save_vis_n=1 if train_vis_flag else -1,
    )

    # we assume gpu is necessary
    if torch.cuda.is_available():
        model.to(device)
    model_without_ddp = model

    if opt.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[opt.gpu],
            find_unused_parameters=True)
        model_without_ddp = model.module

    # fsd models return their own loss, so no criterion is built for them
    criterion = None if is_fsd else train_utils.create_loss(hypes)

    # optimizer setup
    optimizer = train_utils.setup_optimizer(hypes, model_without_ddp)
    # lr scheduler setup
    num_steps = len(train_loader)
    scheduler = train_utils.setup_lr_schedular(hypes, optimizer, num_steps)
    warm_cosine = hypes['lr_scheduler']['core_method'] == 'cosineannealwarm'

    # record training
    writer = SummaryWriter(saved_path)
    mem_tracker = None
    try:
        # half precision training
        if opt.half:
            scaler = torch.cuda.amp.GradScaler()

        print('Training start')
        epoches = hypes['train_params']['epoches']
        if opt.debug:
            tracker = MemTracker()
            tracker.create_track_thread()
            mem_tracker = tracker

        for epoch in range(init_epoch, max(epoches, init_epoch)):
            # the warm-up cosine scheduler is stepped per iteration,
            # every other scheduler once per epoch
            if warm_cosine:
                scheduler.step_update(epoch * num_steps + 0)
            else:
                scheduler.step(epoch)
            for param_group in optimizer.param_groups:
                print('learning rate %.7f' % param_group["lr"])

            if opt.distributed:
                sampler_train.set_epoch(epoch)

            pbar2 = None if is_fsd else tqdm.tqdm(total=num_steps, leave=True)

            for i, batch_data in enumerate(train_loader):
                if mem_tracker is not None:
                    mem_tracker.record_epoch(epoch, i, train=True)
                # skip batches in which some sample has no ground truth box
                if _has_empty_gt(batch_data):
                    continue
                # the model will be evaluation mode during validation
                model.train()
                model.zero_grad()
                optimizer.zero_grad()

                batch_data = train_utils.to_device(batch_data, device)
                batch_data['ego']['metas'] = {
                    'model_dir': saved_path,
                    'vis_dir': vis_dir,
                    'batch_idx': i,
                    'vis_n': 1,
                    'vis_type': 'bev',
                    'proj_first': opencood_train_dataset.proj_first,
                }

                # case1 : late fusion train --> only ego needed,
                # and ego is random selected
                # case2 : early fusion train --> all data projected to ego
                # case3 : intermediate fusion --> ['ego']['processed_lidar']
                # becomes a list, which containing all data from other cavs
                # as well
                if opt.half:
                    with torch.cuda.amp.autocast():
                        output_dict, final_loss = _forward_loss(
                            model, criterion, batch_data['ego'], is_fsd)
                else:
                    output_dict, final_loss = _forward_loss(
                        model, criterion, batch_data['ego'], is_fsd)

                if is_fsd:
                    log_vars = output_dict['log_vars']
                    loss_log = ','.join(
                        f"{k}: {v:.4f}" for k, v in log_vars.items())
                    print(f"[epoch {epoch}][{i + 1}/{num_steps}], {loss_log}")
                    for k, v in log_vars.items():
                        writer.add_scalar(k, v, epoch * num_steps + i)
                else:
                    criterion.logging(epoch, i, num_steps, writer, pbar=pbar2)
                    pbar2.update(1)

                if opt.half:
                    scaler.scale(final_loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    final_loss.backward()
                    optimizer.step()

                if warm_cosine:
                    scheduler.step_update(epoch * num_steps + i)

            if pbar2 is not None:
                # release this epoch's bar before the next one is created
                pbar2.close()

            if epoch % hypes['train_params']['save_freq'] == 0:
                torch.save(model_without_ddp.state_dict(),
                           os.path.join(saved_path,
                                        'net_epoch%d.pth' % (epoch + 1)))

            # NOTE: validation is currently disabled; the original pass is
            # kept below for reference.
            # if epoch % hypes['train_params']['eval_freq'] == 0:
            #     valid_ave_loss = []
            #     with torch.no_grad():
            #         for i, batch_data in enumerate(val_loader):
            #             if opt.debug:
            #                 mem_tracker.record_epoch(epoch, i, train=False)
            #             model.eval()
            #
            #             batch_data = train_utils.to_device(batch_data, device)
            #             batch_data['ego']['metas'] = {'model_dir': saved_path, 'vis_dir': os.path.join(saved_path, 'train_vis'),
            #                 'batch_idx': i, 'vis_n': 1, 'vis_type': 'bev', 'proj_first': opencood_validate_dataset.proj_first}
            #             output_dict = model(batch_data['ego'])
            #
            #             if not is_fsd:
            #                 final_loss = criterion(output_dict,
            #                                        batch_data['ego']['label_dict'])
            #             else:
            #                 final_loss = output_dict['loss']
            #             valid_ave_loss.append(final_loss.item())
            #     valid_ave_loss = statistics.mean(valid_ave_loss)
            #     print('At epoch %d, the validation loss is %f' % (epoch,
            #                                                       valid_ave_loss))
            #     writer.add_scalar('Validate_Loss', valid_ave_loss, epoch)

        print('Training Finished, checkpoints saved to %s' % saved_path)
    finally:
        # stop the tracker and close the event file even if training fails
        if mem_tracker is not None:
            mem_tracker.end_track()
        writer.close()


# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
