# Standard imports
import argparse
import gc
import os
from pathlib import Path
from tqdm import tqdm
from sys import exit
import wandb
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils
from torch.nn.utils import weight_norm

# Our imports
from train import train
from models import coRNN
from model_utils import get_model # for transformers

# mamba imports
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

# Command-line interface for the training script.
#
# Every option is registered on the argument group it belongs to (rather than
# on the root parser) so that `--help` lists the options under their
# 'Logging' / 'Data' / 'Model' / 'Training' headings. Parsing behaviour
# (option names, types, defaults) is the same as registering them directly on
# `parser`.
parser = argparse.ArgumentParser(description='training parameters')

# Logging
logging_group = parser.add_argument_group('Logging')
logging_group.add_argument('--run_name', type=str, default='Neural Wave Field',
                           help='name of run for wandb')
logging_group.add_argument('--log_interval', type=int, default=100,
                           help='log interval')
logging_group.add_argument('--dataset', type=str, default='copy',
                           help='dataset name for wandb')
# Parsed as a plain string ('True' / 'False'), not as a bool.
# NOTE(review): presumably compared against the string downstream - confirm
# in train before switching to a real boolean flag.
logging_group.add_argument('--use_wandb', type=str, default='False',
                           help='whether to log to wandb')

# Data Args
data_group = parser.add_argument_group('Data')
data_group.add_argument('--T', type=int, default=20,
                        help='length of sequences')
# NOTE(review): the help text below duplicates --T; presumably this is the
# length of the part of the sequence to be memorized - confirm and reword.
data_group.add_argument('--mem_len', type=int, default=10,
                        help='length of sequences')

# Model Args
model_group = parser.add_argument_group('Model')
model_group.add_argument('--model', type=str, default='wavernn',
                         help='model to train and evaluate')
model_group.add_argument('--n_hid', type=int, default=100,
                         help='hidden size of recurrent net')
model_group.add_argument('--n_in', type=int, default=1,
                         help='input channels')
model_group.add_argument('--n_out', type=int, default=1,
                         help='output channels')
model_group.add_argument('--n_ch', type=int, default=1,
                         help='Num hidden state channels')
model_group.add_argument('--act', type=str, default='relu',
                         help='hidden state activation')
model_group.add_argument('--ksize', type=int, default=3,
                         help='Hidden Kernelsize')
model_group.add_argument('--n_layers', type=int, default=1,
                         help='number of layers in network, currently only impelemented for mamba and TF models')
model_group.add_argument('--heads', default=2, type=int,
                         help="Number of heads for the transformer models.")


# Training Args
training_group = parser.add_argument_group('Training')
training_group.add_argument('--loss', type=str, default='mse',
                            help='loss function')
training_group.add_argument('--max_steps', type=int, default=1000,
                            help='max learning steps')
training_group.add_argument('--batch', type=int, default=128,
                            help='batch size')
training_group.add_argument('--batch_test', type=int, default=50,
                            help='size of test set')
training_group.add_argument('--lr', type=float, default=1e-3,
                            help='learning rate')
# Parsed as a plain string ('True' / 'False'), same caveat as --use_wandb.
training_group.add_argument('--only_last', type=str, default='True',
                            help='Only train on memory ouputs')
# NOTE(review): no help text; presumably 0.0 disables clipping - confirm
# against the training loop.
training_group.add_argument('--grad_clip', type=float, default=0.0)


# Hard-coded vocabulary size for the token-based models (mamba / transformers):
# all experiments use the digits 0-9 for the copying tasks.
COPY_TASK_VOCAB_SIZE = 10

# Values of --model that are built through model_utils.get_model.
TRANSFORMER_MODELS = ('T_alibi', 'T_rope', 'T_nope')


def build_model(args, device):
  """Instantiate the network selected by ``args.model`` and move it to ``device``.

  Side effect: for the mamba and transformer models ``args.loss`` is
  overwritten with ``'crossentropy'`` (these models require CE loss), and this
  happens before the model is constructed.

  Args:
    args: Parsed command-line namespace; ``args.model`` selects the
      architecture and the remaining model options parameterize it.
    device: Target ``torch.device`` for the model.

  Returns:
    The constructed model, already moved to ``device``.

  Raises:
    RuntimeError: If ``args.model`` is ``'mamba'`` and ``device`` is not CUDA.
    NotImplementedError: If ``args.model`` is not a supported model name.
  """
  if args.model == 'wavernn':
    return coRNN(n_inp=args.n_in,
                 n_out=args.n_out,
                 n_hid=args.n_hid,
                 n_ch=args.n_ch,
                 act=args.act,
                 ksize=args.ksize).to(device)

  if args.model == 'mamba':
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let a CPU-only run proceed into mamba.
    if device.type != 'cuda':
      raise RuntimeError('GPU is needed for mamba')

    # mamba/tf require CE loss
    args.loss = 'crossentropy'

    config = MambaConfig(d_model=args.n_hid,
                         n_layer=args.n_layers,
                         vocab_size=COPY_TASK_VOCAB_SIZE,
                         pad_vocab_size_multiple=1,
                         tie_embeddings=False)
    return MambaLMHeadModel(config).to(device)

  if args.model in TRANSFORMER_MODELS:
    # mamba/tf require CE loss; set before get_model receives args.
    args.loss = 'crossentropy'
    return get_model(args, COPY_TASK_VOCAB_SIZE).to(device)

  supported = ', '.join(('mamba', 'wavernn') + TRANSFORMER_MODELS)
  raise NotImplementedError(
      f"unsupported --model {args.model!r}; expected one of: {supported}")


if __name__ == "__main__":
  args = parser.parse_args()
  # Make the run name self-describing: base name, dataset, sequence length
  # and hidden size.
  args.run_name = f'{args.run_name} {args.dataset}, T={args.T}, nhid={args.n_hid}'
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(args.model)

  model = build_model(args, device)

  # Printed after model construction so it reflects any loss override.
  print(args)

  data, label, pred, y_seq = train(model, args, device)
