import os
import time
import copy
import numpy as np
from torchvision import datasets, transforms
import csv
import torch
from utils.options import args_parser

from src.server import Server
from models.Nets import CNNCifar, MLP
from models.resnet import resnet18, resnet8, resnet10, resnet6
from utils.util import set_logger
import random
import torchvision.models as torch_model
from utils.log_utils import Logger, client_sampling, VariableMonitor


def exp_parameter(args):
    """Print the main experiment settings, one ``label: value`` line each.

    Args:
        args: Options object exposing ``epochs``, ``num_users``, ``local_ep``,
            ``local_bs``, ``lr``, ``policy``, ``save_model``, ``noniid``,
            ``alpha`` and ``seed``.
    """
    # Each setting is listed exactly once ("Communication Rounds" used to be
    # printed twice).
    settings = (
        ('Communication Rounds', args.epochs),
        ('Client Number', args.num_users),
        ('Local Epochs', args.local_ep),
        ('Local Batch Size', args.local_bs),
        ('Learning Rate', args.lr),
        ('Policy', args.policy),
        ('Save model', args.save_model),
        ('noniid', args.noniid),
        ('alpha', args.alpha),
        ('seed', args.seed),
    )
    for label, value in settings:
        print(f'{label}: {value}')

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds torch (CPU generator and all CUDA devices), NumPy's global RNG and
    Python's ``random`` module with the same value, then switches on the
    cuDNN ``deterministic`` flag.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

def _build_model(args, device):
    """Instantiate the local model selected by ``args.dataset`` / ``args.model``.

    Side effect: ``args.num_classes`` is overwritten for 'mnist'/'fmnist' (10),
    'cifar-100' (100) and 'tinyimagenet' (200) before the model is built. For
    'cifar' the value already present on ``args`` is used as-is.

    Args:
        args: Options object; reads ``dataset``, ``model``, ``num_classes`` and,
            for the cifar-100 ResNets, ``feature_dim``.
        device: ``torch.device`` the new model is moved to.

    Returns:
        The constructed model, already moved to ``device``.

    Raises:
        NotImplementedError: If the dataset is unknown or the model name is not
            available for that dataset.
    """
    dataset = args.dataset

    # Builders are zero-argument callables so nothing is constructed (and no
    # optional attribute such as feature_dim is read) until the choice is made.
    # NOTE(review): 'lenet' builds torchvision's GoogLeNet, and the torchvision
    # constructors are called with their default arguments, i.e. they are not
    # given args.num_classes - confirm this is intended.
    shared = {
        'cnn': lambda: CNNCifar(args),
        'lenet': lambda: torch_model.googlenet(),
        'mobilenet': lambda: torch_model.mobilenet_v2(),
    }

    if dataset == 'cifar':
        builders = {
            **shared,
            'resnet8': lambda: resnet8(num_labels=args.num_classes),
            'vgg': lambda: torch_model.vgg11(),
        }
    elif dataset == 'mnist' or dataset == 'fmnist':
        args.num_classes = 10
        builders = {
            'resnet6': lambda: resnet6(num_labels=args.num_classes, input_dim=1),
        }
    elif dataset == 'cifar-100':
        args.num_classes = 100
        # Unlike the other datasets, the ResNets here are sized by feature_dim.
        builders = {
            **shared,
            'resnet8': lambda: resnet8(num_labels=args.feature_dim),
            'resnet10': lambda: resnet10(num_labels=args.feature_dim),
            'resnet18': lambda: resnet18(num_labels=args.feature_dim),
        }
    elif dataset == 'tinyimagenet':
        args.num_classes = 200
        builders = {
            **shared,
            'resnet8': lambda: resnet8(num_labels=args.num_classes),
            'resnet10': lambda: resnet10(num_labels=args.num_classes),
            'resnet18': lambda: resnet18(num_labels=args.num_classes),
        }
    else:
        raise NotImplementedError(f'Unsupported dataset: {dataset!r}')

    build = builders.get(args.model)
    if build is None:
        raise NotImplementedError(
            f'Model {args.model!r} is not supported for dataset {dataset!r}')
    return build().to(device)


def train(args, logger=None):
    """Run the experiment ``args.repeat`` times and print mean/std of the best accuracy.

    For every repetition this seeds the RNGs, creates a log directory path
    derived from the options, builds the local model, trains a ``Server`` and
    records ``server.best_accuracy``.

    Side effects on ``args``:
        * ``args.seed`` is replaced by a random value in [0, 10000) when it is
          -1, and is reset to -1 at the end of every repetition, so each later
          repetition draws a new seed.
        * ``args.root_file`` is set to the repetition's log directory path.
        * ``args.num_classes`` may be overwritten (see ``_build_model``).

    Args:
        args: Parsed options object.
        logger: Ignored; a fresh logger is created per repetition because the
            log path depends on that repetition's seed. Kept for interface
            compatibility.

    Raises:
        NotImplementedError: If the dataset/model combination is unsupported.
    """
    use_gpu = torch.cuda.is_available() and args.gpu != -1
    device = torch.device('cuda:{}'.format(args.gpu) if use_gpu else 'cpu')
    accList = []
    for _ in range(args.repeat):    # repeat args.repeat times
        if args.seed == -1:  # seed not specified: pick one at random
            args.seed = np.random.randint(0, 10000)
        setup_seed(args.seed)

        # The seed is part of the path, so every repetition logs separately.
        logFilePath = (
            f'./log/{args.dataset}_po{args.policy}_user{args.num_users}_round{args.epochs}_epoch({args.local_ep}_{args.contrastive_ep})'
            f'_lr({args.lr}_{args.mlp_lr})_proj({args.proj})_mlp{(args.mlp)}_{args.noniid}_{args.alpha}_trainnum{args.train_num}_gcls({args.personalized_classifier})'
            f'_seed{args.seed}/gpnum({args.personalized_g_prompt}_{args.prompt_num_tokens_g})_ppnum{args.prompt_num_tokens_p}_lamda({args.lamda1})/'
        )

        tensorboardLogger = Logger(logFilePath)
        logger = set_logger(logFilePath + 'textlog.log')

        logger.info(f'args: {args}')

        args.root_file = logFilePath

        local_model = _build_model(args, device)
        print(f'Model Structure: {local_model}')
        logger.info(f'Model Structure: {local_model}')

        server = Server(device, local_model, args, logger=logger, tensorboardLogger=tensorboardLogger)
        server.train()
        print('Best:', server.best_accuracy)
        logger.info(f'Best: {server.best_accuracy}')
        accList.append(server.best_accuracy)

        # Force the next repetition to draw a new seed.
        args.seed = -1
    print(f'Repeat {args.repeat} times, mean:{np.mean(accList)}, std:{np.std(accList)}')




if __name__ == '__main__':
    # Script entry point: parse the options, print the configuration, then
    # run the repeated training experiment.
    args = args_parser()
    args.verbose = 0  # override whatever value the parser produced
    exp_parameter(args)
    train(args)


