import torch
import argparse
import os
import sys
import inspect
from util import parse_bool, concat_tensor_dicts
from context_dgp_functions import CONTEXT_DGPs

def _parse_args():
    """Build the CLI parser, parse ``sys.argv`` and validate cross-argument constraints.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (SystemExit with a usage message) when
    ``--method`` is missing or unknown, when ``--Z_synthetic`` is set without
    ``--Z_file``, or when ``--repeat_train_Z`` is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--D", type=int, default=100,
                        help="number of domains (training)")
    parser.add_argument("--D_eval", type=int, default=1000,
                        help="number of domains (evaluation)")
    parser.add_argument("--D_extra", type=int, default=0,
                        help="number of domains (extra)")
    parser.add_argument("--N", type=int, default=50,
                        help="number of observations per domain")
    # Required: without it every output path would start with the literal "None/".
    parser.add_argument("--datadir", type=str, required=True,
                        help="save dataset to this path")
    parser.add_argument("--one_X_per_col", type=parse_bool, default=False)
    parser.add_argument("--dimX", type=int, default=1)
    parser.add_argument("--method", type=str, default=None, choices=CONTEXT_DGPs.keys())
    parser.add_argument('--seed', type=int, default=23948452)
    parser.add_argument('--flip', type=parse_bool, default=False)
    parser.add_argument('--Z_file', type=str, default=None)
    parser.add_argument('--Z_synthetic', type=parse_bool, default=False)
    # NOTE(review): parsed (and printed) but never read by main().
    parser.add_argument('--only_dump_unused', type=parse_bool, default=False)
    parser.add_argument('--name', type=str, default='')
    parser.add_argument('--repeat_train_Z', type=int, default=1)
    args = parser.parse_args()

    # argparse does not check defaults against `choices`, so an omitted
    # --method arrives here as None and must be rejected explicitly.
    if args.method not in CONTEXT_DGPs:
        parser.error(f"--method is required; choose from {sorted(CONTEXT_DGPs.keys())}")
    if args.Z_synthetic and args.Z_file is None:
        parser.error("--Z_synthetic requires --Z_file")
    if args.repeat_train_Z < 0:
        parser.error("--repeat_train_Z must be >= 0")
    return args


def _print_dgp_checks(generate_fn, dimX, g, generate_Z=None):
    """Draw a fresh check dataset from ``generate_fn`` and print MSE baselines.

    Must run after the saved datasets are generated: it consumes the same
    generator ``g``, so moving it earlier would change the saved data.

    Args:
        generate_fn: the DGP function selected from ``CONTEXT_DGPs``.
        dimX: forwarded to ``generate_fn``.
        g: the ``torch.Generator`` used for all dataset generation.
        generate_Z: optional Z matrix; its first 1000 rows are passed to
            ``generate_fn`` when given.
    """
    # Positional call, presumably (D=1000, N=1000, dimX, ave_U=True,
    # one_X_per_col=False, g[, Z]) -- TODO confirm against generate_fn.
    # ave_U=True is what makes 'click_rate_ave_U' available below.
    # NOTE(review): assumes generate_Z has at least 1000 rows.
    if generate_Z is not None:
        check_data = generate_fn(1000, 1000, dimX, True, False, g, generate_Z[:1000])
    else:
        check_data = generate_fn(1000, 1000, dimX, True, False, g)
    check_clicks = check_data['click_rate']
    mean_click_rate_over_u = check_data['click_rate_ave_U']
    check_Ys = check_data['Y']
    print("DGP Checks\n")

    MSE_prob_vs_Y = torch.mean(torch.square(check_clicks - check_Ys))
    best_no_context = torch.mean(mean_click_rate_over_u, axis=1, keepdim=True)
    best_everywhere_avg = torch.mean(mean_click_rate_over_u)
    MSE_best_marginal = torch.mean(torch.square(check_clicks - mean_click_rate_over_u))
    MSE_best_no_context = torch.mean(torch.square(check_clicks - best_no_context))
    MSE_best_no_Zs = torch.mean(torch.square(check_clicks - best_everywhere_avg))
    best_seq_noX = torch.mean(check_clicks, axis=1, keepdim=True)
    MSE_best_seq_noX = torch.mean(torch.square(check_clicks - best_seq_noX))

    # The best sequential (X,Z) model recovers the true click rates, hence a literal 0.
    print(f"MSE of best sequential model (X,Z) (comparing to true click rates)          {0:.3f}")
    print(f"MSE of best no context sequential model (Z) (comparing to true click rates) {MSE_best_seq_noX.item():.3f}")
    print(f"MSE of best marginal model (X,Z) (comparing to true click rates)            {MSE_best_marginal.item():.3f}")
    print("")
    print(f"MSE of no X marginal model (comparing to true click rates)                  {MSE_best_no_context.item():.3f}")
    print(f"MSE of no (Z,X) marginal model (comparing to true click rates)              {MSE_best_no_Zs.item():.3f}")
    print("")
    print(f"MSE of best sequential model (X,Z) (comparing to Ys)                        {MSE_prob_vs_Y.item():.3f}")


def main():
    """Generate and save eval/train(/extra) datasets for one context DGP.

    Steps:
      1. Parse and validate the CLI arguments.
      2. Optionally load Z matrices from ``--Z_file`` (a dict with keys
         ``'generate_Z'`` and ``'train_Z'``).
      3. Generate the eval split from rows ``[0, D_eval)``, the train split
         from ``[D_eval, D_eval + D)`` and the optional extra split from
         ``[D_eval + D, D_eval + D + D_extra)``.
      4. Save each split under a directory whose name encodes the config.
      5. When a Z file is used, dump the Z rows not used by train/eval.
      6. Print DGP diagnostics.

    Raises:
        ValueError: if the Z file has fewer than ``D + D_eval + D_extra`` rows.
    """
    currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    parentdir = os.path.dirname(currentdir)
    sys.path.insert(0, parentdir)

    from util import set_seed, make_parent_dir

    args = _parse_args()
    print("")
    print(vars(args))
    set_seed(382374298)
    z_synth_str = ',synth_Z' if args.Z_synthetic else ''

    if args.Z_file is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted Z files.
        all_Z = torch.load(args.Z_file)
        generate_Z = all_Z['generate_Z']
        n_needed = args.D + args.D_eval + args.D_extra
        if generate_Z.shape[0] < n_needed:
            raise ValueError(
                f"Z file {args.Z_file} has {generate_Z.shape[0]} rows, "
                f"need at least D + D_eval + D_extra = {n_needed}"
            )
        train_Z = all_Z['train_Z']
        print(f'Z file has {len(generate_Z)} rows')
    else:
        generate_Z = None
        train_Z = None
    generate_fn = CONTEXT_DGPs[args.method]

    g = torch.Generator()
    g.manual_seed(args.seed)
    flipstr = ',flip' if args.flip else ''

    # Row ranges into the Z matrices: eval first, then train, then extra.
    eval_start = 0
    eval_end = args.D_eval
    train_start = eval_end
    train_end = train_start + args.D
    extra_start = train_end
    extra_end = train_end + args.D_extra
    print(f'eval start: {eval_start}, eval end: {eval_end}, train end: {train_end}')

    name = f'{args.name}:' if args.name else ''
    Dstr = args.D
    if args.repeat_train_Z > 1:
        Dstr = f'{Dstr}x{args.repeat_train_Z}'
    D_extra_str = f'D_extra={args.D_extra},' if args.D_extra > 0 else ''
    save_dir = f"{args.datadir}/{name}N={args.N},D={Dstr},D_eval={args.D_eval},{D_extra_str}method={args.method},dimX={args.dimX},one_X_per_col={args.one_X_per_col}{flipstr}{z_synth_str}/"
    eval_fname = os.path.join(save_dir, 'eval_data.pt')
    train_fname = os.path.join(save_dir, 'train_data.pt')
    extra_fname = os.path.join(save_dir, 'extra_data.pt')

    # Dataset helpers --------------------------------------------------
    def generate_dataset(args, g, start, end, one_X_per_col, fname=None):
        """Generate domains ``[start, end)`` and optionally save them to ``fname``.

        When a Z file is loaded, rows ``start:end`` of ``generate_Z`` are
        passed to the DGP. ``'Z'`` is set to those same rows when
        ``--Z_synthetic``; otherwise it is set to the ``train_Z`` rows, and
        ``'Z_gen'`` to the ``generate_Z`` rows.
        """
        if args.Z_file is not None:
            data_dict = generate_fn(D=end-start, N=args.N, dimX=args.dimX, ave_U=False, one_X_per_col=one_X_per_col, g=g, Z=generate_Z[start:end])
        else:
            data_dict = generate_fn(D=end-start, N=args.N, dimX=args.dimX, ave_U=False, one_X_per_col=one_X_per_col, g=g)
        data_dict['N'] = args.N
        data_dict['D'] = end - start
        if train_Z is not None:
            if args.Z_synthetic:
                data_dict['Z'] = generate_Z[start:end]
            else:
                data_dict['Z'] = train_Z[start:end]
                data_dict['Z_gen'] = generate_Z[start:end]
            data_dict['Z_file'] = args.Z_file
        if fname is not None:
            with open(fname, 'wb') as f:
                torch.save(data_dict, f)
                print('saving to ', fname)
        return data_dict

    make_parent_dir(save_dir)

    def generate_train_dataset(args, g, start, end, one_X_per_col, fname):
        """Generate the train split, repeating it ``repeat_train_Z`` times when > 0.

        Every repetition reuses the same Z rows and the shared generator
        ``g``. Only the concatenation is written to ``fname``; the individual
        repetitions are not saved (previously each one overwrote ``fname``
        before the final write).
        """
        if args.repeat_train_Z == 0:
            return generate_dataset(args, g, start, end, one_X_per_col, fname)
        pieces = [
            generate_dataset(args, g, start, end, one_X_per_col)
            for _ in range(args.repeat_train_Z)
        ]
        res = concat_tensor_dicts(pieces)
        with open(fname, 'wb') as f:
            torch.save(res, f)
            print('saving to ', fname)
        return res

    # --flip swaps which split draws from `g` first; the order affects the data.
    if args.flip:
        eval_data_dict = generate_dataset(args, g, eval_start, eval_end, args.one_X_per_col, eval_fname)
        data_dict = generate_train_dataset(args, g, train_start, train_end, args.one_X_per_col, train_fname)
    else:
        data_dict = generate_train_dataset(args, g, train_start, train_end, args.one_X_per_col, train_fname)
        eval_data_dict = generate_dataset(args, g, eval_start, eval_end, args.one_X_per_col, eval_fname)

    if args.D_extra > 0:
        generate_dataset(args, g, extra_start, extra_end, args.one_X_per_col, extra_fname)

    # Dump unused Z's: rows not used by train or eval. This range can overlap
    # the extra split generated above.
    if args.Z_file is not None:
        all_extra_start = args.D_eval + args.D
        all_extra_end = len(train_Z)
        torch.save({'Z': generate_Z[all_extra_start:all_extra_end]}, os.path.join(save_dir, 'unused_generate_Z.pt'))
        torch.save({'Z': train_Z[all_extra_start:all_extra_end]}, os.path.join(save_dir, 'unused_raw_Z.pt'))

    # Sanity check that eval and train were not drawn identically.
    assert eval_data_dict['click_rate'][0,0] != data_dict['click_rate'][0][0]

    # Checks of DGP: run after generation so the saved data's seeding is unaffected.
    _print_dgp_checks(generate_fn, args.dimX, g, generate_Z)


if __name__ == "__main__":
    main()
