import os
import torch

from torch.utils.data import DataLoader
from src.data_loader import UniversalSMIGenerator, NpzDataset
from src.utils import network_name, loss_weight_name, agent_name, dat_filename, count_network_size
from src.models import ModelSM

from torch.utils.tensorboard import SummaryWriter
import pathlib

from tqdm import tqdm

import argparse


def parse_args(argv=None):
    """Define the training script's command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace: one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()

    # args for train data loader.
    parser.add_argument("-d", "--distributions", type=str, default='UU',
                        help='set distribution type for each side. [(U|D|G)+] (default: "UU")')
    parser.add_argument("-p", "--distrib_params", type=float, nargs='+', default=[0.4,0.4],
                        help='set the parameters of each distributions. (default: [0.4,0.4])')
    # The default must be a list: with nargs='+' the parsed value is a list and
    # the rest of the script calls len() on it and indexes one entry per side.
    parser.add_argument("-N", "--num_agents", type=int, nargs='+', default=[5, 5],
                        help='set the number of agents for each side. (default: [5,5])')
    parser.add_argument("-N_max", "--max_num_agents", type=int, nargs='+', default=[],
                        help='set the maximum number of agents for each size. If skippe. num_agents is used.')
    parser.add_argument("-b","--batch_size", type=int, default=8,
                        help='set batch size. (default: 8)')
    parser.add_argument("-s","--steps_per_epoch",type=int, default=200,
                        help='set the number of steps per epoch. default:200')

    parser.add_argument("--continuous_satisfaction", action='store_true',
                        help='use continuous satisfaction value instead of discretized preference list. (default: False)')

    # args for val data loader
    parser.add_argument("--val_dat", type=str, default=None,
                        help='set validation dataset dat file path. If None, the script tries to identify the dat file in val_data_dir.')
    parser.add_argument("--val_data_dir", type=str, default = "./datasets/validation/",
                        help='set validation dataset location. This is used only when val_dat is None. (default: "./datasets/validation/")')

    # args for network
    parser.add_argument("-L", "--n_layers", type=int, default = 18,
                        help='set the number of feature weaving layers. (default: 18)')
    parser.add_argument("-w", "--width1", type=int, default =64,
                        help='set the width of a layer (number of channels). (default: 64)')
    parser.add_argument("-W", "--width2", type=int, default =64,
                        help='set the second width of a layer (number of channels). This option is valid only with WeaveNet. (default: 64)')

    parser.add_argument("-a", "--asymmetric", action='store_true',
                        help='use asymmetric adjacent edge encoders for each side. (default: False)')
    parser.add_argument("-r", "--use_resnet", action='store_true',
                        help='use residual strucuture. (default: False for L<15, True for L>=15)')

    parser.add_argument("--network_type", type=str, default='WeaveNet',
                       help='Select the network type from [MLP|GCN|WeaveNet]. (default: WeaveNet)')

    # args for training
    parser.add_argument("-e", "--num_epochs", type=int, default=1000,
                        help='set the number of epochs (default: 1000)')
    parser.add_argument("--start_epoch",type=int,default=0,
                        help='Load a trained model at the start epoch and continue the training. (default: 0)')

    # args for loss weights
    parser.add_argument("-lr","--learning_rate", type=float, default = 0.0001,
                        help="set learning rate. (default: 0.0001)")
    parser.add_argument("-lm","--lambda_matrix_constraint", type=float, default = 1.0,
                        help = "set weight for matrix constraint loss. (default: 1.0)")
    parser.add_argument("-ls","--lambda_unstability", type=float, default = 0.7,
                        help = "set weight for unstability loss. (default: 0.7)")
    parser.add_argument("-le","--lambda_satisfaction", type=float, default = 0.0,
                        help = "set weight for satisfaction loss. (default: 0.0)")
    parser.add_argument("-lf","--lambda_fairness", type=float, default = 0.0,
                        help = "set weight for fairness loss. (default: 0.0)")
    parser.add_argument("-lb","--lambda_balance", type=float, default = 0.0,
                        help = "set weight for balance loss. This cannot be set to >0 with ls or lf. (default: 0.0). ")

    parser.add_argument('--constraint_p', dest='constraint_p', type=float, default=2.,
                        help='Value of constraint_p. If <0, use a Eucledean-distance-based matrix constraint.')

    # args for output
    parser.add_argument("--checkpoint_root", type=str, default="./checkpoints",
                       help="set the root directory where checkpoint models will be stored. (default: ./checkpoints)")
    parser.add_argument("--name", type=str, default=None,
                        help='set experiment name used to generate checkpoint and log file names. If None, the script automatically generate the name. (default: None)')
    parser.add_argument("-cp", "--checkpoint", type=int, default=200,
                        help="save the trained model parameter at every {checkpoint} epoch. (default:200)")
    # Help text now states the actual default (1), not 10.
    parser.add_argument("-v", "--validation", type=int, default=1,
                        help="validate the model at every {validation} epochs. (default: 1)")
    parser.add_argument("-t","--tensorboard_dir",type=str,default='./tensorboard_log',
                        help="set the directory for tensorboard log.")

    # args for acceleration
    parser.add_argument("--num_workers", type=int, default=1,
                        help="set num_workers to data loaders. If set to -1, use maximum number of cpus. (default: 1)")
    parser.add_argument("--pin_memory", action='store_true',
                        help="set the pin_memory flag to data loaders.")
    parser.add_argument("--cudnn_benchmark",action='store_true',
                       help="set the cudnn benchmark flag.")
    parser.add_argument("--non_blocking", action='store_true',
                       help="set the non_blocking flag.")
    # Help text now states the actual default (-1), not None.
    parser.add_argument("--device", type=int, default=-1,
                        help="set device. If -1, automatically select cpu or gpu. If -2, force to use cpu (default: -1)")

    return parser.parse_args(argv)

def get_exp_name(args):
    """Build the relative experiment path used for checkpoints and logs.

    The result is ``os.path.join`` of three components produced by the
    project helpers ``network_name``, ``loss_weight_name`` and ``agent_name``.

    Args:
        args: Parsed options. Reads the network, loss-weight and distribution
            attributes plus ``num_agents`` / ``max_num_agents``.

    Returns:
        str: the joined experiment name.
    """
    # Work on copies so that neither args.max_num_agents nor the argparse
    # default list is modified as a side effect of computing a name.
    num_agents = args.num_agents
    if isinstance(num_agents, int):
        num_agents = [num_agents]
    num_agents = list(num_agents)
    max_agents = list(args.max_num_agents)
    # Sides without an explicit maximum fall back to their own agent count.
    max_agents += num_agents[len(max_agents):]

    n_ranges = list(zip(num_agents, max_agents))
    first = n_ranges[0]
    # A single agent count describes both sides (main() treats it the same
    # way by indexing with [-1]); otherwise use the second entry as before.
    second = n_ranges[1] if len(n_ranges) > 1 else n_ranges[0]

    return os.path.join(
        network_name(args.network_type, args.n_layers, args.width1, args.width2,
                     args.asymmetric, args.use_resnet),
        loss_weight_name(
            args.learning_rate,
            args.lambda_matrix_constraint,
            args.lambda_unstability,
            args.lambda_satisfaction,
            args.lambda_fairness,
            args.lambda_balance,
            args.constraint_p),
        agent_name(args.distributions, first[0], second[0], first[1], second[1])
    )

def get_model_filename(epoch):
    """Return the checkpoint file name for ``epoch``, zero-padded to 5 digits."""
    return f'{epoch:05d}_net.pth'

def _resolve_agent_ranges(num_agents, max_num_agents):
    """Return two ``(n, n_max)`` tuples, one per side, without mutating inputs."""
    if isinstance(num_agents, int):
        num_agents = [num_agents]
    num_agents = list(num_agents)
    if len(num_agents) > 2:
        raise RuntimeError('this program does not support K-partite matching with K>2.')
    max_num_agents = list(max_num_agents)
    # Sides without an explicit maximum fall back to their own agent count.
    max_num_agents += num_agents[len(max_num_agents):]
    ranges = list(zip(num_agents, max_num_agents))
    if len(ranges) == 1:
        # A single value describes both sides.
        ranges = ranges * 2
    return ranges


def _select_device(device_index):
    """Map the ``--device`` option to a ``torch.device``."""
    if device_index == -1:
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device_index < -1:
        return torch.device('cpu')
    return torch.device('cuda:{}'.format(device_index))


def main():
    """Parse options, build the data loaders and model, then run training.

    Side effects: creates the checkpoint and TensorBoard directories, writes
    ``<checkpoint dir>.model.txt`` with the parameter count and network
    description, and prints progress information.

    Raises:
        RuntimeError: if more than two agent counts are given.
    """
    args = parse_args()
    if args.n_layers>=15:
        args.use_resnet = True

    # Normalize the agent counts once and store fresh lists back on args so
    # that get_exp_name(args) sees the same two-sided configuration.
    N_ranges = _resolve_agent_ranges(args.num_agents, args.max_num_agents)
    args.num_agents = [n for n, _ in N_ranges]
    args.max_num_agents = [n_max for _, n_max in N_ranges]

    distrib_m = args.distributions[0]
    distrib_f = args.distributions[-1]

    if args.val_dat is None:
        dat_path=os.path.join(args.val_data_dir,dat_filename(distrib_m,distrib_f,N_ranges[0][0],'validation'))
    else:
        dat_path=args.val_dat

    dataset_train = UniversalSMIGenerator(
        distrib_m, distrib_f,
        N_range_m=N_ranges[0],
        N_range_w=N_ranges[-1],
        sigma_m=args.distrib_params[0],
        sigma_f=args.distrib_params[-1],
        batch_size=args.batch_size,
        len=args.steps_per_epoch*args.batch_size,
        transform = not args.continuous_satisfaction,
    )
    num_workers = args.num_workers
    if num_workers<=0:
        # os.cpu_count() may return None; fall back to a single worker.
        num_workers = os.cpu_count() or 1
    dl_train = DataLoader(dataset_train, batch_size=args.batch_size,num_workers=num_workers,pin_memory=args.pin_memory)

    print("validation set: ",dat_path)
    dataset = NpzDataset(dat_path,args.val_data_dir)
    dl_val = DataLoader(dataset,batch_size=1,shuffle=False,num_workers=num_workers,pin_memory=args.pin_memory)

    device = _select_device(args.device)

    exp_name = get_exp_name(args)
    cp_dir = os.path.join(args.checkpoint_root,exp_name)

    # Draw one batch only to learn the input shape the network is built for.
    sab, _, _, _ = next(iter(dl_train))

    model = ModelSM(
        network_type=args.network_type,
        sab_shape=sab.shape,
        L = args.n_layers,
        w1 = args.width1,
        w2 = args.width2,
        asymmetric = args.asymmetric,
        use_resnet = args.use_resnet,
        lr=args.learning_rate,
        lambda_m=args.lambda_matrix_constraint,
        lambda_u=args.lambda_unstability,
        lambda_s=args.lambda_satisfaction,
        lambda_f=args.lambda_fairness,
        lambda_b=args.lambda_balance,
        constraint_p=args.constraint_p,
        device = device,
        checkpoints_dir = cp_dir
    )

    csv_file = os.path.join(args.checkpoint_root,exp_name+'.csv')
    pathlib.Path(cp_dir).mkdir(parents=True, exist_ok=True)
    n_params = count_network_size(model.net)
    with open("{}.model.txt".format(cp_dir),'w') as f:
        f.write("# of network parameters: {}\n".format(n_params))
        f.write(str(model.net))
    print("# of network parameters: ",n_params)

    if args.start_epoch > 0:
        load_filename = get_model_filename(args.start_epoch)
        model.load_network(load_filename)

    model.net.to(device)

    tensorboard_log_dir = os.path.join(args.tensorboard_dir,exp_name)
    pathlib.Path(tensorboard_log_dir).mkdir(parents=True, exist_ok=True)
    # Write events into the per-experiment directory created above instead of
    # SummaryWriter's default location.
    writer = SummaryWriter(log_dir=tensorboard_log_dir)
    try:
        train(dl_train, dl_val, model, writer, csv_file,
              device,
              args.num_epochs, args.checkpoint, args.validation,
              # Forward the acceleration flags, which were parsed but unused.
              cudnn_benchmark=args.cudnn_benchmark,
              non_blocking=args.non_blocking,
              start_epoch=args.start_epoch
              )
    finally:
        # Flush pending events even if training is interrupted.
        writer.close()

def train(dl_train, dl_val, model, writer, csv_file,
          device,
          num_epochs: int,
          checkpoint: int,
          validation: int,
          cudnn_benchmark=False,
          non_blocking=False,
          start_epoch=0,
         ):
    """Run the epoch loop with periodic validation, logging and checkpoints.

    Args:
        dl_train: Iterable of training batches; each batch is unpacked into
            four tensors ``(sab, sba, na, nb)``.
        dl_val: Loader handed to ``model.validate``. ``None`` disables
            validation and the CSV/TensorBoard logging that goes with it.
        model: Object exposing ``net``, ``train(ss, ns)``,
            ``validate(dl, device)``, ``save_network(filename)`` and
            ``save_dir``.
        writer: TensorBoard-style writer providing ``add_scalar``.
        csv_file: Path of the CSV log. It is rewritten with a header when
            training starts from epoch 0 or the file does not exist yet,
            and appended to otherwise.
        device: Target device for each batch tensor.
        num_epochs: Index at which the epoch loop stops (exclusive).
        checkpoint: Save the network every ``checkpoint`` epochs; values <= 0
            disable periodic saving. The last epoch is always saved.
        validation: Validate every ``validation`` epochs; values <= 0 disable
            periodic validation. The last epoch is always validated.
        cudnn_benchmark: Value assigned to ``torch.backends.cudnn.benchmark``.
        non_blocking: Passed to ``Tensor.to`` when moving batches.
        start_epoch: First epoch index of the loop (0-based).

    Raises:
        ValueError: if ``start_epoch`` is negative.
    """
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if start_epoch < 0:
        raise ValueError("start_epoch must be >= 0, got {}".format(start_epoch))

    model.net.train()
    torch.backends.cudnn.benchmark = cudnn_benchmark

    # When resuming with a missing CSV, still emit the header so the appended
    # rows stay interpretable.
    if start_epoch==0 or not os.path.exists(csv_file):
        with open(csv_file,'w') as f:
            f.write("epochs,L_m(tr), L_u(tr), L_f(tr), L_b(tr), L_m(val), L_u(val), L_f(val), L_b(val), Acc, SEq, Bal\n")

    for e in range(start_epoch,num_epochs):
        epoch = e+1
        is_last_epoch = e==num_epochs-1
        print("Epoch [{:05d}/{:05d}]".format(epoch,num_epochs))
        progress = tqdm(dl_train)
        l_m_run = 0
        l_u_run = 0
        l_f_run = 0
        l_b_run = 0
        num_samples = 0

        for batch in progress:
            sab, sba, na, nb = [b.to(device,non_blocking=non_blocking) for b in batch]
            ss =[sab,sba]
            num_samples += sab.shape[0]
            # model.train returns six values; the 1st and 4th are not used here.
            _, l_m, l_u, _, l_f, l_b, = model.train(ss,[na,nb])

            l_m_run += l_m.detach().cpu()
            l_u_run += l_u.detach().cpu()
            l_f_run += l_f.detach().cpu()
            l_b_run += l_b.detach().cpu()

            progress.set_postfix(
                L_m=float(l_m_run/num_samples),
                L_u=float(l_u_run/num_samples),
                L_f=float(l_f_run/num_samples),
                L_b=float(l_b_run/num_samples))

        # Guard the modulo so an interval of 0 disables the periodic action
        # instead of raising ZeroDivisionError.
        validate_now = (validation>0 and epoch%validation==0) or is_last_epoch
        if dl_val is not None and validate_now:
            # save train losses
            # An empty training loader leaves num_samples at 0; avoid 0/0.
            denom = max(num_samples,1)
            L_m = l_m_run/denom
            L_u = l_u_run/denom
            L_f = l_f_run/denom
            L_b = l_b_run/denom
            line = [epoch,L_m,L_u,L_f,L_b]
            # Pass the epoch as global_step so each value is recorded at its
            # own step rather than without a step index.
            writer.add_scalar('L_m(tr)',L_m,epoch)
            writer.add_scalar('L_u(tr)',L_u,epoch)
            writer.add_scalar('L_f(tr)',L_f,epoch)
            writer.add_scalar('L_b(tr)',L_b,epoch)

            # calc val losses and costs.
            model.net.eval()
            L_m, L_u, L_f, L_b, acc, SEq, Bal = model.validate(dl_val,device)

            # save val losses and costs
            writer.add_scalar('L_m(val)',L_m,epoch)
            writer.add_scalar('L_u(val)',L_u,epoch)
            writer.add_scalar('L_f(val)',L_f,epoch)
            writer.add_scalar('L_b(val)',L_b,epoch)
            writer.add_scalar('Acc',acc,epoch)
            writer.add_scalar('SEq',SEq,epoch)
            writer.add_scalar('Bal',Bal,epoch)
            line+= [L_m,L_u,L_f, L_b, acc, SEq, Bal]

            with open(csv_file,'a') as f:
                f.write(",".join([str(float(l)) for l in line]))
                f.write("\n")
            print("Validation")
            print("L_m(tr), L_u(tr), L_f(tr), L_b(tr), L_m(val), L_u(val), L_f(val), L_b(val), Acc, SEq, Bal")
            print(",".join([str(float(l)) for l in line[1:]]))
            # Return to training mode for the next epoch.
            model.net.train()

        if (checkpoint>0 and epoch%checkpoint==0) or is_last_epoch:
            save_filename = get_model_filename(epoch)
            print("save network parameters to: ",os.path.join(model.save_dir,save_filename))
            model.save_network(save_filename)
            
            
    
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
    
