from option import args

import torch
import utility
import data
import loss
from trainer import Trainer
import warnings
# Ignore all warnings for the remainder of the run.
warnings.filterwarnings('ignore')
import os
# os.system('pip install einops')
import model
# Seed torch's RNG from `args.seed` at import time.
# NOTE(review): main() later calls init_seed() with its default seed, which
# re-seeds torch and therefore overrides this value.
torch.manual_seed(args.seed)
# Created at import time; main() checks its `ok` flag, passes it to the model,
# loss and trainer, and calls done() after training.
checkpoint = utility.checkpoint(args)
from data.div2k import DIV2K

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import dist_util


def init_seed(seed: int = 23) -> None:
    r"""Seed the random number generators of Python, NumPy and PyTorch.

    The seed is also exported as the ``PYTHONHASHSEED`` environment variable.
    cuDNN settings are not modified by this function.

    Args:
        seed (int): Seed applied to every generator. Defaults to 23.
    """
    import random
    import numpy as np

    # The running interpreter fixed its hash seed at startup; this export is
    # only visible to processes that inherit the environment afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed the current CUDA device and then every CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    """Run distributed pre-training, or evaluation only if ``args.test_only``.

    Initialises distributed mode, enables cuDNN benchmarking, seeds the RNGs,
    builds the DIV2K training dataset and the data loader, loads the state
    dict found at ``args.pretrain`` into a DDP-wrapped model and then either

    * calls ``Trainer.pretrain`` for ``args.epochs`` epochs, running
      ``Trainer.test`` on rank 0 after every epoch, or
    * calls ``Trainer.test`` once when ``args.test_only`` is set.

    Nothing past the seeding step runs when ``checkpoint.ok`` is falsy.
    """
    # Set up DDP.
    dist_util.init_distributed_mode(args)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    num_tasks = dist_util.get_world_size()
    global_rank = dist_util.get_rank()

    # NOTE(review): this uses init_seed's default seed and re-seeds torch, so
    # the module-level torch.manual_seed(args.seed) has no lasting effect.
    # Pass args.seed here if the configured seed is meant to be honoured.
    init_seed()

    # To force evaluation only, set `args.test_only = True` here.
    if not checkpoint.ok:
        return

    args.train_dataset = DIV2K(args)
    loader = data.Data(args, num_tasks, global_rank)

    # NOTE: torch.load unpickles the file; only point args.pretrain at
    # checkpoints from a trusted source.
    state_dict = torch.load(args.pretrain, map_location='cpu')

    _model = model.Model(args, checkpoint)
    _model = DDP(
        _model,
        device_ids=[args.gpu],
        output_device=args.gpu,
        find_unused_parameters=True,
    )
    # Use this line when loading quantized weights.
    load_result = _model.load_state_dict(state_dict, strict=False)
    # strict=False silently skips keys that do not match, so report them:
    # otherwise a checkpoint with differently prefixed keys would leave the
    # model at its initial weights without any indication.
    if global_rank == 0 and (load_result.missing_keys or load_result.unexpected_keys):
        print(
            f'[load_state_dict] {args.pretrain}: '
            f'{len(load_result.missing_keys)} missing key(s), '
            f'{len(load_result.unexpected_keys)} unexpected key(s)'
        )
        print(f'  missing: {load_result.missing_keys}')
        print(f'  unexpected: {load_result.unexpected_keys}')

    _loss = loss.Loss(args, checkpoint) if not args.test_only else None
    t = Trainer(args, loader, _model, _loss, checkpoint)

    os.makedirs(args.save, exist_ok=True)

    if not args.test_only:
        for epoch in range(0, args.epochs):
            t.pretrain(epoch)
            # Only rank 0 evaluates after each epoch.
            # NOTE(review): if t.test issues collective operations, the other
            # ranks will not take part in them -- confirm this cannot hang.
            if global_rank == 0:
                t.test(args)
        checkpoint.done()
    else:
        t.test(args)
            
# Script entry point: run main() only when this file is executed directly.
if __name__ == '__main__':
    main()
