from header import * 

import os 

from utils import train, sample_massive_batch, sample_massive_batch_tree

def parse_env_flag(name, default=False):
    """Read a boolean switch from the environment variable *name*.

    Returns *default* when the variable is unset. Otherwise the value is
    treated as False when, after stripping whitespace and lower-casing, it is
    empty or one of ``0``, ``false``, ``no``, ``off``; any other value is True.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in {'', '0', 'false', 'no', 'off'}


# Parsed rather than used raw so that e.g. DISABLE_TQDM=0 keeps progress bars
# enabled; any other non-empty value (1, true, yes, ...) disables them.
disable_tqdm = parse_env_flag('DISABLE_TQDM')

def main(args):
    """Train a GFlowNet and periodically record sampling-quality metrics.

    The model is evaluated once before any training and then again after every
    block of ``args.epochs_per_step`` training epochs. Each evaluation records:

    * ``ys``: the mean of the (up to) 100 largest values of ``rewards`` among
      the sampled objects;
    * ``tv``: ``sum |p - q|``, where ``p`` is the empirical frequency of each
      distinct sampled object and ``q`` is ``exp(rewards)`` renormalised over
      those same distinct objects.

    The results ``{'xs', 'ys', 'tv'}`` are written to
    ``{args.domain}/criteria/plots_{args.seed}_{args.criterion}.json``, where
    ``xs[i] = i * args.epochs_per_step``.

    Args:
        args: Namespace-like object. This function reads ``seed``, ``epochs``,
            ``epochs_per_step``, ``domain``, ``criterion``, ``lr``,
            ``batch_size_train``, ``batch_size_eval`` and ``num_batches_eval``.
            NOTE: ``args.criterion`` is rewritten in place from ``'db'`` to
            ``'dbc'`` for the ``grid`` and ``sequences`` domains.
    """
    torch.manual_seed(args.seed)
    # Number of training steps; one extra, untrained evaluation runs first.
    # NOTE(review): with the `+ 1`, training covers
    # (epochs // epochs_per_step + 1) * epochs_per_step epochs, which is more
    # than args.epochs -- confirm this is intended.
    num_steps = args.epochs // args.epochs_per_step + 1

    save_dir = pathlib.Path(args.domain) / 'criteria'
    save_dir.mkdir(exist_ok=True, parents=True)

    if args.domain in ['grid', 'sequences'] and args.criterion == 'db':
        args.criterion = 'dbc'

    # Instantiate the model. `create_gfn`, `create_state_flow`, `GFlowNet`,
    # `create_env_args` and `flow_args` are not defined in this file;
    # presumably they come from the `header` star import -- confirm.
    # NOTE(review): the meaning of the hard-coded `client=42` is not visible here.
    forward_flow, backward_flow, log_reward = create_gfn(client=42, **create_env_args, **flow_args)
    state_flow = create_state_flow(**create_env_args, **flow_args)
    gflownet = GFlowNet(forward_flow, backward_flow, state_flow=state_flow, off_policy_rate=.5, criterion=args.criterion)

    # Only TB gets an explicit optimizer here (it also learns the log-partition
    # function with its own learning rate). Both groups set their own `lr`, so
    # the default `lr=args.lr` applies to no group. For other criteria
    # `train` receives optimizer=None together with `lr` -- presumably it
    # builds its own optimizer; confirm in utils.train.
    if args.criterion == 'tb':
        optimizer = torch.optim.Adam([
            {'params': chain(gflownet.forward_flow.parameters(), gflownet.backward_flow.parameters()), 'lr': 1e-3},
            {'params': gflownet.log_partition_function, 'lr': 1e-1},
        ], lr=args.lr)
    else:
        optimizer = None

    hrewards = []
    tv = []

    for step in range(num_steps + 1):
        gflownet.train()
        # Step 0 evaluates the untrained model; every later step trains first.
        if step > 0:
            train_env = partial(create_env, batch_size=args.batch_size_train, log_reward=log_reward, **create_env_args)
            train(gflownet, epochs=args.epochs_per_step, create_env=train_env, optimizer=optimizer, lr=args.lr)

        # Sample from the currently learned distribution.
        gflownet.eval()
        eval_env = partial(create_env, batch_size=args.batch_size_eval, log_reward=log_reward, **create_env_args)
        if args.domain == 'phylogenetics':
            samples, rewards, newick = sample_massive_batch_tree(gflownet, create_env=eval_env, num_batches=args.num_batches_eval, disable_tqdm=disable_tqdm)
            # Distinct trees are identified by their Newick strings.
            _, indices, counts = np.unique(newick, return_index=True, return_counts=True)
        else:
            # NOTE(review): disable_tqdm is not forwarded here, unlike the
            # phylogenetics branch.
            samples, rewards = sample_massive_batch(gflownet, create_env=eval_env, num_batches=args.num_batches_eval)
            indices, counts = unique_smp(samples)
        counts = torch.tensor(counts)

        # Discrepancy between the empirical distribution over the distinct
        # samples and the reward distribution renormalised over them.
        # `rewards` is treated as log-rewards (they are exponentiated), and is
        # assumed to be 1-D. softmax is the overflow-safe form of
        # exp(r) / exp(r).sum().
        # NOTE(review): this is the L1 distance, i.e. twice the conventional
        # total-variation distance; kept unscaled for comparability with
        # earlier runs.
        target = torch.softmax(rewards[indices], dim=0)
        tv.append((counts / counts.sum() - target).abs().sum().item())

        # The domains are too large for an accurate distributional estimate,
        # so also track the mean of the highest rewards among the samples
        # (clamped so that small sample sets do not make topk raise).
        k = min(int(1e2), rewards.shape[-1])
        hrewards.append(rewards.topk(k).values.mean().item())

    # Persist the learning curves for later plotting.
    results = {
        'xs': (np.arange(num_steps + 1) * args.epochs_per_step).tolist(),
        'ys': hrewards,
        'tv': tv,
    }
    with open(save_dir / f'plots_{args.seed}_{args.criterion}.json', 'w') as f:
        json.dump(results, f)

if __name__ == "__main__":
    # Script entry point. `args` is not defined in this file; presumably it is
    # supplied by the `from header import *` star import -- confirm.
    main(args=args)

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.set_default_device(device)
# torch.set_default_dtype(torch.float64)

# criteria = ['db', 'tb', 'cb']

# set_size = 32
# warehouse_size = 15
# batch_size = 512

# emb_dim = min(warehouse_size, 32)
# hidden_dim = 128

# seeds = [42, 84, 128]
# epochs_per_step = int(2e1)
# steps = int(3e2)

# log_reward = LogRewardLinear(21, warehouse_size, device=device)

# rewards_per_criterion = dict()
# for seed in seeds:
#     torch.manual_seed(seed)
#     for criterion in criteria:
#         print(criterion)
#         if criterion not in rewards_per_criterion:
#             rewards_per_criterion[criterion] = dict()

#         rewards_per_criterion[criterion][seed] = list()

#         forward_flow = ForwardFlow(emb_dim=emb_dim, hidden_dim=hidden_dim, warehouse_size=warehouse_size)
#         backward_flow = BackwardFlow()
#         state_flow = StateFlow(emb_dim=emb_dim, hidden_dim=hidden_dim, warehouse_size=warehouse_size)
#         gflownet = GFlowNet(forward_flow, backward_flow, state_flow=state_flow, off_policy_rate=.5, criterion=criterion)

#         if criterion == 'tb':
#             optimizer = torch.optim.Adam([
#                 {'params': gflownet.forward_flow.parameters(), 'lr': 1e-3},
#                 {'params': gflownet.backward_flow.parameters(), 'lr': 1e-3},
#                 {'params': gflownet.log_partition_function, 'lr': 1e-1}
#             ])
#         else:
#             optimizer = None
        
#         create_env = lambda: Set(set_size, warehouse_size, batch_size=int(1e4), log_reward=log_reward) 
#         env, rewards = sample_massive_batch(gflownet, create_env=create_env, num_batches=int(1e2)) 
#         rewards_per_criterion[criterion][seed].append(rewards.topk(int(1e2)).values.tolist()) 

#         for step in range(steps):
#             create_env = lambda: Set(set_size, warehouse_size, batch_size, log_reward=log_reward)
#             train(gflownet, epochs=epochs_per_step, create_env=create_env, optimizer=optimizer, lr=3e-3)

#             create_env = lambda: Set(set_size, warehouse_size, batch_size=int(1e4), log_reward=log_reward)
#             env, rewards = sample_massive_batch(gflownet, create_env, num_batches=int(1e2))

#             rewards_per_criterion[criterion][seed].append(rewards.topk(int(1e2)).values.tolist())

# for criterion in criteria: 
#     values_for_criterion = torch.ones((steps + 1, len(seeds))) 
#     for i, seed in enumerate(seeds): 
#         values = torch.tensor([torch.tensor(el).mean() for el in rewards_per_criterion[criterion][seed]]) 
#         values_for_criterion[:, i] = values 
#     mu = values_for_criterion.mean(dim=1) 
#     std = values_for_criterion.std(dim=1) 

#     xs = (torch.arange(steps + 1)) * epochs_per_step 
    
#     torch.save(
#         xs, f'xs_{criterion}.pt' 
#     )
#     torch.save(
#         mu, f'mu_{criterion}.pt' 
#     )
#     torch.save(
#         std, f'std_{criterion}.pt' 
#     )

# #     plt.plot(xs.cpu(), mu.cpu()) 
# #     plt.fill_between(xs.cpu(), y1=mu.cpu()-std.cpu(), y2=mu.cpu()+std.cpu(), alpha=.5, label=criterion_to_label[criterion])  

# #     plt.ylabel('Rewards (top-5 avg.)') 
# #     plt.xlabel('Epochs') 

# # plt.legend() 
# # plt.savefig('criteria.pdf') 
