import torch
import numpy as np

import torch.nn.functional as F
import torch.distributed as dist
import sklearn.metrics as metrics

from tqdm import tqdm
from datetime import datetime
from torch.nn.utils import clip_grad_norm_


def cal_loss(pred, gold, smoothing=False, ignore_index=255):
    ''' Calculate cross entropy loss, apply label smoothing if needed.

    Args:
        pred: Raw class scores with the class axis at dim 1; indexed as a
            2-D ``(N, C)`` tensor by the smoothing branch.
        gold: Integer class labels; flattened to 1-D before use.
        smoothing: When True, use a smoothed target distribution
            (``eps = 0.2``) instead of hard one-hot targets.
        ignore_index: Label value whose samples are excluded from the loss.
            Honoured by both branches.

    Returns:
        Scalar tensor holding the mean loss over the non-ignored samples.
        As with ``F.cross_entropy``, the result is NaN when every sample is
        ignored.
    '''

    gold = gold.contiguous().view(-1)

    if not smoothing:
        return F.cross_entropy(pred, gold, reduction='mean', ignore_index=ignore_index)

    eps = 0.2
    n_class = pred.size(1)

    # Samples labelled ``ignore_index`` must not contribute to the loss, and
    # their label may be out of range for ``scatter`` (e.g. 255 with 40
    # classes), so point them at class 0 and drop them after the reduction.
    valid = gold != ignore_index
    safe_gold = gold.masked_fill(~valid, 0)

    one_hot = torch.zeros_like(pred).scatter(1, safe_gold.view(-1, 1), 1)
    # max(..., 1) avoids a ZeroDivisionError for a single-class problem, where
    # there is no off-target class to receive the smoothing mass anyway.
    off_target = eps / max(n_class - 1, 1)
    soft_target = one_hot * (1 - eps) + (1 - one_hot) * off_target
    log_prb = F.log_softmax(pred, dim=1)

    per_sample = -(soft_target * log_prb).sum(dim=1)
    return per_sample[valid].mean()

def single_test(cfg,fabric, data, model,epoch):
    ''' Evaluate ``model`` on ``data`` and aggregate the metrics over all ranks.

    Args:
        cfg: Config object; reads ``cfg.data.classes``, ``cfg.opt.batch_size``
            and ``cfg.opt.epochs``.
        fabric: Object providing ``device``, ``to_device``, ``barrier`` and
            ``all_reduce`` (presumably a Lightning Fabric instance -- confirm
            against the caller).
        data: Sized iterable of batches. Each batch is passed whole to the
            model and must expose the targets under ``batch['label']``.
        model: Callable mapping a batch to logits with the class axis at dim 1.
        epoch: Zero-based epoch index, used only for the progress-bar title.

    Returns:
        Tuple ``(test_acc, avg_per_class_acc, test_loss, count)``:
        overall accuracy summed over ranks; mean accuracy over the classes
        that actually occur in the evaluated labels; mean of the per-batch
        losses over all ranks; and the local nominal sample count
        (number of local batches times ``cfg.opt.batch_size``).
    '''
    loss_fn = fabric.to_device(torch.nn.CrossEntropyLoss())
    num_classes = cfg.data.classes
    batch_size = cfg.opt.batch_size
    class_correct = torch.zeros(num_classes, device=fabric.device)
    class_total = torch.zeros(num_classes, device=fabric.device)
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    count = 0.0
    batches_seen = 0

    model.eval()
    with torch.no_grad():
        # The context manager closes the bar even if a batch raises.
        with tqdm(total=len(data), desc=f'Epoch {epoch+1}/{cfg.opt.epochs}') as pbar:
            for gs in data:
                torch.cuda.empty_cache()
                logits = model(gs)
                labels = gs['label']
                preds = logits.argmax(1)
                hits = preds == labels

                val_correct += hits.sum().item()
                val_total += labels.shape[0]
                val_loss += loss_fn(logits, labels).item()
                for c in range(num_classes):
                    class_mask = labels == c
                    class_total[c] += class_mask.sum()
                    class_correct[c] += (hits & class_mask).sum()
                count += batch_size
                batches_seen += 1

                pbar.update(1)
                # Show the running mean rather than the ever-growing sum.
                pbar.set_postfix(loss=f'{val_loss / batches_seen:.4f}')
    fabric.barrier()

    # Every rank must issue the same collectives in the same order.
    test_acc = fabric.all_reduce(val_correct, reduce_op="sum").item()
    test_total = fabric.all_reduce(val_total, reduce_op="sum").item()
    test_loss = fabric.all_reduce(val_loss, reduce_op="sum").item()

    class_correct = fabric.all_reduce(class_correct, reduce_op="sum")
    class_total = fabric.all_reduce(class_total, reduce_op="sum")
    total_batches = fabric.all_reduce(len(data), reduce_op="sum").item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    test_acc = test_acc / max(test_total, 1)

    # Average only over classes that occur in the labels; a class with no
    # samples has no defined accuracy and must not count as 0%.
    present = class_total > 0
    if present.any():
        avg_per_class_acc = (class_correct[present] / class_total[present]).mean().item()
    else:
        avg_per_class_acc = 0.0

    test_loss = test_loss / max(total_batches, 1)

    return test_acc, avg_per_class_acc, test_loss, count

def zoom_singleclass(cfg,fabric,scheduler,ema,Trainer,optimizer,io,writer,train_dataloader,test_dataloader):
    ''' Train ``Trainer`` for ``cfg.opt.epochs`` epochs, evaluating after each one.

    Per epoch: one pass over ``train_dataloader`` (cross-entropy loss,
    gradient-norm clipping at 1, optimizer step, ``ema.update()``), one
    ``scheduler.step()``, then an evaluation through ``single_test``. On the
    global-zero rank the best-accuracy weights and, after the final epoch, the
    last weights are written under ``checkpoints/<cfg.model.name>/``.

    Args:
        cfg: Config object; reads ``cfg.opt.epochs`` and ``cfg.model.name``
            (``single_test`` reads further fields).
        fabric: Object providing ``to_device``, ``backward``, ``barrier``,
            ``all_reduce``, ``all_gather`` and ``is_global_zero`` (presumably
            a Lightning Fabric instance -- confirm against the caller).
        scheduler: Learning-rate scheduler, stepped once per epoch.
        ema: Object whose ``update()`` is called after every optimizer step.
        Trainer: The model; called with a whole batch and returns logits.
        optimizer: Optimizer for the parameters of ``Trainer``.
        io: Logger exposing ``cprint(str)``.
        writer: Scalar writer exposing ``add_scalar`` and ``close``; it is
            closed once, after the last epoch.
        train_dataloader: Sized iterable of training batches with a
            ``'label'`` entry.
        test_dataloader: Evaluation batches, forwarded to ``single_test``.
    '''
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    best_test_acc = 0.0
    loss_fn = fabric.to_device(torch.nn.CrossEntropyLoss())
    num_epochs = cfg.opt.epochs

    for epoch in range(num_epochs):
        # ---- train ----
        Trainer.train()
        torch.cuda.empty_cache()
        # Only epoch-aware samplers expose set_epoch; skip it for the rest
        # instead of failing with AttributeError.
        set_epoch = getattr(getattr(train_dataloader, 'sampler', None), 'set_epoch', None)
        if callable(set_epoch):
            set_epoch(epoch)

        train_loss = 0.0
        train_correct = 0
        count = 0

        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
            for gs in train_dataloader:
                labels = gs['label']
                # Use the real batch size: the last batch may be smaller than
                # cfg.opt.batch_size, which would skew loss and accuracy.
                batch_size = labels.shape[0]

                logits = Trainer(gs)
                loss = loss_fn(logits, labels)
                assert not loss.isnan(), "loss is nan"
                fabric.backward(loss)
                clip_grad_norm_(Trainer.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()
                ema.update()

                count += batch_size
                train_loss += loss.item() * batch_size
                train_correct += (logits.argmax(1) == labels).sum().item()

                pbar.update(1)
                pbar.set_postfix(loss=f'{loss:.4f}',)
                if dist.is_initialized():
                    dist.barrier()
        scheduler.step()
        fabric.barrier()
        torch.cuda.empty_cache()

        train_acc = fabric.all_reduce(train_correct, reduce_op="sum").item()
        train_total = fabric.all_reduce(count, reduce_op="sum").item()
        # max(..., 1) guards against an empty training loader.
        train_acc = train_acc / max(train_total, 1)

        writer.add_scalar('train_acc', train_acc, epoch)

        outstr = 'Train %d, loss: %.6f, train acc: %.6f, ' % (epoch,
                                                                train_loss / max(count, 1),
                                                                train_acc,
                                                                )
        io.cprint(outstr)

        # ---- test (every epoch) ----
        torch.cuda.empty_cache()

        # The returned sample count is not needed here: test_loss is already
        # a mean and must not be divided a second time.
        test_acc, avg_per_class_acc, test_loss, _ = single_test(cfg,fabric,test_dataloader,Trainer,epoch)

        # Collectives run on every rank; only rank 0 logs and saves.
        test_acc_all = fabric.all_gather(test_acc)
        avg_test_acc_all = fabric.all_gather(avg_per_class_acc)
        if fabric.is_global_zero:
            mean_acc_all = test_acc_all.mean().item()
            mean_avg_acc_all = avg_test_acc_all.mean().item()

            writer.add_scalar('test_acc', mean_acc_all, epoch)
            writer.add_scalar('avg_test_acc', mean_avg_acc_all, epoch)

            outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (epoch,
                                                                                test_loss,
                                                                                mean_acc_all,
                                                                                mean_avg_acc_all)
            io.cprint(outstr)

            # Keep the weights of the best epoch so far (ties favour the later one).
            if mean_acc_all >= best_test_acc:
                best_test_acc = mean_acc_all
                torch.save(Trainer.state_dict(), 'checkpoints/%s/model_view_%s.t7' % (cfg.model.name, current_time) )

            # range(num_epochs) ends at num_epochs - 1, so that is the final
            # epoch. Saving on rank 0 only matches the best-model save and
            # avoids several ranks writing the same file.
            if epoch == num_epochs - 1:
                torch.save(Trainer.state_dict(), 'checkpoints/%s/last_epoch_model_%s.t7' % (cfg.model.name, current_time))

    # Close once after training; closing inside the loop shut the writer
    # down after the first epoch.
    writer.close()