import yaml
from torch.utils.data import DataLoader
from models.resnet_wrapper import ResNetWithHead
from train.train_base import train_base_model
from train.finetune_head import finetune_head
import argparse
import torch
import os
from data.dataloader import get_dataloader
from evaluate.metrics import log_per_class_accuracy, evaluate_per_class_accuracy

def dir_type(path: str) -> str:
    """Argparse ``type=`` validator for a YAML config-file path.

    Despite its name, this checks for an existing *file*, not a directory.

    Args:
        path: Raw value supplied on the command line (or the string default).

    Returns:
        ``path`` unchanged, so argparse stores it as the option value.

    Raises:
        argparse.ArgumentTypeError: If ``path`` is not an existing regular
            file, or does not end in ``.yaml`` / ``.yml``. argparse reports
            this as a usage error.
    """
    if not os.path.isfile(path):
        raise argparse.ArgumentTypeError(f"'{path}' is not a valid file path.")
    # Accept both common YAML extensions.
    if not path.endswith(('.yaml', '.yml')):
        raise argparse.ArgumentTypeError(f"'{path}' is not a YAML file.")
    return path

def parse_args():
    """Build the experiment command-line interface and parse ``sys.argv``.

    Options:
        --config: Path to the YAML experiment config, validated by
            ``dir_type``. argparse also runs ``type=`` on a string default,
            so the default path must exist when ``--config`` is omitted.
        --val_n: Optional integer; number of points to use for LLR.
            Defaults to ``None``.
        --seed: Optional integer seed. Defaults to ``None``.

    Returns:
        argparse.Namespace with ``config``, ``val_n`` and ``seed`` attributes.
    """
    cli = argparse.ArgumentParser(description="Run experiment")
    option_specs = (
        ('--config', dict(type=dir_type, default='configs/cgmt_experiment.yaml',
                          help='Path to the config file')),
        ('--val_n', dict(type=int, default=None,
                         help='Number of points to use for LLR')),
        ('--seed', dict(type=int, default=None,
                        help='Seed for randomness')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def load_config(path):
    """Load an experiment configuration from a YAML file.

    Args:
        path: Filesystem path to the YAML config.

    Returns:
        The document parsed by ``yaml.safe_load``. ``main`` indexes it as a
        nested mapping (``general``, ``dataset``, ``backbones``, ...). An
        empty file yields ``None``.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # mis-decode non-ASCII text in the config.
    with open(path, 'r', encoding='utf-8') as f:
        # safe_load only constructs plain Python objects (no arbitrary tags).
        return yaml.safe_load(f)

def main(args):
    """Run the experiment sweep described by the config file.

    For every backbone and each of its latent dims, this:
    - trains a base model on the train split;
    - for every fairness weight in ``config['llr']``, fine-tunes a head on
      the validation split;
    - evaluates per-class accuracy on the test split;
    - passes the result to ``log_per_class_accuracy`` for
      ``logs/<experiment_name>/per_class_eval.csv``.

    Args:
        args: Parsed CLI namespace with ``config``, ``val_n`` and ``seed``.
            ``val_n`` / ``seed`` override the config values when not ``None``.
    """
    config = load_config(args.config)
    # CLI overrides take precedence over the config file.
    if args.val_n is not None:
        config['dataset']['valid_n'] = args.val_n
    if args.seed is not None:
        config['general']['seed'] = args.seed

    seed = config['general']['seed']
    # Loop-invariant values, looked up once (a missing key now fails fast).
    experiment_name = config['general']['experiment_name']
    eval_log_path = f"logs/{experiment_name}/per_class_eval.csv"
    torch.manual_seed(seed)

    dataloader = get_dataloader(config['dataset'], split='train', seed=seed)
    val_dataloader = get_dataloader(config['dataset'], split='valid', seed=seed)
    test_dataloader = get_dataloader(config['dataset'], split='test', seed=seed)
    print(f'Train Size: {len(dataloader.dataset)}')
    print(f'Val Size: {len(val_dataloader.dataset)}')
    print(f'Test Size: {len(test_dataloader.dataset)}')

    for backbone in config['backbones']:
        backbone_name = backbone['name']
        for latent_dim in backbone['latent_dims']:
            print(f"Training base model: {backbone_name} with latent dim {latent_dim}")
            model = ResNetWithHead(backbone_name=backbone_name, latent_dim=latent_dim,
                                   num_classes=config['base_model']['num_classes'])
            train_base_model(model, dataloader, config['base_model'], backbone_name, latent_dim,
                             seed=seed, experiment_name=experiment_name)

            for weight in config['llr']['fairness_weights']:
                print(f"Fine-tuning last layer for: {backbone_name} with latent dim {latent_dim} and weight {weight}")
                # Ensure the base model is in evaluation mode (if frozen) for feature extraction.
                # Kept inside the loop in case finetune_head switches the model's mode.
                model.eval()
                classifier = finetune_head(model, val_dataloader, config['llr'], weight, backbone_name,
                                           latent_dim, seed=seed, experiment_name=experiment_name)

                print("Evaluating per-class accuracy on test set...")
                acc_dict = evaluate_per_class_accuracy(model, classifier, test_dataloader)
                model_info = {
                    "backbone": backbone_name,
                    "latent_dim": latent_dim,
                    "fairness_weight": weight,
                    "n": len(val_dataloader.dataset),
                    "loss": config['llr']['loss'],
                }
                log_per_class_accuracy(eval_log_path, model_info, acc_dict)

if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then run the experiment sweep.
    main(parse_args())
