import itertools
import os
import wandb
import json
import argparse
from copy import copy
from transformers import DataCollatorForLanguageModeling
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, DatasetDict

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
import re
import torch
import random
import wandb

from transformers import get_scheduler, AutoTokenizer, AutoModelForCausalLM, AutoConfig

from tqdm import tqdm
from collections import Counter 
from pathlib import Path

import string
from model_utils import get_model
from graph_data_utils import get_graph_dataset,Dataset,get_graph_tokenizer
from phonebook_data_utils import get_phonebook_dataset,get_phonebook_tokenizer

from train_utils import train#, save_model
from torch.utils.data import DataLoader



def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable.

    Only tensors with ``requires_grad`` set are counted; frozen parameters
    are ignored.

    Args:
        model: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The summed ``numel()`` of all trainable parameter tensors (0 if none).
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    CUDA), forces cuDNN into deterministic, non-benchmarking mode, exports
    ``PYTHONHASHSEED`` and prints the seed that was applied.

    Args:
        seed: Value fed to every generator. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Stop cuDNN from autotuning / choosing non-deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # NOTE: CPython reads PYTHONHASHSEED only at interpreter start-up, so this
    # only influences subprocesses launched afterwards, not the current one.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print("Random seed set as " + str(seed))


def parse_args(argv=None):
    """Parse the training configuration from the command line.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets the parser be reused from tests or notebooks.

    Returns:
        argparse.Namespace holding model, optimisation, data and task settings.
    """
    parser = argparse.ArgumentParser()

    # model
    parser.add_argument('--model', choices=['dense', 'sparse', 'gmlp', 'attention_only'], default='sparse')
    parser.add_argument('--hidden_size', default=512, type=int)
    parser.add_argument('--layers', default=6, type=int)
    parser.add_argument('--heads', default=8, type=int)
    parser.add_argument('--num_experts', default=8, type=int)
    parser.add_argument('--router_aux_loss_coef', default=0.02, type=float)

    # optimisation / epochs
    parser.add_argument('--learning_rate', default=1e-4, type=float)
    parser.add_argument('--num_epochs', default=1, type=int)

    parser.add_argument('--train_batch_size', default=8, type=int)
    parser.add_argument('--eval_batch_size', default=10, type=int)
    parser.add_argument('--gradient_accumulation_steps', default=1, type=int)
    parser.add_argument('--weight_decay', default=0.1, type=float)

    # data generation settings (num_nodes and p_edge also select the dataset directory)
    parser.add_argument('--vocab_size', default=26, type=int)
    parser.add_argument('--num_nodes', default=15, type=int)
    parser.add_argument('--p_edge', default=0.25, type=float)

    # dataset sizes
    parser.add_argument('--num_examples_train', default=int(2e4), type=int)
    parser.add_argument('--num_examples_val', default=1000, type=int)
    parser.add_argument('--num_examples_test', default=1000, type=int)

    parser.add_argument('--sequence_length', default=500, type=int)
    parser.add_argument('--label_sequence_length', default=8, type=int)
    parser.add_argument('--seed', default=42, type=int)

    # directed graph flag (integer, default 1)
    parser.add_argument('--directed', default=1, type=int)

    # task
    parser.add_argument('--task', choices=["phone", "graph"], default="phone", type=str)

    return parser.parse_args(argv)

args = parse_args()

print(args)

set_seed(args.seed)


## load dataset
# The dataset directory name is derived from these settings, in this insertion
# order. It presumably has to match the name the generation script wrote — verify
# before changing any key or its order.
config_dataset = {
    'task': args.task,
    # dense, sparse and attention_only are all mapped to the same "dense" dataset
    'model': "dense" if args.model in ["dense", "sparse", "attention_only"] else args.model,
    'nodes': int(args.num_nodes),
    'p': args.p_edge,
    'num_train': args.num_examples_train,
    'seed': args.seed,
}

# e.g. "task_phone_model_dense_nodes_15_p_0.25_num_train_20000_seed_42_"
# (the trailing underscore is part of the on-disk name).
dataset_name = "".join(f"{key}_{value}_" for key, value in config_dataset.items())

dataset_path = os.path.join("./datasets/" + args.task + "/", dataset_name)

train_path = os.path.join(dataset_path, "train.json")
test_path = os.path.join(dataset_path, "test.json")

# Explicit checks instead of `assert`, which is stripped under `python -O`.
for required_path in (train_path, test_path):
    if not os.path.exists(required_path):
        raise FileNotFoundError(f"Expected dataset file not found: {required_path}")


def _load_split(path, model_name):
    """Load one JSON dataset split and convert its per-example lists to tensors.

    ``input_ids`` entries become int64 tensors. The companion field becomes a
    tensor with inferred dtype: ``label_ids`` for the "gmlp" model, ``mask``
    for every other model. Only indices that exist in ``input_ids`` are
    converted.

    Args:
        path: Path to the split's JSON file.
        model_name: Value of ``--model``; selects which companion field is used.

    Returns:
        The parsed dict, converted in place.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    aux_key = "label_ids" if model_name == "gmlp" else "mask"
    input_ids = data['input_ids']
    for idx, ids in enumerate(input_ids):
        # Replacing the current element does not disturb the iteration.
        input_ids[idx] = torch.tensor(ids, dtype=torch.int64)
        data[aux_key][idx] = torch.tensor(data[aux_key][idx])
    return data


train_data = _load_split(train_path, args.model)

if args.task == "graph":
    val_path = os.path.join(dataset_path, "val.json")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(val_path):
        raise FileNotFoundError(f"Expected dataset file not found: {val_path}")

    val_data = _load_split(val_path, args.model)
    val_dataset = Dataset(val_data, args.model)
    val_dataloader = DataLoader(val_dataset, batch_size=args.eval_batch_size, shuffle=False)

    tokenizer, TO_TOKEN, TO_CHAR = get_graph_tokenizer(args)
else:
    # The "phone" task has no validation split here.
    val_dataloader = None
    tokenizer, TO_TOKEN, TO_CHAR = get_phonebook_tokenizer(args)

test_data = _load_split(test_path, args.model)

train_dataset = Dataset(train_data, args.model)
test_dataset = Dataset(test_data, args.model)

train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False)
# Draw one (shuffled) training batch purely to eyeball what the model will see.
batch = next(iter(train_dataloader))

separator = "-" * 100
print(separator, flush=True)
print(f"EXAMPLE {batch['input'][0]}", flush=True)
print(separator, flush=True)
if args.model != "gmlp":
    last_ids = batch['input_ids'][-1]
    last_mask = batch['mask'][-1]
    # ids selected by the mask, the full id sequence, then the raw input of the last example
    print(last_ids[last_mask == 1], last_ids, batch['input'][-1], flush=True)
print("^" * 100, flush=True)
 




## Get model
model = get_model(args, tokenizer)
num_parameters = count_parameters(model)

# Print the architecture and its trainable-parameter count between banners.
banner = "^" * 100
for line in (banner, model, f"NUM PARAMS OF MODEL {num_parameters}", banner):
    print(line)

 

## setup wandb
# Hyperparameters logged to wandb. The run name is built from them in this
# insertion order, so the order of keys matters.
config_wandb = {
    'model': args.model,
    'hs': args.hidden_size,
    'layers': args.layers,
    'heads': args.heads,
}
# The expert count is only recorded for the "sparse" model.
if args.model == "sparse":
    config_wandb['experts'] = args.num_experts
config_wandb.update(
    nodes=int(args.num_nodes),
    p=args.p_edge,
    num_train=args.num_examples_train,
    lr=args.learning_rate,
    bs=args.train_batch_size,
    ga=args.gradient_accumulation_steps,
    wd=args.weight_decay,
    epochs=args.num_epochs,
    task=args.task,
    seed=args.seed,
)

# "key_value_" pairs concatenated, e.g. "model_sparse_hs_512_..._seed_42_"
# (trailing underscore kept so run names match earlier runs).
wandb_name = "".join(f"{key}_{value}_" for key, value in config_wandb.items())

wandb.init(project="new_new_graph_synthetic_moe", entity="fkddf", name=wandb_name, config=config_wandb)



## train + evaluation function
# Delegate training and evaluation to train_utils.train. val_dataloader is None
# for the "phone" task (no validation split is loaded for it); presumably train
# skips validation in that case — confirm in train_utils.
train(args,model,train_dataloader,val_dataloader,test_dataloader,tokenizer,TO_TOKEN)

 


 
