import torch
import pickle
import argparse
import os
import csv
import random
import logging
from dataset import LongRangeDataset
from models import TransformerLanguageModel, LSTMLanguageModel, LRULanguageModel, MambaLanguageModel
from utils import ModelHandler
from datetime import datetime
import faulthandler

faulthandler.enable()
def config_seed(seed):
    """Seed both PyTorch's and Python's random number generators.

    Args:
        seed: Value handed to ``torch.manual_seed`` and ``random.seed``.
    """
    # Order is kept as torch first, then the stdlib ``random`` module.
    for seed_fn in (torch.manual_seed, random.seed):
        seed_fn(seed)

# Command-line options as (flag, add_argument keyword arguments) pairs.
# They are registered in the order listed, so the --help output keeps the
# same ordering as before.
_CLI_OPTIONS = (
    ('--data_file', dict(type=str, default='', help="Location of data")),
    ('--save_location', dict(type=str, default='./models/',
                             help="Where to save model")),
    ('--model_to_load', dict(type=str, default="",
                             help="Load and train existing model")),
    ('--model_type', dict(type=str, default="lstm",
                          choices=["lstm", "transformer", "lru", "mamba"],
                          help="Which architecture to use")),
    ('--batch_size', dict(type=int, default=256, help="Batch size")),
    ('--epochs', dict(type=int, default=500, help="Training epochs")),
    ('--embed_dim', dict(
        type=int, default=16,
        help="Embedding dimension, also d_model for transformer")),
    ('--hidden_size', dict(type=int, default=128,
                           help="Hidden size for LSTM")),
    ('--num_layers', dict(type=int, default=2,
                          help="Number of layers in model")),
    ('--num_heads', dict(type=int, default=4,
                         help="Number of attention heads in transformer")),
    ('--ff_dim', dict(type=int, default=64,
                      help="Feedforward dimension for transformer")),
    ('--r_min', dict(type=float, default=0.9, help="r_min for LRU")),
    ('--r_max', dict(type=float, default=0.999, help="r_max for LRU")),
    ('--lr', dict(type=float, default=1e-3, help="Learning rate")),
    ('--data_percent', dict(type=float, default=1.0,
                            help="Percentage of training data to use")),
    # No default: args.seed is None unless --seed is given.
    ('--seed', dict(type=int, help="Random seed")),
    ('--force_cpu', dict(action='store_true', help="Force use of cpu")),
    ('--no_attention', dict(action='store_true',
                            help="LSTM or LRU without attention")),
    ('--pos_enc', dict(action='store_true',
                       help="Add position encoding for transformers")),
    ('--d_conv', dict(type=int, default=4, help="d_conv for mamba")),
    ('--d_state', dict(type=int, default=16, help="d_state for mamba")),
    ('--expand', dict(type=int, default=2, help="expand for mamba")),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Use the seed from the command line when one was given; otherwise draw one
# in [1, 10000] so the run still has a concrete seed value.
seed = args.seed if args.seed is not None else random.randint(1, 10000)

config_seed(seed)

# Append timestamped INFO-level records to training.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    filename='training.log',
    filemode='a',
)

dataset = LongRangeDataset(args.data_file, args.data_percent)
print(len(dataset.all_data['train']))

# CUDA is used only when it is available and --force_cpu was not passed.
device = torch.device(
    'cuda' if (not args.force_cpu and torch.cuda.is_available()) else 'cpu'
)

# Build the model: either resume a previously saved one or construct a fresh
# network of the requested architecture. For fresh models, add_model_info
# collects the architecture-specific tags that later go into the model name.
if args.model_to_load != "":
    # Close the checkpoint file deterministically and remap the stored
    # tensors onto the selected device, so a model saved on a GPU can be
    # resumed on a CPU-only host and --force_cpu is honoured.
    # NOTE(review): this unpickles a whole model object, so only load files
    # you trust; recent PyTorch releases may additionally require
    # weights_only=False for such files -- confirm against the installed
    # version.
    with open(args.model_to_load, "rb") as model_file:
        model = torch.load(model_file, map_location=device)
else:
    # Tag shared by the recurrent architectures (LSTM and LRU).
    attention_tag = "no_attention" if args.no_attention else "with_attention"

    if args.model_type == "lstm":
        model = LSTMLanguageModel(args.embed_dim, dataset.vocab_size,
            args.hidden_size, args.num_layers, pad_idx=dataset.pad_idx,
            device=device, use_attention=not args.no_attention).to(device)
        add_model_info = ["hidden_size_{}".format(args.hidden_size),
                          attention_tag]
    elif args.model_type == "lru":
        model = LRULanguageModel(args.embed_dim, dataset.vocab_size,
            args.hidden_size, args.num_layers, pad_idx=dataset.pad_idx,
            device=device, r_min=args.r_min, r_max=args.r_max,
            use_attention=not args.no_attention).to(device)
        add_model_info = ["hidden_size_{}".format(args.hidden_size),
                          "r_min_{:.4f}".format(args.r_min),
                          "r_max_{:.4f}".format(args.r_max),
                          attention_tag]
    elif args.model_type == "transformer":
        model = TransformerLanguageModel(args.embed_dim, dataset.vocab_size,
            args.num_layers, args.num_heads, args.ff_dim,
            pad_idx=dataset.pad_idx, device=device,
            pos_enc=args.pos_enc).to(device)
        add_model_info = ["heads_{}".format(args.num_heads),
                          "ff_dim_{}".format(args.ff_dim),
                          "pos_enc" if args.pos_enc else "no_pos_enc"]
    elif args.model_type == "mamba":
        model = MambaLanguageModel(args.embed_dim, dataset.vocab_size,
            d_conv=args.d_conv, d_state=args.d_state,
            expand=args.expand, pad_idx=dataset.pad_idx,
            device=device).to(device)
        add_model_info = ["d_conv_{}".format(args.d_conv),
                          "d_state_{}".format(args.d_state),
                          "expand_{}".format(args.expand)]
    # No final else: argparse restricts --model_type to the four choices
    # handled above.

# Total element count over the parameters that have requires_grad set.
num_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

print(f"Model with {num_params} parameters")

optim = torch.optim.Adam(model.parameters(), lr=args.lr)

if args.model_to_load != "":
    # Resumed model: reuse the file name of the model that was loaded.
    model_name = args.model_to_load.split("/")[-1]
else:
    # Fresh model: encode the main hyperparameters in the file name, with
    # the fields joined by three underscores.
    model_info = [
        args.model_type,
        args.data_file.split("/")[-1].split(".")[0],
        f"parameters_{num_params}",
        f"embed_{args.embed_dim}",
        f"layers_{args.num_layers}",
        f"lr_{args.lr:.5f}",
        f"seed_{seed}",
        f"batch_size_{args.batch_size}",
    ] + add_model_info
    model_name = '___'.join(model_info) + '.model'

# The launch time, rounded to whole seconds since the epoch, serves as the
# identifier that ties together this run's log records.
curr_dt = datetime.now()
train_id = int(round(curr_dt.timestamp()))
logging.info("ID: %s, beginning training on model %s", train_id, model_name)

handler = ModelHandler(
    model, optim, dataset, device, args.save_location, model_name,
)
handler.train_model(args.batch_size, args.epochs)

# Reload the model stored under save_location/model_name before evaluating.
# NOTE(review): presumably this is the checkpoint written by train_model --
# confirm in ModelHandler.
handler.load_model(os.path.join(args.save_location, model_name),
    args.batch_size)

# Beam-search evaluation on the test and validation splits; only the test
# split's detailed results are kept for pickling.
test_acc, results = handler.test_model("beam_search", split="test",
    record_results=True)
val_acc, _ = handler.test_model("beam_search", split="val",
    record_results=True)

# Create the output directory if needed so a missing folder cannot discard
# the results of a finished training run, and close the file once written.
os.makedirs('results', exist_ok=True)
with open('results/{}_results.pickle'.format(model_name), 'wb') as results_file:
    pickle.dump(results, results_file)

# Write a human-readable summary of the finished run to the training log.
if args.model_to_load == "":
    # Fresh model: log float hyperparameters at full precision. The "[LRU]"
    # marker is swapped for the r_min/r_max lines on LRU models and removed
    # otherwise.
    record_string = ("ID: {}\nFinished training model for {} epochs. "
        "Saved as {}. Full value of float hyperparameters:\n\tLearning"
        " rate: {} [LRU]\n Best validation loss: {:.2f}\nTest Accuracy:"
        " {:.2f} achieved wtih beam search")
    if args.model_type == "lru":
        record_string = record_string.replace(
            "[LRU]", "\n\tr_min:{}\n\tr_max:{}")
        logging.info(record_string.format(train_id, args.epochs, model_name,
            args.lr, args.r_min, args.r_max, handler.best_val_loss,
            test_acc * 100))
    else:
        record_string = record_string.replace("[LRU]", "")
        logging.info(record_string.format(train_id, args.epochs, model_name,
            args.lr, handler.best_val_loss, test_acc * 100))
else:
    # Continued training of a loaded model.
    record_string = ("ID:{}\nTrained model {} for an additional {} "
        "epochs with learning rate {}. \n Best validation loss: "
        "{:.2f}\nTest Accuracy: {:.2f} achieved wtih beam search")
    # Exactly one argument per placeholder, in order: id, model name,
    # epochs, learning rate, best validation loss, test accuracy. A
    # duplicated model_name used to shift the last three values by one slot
    # and silently drop the test accuracy.
    logging.info(record_string.format(train_id, model_name, args.epochs,
        args.lr, handler.best_val_loss, test_acc * 100))

# Append one summary row for this run to model_results.csv.

# NOTE(review): the unpacking below requires the data file stem to contain
# exactly one underscore ("<task>_<mode>"); any other name raises ValueError
# here, after training has finished -- confirm the naming scheme.
task, mode = args.data_file.split("/")[-1].split(".")[0].split("_")

# Columns that do not apply to the architecture get placeholder values
# (0, 1 or None).
if args.model_type == "transformer":
    attention = None
    hidden_size = 0
    ff_dim = args.ff_dim
    num_heads = args.num_heads
    pos_enc = int(args.pos_enc)
else:
    # NOTE(review): every non-transformer type, including mamba, records the
    # --no_attention and --hidden_size values here -- confirm that is the
    # intended schema for mamba rows.
    ff_dim = 0
    num_heads = 1
    attention = int(not args.no_attention)
    hidden_size = args.hidden_size
    pos_enc = None

to_write = [args.model_type, task, mode, num_params, args.embed_dim,
            args.num_layers, args.lr, seed, args.batch_size, hidden_size,
            attention, num_heads, ff_dim, pos_enc, val_acc]

# The context manager closes the file even if writing the row fails.
# newline='' is what the csv module asks for on file objects; the writer
# emits its own line terminator. None values are written as empty fields.
with open('model_results.csv', 'a', newline='') as csvfile:
    csvwriter = csv.writer(csvfile, delimiter=',')
    csvwriter.writerow(to_write)