import torch
import argparse
from model_trainer import ModelTrainer
from utils import get_cifar10
import gc


parser = argparse.ArgumentParser(description='Run model training and evaluation.')
# Required positional for every method, but only forwarded to ModelTrainer (and
# written to the timing log) when --method INSTANT is used.
parser.add_argument('explained_var', type=float, help='Explained variance threshold')

parser.add_argument('--method', type=str, choices=['vanilla','INSTANT', 'LBPWHT', "GF"], default='vanilla', help='Method to use: vanilla or LBPWHT or INSTANT or Gradient Filtering')

parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', 
                    choices=['cpu', 'cuda'], help='Device to use for training (default: cuda if available)')

parser.add_argument('--model_name', type=str, default='vit_b_32', help='vit_b_32 or swinT or efficientformer_l1')
parser.add_argument('--num_batch', type=int, default=780, help='Training batch each epoch')
# Only forwarded to ModelTrainer for --method INSTANT.
parser.add_argument('--over_sampling', type=int, default=0, help='Oversampling value')

args = parser.parse_args()

# Fail fast with a clear CLI error rather than crashing later inside training
# when CUDA is explicitly requested on a machine that does not have it.
if args.device == 'cuda' and not torch.cuda.is_available():
    parser.error("--device cuda was requested but CUDA is not available")

explained_var = args.explained_var
method = args.method
device = torch.device(args.device)
model_name = args.model_name

# Number of layers to fine-tune (passed to ModelTrainer as num_of_finetune);
# efficientformer_l1 uses 10, every other model uses 9.
num_of_finetune = 10 if model_name == "efficientformer_l1" else 9

print("---------------------------------")
print("explained_var = ", explained_var, " | method = ", args.method)

# Configurations:
torch.manual_seed(233)
batch_size = 64
num_batches = args.num_batch
num_epochs = 1

# INSTANT is given 5 batches beyond num_batches; every other method gets exactly
# num_batches. Build the loader once: previously INSTANT built a num_batches
# loader, threw it away and built a second one, wasting the work (and any RNG
# state the discarded build consumed).
# NOTE(review): presumably the extra batches are warm-up/calibration for
# INSTANT -- confirm against ModelTrainer.
num_loader_batches = num_batches + 5 if method == 'INSTANT' else num_batches
dataloader = get_cifar10(batch_size, num_loader_batches)

# Method-specific ModelTrainer options; all remaining arguments are shared.
# `method` is restricted by argparse choices, so the lookup cannot miss.
method_options = {
    'vanilla': {'with_base': True},
    'INSTANT': {
        'with_INSTANT': True,
        'over_sampling': args.over_sampling,
        'explained_var': explained_var,
    },
    'LBPWHT': {'with_LBPWHT': True},
    'GF': {'with_GF': True},
}[method]

model = ModelTrainer(model_name, batch_size, num_epochs, device=device,
                     dataloader=dataloader, output_channels=10,
                     num_of_finetune=num_of_finetune, **method_options)
model.train_model()

# Release cached GPU memory and collect garbage once training is done.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()


import os
# Result folder for the timing logs, one per device/method/model/batch size.
# makedirs creates any missing parent directories, and exist_ok=True replaces
# the racy "check exists, then create" sequence.
processed_time_folder = f"processed_time_{args.device}/{method}/{model_name}_batch_size={batch_size}"
os.makedirs(processed_time_folder, exist_ok=True)

def check_and_write_title(file_path, title):
    """Append ``title`` to ``file_path`` only if the file is missing or empty.

    Files that already have content are left untouched, so the header line is
    written at most once per file.
    """
    already_has_content = os.path.exists(file_path) and os.path.getsize(file_path) != 0
    if already_has_content:
        return
    with open(file_path, "a") as header_file:
        header_file.write(title)

# Append this run's mean per-batch backward and forward times (summed recorded
# times divided by num_batches) to tab-separated logs, writing the header
# first if a log is new or empty. Only INSTANT is parameterised by
# explained_var; every other method logs "None" in that column.
var_label = explained_var if method == "INSTANT" else "None"
for phase, phase_times in (("backward", model.backward_time), ("forward", model.forward_time)):
    log_path = f'{processed_time_folder}/training_time_{phase}.txt'
    check_and_write_title(log_path, "explained_var\ttime\tmethod\n")
    with open(log_path, "a") as log_file:
        log_file.write(f"{var_label}\t{sum(phase_times) / num_batches}\t{method}\n")