import argparse
import gc
import os

import torch

from model_trainer import ModelTrainer
from utils import get_cifar10

parser = argparse.ArgumentParser(description='Run model training and evaluation.')
parser.add_argument('explained_var', type=float, help='Explained variance threshold')

parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', 
                    choices=['cpu', 'cuda'], help='Device to use for training (default: cuda if available)')

parser.add_argument('--model_name', type=str, default='vit_b_32', help='vit_b_32 or swinT')

args = parser.parse_args()
num_of_finetune = "all"
explained_var = args.explained_var
device = torch.device(args.device)
model_name = args.model_name

print("---------------------------------")
print("explained_var = ", explained_var)

# Configurations:
torch.manual_seed(233)
batch_size = 128
num_batches = 1
num_epochs = 1

# Get data
dataloader = get_cifar10(batch_size, num_batches)

if explained_var == 0.0:
    ######################## Vanilla ############################
    # Get model
    model = ModelTrainer(model_name, batch_size, num_epochs, device=device,
                    with_base=True, dataloader=dataloader, output_channels=10, num_of_finetune=num_of_finetune)

    model.train_model()
    model.inference_model()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

else:
    ######################### WASI ##########################
    # Get model
    perplexity_link = {'vit_b_32':"perplexity/vit_b_32_perplexity.pkl",
                       'swinT':'perplexity/swinT_perplexity.pkl'}
    model = ModelTrainer(model_name, batch_size, num_epochs, device=device,
                    with_WASI=True, dataloader=dataloader, output_channels=10, explained_var=explained_var, num_of_finetune=num_of_finetune, perplexity_link=perplexity_link[model_name])
    model.train_model()
    model.inference_model()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

import os
os.makedirs(f'processed_time_{args.device}', exist_ok=True)

# Create result folder
processed_time_folder = f"processed_time_{args.device}/batch_size={batch_size}/{model_name}"
if not os.path.exists(processed_time_folder):
    os.makedirs(processed_time_folder)

def check_and_write_title(file_path, title):
    if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
        with open(file_path, "a") as file:
            file.write(title)

check_and_write_title(f'{processed_time_folder}/training_time_backward.txt', "explained_var\ttime\n")
check_and_write_title(f'{processed_time_folder}/training_time_forward.txt', "explained_var\ttime\n")
check_and_write_title(f'{processed_time_folder}/training_time.txt', "explained_var\ttime\n")
check_and_write_title(f'{processed_time_folder}/inference_time.txt', "explained_var\ttime\n")

# Log results
with open(f'{processed_time_folder}/training_time_backward.txt', "a") as file:
    file.write(f"{explained_var}\t{sum(model.backward_time)}\n")
with open(f'{processed_time_folder}/training_time_forward.txt', "a") as file:
    file.write(f"{explained_var}\t{sum(model.forward_time)}\n")
with open(f'{processed_time_folder}/training_time.txt', "a") as file:
    file.write(f"{explained_var}\t{sum(model.backward_time) + sum(model.forward_time)}\n")
with open(f'{processed_time_folder}/inference_time.txt', "a") as file:
    file.write(f"{explained_var}\t{sum(model.inference_time)}\n")