import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
from .utils import progress_bar


# Number of target classes for every dataset name this module recognises.
_N_CLS_BY_DATASET = {
    'CIFAR10': 10,
    'fashion_mnist': 10,
    'MNIST': 10,
    'CIFAR100': 100,
    'emnist': 26,
    'TinyImage': 200,
}


def get_n_cls(opt):
    """Return the number of classes for the dataset named by ``opt.dataset``.

    Args:
        opt: Options object exposing a ``dataset`` attribute.

    Returns:
        The class count registered for ``opt.dataset``.

    Raises:
        ValueError: If ``opt.dataset`` is not one of the known dataset names.
    """
    try:
        return _N_CLS_BY_DATASET[opt.dataset]
    except (KeyError, TypeError):
        # KeyError: unknown name; TypeError: unhashable value. Either way the
        # caller gets one clear error instead of an unbound-local failure.
        raise ValueError(
            'Unknown dataset {!r}; expected one of {}'.format(
                opt.dataset, sorted(_N_CLS_BY_DATASET)
            )
        ) from None

def _build_resnet18(n_cls, use_adaptive_pool):
    """Build one torchvision ResNet-18 whose final layer outputs ``n_cls`` values."""
    model = torchvision.models.resnet18()
    if use_adaptive_pool:
        model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = torch.nn.Linear(512, n_cls, bias=True)
    return model


def get_model(opt):
    """Build the per-agent models and the global model described by ``opt``.

    Args:
        opt: Options object. Reads ``opt.model``, ``opt.n_agents`` and
            ``opt.dataset`` (the latter also through ``get_n_cls``).

    Returns:
        A tuple ``(f, global_f)`` where ``f`` is a list of ``opt.n_agents``
        ResNet-18 models and ``global_f`` is one more model built the same
        way. The models are returned as constructed; this function does not
        move them to any device.

    Raises:
        ValueError: If ``opt.model`` is not ``'ResNet'``.
    """
    n_cls = get_n_cls(opt)

    if opt.model != 'ResNet':
        # Previously this fell through to the return statement and failed
        # with an unbound-local error that did not name the offending model.
        raise ValueError(
            "Unsupported model {!r}; only 'ResNet' is available".format(opt.model)
        )

    # For 'TinyImage' the pooling layer is replaced with AdaptiveAvgPool2d(1).
    use_adaptive_pool = opt.dataset == 'TinyImage'

    # Agents are built first and the global model last, so the order in which
    # random initialisation is consumed is the same as before.
    f = [_build_resnet18(n_cls, use_adaptive_pool) for _ in range(opt.n_agents)]
    global_f = _build_resnet18(n_cls, use_adaptive_pool)

    return f, global_f


def get_optimizer(model, lr, opt):
    """Create an Adam optimizer over all parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.
        lr: Learning rate.
        opt: Accepted for call-site compatibility; not read here.

    Returns:
        A ``torch.optim.Adam`` instance configured with betas ``(0.9, 0.99)``,
        eps ``1e-8`` and weight decay ``5e-4``.
    """
    adam_settings = {
        'lr': lr,
        'betas': (0.9, 0.99),
        'eps': 1e-8,
        'weight_decay': 5e-4,
    }
    return torch.optim.Adam(model.parameters(), **adam_settings)

def train(epoch, model, loss_fn, trainloader, optimizer, device):
    """Run one optimisation pass over ``trainloader`` and return the model.

    Before the loop the model is put in training mode, every parameter is
    marked trainable and the model is moved to ``device``. After the loop
    every parameter has ``requires_grad`` set to ``False`` and the model is
    switched to eval mode.

    Args:
        epoch: Not read by this function.
        model: Module to train; modified in place.
        loss_fn: Called as ``loss_fn(outputs, targets)``; the result is
            back-propagated and its ``.item()`` is accumulated for reporting.
        trainloader: Iterable yielding ``(inputs, targets)`` pairs; ``len()``
            is taken on it for the progress display.
        optimizer: Zeroed and stepped once per batch.
        device: Target passed to ``.to()`` for the model and each batch.

    Returns:
        The same ``model`` object that was passed in.
    """
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    model.to(device)

    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, batch in enumerate(trainloader):
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Predicted label is the index of the largest output along dim 1.
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            running_loss / (step + 1),
            100. * n_correct / n_seen,
            n_correct,
            n_seen,
        )
        progress_bar(step, len(trainloader), status)

    # Hand the model back frozen and in inference mode.
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model
    
def test(epoch, model, loss_fn, testloader, optimizer, device):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    The model is switched to eval mode but is not moved to ``device`` here;
    only the batches are.

    Args:
        epoch: Not read by this function.
        model: Module to evaluate.
        loss_fn: Called as ``loss_fn(outputs, targets)``; its ``.item()`` is
            accumulated once per batch.
        testloader: Iterable yielding ``(inputs, targets)`` pairs; ``len()``
            is taken on it for the progress display.
        optimizer: Not read by this function.
        device: Target passed to ``.to()`` for each batch.

    Returns:
        A tuple ``(acc, total_loss)``: ``acc`` is the percentage of correct
        predictions, and ``total_loss`` is the accumulated per-batch loss
        divided by the number of samples seen (not the number of batches).

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    model.eval()

    loss_sum = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for step, batch in enumerate(testloader):
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            batch_loss = loss_fn(outputs, targets)

            loss_sum += batch_loss.item()
            # Predicted label is the index of the largest output along dim 1.
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                loss_sum / (step + 1),
                100. * n_correct / n_seen,
                n_correct,
                n_seen,
            )
            progress_bar(step, len(testloader), status)

    accuracy = 100. * n_correct / n_seen
    loss_per_sample = loss_sum / n_seen
    return accuracy, loss_per_sample
