import os
import math
import time
from tqdm import tqdm
import torch
import numpy as np
from utils import *
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def train(args, train_loader, model, criterion, optimizer, epoch, weight=None):
    """
    Run one train epoch.

    Args:
        args: Run configuration. Reads ``train_num``, ``batch_size``, ``half``,
            ``epochs`` and ``logger``; optionally ``match_num_iter_preselect``
            (together with ``start_epoch``), ``match_num_iter`` and
            ``subset_weight``.
        train_loader: Loader yielding ``(input, target, idx)`` batches, where
            ``idx`` indexes into ``weight`` and into the returned accuracy
            array.
        model: Network to train; switched to train mode here.
        criterion: Loss function. Its output is multiplied element-wise by the
            per-example weights before averaging (presumably built with
            ``reduction='none'`` -- confirm against the caller).
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based epoch index (shown as ``epoch + 1`` in the bar).
        weight: Optional tensor of per-example weights of length
            ``args.train_num``. Defaults to all ones.

    Returns:
        Tuple ``(data_time.sum, batch_time.sum, acc)`` where ``acc`` is a NumPy
        array of length ``args.train_num`` holding 1 for every example that was
        predicted correctly in at least one iteration of this epoch, else 0.
    """
    if weight is None:
        weight = torch.ones(args.train_num).cuda()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    acc = np.zeros(args.train_num)

    # switch to train mode
    model.train()

    end = time.time()

    # By default one pass over the loader; optionally stretched so that a
    # smaller subset runs as many iterations as the full training set would.
    num_iter = len(train_loader)
    if getattr(args, 'match_num_iter_preselect', False) and epoch < args.start_epoch:
        num_iter = math.floor(args.train_num/args.batch_size)
        args.logger.info(f'Iterate train loader to match the number of iterations {num_iter}')
    if getattr(args, 'match_num_iter', False) and (len(train_loader.dataset) < args.train_num):
        num_iter = math.floor(args.train_num/args.batch_size)
        args.logger.info(f'Iterate train loader to match the number of iterations {num_iter}')

    normalize_per_batch = getattr(args, 'subset_weight', None) == 'minibatch'

    data_iter = iter(train_loader)
    # ``total=`` gives the bar a known length; the context manager closes it.
    with tqdm(total=num_iter) as p_bar:
        for i in range(num_iter):
            try:
                inputs, target, idx = next(data_iter)
            except StopIteration:
                # Loader exhausted before num_iter was reached: start over.
                # Only exhaustion is handled here so that genuine data-loading
                # errors still surface.
                data_iter = iter(train_loader)
                inputs, target, idx = next(data_iter)

            # measure data loading time
            data_time.update(time.time() - end)

            target = target.cuda()
            input_var = inputs.cuda()
            if args.half:
                input_var = input_var.half()

            # compute output
            output = model(input_var)
            loss = criterion(output, target)
            batch_weight = weight[idx.long()]
            if normalize_per_batch:
                # Rescale so the weights within this batch average to 1.
                batch_weight = batch_weight / torch.sum(batch_weight) * len(batch_weight)
            # Weighted mean of the loss over the batch.
            loss = (loss * batch_weight).mean()

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            output = output.float()
            loss = loss.float()
            # measure accuracy and record loss
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))

            preds = torch.argmax(nn.Softmax(dim=1)(output), dim=1)
            acc[idx] += preds.eq(target).float().detach().cpu().numpy()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            p_bar.set_description("Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. LR: {lr:.4f}. Data: {data:.3f}s. Batch: {bt:.3f}s. Top1: {top:.4f}. ".format(
                epoch=epoch + 1,
                epochs=args.epochs,
                batch=i+1,
                iter=num_iter,
                lr=optimizer.param_groups[0]['lr'],
                data=data_time.avg,
                bt=batch_time.avg,
                top=top1.avg))
            p_bar.update()

    # An example can be visited more than once when the loader is restarted;
    # clamp so the array stays a 0/1 indicator.
    acc[acc>1] = 1

    return data_time.sum, batch_time.sum, acc


def validate(args, val_loader, model, criterion, weight=None, example_losses=False):
    """
    Run evaluation over ``val_loader``.

    Args:
        args: Run configuration. Reads ``half`` and ``logger``, plus
            ``train_num`` when ``example_losses`` is set.
        val_loader: Loader yielding ``(input, target, idx)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss function. May return a scalar or one loss per example;
            per-example output is needed for ``weight`` / ``example_losses``
            to be meaningful.
        weight: Optional tensor of per-example weights indexed by ``idx``.
        example_losses: When true, also collect the loss of every example.

    Returns:
        ``(top1.avg, losses.avg, acc)`` and, when ``example_losses`` is set, a
        fourth element: a CUDA tensor of length ``args.train_num`` with the
        unweighted loss of each visited example. ``acc`` is a NumPy array of
        length ``len(val_loader.dataset)`` counting correct predictions per
        example.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    acc = np.zeros(len(val_loader.dataset))
    if example_losses:
        loss_per_example = torch.zeros(args.train_num).cuda()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for inputs, target, idx in val_loader:
            target = target.cuda()
            input_var = inputs.cuda()

            if args.half:
                input_var = input_var.half()

            # compute output
            output = model(input_var)
            # Cast to float32 so half-precision losses match the dtype of the
            # float32 buffer and weights they are combined with below.
            loss = criterion(output, target).float()
            if example_losses:
                loss_per_example[idx] = loss
            if weight is not None:
                # Apply the weights while the loss is still per-example, so the
                # reduction below yields the weighted mean.
                loss = loss * weight[idx.long()]
            # No-op when the criterion already returned a scalar.
            loss = loss.mean()

            output = output.float()

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))

            preds = torch.argmax(nn.Softmax(dim=1)(output), dim=1)
            acc[idx] += preds.eq(target).float().detach().cpu().numpy()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    args.logger.info(f' * Prec@1 {top1.avg:.3f}\tTime {batch_time.sum:.3f}')

    if example_losses:
        return top1.avg, losses.avg, acc, loss_per_example
    return top1.avg, losses.avg, acc


def validate_by_group(args, val_loader_list, model, criterion, weight=None,
                      example_losses=False, group_weights=None):
    """
    Run evaluation separately on each group loader and aggregate.

    Args:
        args: Run configuration. Reads ``half`` and ``logger``.
        val_loader_list: Iterable of loaders, one per group, each yielding
            ``(input, target, idx)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss function; may return a scalar or per-example losses.
        weight: Optional tensor of per-example weights indexed by ``idx``.
        example_losses: Accepted for signature compatibility. Per-example
            losses are not returned by this function, so none are collected.
        group_weights: Optional sequence with one weight per group, used for
            the weighted average accuracy. Defaults to the four constants this
            function has always used (they sum to 1; presumably the relative
            group sizes of the dataset it was written for -- confirm).

    Returns:
        ``(avg, worst, losses.avg)``: weighted average of the per-group top-1
        accuracies, the lowest per-group accuracy, and the average loss over
        all groups.

    Raises:
        ValueError: If the number of evaluated groups does not match the
            number of group weights.
    """
    if group_weights is None:
        group_weights = (
            0.7295099061522419,
            0.0383733055265902,
            0.01167883211678832,
            0.22043795620437956,
        )

    losses = AverageMeter()
    batch_time = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    val_results = []
    with torch.no_grad():
        for i, val_loader in enumerate(val_loader_list):
            top1 = AverageMeter()
            for inputs, target, idx in val_loader:
                target = target.cuda()
                input_var = inputs.cuda()

                if args.half:
                    input_var = input_var.half()

                # compute output
                output = model(input_var)
                loss = criterion(output, target).float()
                if weight is not None:
                    # Weight while the loss is still per-example so the
                    # reduction below yields the weighted mean.
                    loss = loss * weight[idx.long()]
                # No-op when the criterion already returned a scalar.
                loss = loss.mean()

                output = output.float()

                # measure accuracy and record loss
                prec1 = accuracy(output.data, target)[0]
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))

            # measure elapsed time (one measurement per group)
            batch_time.update(time.time() - end)
            end = time.time()

            val_results.append(top1.avg)
            args.logger.info(f' * Group {i} Prec@1 {top1.avg:.3f}\tTime {batch_time.sum:.3f}')

    if len(val_results) != len(group_weights):
        raise ValueError(
            f'expected {len(group_weights)} groups to match group_weights, '
            f'got {len(val_results)}'
        )

    # Accumulate left to right, the same order the terms were originally
    # written in.
    avg = 0.0
    for group_weight, group_acc in zip(group_weights, val_results):
        avg += group_weight * group_acc
    worst = min(val_results)
    args.logger.info(f' * Average Prec@1 {avg:.3f}\tWorst Prec@1 {worst:.3f}')

    return avg, worst, losses.avg


def evaluate(args, dataloader, net, criterion):
    """
    Evaluate ``net`` and report average and worst-group accuracy.

    Originally marked as the validation routine for CMNIST. Groups are the
    (target label, spurious label) pairs taken from
    ``dataloader.dataset.targets_all``.

    Args:
        args: Run configuration. Reads ``logger``.
        dataloader: Loader yielding ``(inputs, labels, data_ix)`` batches whose
            dataset exposes ``targets_all['spurious']`` and
            ``targets_all['target']`` arrays; ``data_ix`` indexes into them.
        net: Network to evaluate; moved to CUDA and put in eval mode here.
        criterion: Loss function returning a scalar per batch.

    Returns:
        ``(running_loss, avg_acc, min_acc)``: the sum of the batch losses, the
        overall accuracy, and the lowest accuracy among groups that contain at
        least one evaluated example.
    """
    running_loss = 0.0
    correct = 0
    total = 0

    targets_s = dataloader.dataset.targets_all['spurious'].astype(int)
    targets_t = dataloader.dataset.targets_all['target'].astype(int)

    # Rows are target labels, columns are spurious labels; label values are
    # used directly as indices.
    correct_by_groups = np.zeros([len(np.unique(targets_t)),
                                  len(np.unique(targets_s))])
    total_by_groups = np.zeros(correct_by_groups.shape)

    net.cuda()
    net.eval()

    with torch.no_grad():
        for inputs, labels, data_ix in tqdm(dataloader):
            inputs = inputs.cuda()
            labels = labels.cuda()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs.data, 1)

            # Move results to the host once per batch rather than per sample.
            batch_correct = (predicted == labels).detach().cpu().numpy()
            labels_np = labels.detach().cpu().numpy().astype(int)
            spurious_np = targets_s[torch.as_tensor(data_ix).cpu().numpy()]

            total += labels.size(0)
            correct += int(batch_correct.sum())
            running_loss += loss.item()

            # np.add.at is unbuffered, so several samples falling into the same
            # group within a batch are all counted.
            np.add.at(correct_by_groups, (labels_np, spurious_np),
                      batch_correct.astype(float))
            np.add.at(total_by_groups, (labels_np, spurious_np), 1)

    avg_acc = correct / total
    # Skip groups without examples: 0/0 would be NaN and turn the worst-group
    # accuracy into NaN as well.
    observed = total_by_groups > 0
    group_acc = correct_by_groups[observed] / total_by_groups[observed]
    min_acc = np.amin(group_acc)
    args.logger.info(f' * Average Prec@1 {avg_acc:.3f}\tWorst Prec@1 {min_acc:.3f}')

    return running_loss, avg_acc, min_acc


def predictions(args, loader, model):
    """
    Get softmax predictions and labels for every example in ``loader``.

    Args:
        args: Run configuration. Reads ``train_num``, ``class_num`` and
            ``half``.
        loader: Loader yielding ``(input, target, idx)`` batches; ``idx``
            selects the rows of the outputs that the batch fills.
        model: Network to run; switched to eval mode here.

    Returns:
        ``(preds, labels)`` as NumPy arrays of shape
        ``(args.train_num, args.class_num)`` and ``(args.train_num,)``. Rows
        for indices never produced by the loader stay zero.
    """
    # switch to evaluate mode
    model.eval()

    preds = torch.zeros(args.train_num, args.class_num).cuda()
    labels = torch.zeros(args.train_num, dtype=torch.int)
    with torch.no_grad():
        for inputs, target, idx in loader:
            input_var = inputs.cuda()

            if args.half:
                input_var = input_var.half()

            # Cast to float32 so half-precision outputs match the dtype of the
            # float32 buffer they are written into.
            logits = model(input_var).float()
            preds[idx, :] = nn.Softmax(dim=1)(logits)
            labels[idx] = target.int()

    return preds.cpu().data.numpy(), labels.cpu().data.numpy()


def get_losses_and_preds(args, loader, model, criterion):
    """
    Get the loss and softmax predictions for every example in ``loader``.

    Args:
        args: Run configuration. Reads ``train_num``, ``class_num`` and
            ``half``.
        loader: Loader yielding ``(input, target, idx)`` batches; ``idx``
            selects the entries of the outputs that the batch fills.
        model: Network to run; switched to eval mode here.
        criterion: Loss function whose output is stored per index (presumably
            built with ``reduction='none'`` -- confirm against the caller).

    Returns:
        ``(losses, preds)`` as NumPy arrays of shape ``(args.train_num,)`` and
        ``(args.train_num, args.class_num)``. Entries for indices never
        produced by the loader stay zero.
    """
    # switch to evaluate mode
    model.eval()

    preds = torch.zeros(args.train_num, args.class_num).cuda()
    losses = torch.zeros(args.train_num).cuda()
    with torch.no_grad():
        for inputs, target, idx in loader:
            input_var = inputs.cuda()
            target_var = target.cuda()

            if args.half:
                input_var = input_var.half()
            output = model(input_var)
            # Cast to float32 so half-precision results match the dtype of the
            # float32 buffers they are written into.
            loss = criterion(output, target_var).float()

            preds[idx, :] = nn.Softmax(dim=1)(output.float())
            losses[idx] = loss

    return losses.cpu().data.numpy(), preds.cpu().data.numpy()
