import argparse
import wandb
import torch

from utils import *
from torch.utils.data import DataLoader, ConcatDataset

def _scaled_int(value, scale):
    """Return ``value * scale`` rounded to the nearest integer.

    Used to embed float hyperparameters (probabilities, learning rate, weight
    decay) in run/checkpoint names. Plain ``int()`` truncates binary-float
    error, e.g. ``int(0.29 * 100) == 28`` because ``0.29 * 100`` evaluates to
    ``28.999999999999996``.
    """
    return int(round(value * scale))


def _import_dim_getters(task_name):
    """Import ``(get_cut_dim, get_z_dim)`` from the model module for *task_name*.

    Returns ``None`` when *task_name* has no model module mapped here.
    """
    if task_name == 'fashionmnist':
        from models.resnet18 import get_cut_dim, get_z_dim
    elif task_name == 'modelnet10':
        from models.modelnet import get_cut_dim, get_z_dim
    elif task_name in ('hapt', 'isolet'):
        from models.mlp import get_cut_dim, get_z_dim
    else:
        return None
    return get_cut_dim, get_z_dim


def _epoch_metrics(train_metrics, test_metrics):
    """Select the train/test loss and accuracy entries that are logged to wandb."""
    return {
        'train_loss': train_metrics['train_loss'],
        'train_acc': train_metrics['train_acc'],
        'test_loss': test_metrics['test_loss'],
        'test_acc': test_metrics['test_acc'],
    }


@time_decorator
def main(args):
    """Run one experiment end to end.

    Steps: derive the run name (``args.wandb_name``), optionally load a
    pretrained checkpoint and freeze parameters, train for
    ``config["num_epochs"]`` epochs, then evaluate on the final test loaders.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
            Mutated in place: ``args.mnar`` and ``args.wandb_name`` are set here.

    Raises:
        ValueError: if ``args.method`` is ``'ours'``/``'ours_mnar'`` and
            ``args.task_name`` has no model module mapped in
            ``_import_dim_getters``.

    Side effects:
        Logs to wandb when ``args.use_wandb``; in pretrain mode writes the
        model state dict to ``pretrain_models/<wandb_name>.pth``.
    """
    import os

    needs_dims = args.method in ('ours', 'ours_mnar')
    # Derived flag, True only for the 'ours_mnar' method; read outside this file.
    args.mnar = args.method == 'ours_mnar'

    dim_getters = _import_dim_getters(args.task_name)
    if needs_dims and dim_getters is None:
        # Fail before the expensive setup instead of with a NameError later on.
        raise ValueError(f"Unsupported task_name for method {args.method!r}: {args.task_name!r}")

    set_seed(args.seed)
    config, models, optimizers, schedulers, criterion, train, test = setup_task(args)

    # Run name: experiment setting + missing-data type and its parameter.
    # For any other missing_type the name is left as passed on the command line.
    base_name = (
        f"{args.method}_{args.task_name}_n{args.num_clients}"
        f"_la{args.labeled_aligned_num}_lu{args.labeled_unaligned_num}_{args.missing_type}"
    )
    if args.missing_type == 'mcar':
        args.wandb_name = f"{base_name}{_scaled_int(args.p_mcar, 100)}"
    elif args.missing_type == 'mar':
        args.wandb_name = f"{base_name}{args.option}"
    elif args.missing_type == 'mnar':
        args.wandb_name = f"{base_name}{_scaled_int(args.p_miss, 100)}"

    if needs_dims:
        get_cut_dim, get_z_dim = dim_getters
        args.wandb_name += f"_z{get_z_dim(args.task_name)}_cut{get_cut_dim(args.task_name)}_K{args.K}_{args.K_test}_{args.ours_agg}"
        # lr / weight decay are scaled by 1e5 so they appear in the name as integers.
        hparam_tag = (
            f"_lr{_scaled_int(config['lr'], 100000)}"
            f"_wd{_scaled_int(config['weight_decay'], 100000)}"
            f"_bs{config['batch_size']}"
        )

        if args.pretrain:
            args.wandb_name = f"pretrain_{args.wandb_name}{hparam_tag}{args.add_wandb}"
            # Pretraining keeps every parameter of models[0].disc frozen.
            for param in models[0].disc.parameters():
                param.requires_grad = False
        else:
            if not args.label_only:
                # Name of the pretraining run to load, rebuilt from the *_load arguments
                # (mirrors the name the pretrain branch above saves under).
                args.wandb_name += f"_lr{args.lr_load}_wd{args.wd_load}_bs{args.bs_load}{args.wb_load}"
                ckpt_path = "pretrain_models/pretrain_" + args.wandb_name + ".pth"
                try:
                    pretrained = torch.load(ckpt_path, map_location=args.device)
                except FileNotFoundError:
                    # NOTE(review): training then continues from the weights returned by
                    # setup_task, with everything except "disc" frozen below.
                    print(f"[WARN] No pretrained checkpoint found at {ckpt_path}.")
                else:
                    # Keep only entries whose name and shape match the current model.
                    model_dict = models[0].state_dict()
                    pretrained = {k: v for k, v in pretrained.items()
                                  if k in model_dict and model_dict[k].shape == v.shape}
                    models[0].load_state_dict(pretrained, strict=False)

                # Fine-tuning trains only parameters whose name contains "disc".
                for name, param in models[0].named_parameters():
                    param.requires_grad = "disc" in name

            # Appended whether or not a checkpoint was found, so runs with different
            # hyperparameters or seeds never share a name.
            args.wandb_name += f"{hparam_tag}{args.add_wandb}_s{args.seed}"
    else:
        args.wandb_name += f"{args.add_wandb}"

    if args.use_wandb:
        init_wandb(args, config)

    ds_la, ds_lu, ds_ua, ds_uu = create_train_datasets_with_cache(args)
    dict_tla, dict_tlu = create_test_configs_with_cache(args)

    train_unlabeled_loader, train_labeled_loader = get_ours_loaders(ds_la, ds_lu, ds_ua, ds_uu, config["batch_size"], args)
    test_loader = get_ours_test_loaders(dict_tla, dict_tlu, config["batch_size"], args)
    # Pretraining uses the unlabeled training loader; otherwise the labeled one.
    train_loader = train_unlabeled_loader if args.pretrain else train_labeled_loader

    print('Computing initial metrics...')
    train_metrics = test(train_loader, models, criterion, args, is_train_data=True)
    # test_loader[0] is the loader used for per-epoch evaluation.
    test_metrics = test(test_loader[0], models, criterion, args)
    if args.use_wandb:
        wandb.log(_epoch_metrics(train_metrics, test_metrics))

    for epoch in range(config["num_epochs"]):
        print_exp_info(args, config, epoch)
        train_metrics = train(train_loader, models, optimizers, criterion, args)
        test_metrics = test(test_loader[0], models, criterion, args)

        for scheduler in schedulers:
            scheduler.step()
        if args.use_wandb:
            wandb.log(_epoch_metrics(train_metrics, test_metrics))

    if args.pretrain:
        # Create the directory so a completed pretraining run cannot fail on save.
        os.makedirs("pretrain_models", exist_ok=True)
        torch.save(models[0].state_dict(), f"pretrain_models/{args.wandb_name}.pth")

    # NOTE(review): the number of final test loaders is hard-coded to 8; confirm it
    # matches what get_ours_test_loaders returns.
    for i in range(8):
        test_metrics = test(test_loader[i], models, criterion, args, is_final=True)
        if args.use_wandb:
            metrics = {f'final_test_loss_{i}': test_metrics["final_test_loss"], f'final_test_acc_{i}': test_metrics["final_test_acc"]}
            wandb.log(metrics)
        print_str = f'(final_miss_test {i}) final_test_loss: {test_metrics["final_test_loss"]}'
        print_str += f' | final_test_acc: {test_metrics["final_test_acc"]}'
        print(print_str)

    if args.use_wandb:
        wandb.finish()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # Restricted to the tasks main() has a model module for.
    parser.add_argument('--task_name', default='fashionmnist',
                        choices=['fashionmnist', 'modelnet10', 'hapt', 'isolet'])
    parser.add_argument('--cuda_id', type=int, default=0)
    parser.add_argument('--wandb_name',
                        help='Name of the run. Overwritten by main() when missing_type is mcar, mar or mnar.')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_clients', type=int, default=8)
    parser.add_argument('--labeled_aligned_num', type=int, default=100)
    parser.add_argument('--labeled_unaligned_num', type=int, default=400)
    parser.add_argument('--missing_type', default='mcar')
    parser.add_argument('--p_mcar', type=float, default=0.5)
    parser.add_argument('--p_miss', type=float, default=0.7)
    parser.add_argument('--option', type=int, default=0)
    parser.add_argument('--no_wandb', action='store_false', dest='use_wandb', help='Disable wandb logging.')
    parser.add_argument('--method', choices=['ours', 'ours_mnar'], required=True)
    parser.add_argument("--K", type=int, default=10, help="number of IS during training")
    parser.add_argument("--K_test", type=int, default=50)
    parser.add_argument("--pretrain", action='store_true',
                        help='Train on the unlabeled loader and save the model under pretrain_models/.')
    parser.add_argument("--gpu_order", type=int, default=0)
    parser.add_argument("--ngpu", type=int, default=1)
    parser.add_argument("--ours_agg", choices=['mean', 'sum', 'weighted'], default='mean')
    parser.add_argument("--lr_load", type=int, default=10,
                        help='lr (scaled by 1e5) of the pretraining run to load; only used to build the checkpoint name.')
    parser.add_argument("--wd_load", type=int, default=10,
                        help='weight decay (scaled by 1e5) of the pretraining run to load; only used to build the checkpoint name.')
    parser.add_argument("--bs_load", type=int, default=1024,
                        help='batch size of the pretraining run to load; only used to build the checkpoint name.')
    parser.add_argument("--add_wandb", type=str, default='', help='Extra suffix appended to the run name.')
    parser.add_argument("--wb_load", type=str, default='',
                        help='add_wandb suffix of the pretraining run to load.')
    parser.add_argument('--label_only', action='store_true',
                        help='Skip loading a pretrained checkpoint and the parameter freezing that goes with it.')
    args = parser.parse_args()

    # Not read in this file; presumably consumed by init_wandb — verify.
    args.project = 'vfl'
    args.device = torch.device(f'cuda:{args.cuda_id}' if torch.cuda.is_available() else 'cpu')

    main(args)

