import argparse
import torch
import torch.nn as nn
from train.training import train_and_infer
import os
from helper import log_and_print
import time
from helper import set_seed, read_config_file, bool_from_string, \
    retrieve_model_and_builder, get_dataloaders, update_args, timer_decorator, Args
import wandb
import ast
from train.utils import Trainer



@timer_decorator
def setup_training():
    """Train (and evaluate) the network described by the registered arguments.

    Configuration is fetched from ``Args().get_args()``; ``main`` stores the
    parsed arguments there via ``set_args`` before calling this function
    (presumably ``Args`` shares state across instances -- confirm in helper).

    Steps: build the train/test dataloaders and the model plus its LoRA
    builder, wrap a cross-entropy criterion in a ``Trainer``, and hand
    everything to ``train_and_infer``. The network it returns is then
    measured: its total element count over ``parameters()`` is logged (and
    sent to wandb when ``args.use_wandb`` is set). When ``args.save_model``
    is set, its ``state_dict`` is written to ``<args.log_dir>/<timestamp>.pt``.
    """
    args = Args().get_args()

    train_loaders, test_loaders = get_dataloaders()
    model, builder = retrieve_model_and_builder()
    trainer = Trainer(nn.CrossEntropyLoss(), builder, args.device)

    model = train_and_infer(
        model=model,
        train_dataloaders=train_loaders,
        test_dataloaders=test_loaders,
        trainer=trainer,
    )

    # Counts every parameter (trainable or frozen) of the post-training model.
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    if args.use_wandb:
        wandb.log({"num_params": total_params})
    log_and_print(f"FINAL: Total no.of parameters after CL: {total_params}", args.logger, args.verbose)

    if not args.save_model:
        return

    # Timestamped filename so successive runs in the same log_dir don't overwrite each other.
    checkpoint_path = os.path.join(args.log_dir, f"{time.strftime('%Y%m%d_%H%M%S')}.pt")
    torch.save(model.state_dict(), checkpoint_path)
    args.logger.info(f"Saved model: {checkpoint_path}")



def _parse_bool(text):
    """Convert a boolean string from the config file or command line to ``bool``.

    Unlike ``type=bool`` (where ``bool("False")`` is ``True``), this parses the
    text. Accepts true/false, 1/0, yes/no, y/n, on/off, case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a recognised boolean;
            argparse turns this into a clean usage error.
    """
    lowered = str(text).strip().lower()
    if lowered in ("true", "1", "yes", "y", "on"):
        return True
    if lowered in ("false", "0", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _infer_arg_type(value):
    """Return ``(converter, default)`` for a raw config-file string.

    ``"true"``/``"false"`` (any case) become booleans; otherwise ``int`` is
    tried, then ``float`` (covers ``0.5``, ``1e-4`` and ``1E-4``), and anything
    else stays a ``str``. The converter is also used by argparse for CLI
    overrides, so overrides get the same type as the config default.
    """
    if value.lower() in ("true", "false"):
        return _parse_bool, _parse_bool(value)
    for converter in (int, float):
        try:
            return converter, converter(value)
        except ValueError:
            continue
    return str, value


def _config_value(config, name):
    """Look up ``name`` in ``config``; return ``None`` when it is absent.

    Attribute access is tried first, for config objects that expose settings as
    attributes. Otherwise every section is scanned with a case-insensitive key
    match, because ConfigParser lower-cases option names by default.
    """
    value = getattr(config, name, None)
    if value is not None:
        return value
    wanted = name.lower()
    for section in config.sections():
        for key, raw in config.items(section):
            if key.lower() == wanted:
                return raw
    return None


def main():
    """Parse arguments, initialise logging/seeding, then run ``setup_training``.

    Arguments are parsed in two passes:

    1. Only ``--config`` is known. It is read with ``read_config_file``.
    2. Every ``key = value`` entry of every config section is registered as an
       optional ``--key`` flag. Its default is the config value and its type
       is inferred by ``_infer_arg_type``. Any config entry can therefore be
       overridden on the command line, e.g. ``--config c.ini --lr 0.01``.

    The parsed namespace is post-processed by ``update_args``. Optionally
    wandb is initialised (project/entity/API key taken from the config file).
    The seed is set and the namespace is published via ``Args().set_args``
    before training starts.
    """
    parser = argparse.ArgumentParser(description="Train the network with specified configuration")
    parser.add_argument("--config", type=str, required=True, help="Path to the configuration file")

    # First pass: the config-key flags are not registered yet, so tolerate them
    # here; plain parse_args() would exit with "unrecognized arguments".
    cli_args, _unknown = parser.parse_known_args()
    config = read_config_file(cli_args.config)

    registered = {}
    for section in config.sections():
        for key, value in config.items(section):
            # ConfigParser.items() repeats [DEFAULT] entries in every section;
            # skip exact repeats so they don't trigger an argparse conflict.
            # Genuinely conflicting values still raise from add_argument.
            if key in registered and registered[key] == value:
                continue
            registered[key] = value
            converter, default = _infer_arg_type(value)
            parser.add_argument(f"--{key}", type=converter, default=default, help=f"{key} from config")

    # Second pass: config values are the defaults, CLI flags override them.
    args = parser.parse_args()

    # target_modules is written as a Python literal (e.g. a list) in the config.
    target_modules = getattr(args, "target_modules", None)
    if isinstance(target_modules, str) and target_modules:
        args.target_modules = ast.literal_eval(target_modules)

    args = update_args(args)
    if args.use_wandb:
        api_key = _config_value(config, "WANDB_API_KEY")
        if api_key:
            os.environ["WANDB_API_KEY"] = str(api_key)
        wandb.init(
            project=_config_value(config, "WANDB_PROJECT"),
            entity=_config_value(config, "WANDB_ENTITY"),
        )
        # Every config entry also became an argument; keep the API key out of
        # the run config that gets uploaded to wandb.
        wandb.config.update(
            {k: v for k, v in vars(args).items() if k.lower() != "wandb_api_key"}
        )

    # Set random seeds for reproducibility
    set_seed(args.seed)

    # Publish the arguments so setup_training() can retrieve them via Args().
    args_global = Args()
    args_global.set_args(args)

    # prepare and train
    setup_training()


# Script entry point: run the config-parse -> train pipeline only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
