import itertools
import os
import wandb
import json
import argparse
from copy import copy
from transformers import DataCollatorForLanguageModeling
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, DatasetDict
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import json

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
import re
import torch
import random
import wandb

from transformers import get_scheduler, AutoTokenizer, AutoModelForCausalLM, AutoConfig

from tqdm import tqdm
from collections import Counter 
from pathlib import Path

import string
from model_utils import get_model
from graph_data_utils import get_graph_dataset, get_graph_to_save 
from phonebook_data_utils import get_phonebook_dataset, get_phonebook_to_save 

from train_utils import train#, save_model
from torch.utils.data import DataLoader



def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only tensors with ``requires_grad=True`` are counted; frozen parameters
    are skipped. A module without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def set_seed(seed: int = 42) -> None:
    """Seed every RNG this script relies on and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    current CUDA device's RNG. It then disables cuDNN autotuning in favour of
    deterministic kernels.

    Args:
        seed: value passed to every RNG (default 42).

    Note:
        ``PYTHONHASHSEED`` is exported as well. It is read at interpreter
        start-up, so it only influences child Python processes, not string
        hashing in the current process.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Both flags are needed for reproducible results on the cuDNN backend.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")
 

def parse_args(argv=None):
    """Build the command-line parser for dataset generation and parse it.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing a
            list lets other scripts and tests call this without touching
            the process's real command line.

    Returns:
        argparse.Namespace holding the options declared below.
    """
    parser = argparse.ArgumentParser()

    # model
    parser.add_argument('--model', choices=['dense', 'gmlp'], default='dense')

    # generator sizes / graph parameters
    parser.add_argument('--vocab_size', default=26, type=int)
    parser.add_argument('--num_nodes', default=60, type=int)
    parser.add_argument('--p_edge', default=0.095, type=float)

    # number of examples per split
    parser.add_argument('--num_examples_train', default=int(2e4), type=int)
    parser.add_argument('--num_examples_val', default=1000, type=int)
    parser.add_argument('--num_examples_test', default=1000, type=int)

    # sequence lengths and reproducibility
    parser.add_argument('--sequence_length', default=1800, type=int)
    parser.add_argument('--label_sequence_length', default=20, type=int)
    parser.add_argument('--seed', default=42, type=int)

    # directed (integer flag)
    parser.add_argument('--directed', default=1, type=int)

    # task: selects which dataset generator the script runs
    parser.add_argument('--task', choices=["phone", "graph"], default="graph", type=str)

    # Forwarding None keeps argparse's default of reading sys.argv[1:].
    return parser.parse_args(argv)

args = parse_args()
print(args, flush=True)
set_seed(args.seed)

# Options that identify a generated dataset. Insertion order matters: it fixes
# the directory name built below.
config_wandb = {
    'task': args.task,
    'model': args.model,
    'nodes': int(args.num_nodes),
    'p': args.p_edge,
    'num_train': args.num_examples_train,
    'seed': args.seed,
}

# "key_value_" pairs concatenated; the name intentionally keeps its trailing "_".
# NOTE(review): vocab_size, sequence_length, directed and the val/test sizes are
# not part of the name, so runs that differ only in those overwrite each other.
wandb_name = "".join("{}_{}_".format(key, value) for key, value in config_wandb.items())

if not os.path.exists("./datasets/"):
    os.makedirs("./datasets/")

dir_save = f"./datasets/{args.task}/{wandb_name}"
train_path = os.path.join(dir_save, "train.json")
test_path = os.path.join(dir_save, "test.json")

# The graph generator returns three splits; the phonebook one has no val split.
if args.task == "graph":
    train_data, val_data, test_data = get_graph_to_save(args)
elif args.task == "phone":
    train_data, test_data = get_phonebook_to_save(args)

Path(dir_save).mkdir(parents=True, exist_ok=True)

# Splits are written (and announced) in order: train, [val], test.
splits_to_write = [(train_path, train_data)]
if args.task == "graph":
    val_path = os.path.join(dir_save, "val.json")
    splits_to_write.append((val_path, val_data))
splits_to_write.append((test_path, test_data))

for out_path, payload in splits_to_write:
    with open(out_path, 'w') as handle:
        json.dump(payload, handle)
    print(f"saved {out_path}", flush=True)
 
