from YParams import YParams
import os
import wandb
import argparse
import sys
import torch
import numpy as np
from train.trainer import *
from utils.core_utils import *
from models.model_handler import get_model
from dataloader.dataset_handler import get_data_module
from test.evaluate import *
from test.equivarinace_tester import *
from thop import profile
from fvcore.nn import FlopCountAnalysis
from torchsummary import summary
from thop import clever_format
import random

def count_parameters(model):
    """Return how many scalar values live in ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` is False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameter, skip it
        total += param.numel()
    return total
def convert_string_to_bool(x):
    """Interpret a command-line string flag as a boolean.

    ``"true"`` (case-insensitive) and ``"1"`` map to ``True``; any other
    string maps to ``False``. Surrounding whitespace is ignored.
    """
    return x.strip().lower() in ("true", "1")

# Seed used only when neither the YAML config section nor --random_seed
# provides one (this was the old hard-coded CLI default).
_DEFAULT_RANDOM_SEED = 42


def _build_arg_parser():
    """Build the command-line parser for this script.

    Every option other than ``--config`` overrides a value from the selected
    YAML config section; leaving it at its default keeps the config value.
    Because the options use ``nargs="?"``, passing a flag without a value
    yields its ``const`` (``None`` unless set), which is treated as
    "no override".
    """
    parser = argparse.ArgumentParser()
    # const makes a bare `--config` fall back to the default section instead
    # of handing None to YParams.
    parser.add_argument("--config", nargs="?", default="base_config",
                        const="base_config", type=str)
    parser.add_argument("--epochs", nargs="?", default=None, type=int)
    # Default None so the config's random_seed is respected; a numeric
    # default here would overwrite the config value on every run.
    parser.add_argument("--random_seed", nargs="?", default=None, type=int)
    parser.add_argument("--debug", nargs="?", default=None, type=str)
    # Negative defaults are "no override" sentinels; any value >= 0.0
    # (including 0.0) is applied.
    parser.add_argument("--dropout_rate", nargs="?", default=-2.0, type=float)
    parser.add_argument("--weight_decay", nargs="?", default=-1.0, type=float)
    parser.add_argument("--apply_antialiasing", nargs="?", default=None, type=str)
    return parser


def _apply_cli_overrides(params, args):
    """Copy every CLI value that was actually supplied onto ``params``.

    Each applied override is announced on stdout. ``--debug`` and
    ``--apply_antialiasing`` are parsed with ``convert_string_to_bool``.
    """
    if args.random_seed is not None:
        params.random_seed = args.random_seed
        print("Overriding random seed to", params.random_seed)
    if args.epochs is not None:
        params.epochs = args.epochs
        print("Overriding epochs to", params.epochs)
    if args.debug is not None:
        # Attribute name is spelled `debuging`; it is read under that name
        # in main() (and possibly elsewhere), so it is kept as-is.
        params.debuging = convert_string_to_bool(args.debug)
        print("Overriding debug to", params.debuging)
    # `is not None` guards a bare flag (value None); `>= 0.0` lets users
    # request exactly 0.0, which the old `> 0.0` check silently ignored.
    if args.dropout_rate is not None and args.dropout_rate >= 0.0:
        params.dropout_rate = args.dropout_rate
        print("Overriding dropout_rate to", params.dropout_rate)
    if args.weight_decay is not None and args.weight_decay >= 0.0:
        params.weight_decay = args.weight_decay
        print("Overriding weight_decay to", params.weight_decay)
    if args.apply_antialiasing is not None:
        params.apply_antialiasing = convert_string_to_bool(args.apply_antialiasing)
        print("Overriding apply_antialiasing to", params.apply_antialiasing)


def _resolve_random_seed(params):
    """Return the seed to use, falling back to ``_DEFAULT_RANDOM_SEED``.

    The fallback is written back to ``params`` so later readers of
    ``params`` (e.g. the wandb run config) see the seed actually used.
    """
    try:
        seed = params.random_seed
    except (AttributeError, KeyError):  # config section defines no seed
        seed = None
    if seed is None:
        seed = _DEFAULT_RANDOM_SEED
        params.random_seed = seed
    return seed


def _seed_everything(seed):
    """Seed the PyTorch, Python ``random`` and NumPy global generators."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def main():
    """Load the config, apply CLI overrides, train (unless debugging), then evaluate."""
    args = _build_arg_parser().parse_args()

    config = args.config
    print("Loading config", config)
    params = YParams('./config/config_1.yaml', config, print_params=True)
    _apply_cli_overrides(params, args)
    _seed_everything(_resolve_random_seed(params))

    params.config = config
    # Set up WandB logging: run named after the config, grouped by model.
    params.wandb_name = config
    params.wandb_group = params.model
    if params.wandb_log:
        wandb.login(key=get_wandb_api_key())
        wandb.init(
            config=params,
            name=params.wandb_name,
            group=params.wandb_group,
            project=params.wandb_project,
            entity=params.wandb_entity)

    model = get_model(params)

    # NOTE(review): this dummy input is never used (likely a leftover from
    # FLOP/summary profiling). The call is kept only because torch.randn
    # advances the global CPU RNG; removing it would change the random
    # stream seen by everything below compared with earlier runs.
    torch.randn(1, params.in_feature, 32, 32)
    model = model.to(params.device)

    data_module = get_data_module(params)
    train_dataloader = data_module.train_dataloader()
    val_dataloader = data_module.val_dataloader()
    test_dataloader = data_module.test_dataloader()
    un_augmented_test_dataloader = data_module.un_augmented_test_dataloader()

    # In debug mode training is skipped and the untrained model is evaluated.
    if not params.debuging:
        train_classification(
            model=model,
            train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            params=params,
        )

    # `wanbd_log` (sic) is the keyword these evaluators are called with.
    with torch.no_grad():
        evaluate_classification_smi(model=model, test_loader=un_augmented_test_dataloader,
                                    params=params, wanbd_log=params.wandb_log)
        evaluate_classification_orbit(model=model, test_loader=un_augmented_test_dataloader,
                                      params=params, wanbd_log=params.wandb_log)
        evaluate_classification_smi(model=model, test_loader=un_augmented_test_dataloader,
                                    params=params, wanbd_log=params.wandb_log,
                                    random_noise=True, eval_suffix="Noi_")
        evaluate_classification(model=model, test_loader=un_augmented_test_dataloader,
                                params=params, wanbd_log=params.wandb_log, eval_suffix="un_")
        equivariance_tester_classification(model=model,
                                           params=params,
                                           dataloader=test_dataloader,
                                           device=params.device,
                                           number_of_samples=params.n_tests,
                                           wanbd_log=params.wandb_log)

    if params.wandb_log:
        wandb.finish()


if __name__ == "__main__":
    main()
