import os
import argparse
import random
import numpy as np
import sys

import torch
from dataset import MNISTPatch, FashionMNISTPatch, CIFAR10Patch, StarCraftMNISTPatch
from solver import Solver

import wandb

def prepare_wandb(args):
    """Compose the run name and, if logging is enabled, start a wandb run.

    The run name is always written to ``args.run_name``. A wandb run is
    started only when ``args.wandb`` or ``args.sweep`` is truthy; in sweep
    mode ``project`` and ``entity`` are passed as ``None``. After
    initialisation ``wandb.run.log_code()`` is called.

    Args:
        args: namespace holding the parsed command-line options; it is
            mutated in place (``run_name`` is added) and its full contents
            are passed to wandb as the run config.
    """
    name_parts = (
        f'{args.model_setup}R{args.repeat}G{args.n_gossip}',
        f'{args.drop_mode}',
        f'Ndevice{args.n_device}',
        f'{args.note}',
        f'seed{args.seed}',
    )
    args.run_name = '_'.join(name_parts)

    # Nothing more to do when neither regular logging nor a sweep is active.
    if not (args.wandb or args.sweep):
        return

    if args.sweep:
        project, entity = None, None
    else:
        project, entity = args.project_name, args.wandb_entity

    wandb.init(project=project,
               entity=entity,
               name=args.run_name,
               config=vars(args))
    wandb.run.log_code()

def prepare_dataloader(args, kwargs):
    """Build the train, validation and test loaders for ``args.dataset``.

    As a side effect ``args.n_class`` and ``args.image_shape`` are set for the
    selected dataset. The training split of the dataset is divided 80/20 into
    train and validation subsets using a generator seeded with ``args.seed``.

    Args:
        args: namespace providing ``dataset``, ``n_device``, ``data_dir``,
            ``seed`` and ``batch_size``.
        kwargs: extra keyword arguments forwarded to every ``DataLoader``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    # dataset name -> (image shape, accessor for the patch dataset class).
    # The accessors defer the name lookup so that only the selected class is
    # resolved, exactly as an if/elif chain would do.
    catalogue = {
        'mnist': ((28, 28, 1), lambda: MNISTPatch),
        'fmnist': ((28, 28, 1), lambda: FashionMNISTPatch),
        'cifar10': ((32, 32, 3), lambda: CIFAR10Patch),
        'starmnist': ((28, 28, 1), lambda: StarCraftMNISTPatch),
    }
    entry = catalogue.get(args.dataset)
    if entry is None:
        raise ValueError('dataset not supported')

    shape, resolve_cls = entry
    args.n_class = 10  # every supported dataset has ten classes
    args.image_shape = shape
    patch_cls = resolve_cls()
    train_val_data = patch_cls(args.n_device, root=args.data_dir, train=True)
    test_set = patch_cls(args.n_device, root=args.data_dir, train=False)

    # 80/20 train/validation split, made repeatable by seeding the generator.
    n_total = len(train_val_data)
    train_size = int(0.8 * n_total)
    val_size = n_total - train_size
    split_rng = torch.Generator().manual_seed(args.seed)
    train_set, val_set = torch.utils.data.random_split(
        train_val_data, [train_size, val_size], generator=split_rng)

    def make_loader(subset, shuffle):
        # NOTE(review): drop_last=True also applies to the validation and test
        # loaders, so a trailing partial batch is excluded from evaluation --
        # presumably the solver expects fixed-size batches; confirm.
        return torch.utils.data.DataLoader(subset,
                                           batch_size=args.batch_size,
                                           shuffle=shuffle,
                                           drop_last=True,
                                           **kwargs)

    train_loader = make_loader(train_set, True)
    val_loader = make_loader(val_set, False)
    test_loader = make_loader(test_set, False)

    print('training set size: ', len(train_set))
    print('validation set size: ', len(val_set))
    print('test set size: ', len(test_set))
    return train_loader, val_loader, test_loader



def build_parser():
    """Create the command-line parser for a training run.

    Returns:
        argparse.ArgumentParser: parser exposing the basic, setup, model,
        training and logging options of this script.
    """
    parser = argparse.ArgumentParser(description='Learning')

    # ============= basics ============= #
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--dataset', type=str, default='mnist')

    # ============= setup ============= #
    parser.add_argument('--n_device', type=int, default=16)
    parser.add_argument('--drop_mode', type=str, default='comm', choices=['device', 'comm', 'device_ring', 'comm_ring'])
    parser.add_argument('--graph_type', type=str, default='uni')
    parser.add_argument('--drop_rate_train',type=float,default=0.0)
    parser.add_argument('--rgg_radius',type=float,default=1)

    # ============= model ============= #
    parser.add_argument('--model_setup',type=str, default='MVFL',choices=['VFL','MVFL','DeepMVFL','DeepMVFL_unconstrain'])
    parser.add_argument('--repeat', type=int, default=1, help='number of times to repeat the MVFL layer (Deep MVFL)')
    parser.add_argument('--gossip_mode', type=str, default='gm')
    parser.add_argument('--n_gossip', type=int, default=0, help='number of gossip communication')
    parser.add_argument('--activation',type=str,default='relu')
    parser.add_argument('--k_mvfl', type=int, default=16)

    # ============= training ========== #
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=64)

    # ============= logging =========== #
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--save_dir', default='./saved')
    parser.add_argument('--no_save', action='store_true', default=False)
    parser.add_argument('--note', default='')
    # wandb
    parser.add_argument('--no_wandb', action='store_true', default=False)
    parser.add_argument('--project_name', default='InternetLearningTest')
    parser.add_argument('--wandb_entity', default='')
    parser.add_argument('--sweep', action='store_true', default=False, help='tells us whether this is being called as part of WandB sweep or not')
    return parser


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every random number generator.
    """
    # PYTHONHASHSEED is read at interpreter start-up, so setting it here can
    # only influence processes launched afterwards, not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # benchmark=False stops cuDNN from auto-tuning its algorithm choice;
    # deterministic=True additionally restricts it to deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def intermediate_dims(n_device):
    """Return ``(d_inter_init, d_inter)`` for a supported number of devices.

    Args:
        n_device: number of devices; must be 16, 4 or 49.

    Returns:
        Tuple ``(d_inter_init, d_inter)``.

    Raises:
        ValueError: if ``n_device`` is not one of the supported values.
    """
    # Size notes carried over from the original author (presumably
    # per-device input -> d_inter_init -> d_inter, then the concatenation of
    # all devices -> d_inter; confirm against the model code):
    #   16 devices: 49 -> 16, 16 -> 4,  4*16 -> 4
    #    4 devices: 196 -> 64, 64 -> 16, 16*4 -> 16
    #   49 devices: 16 -> 4,  4 -> 2,   2*49 -> 2
    dims_by_n_device = {
        16: (16, 4),
        4: (64, 16),
        49: (4, 2),
    }
    try:
        return dims_by_n_device[n_device]
    except KeyError:
        raise ValueError('n_device not supported') from None


if __name__ == "__main__":

    args = build_parser().parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # --device arrives as a GPU index and is replaced by the torch.device.
    args.device = torch.device(f'cuda:{args.device}' if args.cuda else 'cpu')
    args.save = not args.no_save
    args.wandb = not args.no_wandb
    kwargs = {'num_workers': 1, 'pin_memory': False} if args.cuda else {}
    # ======================== #
    #         randomness       #
    # ======================== #
    if args.seed is not None:
        seed_everything(args.seed)
    else:
        print('Randomness is not controlled.\n'*5)
        # The fallback seed is only recorded on args (run name, train/val
        # split generator); the global RNGs are deliberately left unseeded.
        args.seed = np.random.randint(1000)


    # ======================== #
    #         misc check       #
    # ======================== #
    # if args.model_setup == 'VFL' and args.repeat > 0:
    #     # skip sweep
    #     sys.exit(0)
    # if args.model_setup == 'VFL' and args.n_gossip > 0:
    #     # skip sweep
    #     sys.exit(0)
    args.d_inter_init, args.d_inter = intermediate_dims(args.n_device)
    # ======================== #
    #         log              #
    # ======================== #
    prepare_wandb(args)
    args.save_dir = f'{args.save_dir}/{args.dataset}/{args.run_name}'
    if args.save:
        # exist_ok avoids the check-then-create race between concurrent runs.
        os.makedirs(args.save_dir, exist_ok=True)
    # ======================== #
    #         data             #
    # ======================== #
    print('Loading data...')
    train_loader, val_loader, test_loader = prepare_dataloader(args, kwargs)

    # ======================== #
    #         training         #
    # ======================== #
    solver = Solver(train_loader, val_loader, test_loader, args)
    solver.train_and_test()