import os
import torch
import wandb
from timm.models.vision_transformer import vit_small_patch16_224, vit_tiny_patch16_224
from torchvision import models
from meta.meta_dataset import CIFAR10_oneclass, CIFAR10_nineclass #, CIFAR100, MNIST, FMNIST, CIFAR10p1, GLD23K
from central_data.dataset import CIFAR10
# from central_data.dataset import CIFAR10_Lp, CIFAR100_Lp, MNIST_Lp, FMNIST_Lp, CIFAR10p1_Lp, GLD23K_Lp
# from central_data.options import args_parser 
from meta.args import args_parser 
import torch.optim.lr_scheduler as lr_scheduler
from evaluations.train_test import model_evaluate_imagedata, model_train_imagedata, initialize_optimizer

def replace_final_layer(model, num_classes):
    """Replace the classification layer of ``model`` with a fresh ``num_classes``-way Linear.

    Supported models are torchvision ResNet-family models, whose classifier is
    ``fc``, and timm ``VisionTransformer`` models, whose classifier is ``head``.
    The new layer keeps the old layer's ``in_features``. Its weights are freshly
    initialised, not copied.

    Args:
        model: backbone to modify in place.
        num_classes: number of output logits for the new layer.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``model`` is not a supported type, or if it has neither an
            ``fc`` nor a ``head`` attribute.
    """
    # Import the ViT class directly. The previous approach instantiated a whole
    # throwaway vit_small_patch16_224 on every call just to read its type.
    from timm.models.vision_transformer import VisionTransformer

    if not isinstance(model, (models.ResNet, VisionTransformer)):
        raise ValueError(f"Model type {type(model)} not supported for final layer replacement.")

    if hasattr(model, 'fc'):
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    elif hasattr(model, 'head'):
        model.head = torch.nn.Linear(model.head.in_features, num_classes)
    else:
        raise ValueError("Final layer replacement failed: model has no 'fc' or 'head' attribute.")
    return model

def freeze_layers(model, fine_tune_ver):
    """Choose which parameters of ``model`` receive gradients.

    Args:
        model: a ``torch.nn.Module``. For ``fine_tune_ver == 1`` it must expose
            its classifier as ``fc`` or ``head``.
        fine_tune_ver: fine-tuning mode:
            1 = train only the classifier (``fc``, else ``head``);
            2 = random-layer unfreezing (not implemented);
            3 = train every parameter.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        NotImplementedError: when ``fine_tune_ver == 2``.
        ValueError: for any other unknown ``fine_tune_ver``. Also raised in mode 1
            when the model has neither ``fc`` nor ``head``; by then every
            parameter has already been frozen.
    """
    def _set_trainable(module, flag):
        # Toggle requires_grad on every parameter owned by ``module``.
        for parameter in module.parameters():
            parameter.requires_grad = flag

    if fine_tune_ver == 1:
        # Freeze everything first, then re-enable only the classifier.
        _set_trainable(model, False)
        if hasattr(model, 'fc'):
            classifier = model.fc
        elif hasattr(model, 'head'):
            classifier = model.head
        else:
            raise ValueError("Classification layer replacement failed")
        _set_trainable(classifier, True)
    elif fine_tune_ver == 2:
        # Would unfreeze a random subset of layers; the selection rule is still undefined.
        raise NotImplementedError("Random layer unfreezing is not yet implemented.")
    elif fine_tune_ver == 3:
        _set_trainable(model, True)
    else:
        raise ValueError("Invalid fine_tune_ver value.")
    return model

def initialize_model(args):
    """Build the backbone named by ``args.model`` and prepare it for fine-tuning.

    Steps:
      1. Construct the architecture: torchvision ResNet variants, or timm
         ``'vit'`` (vit_small_patch16_224) / ``'vit_T'`` (vit_tiny_patch16_224).
         Pretrained weights are requested when ``bool(args.pretrain)`` is true.
      2. Replace the classifier with one sized for ``args.dataset``.
      3. Apply the freezing policy selected by ``args.fine_tune_ver``.

    Raises:
        ValueError: for an unknown ``args.model`` or ``args.dataset``. The model
            name is validated, and the model built, before the dataset name is
            checked.
    """
    architectures = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnet152': models.resnet152,
        'wide_resnet50_2': models.wide_resnet50_2,
        'wide_resnet101_2': models.wide_resnet101_2,
        'resnext50_32x4d': models.resnext50_32x4d,
        'resnext101_32x8d': models.resnext101_32x8d,
        # NOTE(review): the *_224 ViTs are named for 224x224 inputs. The original
        # author asked whether smaller images need resizing — TODO confirm.
        'vit': vit_small_patch16_224,
        'vit_T': vit_tiny_patch16_224,
    }
    build = architectures.get(args.model)
    if build is None:
        raise ValueError(f"Unsupported model type: {args.model}")
    model = build(pretrained=bool(args.pretrain))

    # Classifier output size per dataset name. The one-/nine-class CIFAR10
    # variants still use all 10 CIFAR10 logits.
    head_sizes = {
        'CIFAR10': 10,
        'CIFAR10_Lp': 10,
        'CIFAR100': 100,
        'CIFAR100_Lp': 100,
        'MNIST': 10,
        'FMNIST': 10,
        'MNIST_Lp': 10,
        'FMNIST_Lp': 10,
        'GLD23K': 203,
        'GLD23K_Lp': 203,
        'CIFAR10_oneclass': 10,
        'CIFAR10_nineclass': 10,
    }
    num_classes = head_sizes.get(args.dataset)
    if num_classes is None:
        raise ValueError(f"Unsupported dataset type: {args.dataset}")

    model = replace_final_layer(model, num_classes)
    return freeze_layers(model, args.fine_tune_ver)
    
def save_model(model, proj_name, run_name, epoch, base_dir='/directory/saved_models'):
    """Save ``model`` to ``<base_dir>/<proj_name>/<run_name>/epoch_<epoch>/model.pth``.

    The whole module object is written with ``torch.save``, not just its
    ``state_dict``. Loading it back therefore needs the model classes to be
    importable. A '/' inside ``run_name`` (``main`` builds names like
    ``"<class>_<seed>/<seed>"``) simply becomes a nested directory.

    Args:
        model: the ``torch.nn.Module`` to save.
        proj_name: project (W&B project) name, used as a directory level.
        run_name: run name, used as one or more directory levels.
        epoch: epoch number, used in the ``epoch_<epoch>`` directory name.
        base_dir: root directory for all checkpoints. Defaults to the
            previously hard-coded location.

    Returns:
        The path of the written ``model.pth`` file.
    """
    # Plain string composition (not os.path.join) keeps the exact previous
    # layout. os.path.join would silently drop base_dir if a later component
    # started with '/'.
    save_dir = f'{base_dir}/{proj_name}/{run_name}/epoch_{epoch}'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'model.pth')
    print(f"Saving model at {save_path}")
    torch.save(model, save_path)
    return save_path

def initialize_linear_warmup_scheduler(args, optimizer):
    """Wrap ``optimizer`` in a LambdaLR that linearly warms the learning rate up.

    At scheduler step ``e``, each param group's base LR is multiplied by
    ``e / args.warmup_epochs`` while ``e < args.warmup_epochs``, and by 1.0
    afterwards. LambdaLR applies the step-0 factor when it is constructed, so
    the first epoch runs at LR 0 whenever ``args.warmup_epochs > 0``. When
    ``args.warmup_epochs <= 0`` the factor is 1.0 from the start.

    ``args.warmup_epochs`` is re-read on every call of the multiplier rather
    than captured once.
    """
    def warmup_factor(step):
        if not step < args.warmup_epochs:
            return 1.0
        return float(step) / float(args.warmup_epochs)

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

def _pooled_mean(values_and_counts):
    """Return the count-weighted mean of a list of ``(value, count)`` pairs.

    Pools per-split losses/accuracies into one figure, weighted by each
    split's number of datapoints. Returns NaN instead of raising
    ZeroDivisionError when every count is zero.
    """
    total = sum(count for _, count in values_and_counts)
    if total == 0:
        return float('nan')
    return sum(value * count for value, count in values_and_counts) / total


def main():
    """Fine-tune on one class split, then periodically checkpoint, evaluate and log to W&B.

    Flow:
      1. Parse ``args``. Load the per-class DataLoaders of ``args.dataset`` and
         keep the train/val/test loaders of ``args.train_class``. Load the
         CIFAR10 train/val/test loaders used for evaluation.
      2. Build the model, optimizer and linear-warmup LR scheduler; start W&B.
      3. For each epoch:
         - on every ``args.eval_every``-th epoch, checkpoint and evaluate the
           model *before* that epoch's training;
         - train one epoch;
         - on those same epochs, log all metrics to W&B;
         - step the scheduler.
      4. Checkpoint and evaluate the fully trained model as epoch
         ``args.epochs``, which step 3 never covers. Log it and close the run.

    Raises:
        ValueError: if ``args.eval_every`` is 0 (it is used as a modulus).
    """
    args = args_parser()
    # Fail before any data loading or W&B run creation rather than with a
    # ZeroDivisionError inside the training loop.
    if args.eval_every == 0:
        raise ValueError("eval_every must be non-zero (it is used as a modulus).")
    wandb_dir = '/directory/wandblog/'
    os.makedirs(wandb_dir, exist_ok=True)
    # NOTE(review): not read in this function; presumably consumed via the W&B
    # config (args.__dict__) or downstream tooling — confirm.
    args.eval_class = (args.train_class + 1) % 10

    # Training data: per-class DataLoaders of args.dataset, looked up by name
    # among this module's globals. Keep the split for args.train_class.
    dataset_class = globals()[args.dataset]
    train_data = dataset_class(batch_size=args.batch_size, random_seed=args.PySeed, data_randseed=args.DataSeed)
    class_loaders = train_data.load_datasets()[args.train_class]
    train_loader = class_loaders['train']
    val_loader = class_loaders['val']
    test_loader = class_loaders['test']

    # Evaluation data is always CIFAR10; this overrides any eval_dataset value
    # args_parser may have set.
    args.eval_dataset = 'CIFAR10'
    eval_dataset_class = globals()[args.eval_dataset]
    eval_data = eval_dataset_class(batch_size=args.batch_size, random_seed=args.PySeed, data_randseed=args.DataSeed)
    eval_train_loader, eval_val_loader, eval_test_loader = eval_data.load_datasets()

    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = initialize_model(args)
    model.to(args.device)

    optimizer = initialize_optimizer(args, model)
    scheduler = initialize_linear_warmup_scheduler(args, optimizer)

    # Weights & Biases logging.
    project_name = args.base_project_name + f"Meta_CIFAR10_{args.model}_{args.epochs}Epk"
    run_name = f"{args.train_class}_{args.PySeed}/{args.DataSeed}"
    args.run_name_for_wandb_API = run_name # Issues with floating point precision in args.lr in API versus core Python version (loading models)
    os.environ['WANDB_DIR'] = wandb_dir
    os.environ["WANDB__SERVICE_WAIT"] = "300"
    wandb.init(
        project=project_name,
        name=run_name,
        config=args.__dict__
    )

    loss_fn = torch.nn.functional.cross_entropy
    # (metric prefix, loader) pairs, in the same order the script has always
    # evaluated them.
    eval_loaders = (
        ('eval_train', eval_train_loader),
        ('val', val_loader),
        ('eval_val', eval_val_loader),
        ('test', test_loader),
        ('eval_test', eval_test_loader),
    )

    def checkpoint_and_evaluate(epoch):
        """Save ``model`` as ``epoch``, then evaluate it on every loader in ``eval_loaders``.

        Returns ``(metrics, counts)``:
          - ``metrics`` maps ``'<prefix>_loss'`` / ``'<prefix>_accuracy'`` to
            values, plus the pooled ``eval_total_traindata_*`` figures;
          - ``counts`` maps each prefix to its datapoint count.
        """
        save_model(model, project_name, run_name, epoch)
        metrics, counts = {}, {}
        for prefix, loader in eval_loaders:
            loss, accuracy, datapoints = model_evaluate_imagedata(model, loader, args.device, loss_function=loss_fn)
            metrics[f'{prefix}_loss'] = loss
            metrics[f'{prefix}_accuracy'] = accuracy
            counts[prefix] = datapoints
        # Pool the three CIFAR10 evaluation splits into one figure.
        eval_prefixes = ('eval_train', 'eval_val', 'eval_test')
        metrics['eval_total_traindata_accuracy'] = _pooled_mean(
            [(metrics[f'{p}_accuracy'], counts[p]) for p in eval_prefixes])
        metrics['eval_total_traindata_loss'] = _pooled_mean(
            [(metrics[f'{p}_loss'], counts[p]) for p in eval_prefixes])
        return metrics, counts

    for epoch in range(args.epochs):
        is_eval_epoch = epoch % args.eval_every == 0
        if is_eval_epoch:
            # Checkpoint/evaluate the weights as they are *before* this epoch's training.
            metrics, counts = checkpoint_and_evaluate(epoch)

        train_loss, train_accuracy, train_datapoints = model_train_imagedata(model, train_loader, args.device, optimizer, loss_function=loss_fn)

        if is_eval_epoch:
            # NOTE: these pool train metrics gathered *during* this epoch's
            # training with val/test metrics measured *before* it.
            class_splits = [
                (train_datapoints, train_accuracy, train_loss),
                (counts['val'], metrics['val_accuracy'], metrics['val_loss']),
                (counts['test'], metrics['test_accuracy'], metrics['test_loss']),
            ]
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'train_accuracy': train_accuracy,
                **metrics,
                'total_traindata_accuracy': _pooled_mean([(acc, n) for n, acc, _ in class_splits]),
                'total_traindata_loss': _pooled_mean([(loss, n) for n, _, loss in class_splits]),
            })
        scheduler.step()

    # The loop only saves pre-training states, so without this the fully
    # trained model would never be checkpointed or evaluated.
    final_metrics, _ = checkpoint_and_evaluate(args.epochs)
    wandb.log({'epoch': args.epochs, **final_metrics})
    wandb.finish()

# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()
