# Copyright 2021 Zhongyang Zhang
# Contact: mirakuruyoo@gmai.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" This main entrance of the whole project.

    Most of the code should not be changed, please directly
    add all the input arguments of your model's constructor
    and the dataset file's constructor. The MInterface and 
    DInterface can be seen as transparent to all your args.    
"""
import warnings
import os
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning import Trainer
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import TensorBoardLogger

from model import MInterface
from data import DInterface
from utils import load_model_path_by_args
from torch.utils.tensorboard import SummaryWriter
warnings.filterwarnings("ignore")
import torch


class EMA(pl.Callback):
    """Keep an exponential moving average (EMA) of a module's trainable weights.

    After every training batch each shadow tensor is updated as
    ``shadow = decay * shadow + (1 - decay) * param``. Call
    :meth:`apply_shadow` to load the averaged weights into the module and
    :meth:`restore` to put the original training weights back.

    Args:
        decay: Smoothing factor in ``[0, 1]``; values closer to 1 average
            over more steps.

    Raises:
        ValueError: if ``decay`` is outside ``[0, 1]``.
    """

    def __init__(self, decay=0.999):
        super().__init__()
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay}")
        self.decay = decay
        self.shadow = {}  # parameter name -> running-average tensor
        self.backup = {}  # parameter name -> weights saved by apply_shadow

    def _shadow_for(self, name):
        """Return the averaged tensor for ``name``, failing with a clear message."""
        try:
            return self.shadow[name]
        except KeyError:
            raise KeyError(
                f"no EMA shadow for parameter {name!r}; was on_train_start run "
                "and was the parameter trainable at that time?"
            ) from None

    def on_train_start(self, trainer, pl_module):
        """Snapshot the current trainable parameters as the initial average."""
        for name, param in pl_module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Blend the freshly updated parameters into the running average.

        ``dataloader_idx`` defaults to 0 because newer Lightning releases no
        longer pass it to this hook, while older releases still do.
        """
        with torch.no_grad():
            for name, param in pl_module.named_parameters():
                if param.requires_grad:
                    shadow = self._shadow_for(name)
                    # In-place: shadow <- d * shadow + (1 - d) * param
                    shadow -= (1 - self.decay) * (shadow - param.data)

    def apply_shadow(self, pl_module):
        """Copy the averaged weights into ``pl_module``.

        The current trainable weights are cloned into ``self.backup`` first
        (this costs one extra copy of those parameters), so that
        :meth:`restore` can undo the swap.
        """
        with torch.no_grad():
            for name, param in pl_module.named_parameters():
                if param.requires_grad:
                    shadow = self._shadow_for(name)
                    self.backup[name] = param.data.clone()
                    param.data.copy_(shadow)

    def restore(self, pl_module):
        """Undo :meth:`apply_shadow` by copying the backed-up weights back."""
        with torch.no_grad():
            for name, param in pl_module.named_parameters():
                if name in self.backup:
                    param.data.copy_(self.backup[name])
        self.backup = {}


def main(args):
    """Build the data module and the model, then run ``Trainer.test`` on them.

    Args:
        args: Parsed command-line namespace. All attributes are forwarded to
            ``DInterface`` and, when no checkpoint is given, to
            ``MInterface``. ``args.checkpoint_path`` names a checkpoint to
            load; an empty string or ``None`` means "build a fresh model".
            ``args.gpu`` is a torch device string such as ``'cuda:3'``.
    """
    # --seed is parsed by the CLI; apply it so test runs are reproducible.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        pl.seed_everything(seed)

    data_module = DInterface(**vars(args))
    load_path = args.checkpoint_path
    log_dir = "/home/star/Data/g1/yxh/logs"

    # The CLI default for --checkpoint_path is '' (not None), so test
    # truthiness; `is None` sent default runs into load_from_checkpoint('').
    if not load_path:
        model = MInterface(**vars(args))
    else:
        # NOTE(review): `kwargs=args` forwards ONE hyperparameter named
        # "kwargs" holding the whole Namespace; `**vars(args)` may have been
        # intended. Confirm against MInterface before changing it.
        model = MInterface.load_from_checkpoint(load_path, strict=False, kwargs=args)
        args.ckpt_path = load_path
    model = model.to(args.gpu)

    # Put the Trainer on the same GPU index the model was moved to, rather
    # than a hard-coded device that can disagree with --gpu. Strings without
    # an index (e.g. 'cuda') keep the previous device 3.
    gpu_index = torch.device(args.gpu).index
    devices = [gpu_index if gpu_index is not None else 3]
    trainer = Trainer.from_argparse_args(
        args, devices=devices, accelerator='cuda', default_root_dir=log_dir
    )

    trainer.test(model, datamodule=data_module)


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``. This converter interprets the text instead.
    Actual ``bool`` values are returned unchanged.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f"not a boolean: {value!r}")


if __name__ == '__main__':
    parser = ArgumentParser()
    # Basic Training Control
    parser.add_argument('--batch_size', default=50, type=int)
    parser.add_argument('--num_workers', default=16, type=int)
    parser.add_argument('--seed', default=3907, type=int)
    parser.add_argument('--lr', default=0.0005, type=float)
    parser.add_argument('--gpu', default='cuda:3', type=str)
    parser.add_argument('--error_log_dir', default='/home/n611/Projects/b611/lightningOcean/information/log/ocean/errorSC.csv', type=str)
    # LR Scheduler
    parser.add_argument('--lr_scheduler', choices=['step', 'cosine'], type=str)
    parser.add_argument('--lr_decay_steps', default=20, type=int)
    parser.add_argument('--lr_decay_rate', default=0.5, type=float)
    parser.add_argument('--lr_decay_min_lr', default=1e-5, type=float)
    # parser.add_argument('--gradient_clip_val', default=0.5, type=float)

    # Restart Control
    # `type=bool` would make `--load_best False` evaluate to True; str2bool
    # parses the text, and a bare `--load_best` (no value) means True.
    parser.add_argument('--load_best', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--load_path', default=None, type=str)
    parser.add_argument('--load_dir', type=str)
    parser.add_argument('--load_ver', default='version_16', type=str)
    parser.add_argument('--load_v_num', default=None, type=int)

    # Training Info
    parser.add_argument('--dataset', default='standard_data', type=str)
    parser.add_argument('--csv_file', default='./3dshapeWhole.csv', type=str, help='train csv')
    parser.add_argument('--val_dir', default='./3dshapeWhole.csv', type=str, help='test csv')
    parser.add_argument('--model_name', default='CDQAE', type=str)
    parser.add_argument('--loss', default='ce', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--no_augment', action='store_true')
    parser.add_argument('--log_dir', default='CDQAE', type=str)

    # Model Hyperparameters
    parser.add_argument('--hidden_size', default=512, type=int)
    parser.add_argument('--moving_average_decay', default=0.99, type=float)
    parser.add_argument('--nmf_lambda', default=0.3, type=float)
    parser.add_argument('--orthogonal_lambda', default=0.3, type=float)
    parser.add_argument('--similarity_lambda', default=0.7, type=float)
    parser.add_argument('--rank', default=6, type=int)

    # Other
    parser.add_argument('--aug_prob', default=0.5, type=float)
    parser.add_argument('--cls', default='object_color', type=str)
    # Checkpoint path for testing
    parser.add_argument('--checkpoint_path', default='', type=str)
    parser.add_argument('--npz_path', default='', type=str)
    # max_epochs has no CLI option here; set_defaults still places it on the
    # parsed namespace so a Trainer built from these args can pick it up.
    parser.set_defaults(max_epochs=1)

    args = parser.parse_args()

    # List Arguments
    args.mean_sen = [0.5]
    args.std_sen = [0.5]

    main(args)
