import time
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import torch
from torch import nn, optim
from tqdm import tqdm
from utils.backbone import get_model
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lrs
from torch.optim import Adam, SGD
import os

def _build_optimizer(model, args):
    """Create the optimizer and optional LR scheduler for ``args.data_name`` / ``args.model_name``.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is ``None`` for configurations
        that train with a constant learning rate (ViT).

    Raises:
        ValueError: if the dataset/model combination has no configured optimizer.
    """
    if args.data_name not in ("cifar10", "cifar100"):
        raise ValueError(f"Unsupported data_name: {args.data_name!r}")

    if args.model_name in ("ResNet18", "ResNet50"):
        # Both CIFAR recipes share SGD + cosine annealing; only CIFAR-10 uses Nesterov momentum.
        optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4,
                        nesterov=args.data_name == "cifar10")
        scheduler = lrs.CosineAnnealingLR(optimizer, T_max=280)
        return optimizer, scheduler

    if args.model_name == "ViT":
        # AdamW is not in the file's `from torch.optim import ...` list, so reach it via `optim`.
        return optim.AdamW(model.parameters(), lr=2e-5), None

    raise ValueError(f"Unsupported model_name: {args.model_name!r}")


def train_and_save(model, train_loader, test_loader, model_name, args, mode='ori'):
    """Train ``model`` for ``args.train_epochs`` epochs and save its ``state_dict``.

    Args:
        model: network to train; moved to ``args.device``.
        train_loader: loader for the training set; its ``dataset.transform`` is printed.
        test_loader: loader evaluated with ``test`` every 5 epochs.
        model_name: prefix of the checkpoint file name.
        args: namespace providing ``device``, ``data_name``, ``model_name``,
            ``train_epochs`` and ``rnd_seed``.
        mode: tag inserted in the checkpoint file name.

    The checkpoint is written to ``./checkpoints/{model_name}_{mode}{args.rnd_seed}.pth``.

    Raises:
        ValueError: if ``args.data_name`` / ``args.model_name`` is not supported.
    """
    model = model.to(args.device)
    model.train()
    criterion = nn.CrossEntropyLoss()

    print(train_loader.dataset.transform)
    optimizer, scheduler = _build_optimizer(model, args)

    for epoch in range(1, args.train_epochs + 1):
        running_loss = train(model, train_loader, criterion, optimizer, "cross", args.device)

        if epoch % 5 == 0:
            acc, val_loss = test(model, test_loader)
            print(f'epoch {epoch} loss: {float(running_loss):.4f}, acc: {acc}, val_loss: {val_loss:.4f}')
        if scheduler is not None:
            scheduler.step()

        print(optimizer.param_groups[0]['lr'])

    os.makedirs('./checkpoints', exist_ok=True)
    torch.save(model.state_dict(), f'./checkpoints/{model_name}_{mode}{args.rnd_seed}.pth')

def train(model, data_loader, criterion, optimizer, loss_mode, device='cpu'):
    """Run one training epoch over ``data_loader`` and return the summed batch loss.

    Args:
        model: network to optimise; switched to train mode.
        data_loader: iterable yielding ``(batch_x, batch_y)`` pairs.
        criterion: loss callable invoked as ``criterion(output, batch_y)``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_mode: ``"mse"`` or ``"cross"`` minimise ``criterion`` (both call it
            identically); ``"neg_grad"`` minimises ``-criterion``, i.e. ascends the loss.
        device: device each batch is moved to.

    Returns:
        Sum of the per-batch losses as a detached tensor (``0`` if the loader is empty).

    Raises:
        ValueError: if ``loss_mode`` is not one of the supported values.
    """
    if loss_mode not in ("mse", "cross", "neg_grad"):
        raise ValueError(f"Unknown loss_mode: {loss_mode!r}")
    negate = loss_mode == "neg_grad"

    running_loss = 0
    model.train()
    for batch_x, batch_y in tqdm(data_loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        output = model(batch_x)

        loss = criterion(output, batch_y)
        if negate:
            loss = -loss

        loss.backward()
        optimizer.step()
        # Detach so the running sum does not keep every batch's autograd graph alive.
        running_loss += loss.detach()

    return running_loss

def test(model, loader, device=None):
    """Evaluate ``model`` on ``loader`` and return ``(accuracy, mean_batch_loss)``.

    Args:
        model: classifier whose outputs are fed to ``F.cross_entropy`` and
            arg-maxed over dim 1 (presumably logits of shape (batch, classes)).
        loader: iterable yielding ``(images, labels)`` pairs.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, or ``'cuda'`` (the former hard-coded
            choice) for a model without parameters.

    Returns:
        Fraction of correctly classified samples, and the cross-entropy averaged
        over batches (mean of per-batch means).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cuda'

    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss_sum += F.cross_entropy(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

    if total == 0:
        raise ValueError("test() received a loader with no samples")
    return correct / total, loss_sum / num_batches


def _plot_confusion_matrix(y_true, y_predict, labels, title=None):
    """Display a red-scale confusion matrix of ``y_true`` vs ``y_predict`` over ``labels``."""
    matrix = confusion_matrix(y_true, y_predict, labels=labels)
    plt.matshow(matrix, cmap=plt.cm.Reds)
    plt.ylabel('True Label')
    plt.xlabel('Pred Label')
    if title is not None:
        plt.title(title, loc='center')
    plt.show()


def eval(model, data_loader, batch_size=64, mode='backdoor', print_perform=False, device='cpu', name=''):
    """Classify every batch of ``data_loader`` and return the accuracy twice.

    Args:
        model: network to evaluate; switched to eval mode.
        data_loader: iterable yielding ``(batch_x, batch_y)``; for reports its
            ``dataset.classes`` is read.
        batch_size: unused; kept for call compatibility.
        mode: how outputs are post-processed before the argmax over dim 1:
            ``'pruned'`` keeps only the first 10 outputs; an ``int`` k splits the
            last dim into k chunks (``torch.chunk``) and averages them (assumes
            the width is divisible by k, TODO confirm); any other value
            (e.g. ``'backdoor'``, ``'widen'``) uses the outputs unchanged.
        print_perform: if true, print a classification report and/or show a
            confusion matrix depending on ``mode``.
        device: device each batch is moved to.
        name: title prefix for the ``'pruned'`` confusion-matrix plot.

    Returns:
        ``(accuracy_score(...), acc)``: the sklearn accuracy as a float and the
        same accuracy as a 0-d torch tensor.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()  # switch to eval status

    y_true = []
    y_predict = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_y_predict = model(batch_x)
            if mode == 'pruned':
                batch_y_predict = batch_y_predict[:, 0:10]
            elif type(mode) is int:  # exact type check: bools are deliberately excluded
                batch_y_predicts = torch.chunk(batch_y_predict, mode, dim=-1)
                batch_y_predict = sum(batch_y_predicts) / mode

            y_predict.append(torch.argmax(batch_y_predict, dim=1))
            y_true.append(batch_y)

    if not y_true:
        raise ValueError("eval() received an empty data_loader")

    y_true = torch.cat(y_true, 0)
    y_predict = torch.cat(y_predict, 0)

    num_hits = (y_true == y_predict).float().sum()
    acc = num_hits / y_true.shape[0]

    y_true_cpu = y_true.cpu()
    y_predict_cpu = y_predict.cpu()

    if print_perform and mode not in ('backdoor', 'widen', 'pruned'):
        print(classification_report(y_true_cpu, y_predict_cpu, target_names=data_loader.dataset.classes, digits=4))
    if print_perform and mode == 'widen':
        # Build a new list: list.append returns None and would also grow the
        # dataset's own class list on every call.
        class_names = list(data_loader.dataset.classes) + ['extra class']
        # Targets and predictions are integer class indices, so both sklearn
        # calls take integer labels (one per name, including the extra class).
        label_ids = list(range(len(class_names)))
        print(classification_report(y_true_cpu, y_predict_cpu, labels=label_ids,
                                    target_names=class_names, digits=4))
        _plot_confusion_matrix(y_true_cpu, y_predict_cpu, label_ids)
    if print_perform and mode == 'pruned':
        _plot_confusion_matrix(y_true_cpu, y_predict_cpu, list(range(10)),
                               title='{} confusion matrix'.format(name))

    return accuracy_score(y_true_cpu, y_predict_cpu), acc
