"""Main script to train a model"""
import argparse
import json
from svrss.utils.functions import count_params
from svrss.learners.initializer import Initializer
from svrss.learners.model import Model
from svrss.models import TARSSNet_V1, TARSSNet_V2
import os
import torch.nn as nn
import torch
from svrss.utils.distributed_utils import init_distributed_mode, get_rank


def main():
    """Parse CLI arguments, build the configured network and launch training.

    Steps:
      1. Parse the command-line arguments and initialise distributed mode.
      2. Load the JSON config given by ``--cfg`` and derive ``nb_classes``
         from ``cfg['dataset']``.
      3. Build the data through ``Initializer`` and instantiate the network
         selected by ``cfg['model']``.
      4. Optionally convert BatchNorm layers to SyncBatchNorm, seed and
         initialise the weights, then optionally load a fine-tune checkpoint.
      5. Wrap the network in DistributedDataParallel when running
         distributed, and start training.

    Raises:
        KeyError: if ``cfg['dataset']`` or ``cfg['model']`` is not supported.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', help='Path to config file.', default='config.json')
    parser.add_argument('--dataset', default='Carrada_RD', help='dataset for model training')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", action='store_true')
    parser.add_argument("--finetune", default=False, help="the finetune path of model")
    args = parser.parse_args()

    # NOTE(review): args.distributed and args.gpu are not parser options, so
    # they are presumably populated by init_distributed_mode — confirm.
    init_distributed_mode(args)
    with open(args.cfg, 'r') as fp:
        cfg = json.load(fp)
    device = torch.device(cfg['device'])
    cfg['distributed'] = args.distributed

    # NOTE: the --dataset flag is currently ignored; the dataset is taken from
    # the config file (the `cfg['dataset'] = args.dataset` override was
    # disabled @20240327).
    if cfg['dataset'] in ['CWRdata', 'Carrada_RA', 'Carrada_RD']:
        cfg['nb_classes'] = 4
    elif cfg['dataset'] == 'PDRdata':
        cfg['nb_classes'] = 5
    else:
        # Report the offending dataset name (nb_classes is not set on this path).
        raise KeyError('Dataset {} has not been supported yet.'.format(cfg['dataset']))

    init = Initializer(cfg)
    data = init.get_data()
    if cfg['model'] == 'tarssnet_v1':
        net = TARSSNet_V1(n_classes=data['cfg']['nb_classes'],
                          n_frames=data['cfg']['nb_input_channels'])
    elif cfg['model'] == 'tarssnet_v2':
        net = TARSSNet_V2(n_classes=data['cfg']['nb_classes'],
                          n_frames=data['cfg']['nb_input_channels'])
    else:
        # Fail early with a clear message instead of hitting an unbound `net`.
        raise KeyError('Model {} has not been supported yet.'.format(cfg['model']))

    print('Number of trainable parameters in the model: %s' % str(count_params(net)))

    if args.distributed and args.sync_bn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)

    net.to(device)
    # lt @20230522: fix the seed used for weight generation.
    torch.manual_seed(cfg['torch_seed'])
    net.apply(_init_weights)

    # Load the fine-tune checkpoint AFTER the random initialisation; doing it
    # before would let `net.apply(_init_weights)` overwrite the loaded weights.
    if args.finetune:
        saved_model = torch.load(args.finetune, map_location=torch.device('cpu'))
        net.load_state_dict(saved_model)

    if args.distributed:
        net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])

    Model(net, data).train(add_temp=True)

def _init_weights(m):
    """Initialise the parameters of one module; meant for ``net.apply``.

    * ``nn.Linear`` / ``nn.Conv2d``: Xavier-uniform weights, bias set to 0
      when the layer has a bias.
    * ``nn.BatchNorm2d``: weights drawn uniformly between 0 and 1, bias set
      to 0 (both skipped when the layer has no affine parameters).

    Any other module type is left untouched.

    Args:
        m: the module to initialise in place.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False expose `bias` as None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.)
    elif isinstance(m, nn.BatchNorm2d):
        # This branch must sit at the top level: nested under the Conv2d
        # case it could never match a BatchNorm2d module.
        if m.weight is not None:
            nn.init.uniform_(m.weight, 0., 1.)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.)

# def _init_weights(m):
#     if isinstance(m, nn.Linear):
#         torch.nn.init.xavier_uniform_(m.weight)
#         nn.init.constant_(m.bias, 0.)
#     elif isinstance(m, nn.Conv2d):
#         torch.nn.init.xavier_uniform_(m.weight)
#         if m.bias is not None:
#             nn.init.constant_(m.bias, 0.)
#     elif isinstance(m, nn.Conv3d):
#         torch.nn.init.xavier_uniform_(m.weight)
#         if m.bias is not None:
#             nn.init.constant_(m.bias, 0.)
#     elif isinstance(m, nn.ConvTranspose2d):
#         torch.nn.init.xavier_uniform_(m.weight)
#         if m.bias is not None:
#             nn.init.constant_(m.bias, 0.)
#     elif isinstance(m, nn.BatchNorm2d):
#         nn.init.uniform_(m.weight, 0., 1.)
#         nn.init.constant_(m.bias, 0.)
#     elif isinstance(m, nn.BatchNorm3d):
#         nn.init.uniform_(m.weight, 0., 1.)
#         nn.init.constant_(m.bias, 0.)

# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if "__main__" == __name__:
    main()
