# -*- coding: utf-8 -*-

import os
import sys

import statistics
import argparse
import torch
import tqdm
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, DistributedSampler

import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils
from opencood.tools import multi_gpu_utils
from opencood.data_utils.datasets import build_dataset
from ignite.metrics import IoU
from ignite.metrics.confusion_matrix import ConfusionMatrix
import wandb

def train_parser(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly like before.

    Returns:
        argparse.Namespace with ``hypes_yaml``, ``model_dir``, ``half``,
        ``dist_url`` and ``log_tool``.
    """
    parser = argparse.ArgumentParser(description="synthetic data generation")
    parser.add_argument("--hypes_yaml", default='',
                        type=str, required=False, help='data generation yaml file needed ')
    parser.add_argument('--model_dir', default='',
                        help='Continued training path')
    # Half precision stays ON by default. Previously the flag was
    # ``store_true`` with ``default=True``, so it could never be disabled;
    # ``--no_half`` writes to the same destination and provides the opt-out.
    # ``--half`` is still accepted for backward compatibility.
    parser.add_argument("--half", dest='half', action='store_true',
                        help="whether train with half precision.")
    parser.add_argument("--no_half", dest='half', action='store_false',
                        help="disable half precision training.")
    parser.set_defaults(half=True)
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--log_tool', default='wandb', # tensorboard, wandb
                        help='use which tool to log, tensorboard or wandb')
    opt = parser.parse_args(argv)
    return opt

if __name__ == '__main__':
    # Script entry point: parse options, build datasets/loaders, the model,
    # loss, optimizer and LR scheduler, then run the train / checkpoint /
    # validate loop for ``hypes['train_params']['epoches']`` epochs.
    import contextlib  # local import: only the entry point needs it

    # Datasets handled by the loss/metric branches below. Checking up front
    # turns an unsupported name into a clear error, instead of a NameError
    # (``final_loss`` / ``iou_metric``) after everything has been built.
    supported_datasets = ('V2XReal', 'V2XSim')

    opt = train_parser()
    hypes = yaml_utils.load_yaml(opt.hypes_yaml, opt)
    hypes.update({'half_percision': opt.half})  # TODO: Just for log
    dataset_name = hypes['dataset']
    if dataset_name not in supported_datasets:
        raise ValueError('Unsupported dataset %r, expected one of: %s'
                         % (dataset_name, ', '.join(supported_datasets)))
    # Key of the supervision target inside batch_data['ego'].
    label_key = 'label_dict' if dataset_name == 'V2XReal' else 'seg_label'
    multi_gpu_utils.init_distributed_mode(opt)

    print('-----------------Dataset Building------------------')
    opencood_train_dataset = build_dataset(hypes, visualize=False, train=True)
    opencood_validate_dataset = build_dataset(hypes,
                                              visualize=False,
                                              train=False)

    if opt.distributed:
        # NOTE(review): the train sampler does not shuffle, unlike the
        # non-distributed loader below -- confirm this is intended.
        sampler_train = DistributedSampler(opencood_train_dataset, shuffle=False)
        sampler_val = DistributedSampler(opencood_validate_dataset,
                                         shuffle=False)

        batch_sampler_train = torch.utils.data.BatchSampler(
            sampler_train, hypes['train_params']['batch_size'], drop_last=True)

        train_loader = DataLoader(opencood_train_dataset,
                                  batch_sampler=batch_sampler_train,
                                  num_workers=8,
                                  collate_fn=opencood_train_dataset.collate_batch_train)
        # NOTE(review): no batch_size is passed, so DataLoader's default of 1
        # applies here, unlike the non-distributed branch -- confirm.
        val_loader = DataLoader(opencood_validate_dataset,
                                sampler=sampler_val,
                                num_workers=8,
                                collate_fn=opencood_train_dataset.collate_batch_train,
                                drop_last=False)
    else:
        train_loader = DataLoader(opencood_train_dataset,
                                  batch_size=hypes['train_params'][
                                      'batch_size'],
                                  num_workers=16,
                                  collate_fn=opencood_train_dataset.collate_batch_train,
                                  shuffle=True,
                                  pin_memory=False,
                                  drop_last=True)
        val_loader = DataLoader(opencood_validate_dataset,
                                batch_size=hypes['train_params']['batch_size'],
                                num_workers=16,
                                collate_fn=opencood_train_dataset.collate_batch_train,
                                shuffle=False,
                                pin_memory=False,
                                drop_last=True)

    print('---------------Creating Model------------------')
    # The model config needs the dataset's rsu_num before construction.
    hypes['model']['args']['rsu_num'] = opencood_train_dataset.rsu_num
    model = train_utils.create_model(hypes)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # if we want to train from last checkpoint.
    if opt.model_dir:
        saved_path = opt.model_dir
        init_epoch, model = train_utils.load_saved_model(saved_path,
                                                         model)
    else:
        init_epoch = 0
        # if we train the model from scratch, we need to create a folder
        # to save the model,
        saved_path = train_utils.setup_train(hypes)

    # we assume gpu is necessary
    if torch.cuda.is_available():
        model.to(device)
    model_without_ddp = model

    if opt.distributed:
        model = \
            torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[opt.gpu],
                                                      find_unused_parameters=True)
        model_without_ddp = model.module

    # define the loss
    criterion = train_utils.create_loss(hypes)
    # optimizer setup (on the unwrapped module's parameters)
    optimizer = train_utils.setup_optimizer(hypes, model_without_ddp)
    # lr scheduler setup; it is stepped per iteration via step_update()
    num_steps = len(train_loader)
    scheduler = train_utils.setup_lr_schedular(hypes, optimizer, num_steps)

    # record training
    writer = None
    if opt.log_tool == 'tensorboard':
        writer = SummaryWriter(saved_path)
    elif opt.log_tool == 'wandb':
        wandb.init(
            project=hypes['dataset'] + '.' + hypes['project_name'],
            name=saved_path.split('/')[-1],
            id=saved_path.split('/')[-1],
            resume='allow',
            config=hypes
        )
        wandb.define_metric("step")
        wandb.define_metric("epoch")

        wandb.define_metric("Regression_loss", step_metric="step")
        wandb.define_metric("Classification_loss", step_metric="step")
        # These two are also logged per step below; register them so wandb
        # plots them against "step" like the other training losses.
        wandb.define_metric("Commit_loss", step_metric="step")
        wandb.define_metric("Segmentation_loss", step_metric="step")
        wandb.define_metric("Validation_Loss", step_metric="epoch")
        wandb.define_metric("Matched_ious", step_metric="epoch")

    # half precision training: a GradScaler is only created when --half is on
    scaler = torch.amp.GradScaler(str(device)) if opt.half else None

    # Creating evaluator
    # TODO: 2 classes of detection and 4 classes segmentation
    num_classes = 2 if dataset_name == 'V2XReal' else 4
    cm = ConfusionMatrix(num_classes=num_classes)
    iou_metric = IoU(cm)

    print('Training start')
    epoches = hypes['train_params']['epoches']

    for epoch in range(init_epoch, max(epoches, init_epoch)):

        for param_group in optimizer.param_groups:
            print('learning rate %.7f' % param_group["lr"])

        if opt.distributed:
            sampler_train.set_epoch(epoch)

        model.train()
        pbar2 = tqdm.tqdm(total=len(train_loader), leave=True)

        for i, batch_data in enumerate(train_loader):
            model.zero_grad()
            optimizer.zero_grad()

            batch_data = train_utils.to_device(batch_data, device)
            # global iteration index, used for logging and the LR scheduler
            step = epoch * num_steps + i

            amp_context = (torch.amp.autocast(str(device)) if opt.half
                           else contextlib.nullcontext())
            with amp_context:
                output_dict = model(batch_data['ego'])
                final_loss = criterion(output_dict,
                                       batch_data['ego'][label_key])

            if opt.log_tool == 'tensorboard':
                criterion.logging(epoch, i, len(train_loader), writer, pbar=pbar2)
            elif opt.log_tool == 'wandb':
                if dataset_name == 'V2XReal':
                    total_loss = criterion.loss_dict['total_loss']
                    reg_loss = criterion.loss_dict['reg_loss']
                    cls_loss = criterion.loss_dict['cls_loss']
                    # cmt_loss may be None (or, defensively, absent).
                    cmt_loss = criterion.loss_dict.get('cmt_loss')

                    log_entry = {'step': step,
                                 'Regression_loss': reg_loss.item(),
                                 'Classification_loss': cls_loss.item()}
                    if cmt_loss is not None:
                        log_entry['Commit_loss'] = cmt_loss.item()
                    wandb.log(log_entry)

                    pbar2.set_description("[epoch %d][%d/%d], || Loss: %.4f || Conf Loss: %.4f"
                                         " || Loc Loss: %.4f" % (
                                             epoch, i + 1, num_steps,
                                             total_loss.item(), cls_loss.item(), reg_loss.item()))
                else:  # V2XSim
                    seg_loss = criterion.loss_dict['seg_loss']

                    wandb.log({'step': step,
                               'Segmentation_loss': seg_loss.item()})
                    pbar2.set_description("[epoch %d][%d/%d], || Seg Loss: %.4f" % (
                        epoch, i + 1, num_steps, seg_loss.item()))

            pbar2.update(1)

            if scaler is None:
                final_loss.backward()
                optimizer.step()
            else:
                scaler.scale(final_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            scheduler.step_update(step)

        pbar2.close()

        if epoch % hypes['train_params']['save_freq'] == 0:
            torch.save(model_without_ddp.state_dict(),
                       os.path.join(saved_path,
                                    'net_epoch%d.pth' % (epoch + 1)))

        if epoch % hypes['train_params']['eval_freq'] == 0:
            valid_losses = []

            model.eval()
            with torch.no_grad():
                iou_metric.reset()
                for batch_data in val_loader:
                    batch_data = train_utils.to_device(batch_data, device)
                    output_dict = model(batch_data['ego'])
                    target = batch_data['ego'][label_key]

                    final_loss = criterion(output_dict, target)
                    if dataset_name == 'V2XReal':
                        # only channel 0 of the label map is scored, as integer class ids
                        iou_metric.update((output_dict['cls'], target['label_map'][:,0,...].to(int)))
                    else:  # V2XSim
                        iou_metric.update((output_dict['seg'], target.to(int)))

                    valid_losses.append(final_loss.item())

            if not valid_losses:
                # statistics.mean / IoU.compute would raise on an empty run
                # (e.g. drop_last=True with fewer samples than batch_size).
                print('At epoch %d, the validation loader yielded no batches; '
                      'skipping validation metrics' % epoch)
            else:
                valid_ave_loss = statistics.mean(valid_losses)
                # Class index 0 is excluded from the mean IoU
                # (presumably background -- confirm).
                match_ious = statistics.mean(iou_metric.compute().tolist()[1:])

                print('At epoch %d, the validation loss is %f, matched ious is %f' % (epoch, valid_ave_loss, match_ious))
                if opt.log_tool == 'tensorboard':
                    writer.add_scalar('Validate_Loss', valid_ave_loss, epoch)
                    writer.add_scalar('Matched_ious', match_ious, epoch)
                elif opt.log_tool == 'wandb':
                    wandb.log({'epoch': epoch,
                               'Validation_Loss': valid_ave_loss,
                               'Matched_ious': match_ious})
        opencood_train_dataset.reinitialize()

    # Flush and close the logging backends before exiting.
    if opt.log_tool == 'tensorboard':
        writer.close()
    elif opt.log_tool == 'wandb':
        wandb.finish()

    print('Training Finished, checkpoints saved to %s' % saved_path)
    sys.exit(0)