import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import math
import numpy as np
import seaborn as sns
from data.generate_data import gen_data
from data.data_preprocess import block_split,uniform_split,skew_split,strong_skew_token_split
import os
from models.MLP import ModuloClassifier,ModuloClassifier_noEmb
from models.simple_transformer import SimpleTransformerDecoder,Causal_Transformer
from utils.custom_opt import adam_emb_wd,CoordinatedOptimizer
from experiments.trainer  import trainer_mlp,trainer_transformer
from utils.plotting import plot_accuracy,plot_accuracy_21
from torch.utils.data import DataLoader, TensorDataset
from data.generate_data import ArithmeticDataset
import argparse

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` and ``"0"`` -- parses as ``True``.
    This converter accepts the usual spellings instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as "false" because bool('') was the only way to
    # switch a flag off with the previous type=bool parser.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description="Configuration for the model and dataset.")

    # Add arguments corresponding to the config class
    parser.add_argument("--p", type=int, default=97, help="Prime number for dataset generation.")
    parser.add_argument("--group", type=str, default='amodp', help="Group of data.")
    parser.add_argument("--split", type=float, default=0.5, help="Train portion.")
    parser.add_argument(
        "--batch_size",
        type=int,
        # Evaluates to 512. The expression uses the literals 97 and 0.5, so it
        # does not follow --p / --split when those are overridden.
        default=min(512, int(0.5 * 0.5 * (97 - 1) ** 2)),
        help="Batch size."
    )
    parser.add_argument("--shuffle", type=_str2bool, default=True, help="Shuffle the train loader every epoch or not.")
    parser.add_argument("--sample", type=str, default='block', help="Sampling strategy.")
    parser.add_argument("--device", type=int, default=0, help="device.")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=1, help="Weight decay.")
    parser.add_argument("--wd_scale", type=float, default=1, help="Embedding weight decay scale.")
    parser.add_argument("--epochs", type=int, default=4000, help="Number of epochs.")
    parser.add_argument("--print", type=_str2bool, default=False, help="Printing.")
    parser.add_argument("--optimizer", type=str, default='adam_lr', help="Optimizer class.")
    parser.add_argument("--model", type=str, default='mlp', help="Model type.")
    parser.add_argument("--config", type=str, default=None, help="Model config.")
    parser.add_argument("--d_model", type=int, default=128, help="Model dimension or embedding size.")
    parser.add_argument("--seed", type=int, default=42, help="seed")
    parser.add_argument("--seq_len", type=int, default=4, help="Sequence length.")
    parser.add_argument("--from_pretrained", type=str, default=None, help="Start from pretrained layer.")
    parser.add_argument("--save", type=_str2bool, default=True, help="Save model flag.")
    parser.add_argument("--num_heads", type=int, default=4, help="Number of heads in transformer.")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in transformer.")
    parser.add_argument("--pretrain_path", type=str, default= 'checkpoints/modp/mlp_lr_0.001_wd_0.001_split_0.3_BS_512_sample_random.pth', help="Start from pretrained layer.")
    return parser.parse_args(argv)

# Parse the command line once at import time. `args`, `seed` and `device` are
# module-level globals read by the code further down in this file.
args = get_args()
print(args)

# Seed the CPU generator and the CUDA generators from --seed.
seed = args.seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the CUDA device selected by --device when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.device}')
else:
    device = torch.device('cpu')


def _select_split(para, data):
    """Return ``(train_data, test_data)`` for the strategy named by ``para.sample``.

    The uniform and random splits are always computed, whichever strategy is
    selected, exactly as before; only the choice of which tensors to return
    depends on ``para.sample``.

    Raises:
        ValueError: If ``para.sample`` is not one of the supported strategies.
    """
    # make different sampling splits according to section 4.1
    train_data_uni, test_orig = uniform_split(data, para.p, para.split)
    # random_state follows para.seed so that each seeded run gets its own split.
    new_train, test_data = train_test_split(test_orig, test_size=0.2/(1-para.split), random_state=para.seed)
    train_combined = torch.cat([new_train, train_data_uni], dim=0)
    train_data_rand, _ = train_test_split(train_combined, test_size=1-(para.split/0.8), random_state=para.seed)

    if para.sample == 'uniform':
        print('Running on uniform data')
        return train_data_uni, test_data
    if para.sample == 'skew':
        print('Running on skew data')
        train_data, _ = skew_split(data, para.p, para.split)
        return train_data, test_data
    if para.sample == 'block':
        print('Running on block data')
        # block_split provides its own test set, replacing the one built above.
        return block_split(data, para.p, para.split)
    if para.sample == 'random':
        print('Running on random split data')
        return train_data_rand, test_data
    raise ValueError(
        f"split is not defined: {para.sample!r} "
        "(expected 'uniform', 'skew', 'block' or 'random')"
    )


def _build_model(para, pretrain_path):
    """Instantiate the network named by ``para.model`` and move it to ``device``.

    Raises:
        ValueError: If ``para.model`` is not a supported model name.
    """
    model_name = para.model.lower()
    if model_name == 'mlp':
        model = ModuloClassifier(
            para.vocab_size, para.d_model, para.seq_len,
            pretrained_path=pretrain_path if para.from_pretrained else None,
        )
    elif model_name == 'mlp_noemb':
        # NOTE(review): this passes the --from_pretrained value itself as the
        # path, unlike the 'mlp' branch which passes --pretrain_path. Kept as
        # in the original; confirm which one is intended.
        model = ModuloClassifier_noEmb(
            para.vocab_size, para.d_model, para.seq_len,
            pretrained_path=para.from_pretrained,
        )
    elif model_name == 'transformer':
        model = SimpleTransformerDecoder(
            para.vocab_size, para.d_model, para.num_heads, para.num_layers,
            pretrained_path=None,
        )
    elif model_name == 'causal_transformer':
        model = Causal_Transformer(
            para.vocab_size, para.d_model, para.num_heads, para.num_layers,
            pretrained_path=None,
        )
    else:
        raise ValueError(f"model is not supported: {para.model!r}")
    return model.to(device)


def _partition_params(model, is_selected):
    """Split the model's parameters into (selected, remaining) by parameter name."""
    selected, remaining = [], []
    for name, param in model.named_parameters():
        (selected if is_selected(name) else remaining).append(param)
    return selected, remaining


def _build_optimizer(model, para):
    """Create the optimizer named by ``para.optimizer`` (matched case-insensitively).

    Raises:
        ValueError: If ``para.optimizer`` is not a supported optimizer name.
    """
    optimizer_name = para.optimizer.lower()
    if optimizer_name == 'adam':
        return torch.optim.Adam(model.parameters(), lr=para.lr, weight_decay=para.wd)
    if optimizer_name == 'lbfgs':
        return torch.optim.LBFGS(model.parameters(), lr=1, max_iter=20, history_size=10)
    if optimizer_name == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=para.lr, weight_decay=para.wd)
    if optimizer_name == 'adam_custom':
        return adam_emb_wd(model, lr=para.lr, wd=para.wd, wd_scale=para.wd_scale)
    if optimizer_name == 'adamw':
        return torch.optim.AdamW(model.parameters(), lr=para.lr, weight_decay=para.wd)
    if optimizer_name == 'adam_lr':
        # our optimizer called adam_lr: parameters whose name contains
        # "embedding", or both "layers" and "attn", get their learning rate
        # multiplied by wd_scale. The parentheses spell out the precedence the
        # original expression already had.
        embedding_params, other_params = _partition_params(
            model,
            lambda name: "embedding" in name or ("layers" in name and "attn" in name),
        )
        return torch.optim.AdamW([
            {'params': embedding_params, 'lr': para.lr * para.wd_scale, 'weight_decay': para.wd},
            {'params': other_params, 'lr': para.lr, 'weight_decay': para.wd},
        ])
    # The original compared the lower-cased name against 'coordinate_SGD',
    # which can never match, so this optimizer was unreachable.
    if optimizer_name == 'coordinate_sgd':
        embedding_params, other_params = _partition_params(
            model, lambda name: "embedding" in name
        )
        return CoordinatedOptimizer(
            embedding_params=embedding_params,
            other_params=other_params,
            lr_embedding=para.lr,
            lr_other=para.lr,
            wd_embedding=para.wd,
            wd_other=para.wd,
        )
    raise ValueError(f"optimizer is not supported: {para.optimizer!r}")


def main(para):
    """Run one training experiment described by ``para`` and return its results.

    Generates the dataset, builds the train/test loaders for the selected
    sampling strategy, creates the model and optimizer, and hands everything to
    the matching trainer.

    Side effects: sets ``para.vocab_size``; re-seeds torch from ``para.seed``;
    creates ``checkpoints/<group>/``; and sets the module globals
    ``model_identifier``, ``directory`` and ``path_to_save``.

    Args:
        para: Namespace of options as produced by ``get_args``.

    Returns:
        Whatever ``trainer_mlp`` / ``trainer_transformer`` returns.

    Raises:
        ValueError: If ``para.sample``, ``para.model`` or ``para.optimizer``
            names an unsupported choice.
    """
    # Re-seed per call so that changing para.seed between calls takes effect.
    # For the first call this repeats the module-level seeding, so the RNG
    # state is unchanged.
    torch.manual_seed(para.seed)
    torch.cuda.manual_seed_all(para.seed)

    #load data
    data, vocab_size = gen_data(para.group, para.p)
    para.vocab_size = vocab_size

    train_data, test_data = _select_split(para, data)
    train_data = train_data.to(device)
    test_data = test_data.to(device)
    # create dataset: the last column is the target, the rest is the input
    train_dataset = TensorDataset(train_data[:, :-1], train_data[:, -1])
    test_dataset = TensorDataset(test_data[:, :-1], test_data[:, -1])
    # create dataloader
    train_loader = DataLoader(train_dataset, batch_size=para.batch_size, shuffle=para.shuffle)
    test_loader = DataLoader(test_dataset, batch_size=para.batch_size, shuffle=para.shuffle)

    # These globals are read by the module-level code that saves the results.
    global model_identifier
    optimizer_name = para.optimizer.lower()
    model_identifier = f'{para.model}_lr_{para.lr}_wd_{para.wd}_split_{para.split}_BS_{para.batch_size}_sample_{para.sample}'
    if para.wd_scale != 1 and optimizer_name == 'adam_custom':
        model_identifier += f'_scaleWD_{para.wd_scale}'
    if optimizer_name == 'adam_lr':
        model_identifier += f'_adam_{para.wd_scale}'
    if not para.shuffle:
        model_identifier += '_Not_shuffled'
    if para.from_pretrained:
        model_identifier += '_preEmb'
    # Create directory to save the things for this model and this dataset
    global directory, path_to_save
    directory = f"checkpoints/{para.group}/"
    os.makedirs(directory, exist_ok=True)
    print('Initialized experiemnt instance:', directory, model_identifier)
    path_to_save = para.pretrain_path

    # define the loss as cross entropy
    criterion = nn.CrossEntropyLoss()
    # Initialize model and optimizer; supports causal transformer, mlp, and
    # mlp without embedding layer
    model = _build_model(para, path_to_save)
    optimizer = _build_optimizer(model, para)

    # trainer loop
    if para.model.lower() in ['mlp', 'mlp_noemb']:
        print('Start MLP training:', model_identifier)
        results = trainer_mlp(train_loader, test_loader, model, optimizer, criterion, para)
    else:
        # _build_model has already rejected every name other than the two
        # transformer variants.
        print('Start Transformer training:', model_identifier)
        results = trainer_transformer(train_loader, test_loader, model, optimizer, criterion, para)
    return results




# Run the trainer once per seed and average the accuracy curves.
# The first run uses the seed given on the command line (default 42); the
# other two runs use 12 and 93. args.seed is updated before each call so that
# main() receives the seed of the current run.
run_results = []
for run_seed in (args.seed, 12, 93):
    args.seed = run_seed
    run_results.append(main(args))
results_42, results_12, results_93 = run_results

# Accuracy curves are averaged element-wise; every other entry is taken from
# the first run.
results = {}
for key in results_42.keys():
    if key in ['train_acc', 'test_acc']:
        results[key] = [(a + b + c) / 3 for a, b, c in zip(results_42[key], results_12[key], results_93[key])]
    else:
        results[key] = results_42[key]


# Plot the results against optimization steps (a function of batch size,
# epochs and number of training examples).
plt.figure(figsize=(12, 8))
N = (args.p-1)**2*args.split
# The train DataLoader keeps the final partial batch, so an epoch has
# ceil(N / batch_size) batches. Flooring undercounts and yields 0 when the
# train set is smaller than one batch, which collapses the log-scaled x-axis.
# NOTE(review): assumes one optimizer step per batch and (p-1)^2 * split
# training examples -- confirm against the trainer and data generator.
step_per_epoch = max(1, math.ceil(N / args.batch_size))
steps = [i * step_per_epoch for i in range(1, args.epochs+1)]

plt.plot(steps, results['train_acc'], label=f'Train ', color='navy', linestyle='-', linewidth=2)
plt.plot(steps, results['test_acc'], label=f'Test', color='crimson', linestyle='--', linewidth=2)

plt.xscale('log')

plt.xlabel('Optimization Steps', fontsize=16)
plt.ylabel('Accuracy', fontsize=16)
plt.legend(fontsize=16)
plt.grid(True)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
# Save the averaged results and the figure. `directory` and `model_identifier`
# are module globals set by main().
torch.save(results, os.path.join(directory, f'{model_identifier}.pth'))
plt.savefig(f'{directory}/{model_identifier}.png', format='png', bbox_inches='tight')
