import os

from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from dataset import load_dataset
from mix import *
from sample import uniform_sample
from utils import (AverageMeter, GroupWeightedLoss, accuracy, get_output,
                   load_model)


def log_and_save(args, model, optimizer, step, train_loss, avg, worst, test_avg,
                 test_worst, cluster_labels, best_stats,
                 save_name, save_best=False, save_all=False):
    """Update the best-so-far statistics, log metrics, and write a checkpoint.

    Args:
        args: Run configuration. Reads ``logger``, ``save_unit``, ``use_wandb``,
            ``save_dir``, ``seed`` and, when a periodic checkpoint is requested,
            ``sep_infer`` / ``infer_steps``.
        model: Model whose ``state_dict()`` is stored in the checkpoint.
        optimizer: Optimizer whose ``state_dict()`` is stored in the checkpoint.
        step: Current epoch or batch number (unit given by ``args.save_unit``).
        train_loss: Training loss reported to wandb.
        avg: Average validation accuracy.
        worst: Worst-group validation accuracy; drives best-model selection.
        test_avg: Average test accuracy.
        test_worst: Worst-group test accuracy.
        cluster_labels: Stored verbatim in the checkpoint.
        best_stats: Dict with keys ``best_worst_val``, ``best_worst_test``,
            ``best_avg_val`` and ``best_avg_test``; updated in place when a
            new best is found.
        save_name: Directory under which periodic checkpoints are written.
        save_best: Track the best worst-group validation accuracy and save
            ``best_model.pt`` into ``args.save_dir`` on improvement. Metrics
            are sent to wandb only when this flag is set.
        save_all: Also write a checkpoint for this step (restricted to steps
            listed in ``args.infer_steps`` when ``args.sep_infer`` is set).

    Returns:
        The (possibly updated) ``best_stats`` dict.
    """
    # A new best requires best-tracking to be on AND a strictly higher worst-group val accuracy.
    found_new_best = bool(save_best) and bool(worst > best_stats['best_worst_val'])

    if found_new_best:
        args.logger.info(f'Best model found at {args.save_unit} {step} with val worst {worst:.3f} and test worst {test_worst:.3f}')
        best_stats.update(
            best_worst_val=worst,
            best_worst_test=test_worst,
            best_avg_val=avg,
            best_avg_test=test_avg,
        )

    if save_best and args.use_wandb:
        metrics = {
            f'{args.save_unit}': step,
            'train_loss': train_loss,
            'val_avg': avg,
            'val_worst': worst,
            'test_avg': test_avg,
            'test_worst': test_worst,
            'best_val_worst': best_stats['best_worst_val'],
            'best_test_worst': best_stats['best_worst_test'],
            'best_val_avg': best_stats['best_avg_val'],
            'best_test_avg': best_stats['best_avg_test'],
        }
        wandb.log(metrics)

    # Nothing to write unless this is a new best or every step is being kept.
    if not (found_new_best or save_all):
        return best_stats

    if found_new_best:
        # The best model always goes to a fixed file name and overwrites the previous best.
        save_name = os.path.join(args.save_dir, 'best_model.pt')
        args.logger.info(f'Saving model at {args.save_unit} {step}\n')
    else:
        # Reaching this branch implies save_all is set.
        # With sep_infer, only the steps listed in args.infer_steps are kept.
        if args.sep_infer and step not in args.infer_steps:
            return best_stats
        save_name = os.path.join(save_name, f'{args.save_unit}{step}/checkpoint_seed{args.seed}.pt')
        os.makedirs(os.path.dirname(save_name), exist_ok=True)
        args.logger.info(f'Saving model to {save_name}')

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'args': args,
        f'{args.save_unit}': step,
        'best_stats': best_stats,
        'test_avg': test_avg,
        'test_worst': test_worst,
        'val_avg': avg,
        'val_worst': worst,
        'optimizer_state_dict': optimizer.state_dict(),
        'cluster_labels': cluster_labels,
    }
    torch.save(checkpoint, save_name)

    return best_stats


def train_and_validate(
        args, model, optimizer, train_criterion, val_criterion, train_loader,
        val_loader, test_loader, epoch, cluster_labels=None, best_stats=None, max_iter=-1, save_best=True,
        pretrain=False, save_all=False, save_name=None, scheduler=None):
    """Run one pass over ``train_loader`` with per-group validation and checkpointing.

    Depending on ``args``, validation on ``val_loader`` / ``test_loader`` runs
    after every batch (``save_unit == 'batch'`` or one of the loss-based
    inference criteria is active) and/or once after the pass
    (``save_unit == 'epoch'`` or a second-stage run).

    Args:
        args: Run configuration; read extensively and mutated in place
            (``args.infer_steps`` is updated by the loss-based criteria).
        model: Model to train; updated in place.
        optimizer: Optimizer stepping ``model``'s parameters.
        train_criterion: Training loss. The result is reduced with ``.mean()``
            here, optionally after per-sample re-weighting.
        val_criterion: Loss handed to ``validate_by_group``.
        train_loader: Yields ``(data, target, index)`` batches.
        val_loader: Passed to ``validate_by_group`` as its list of group loaders.
        test_loader: Same as ``val_loader`` but for the test split.
        epoch: Zero-based epoch index, used for logging and global step numbers.
        cluster_labels: Per-sample cluster/group labels indexed by the dataset
            ``index``; needed for group DRO, loss re-weighting and
            cluster-aware mixup.
        best_stats: Running best statistics forwarded to ``log_and_save``.
            ``None`` (the default) starts from a fresh empty dict.
        max_iter: Stop after this many batches when positive.
        save_best: Forwarded to ``log_and_save``; also gates wandb per-group logging.
        pretrain: ``True`` for the first-stage (reference/inference) model;
            disables mixup, loss re-weighting and group counting.
        save_all: Forwarded to ``log_and_save``.
        save_name: Checkpoint directory forwarded to ``log_and_save``.
        scheduler: Optional LR scheduler. Stepped per batch for BERT + AdamW,
            otherwise once after the pass (unless the optimizer is AdamW).

    Returns:
        ``(best_stats, model, optimizer, scheduler, max_min_loss, max_loss_diff, args)``.
        ``max_min_loss`` is non-zero only for ``infer_loss_thresh`` pretraining
        and ``max_loss_diff`` only for ``infer_loss_diff`` pretraining; the
        unused slot is ``0``.
    """
    # Avoid a mutable default argument shared between calls.
    if best_stats is None:
        best_stats = {}

    losses = AverageMeter()
    num_groups_per_class = args.num_groups // args.num_classes

    # Per-cluster sample counts. They are inverted later, both for loss
    # re-weighting and for cluster-aware mixup sampling, so they must exist
    # whenever either feature is on (previously only `weight_loss` built them,
    # leaving mixup sampling with an undefined name).
    needs_cluster_sizes = not pretrain and (
        args.weight_loss or (args.mixup != 'none' and args.mixup_sample_by_cluster))
    if needs_cluster_sizes:
        # assumes cluster labels are the contiguous integers 0..K-1 -- TODO confirm
        cluster_weights = torch.zeros(np.unique(cluster_labels).shape[0])
        for i in range(cluster_weights.shape[0]):
            cluster_weights[i] = (cluster_labels==i).sum()

    if args.infer_loss_diff and pretrain:
        if args.sep_infer:
            # one running maximum per class
            max_loss_diff = args.best_loss_diff * np.ones(args.num_classes)
        else:
            # A single running maximum. Keeping it scalar is required because it
            # is compared inside a plain `if` below; an array there is ambiguous.
            max_loss_diff = args.best_loss_diff
        # NOTE: this aliases args.infer_steps, so per-class updates below
        # modify args.infer_steps in place.
        max_loss_diff_step = args.infer_steps

    # count number of samples trained in each group
    group_counts = np.zeros(args.num_groups).astype(int)

    for batch_idx, (data, target, index) in enumerate(train_loader):

        if args.infer_loss_thresh > 0. and pretrain:
            max_min_loss = 0.0

        if max_iter > 0 and batch_idx >= max_iter:
            args.logger.info(f'Breaking after {max_iter} iterations')
            break

        # validate_by_group switches the model to eval mode, so re-enable
        # training mode on every iteration.
        model.train()
        if args.arch == 'bert' and args.optimizer == 'adamw':
            model.zero_grad()

        data, target = data.to(args.device), target.to(args.device)

        if args.balance_classes_infer and (args.infer == 'none' or pretrain):
            # print number of samples in each class
            args.logger.info('Train class counts: {}'.format(np.unique(target.cpu().numpy(), return_counts=True)[1]))

        if args.weighted_sampler and not pretrain:
            # tally how many samples of each group this batch contained
            batch_groups = np.array(train_loader.dataset.group_array)[index]
            for group in range(args.num_groups):
                group_counts[group] += np.sum(batch_groups == group)

        if not pretrain:
            # add inter-cluster and intra-cluster mixup to data
            if args.mixup != 'none':
                if args.mixup_sample_by_cluster:
                    # sample mixing partners with probability inversely
                    # related to their cluster size
                    p = 1 / cluster_weights[cluster_labels[index]] ** args.mixup_sample_by_cluster_power
                    p = p / p.sum()
                    p = p.numpy()
                else:
                    p = None

                if args.mixup == 'mixup':
                    data, target_a, target_b, lam = mixup_data(args, data, target, args.mixup_alpha, args.device, p=p)
                elif args.mixup == 'cutmix':
                    data, target_a, target_b, lam = cutmix_data(args, data, target, args.mixup_alpha, args.device, p=p)

        output = get_output(model, data, target, args)

        # compute loss
        if args.mixup != 'none' and not pretrain:
            loss = mixup_criterion(train_criterion, output, target_a, target_b, lam)
        else:
            loss = train_criterion(output, target)

        # compute group weighted loss
        if (args.train == 'group_dro' and not pretrain) or (args.class_dro_infer and pretrain):
            loss = args.group_weighted_loss(loss, cluster_labels[index])
        elif not pretrain:
            if args.weight_loss:
                # weights are normalised so that they average to 1 over the batch
                batch_weight = 1 / cluster_weights[cluster_labels[index]] ** args.weight_loss_power
                batch_weight = batch_weight / batch_weight.sum() * batch_weight.shape[0]
                batch_weight = batch_weight.to(args.device)
                loss = loss * batch_weight
                loss = loss.mean()
            else:
                loss = loss.mean()
        else:
            loss = loss.mean()

        if args.arch == 'bert' and args.optimizer == 'adamw':
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            # PyTorch requires optimizer.step() before scheduler.step();
            # the reverse order skips the first value of the LR schedule.
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            model.zero_grad()
        else:
            # zero the parameter gradients, backpropagate and update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if args.train == 'group_dro' and args.use_wandb and (not pretrain or args.class_dro_infer):
            # log the weight of each group in wandb
            for i in range(len(args.group_weighted_loss.group_weights)):
                wandb.log({f'group{i}_weight': args.group_weighted_loss.group_weights[i].item()})

        losses.update(loss.item(), data.size(0))

        # per-batch stats and validation
        if (args.save_unit == 'batch' or args.infer_loss_thresh > 0. or args.infer_loss_diff) and (pretrain or args.infer=='none'):
            # 1-based count of batches seen since epoch 0
            global_step = epoch*len(train_loader)+batch_idx+1

            args.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx+1), len(train_loader),
                100. * (batch_idx+1) / len(train_loader), loss.item()))

            args.logger.info('Starting validation by group for val...')
            avg, worst, val_results, val_losses = validate_by_group(args, val_loader, model, val_criterion)

            # log accuracy and loss for each group in wandb if use wandb
            if args.use_wandb and (save_best or args.infer_loss_thresh > 0. or args.infer_loss_diff):
                for i in range(len(val_results)):
                    wandb.log({f'val_group{i}_acc': val_results[i]})

                for i in range(len(val_losses)):
                    wandb.log({f'val_group{i}_loss': val_losses[i]})

            # NOTE(review): max_loss_diff is only initialised when `pretrain`
            # is set; this branch is also reachable with infer == 'none' and
            # pretrain False -- confirm that combination is never used.
            if args.infer_loss_diff and (pretrain or args.infer=='none'):
                best_step = False
                min_loss_diff = 1e10
                # loss difference between the first two groups of each class
                for i in range(args.num_classes):
                    # Groups of class i start at i*num_groups_per_class (same
                    # layout as the loss-threshold criterion below). Indexing
                    # by num_classes was only right when the two coincide.
                    first_group = i * num_groups_per_class
                    loss_diff = abs(val_losses[first_group] - val_losses[first_group + 1])
                    if args.sep_infer:
                        if loss_diff > max_loss_diff[i]:
                            max_loss_diff[i] = loss_diff
                            max_loss_diff_step[i] = int(global_step)
                            best_step = True
                        if args.use_wandb:
                            wandb.log({f'class{i}_max_loss_diff': max_loss_diff[i],
                                    f'class{i}_max_loss_diff_step': max_loss_diff_step[i],
                                    'step': global_step})
                    elif loss_diff < min_loss_diff:
                        min_loss_diff = loss_diff
                if not args.sep_infer:
                    args.logger.info('Loss difference at step {}: {}'.format(global_step, min_loss_diff))
                    if min_loss_diff > max_loss_diff:
                        max_loss_diff = min_loss_diff
                        max_loss_diff_step = int(global_step)
                        best_step = True
                    # log max loss difference in wandb (only when wandb is enabled)
                    if args.use_wandb:
                        wandb.log({'max_loss_diff': max_loss_diff,
                                   'max_loss_diff_step': max_loss_diff_step,
                                   'step': global_step})

                if best_step:
                    args.logger.info('Updating infer_step to {}...'.format(max_loss_diff_step))
                    args.infer_steps = max_loss_diff_step

                    args.logger.info('Saving model at step {}...'.format(global_step))
                    save_name_ = os.path.join(args.checkpoint_path, '{}{}/checkpoint_seed{}.pt'.format(args.save_unit, global_step, args.seed))
                    os.makedirs(os.path.dirname(save_name_), exist_ok=True)

                    # save checkpoint
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'args': args,
                        f'{args.save_unit}': max_loss_diff_step,
                    }, save_name_)

            args.logger.info('Starting validation by group for test...')
            test_avg, test_worst, test_results, _ = validate_by_group(args, test_loader, model, val_criterion)

            # log accuracy for each group in wandb if use wandb
            if args.use_wandb and save_best:
                for i in range(len(test_results)):
                    wandb.log({f'test_group{i}_acc': test_results[i]}, step=global_step)

            args.logger.info('Val Avg: {:.4f}, Val Worst: {:.4f}, Test Avg: {:.4f}, Test Worst: {:.4f}'.format(
            avg, worst, test_avg, test_worst))

            # save model
            # NOTE(review): the step passed here is the within-epoch batch
            # number, not the global step used for the checkpoints above.
            best_stats = log_and_save(args, model, optimizer, batch_idx+1, loss.item(), avg, worst, test_avg,
                         test_worst, cluster_labels, best_stats,
                         save_best=save_best, save_all=save_all, save_name=save_name)

        if args.infer_loss_thresh > 0. and pretrain:
            # The validation block above always runs in this mode, so
            # val_losses / avg / worst / test_* are defined for this batch.
            for i in range(args.num_classes):
                # lowest validation loss among the groups of class i
                min_loss = min(val_losses[i*num_groups_per_class:(i+1)*num_groups_per_class])
                max_min_loss = max(max_min_loss, min_loss)
            # %-style placeholders: Logger.info treats extra positional
            # arguments as format args, not as print-style values.
            args.logger.info('max_min_loss: %s', max_min_loss)
            if max_min_loss < args.infer_loss_thresh:
                step = epoch*len(train_loader)+batch_idx+1
                args.logger.info('Saving model at step %s', step)
                args.infer_steps = [step]
                # Separate local name so the `save_name` directory passed by
                # the caller is still intact for log_and_save after the loop.
                thresh_ckpt_path = os.path.join(args.checkpoint_path, '{}{}/checkpoint_seed{}.pt'.format(args.save_unit, step, args.seed))
                os.makedirs(os.path.dirname(thresh_ckpt_path), exist_ok=True)

                args.logger.info(f'Saving model at {args.save_unit} {step} to {thresh_ckpt_path}\n')

                # save checkpoint
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'args': args,
                    f'{args.save_unit}': step,
                    'best_stats': best_stats,
                    'test_avg': test_avg,
                    'test_worst': test_worst,
                    'val_avg': avg,
                    'val_worst': worst,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'cluster_labels': cluster_labels,
                }, thresh_ckpt_path)

                break

    # per-epoch scheduler step (AdamW schedulers are stepped per batch above)
    if scheduler is not None and args.optimizer != 'adamw':
        scheduler.step()

    if args.weighted_sampler and not pretrain:
        # print group counts
        for i in range(args.num_groups):
            args.logger.info(f'Group {i} count: {group_counts[i]}')

    # print learning rate
    args.logger.info('Learning rate: {}'.format(optimizer.param_groups[0]['lr']))

    # epoch-level stats and validation
    if args.save_unit == 'epoch' or (not pretrain and not args.infer=='none'):
        args.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, (batch_idx+1), len(train_loader),
            100. * (batch_idx+1) / len(train_loader), losses.avg))

        args.logger.info('Starting validation...')
        avg, worst, val_results, val_losses = validate_by_group(args, val_loader, model, val_criterion)

        # log accuracy and loss for each group in wandb if use wandb
        if args.use_wandb and (save_best or args.infer_loss_thresh > 0. or args.infer_loss_diff):
            for i in range(len(val_results)):
                wandb.log({f'val_group{i}_acc': val_results[i]})

            for i in range(len(val_losses)):
                wandb.log({f'val_group{i}_loss': val_losses[i]})

        args.logger.info('Starting test...')
        test_avg, test_worst, test_results, _ = validate_by_group(args, test_loader, model, val_criterion)

        # log accuracy for each group in wandb if use wandb
        if args.use_wandb and save_best:
            for i in range(len(test_results)):
                wandb.log({f'test_group{i}_acc': test_results[i]})

        args.logger.info('Val Avg: {:.4f}, Val Worst: {:.4f}, Test Avg: {:.4f}, Test Worst: {:.4f}\n'.format(
            avg, worst, test_avg, test_worst))

        # save model
        best_stats = log_and_save(args, model, optimizer, epoch+1, losses.avg, avg, worst, test_avg,
                     test_worst, cluster_labels, best_stats,
                     save_best=save_best, save_all=save_all, save_name=save_name)

    if args.infer_loss_thresh > 0. and pretrain:
        return best_stats, model, optimizer, scheduler, max_min_loss, 0, args
    elif args.infer_loss_diff and pretrain:
        return best_stats, model, optimizer, scheduler, 0, max_loss_diff, args
    else:
        return best_stats, model, optimizer, scheduler, 0, 0, args


def validate_by_group(args, val_loader_list, model, criterion, weight=None, example_losses=False):
    """Evaluate ``model`` separately on each loader in ``val_loader_list``.

    Args:
        args: Run configuration. Reads ``logger``, ``dataset``, optionally
            ``device`` (falls back to ``'cuda'`` when absent) and, when
            ``example_losses`` is set, ``train_num``.
        val_loader_list: Iterable of data loaders, one per group; each yields
            ``(input, target, idx)`` batches.
        model: Model to evaluate; switched to eval mode and left in it.
        criterion: Loss function applied to ``(output, target)``.
        weight: Optional tensor indexed by ``idx``; when given, the loss is
            multiplied by ``weight[idx]`` and averaged.
        example_losses: When true, ``criterion`` is expected to return one
            loss per sample; the values are written into a per-example buffer
            and then averaged.
            NOTE(review): that buffer is filled but never returned.

    Returns:
        ``(avg, worst, val_results, val_losses)`` where ``val_results`` and
        ``val_losses`` are numpy arrays holding each group's top-1 accuracy
        and mean loss, ``worst`` is the minimum group accuracy, and ``avg`` is
        a weighted mean of the group accuracies (fixed weights for
        ``'waterbirds'``, otherwise weights proportional to loader dataset size).
    """
    # Use the same device as the training loop instead of hard-coding CUDA,
    # keeping CUDA as the fallback when args carries no device.
    device = getattr(args, 'device', 'cuda')

    if example_losses:
        loss_per_example = torch.zeros(args.train_num, device=device)

    # switch to evaluate mode
    model.eval()

    val_results = []
    val_losses = []
    val_sizes = []
    with torch.no_grad():
        for i, val_loader in enumerate(val_loader_list):
            top1 = AverageMeter()
            losses = AverageMeter()
            for inputs, target, idx in val_loader:
                inputs = inputs.to(device)
                target = target.to(device)

                # compute output
                output = get_output(model, inputs, target, args)
                loss = criterion(output, target)
                if example_losses:
                    loss_per_example[idx] = loss
                    loss = loss.mean()
                if weight is not None:
                    loss = (loss * weight[idx.long()]).mean()

                output = output.float()
                loss = loss.float()

                # measure accuracy and record loss
                prec1 = accuracy(output.data, target)[0]
                batch_size = inputs.size(0)
                losses.update(loss.item(), batch_size)
                top1.update(prec1.item(), batch_size)

            val_results.append(top1.avg)
            val_losses.append(losses.avg)
            val_sizes.append(len(val_loader.dataset))
            args.logger.info(f' * Group {i} Prec@1 {top1.avg:.3f}\tLoss {losses.avg:.3f}')

    # fraction of all evaluated samples that belongs to each group
    val_sizes = np.array(val_sizes)/sum(val_sizes)
    val_results = np.array(val_results)
    val_losses = np.array(val_losses)
    if args.dataset == 'waterbirds':
        # NOTE(review): fixed weights, presumably the group proportions of the
        # Waterbirds training split -- confirm against the dataset.
        avg = 0.7295099061522419 * val_results[0] + 0.0383733055265902 * val_results[1] + 0.01167883211678832 * val_results[2] + 0.22043795620437956 * val_results[3]
    else:
        avg = val_sizes.dot(val_results)
    worst = val_results.min()
    args.logger.info(f' * Average Prec@1 {avg:.3f}\tWorst Prec@1 {worst:.3f}')

    return avg, worst, val_results, val_losses


def train_refer_model(args, train_criterion, val_criterion, val_loader, test_loader, infer_step=0):
    """Train the first-stage (reference) model and return it.

    The stopping rule is picked from ``args``, in this order of precedence:

    * ``infer_loss_thresh > 0``: repeat epochs until the value reported by
      ``train_and_validate`` falls below the threshold.
    * ``infer_loss_diff``: repeat epochs while the reported loss difference
      keeps exceeding ``args.best_loss_diff``.
    * ``save_unit == 'epoch'``: train for ``infer_step`` epochs.
    * otherwise: train for ``infer_step`` batches of a single epoch.

    Args:
        args: Run configuration; ``group_weighted_loss`` and ``best_loss_diff``
            may be set on it here.
        train_criterion: Loss used for the parameter updates.
        val_criterion: Loss used during validation.
        val_loader: Validation loaders forwarded to ``train_and_validate``.
        test_loader: Test loaders forwarded to ``train_and_validate``.
        infer_step: Number of epochs or batches for the two fixed-length modes.

    Returns:
        The trained model.
    """
    train_dataset = load_dataset(args, split='train', augment=args.infer_augment)

    # With class-level DRO the class labels act as the "cluster" labels.
    cluster_labels = None
    if args.class_dro_infer:
        args.group_weighted_loss = GroupWeightedLoss(args, args.num_classes)
        cluster_labels = np.array(train_dataset.targets)

    # Either sample classes uniformly or shuffle the dataset.
    loader_options = {'batch_size': args.batch_size, 'num_workers': args.num_workers}
    if args.balance_classes_infer:
        loader_options['sampler'] = uniform_sample(np.array(train_dataset.targets))
    else:
        loader_options['shuffle'] = True
    train_loader = DataLoader(train_dataset, **loader_options)

    # load model
    t_total = len(train_loader) * args.epochs
    model, optimizer, scheduler = load_model(args, infer=True, t_total=t_total)

    # Keyword arguments common to every train_and_validate call below.
    shared = {'cluster_labels': cluster_labels, 'pretrain': True, 'save_best': False}

    if args.infer_loss_thresh > 0.:
        best_val_loss = args.infer_loss_thresh
        epoch = 0
        # The first epoch always runs because best_val_loss starts at the threshold.
        while True:
            outcome = train_and_validate(
                args, model, optimizer, train_criterion, val_criterion,
                train_loader, val_loader, test_loader, epoch=epoch,
                save_all=False, save_name=args.checkpoint_path, scheduler=scheduler, **shared)
            _, model, optimizer, scheduler, max_min_loss, _, args = outcome
            best_val_loss = min(best_val_loss, max_min_loss)
            epoch += 1
            if not (best_val_loss >= args.infer_loss_thresh):
                break
    elif args.infer_loss_diff:
        # per-class baseline when classes are tracked separately, scalar otherwise
        args.best_loss_diff = 0.0
        if args.sep_infer:
            args.best_loss_diff = np.zeros(args.num_classes)
        epoch = 0
        while True:
            args.logger.info(f'training on all data for one epoch and then stopping when loss difference is the largest with seed {args.seed}\n')
            outcome = train_and_validate(
                args, model, optimizer, train_criterion, val_criterion,
                train_loader, val_loader, test_loader, epoch=epoch,
                save_all=True, save_name=args.checkpoint_path, scheduler=scheduler, **shared)
            _, _, _, _, _, max_loss_diff, args = outcome

            # keep training as long as some loss difference beat the previous best
            exceeded = max_loss_diff > args.best_loss_diff
            improved = np.any(exceeded) if args.sep_infer else exceeded

            args.best_loss_diff = np.maximum(args.best_loss_diff, max_loss_diff)
            epoch += 1
            if not improved:
                break

    elif args.save_unit == 'epoch':
        args.logger.info(f'training on all data for {infer_step} epochs with seed {args.seed}\n')
        for epoch in range(infer_step):
            outcome = train_and_validate(
                args, model, optimizer, train_criterion, val_criterion,
                train_loader, val_loader, test_loader, epoch,
                save_all=True, save_name=args.checkpoint_path, scheduler=scheduler, **shared)
            _, model, optimizer, scheduler, _, _, _ = outcome
    else:
        args.logger.info(f'training on all data for {infer_step} batches with seed {args.seed}\n')
        # Return value is not needed: a single capped pass over the loader.
        train_and_validate(
            args, model, optimizer, train_criterion, val_criterion,
            train_loader, val_loader, test_loader, epoch=0, max_iter=infer_step,
            save_all=True, save_name=args.checkpoint_path, scheduler=scheduler, **shared)

    return model