import datetime
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import umap
from torch import nn
from torch.utils.data import DataLoader

from dataset import *
from inference import *
from sample import *
from train import *
from utils import GroupWeightedLoss, get_parser, load_model, set_logger

# Read the command line with the parser provided by utils.
parser = get_parser()
args = parser.parse_args()

# GPU selection: an already exported CUDA_VISIBLE_DEVICES is left untouched;
# otherwise every id listed in args.gpu is exposed as a comma-separated list.
if os.environ.get('CUDA_VISIBLE_DEVICES') is None:
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, args.gpu))
# Use CUDA when torch reports it as available, else fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
args.device = torch.device(_device_name)

# Dataset-specific settings that override the parsed arguments.
# Settings shared by the two datasets that are forced onto the 'bert' arch.
_bert_settings = {
    'epochs': 5,
    'optimizer': 'adamw',
    'max_grad_norm': 1.0,
    'adam_epsilon': 1e-8,
    'warmup_steps': 0,
    'arch': 'bert',
}
# Exact dataset name -> attributes written onto args.
_settings_by_dataset = {
    'waterbirds': {'num_classes': 2, 'epochs': 300},
    'celeba': {'num_classes': 2, 'epochs': 50},
    'multinli': {'num_classes': 3, **_bert_settings},
    'civilcomments': {'num_classes': 2, **_bert_settings},
    'cifar100sup': {'num_classes': 20, 'epochs': 200, 'arch': 'resnet18'},
    'cifar10': {'num_classes': 10, 'epochs': 200, 'arch': 'resnet18'},
    'imagenet': {'num_classes': 9, 'batch_size': 256, 'epochs': 90, 'arch': 'resnet50'},
}

if args.dataset in _settings_by_dataset:
    _settings = _settings_by_dataset[args.dataset]
elif 'cmnist' in args.dataset:
    # any dataset whose name contains 'cmnist' shares these settings
    _settings = {'num_classes': 5, 'epochs': 20, 'arch': 'lenet5'}
else:
    # unknown dataset: keep whatever the parser produced
    _settings = {}

for _attr, _value in _settings.items():
    setattr(args, _attr, _value)

# set save path
def _compose_save_name(args):
    """Return the run-directory name that encodes the experiment settings.

    The name is assembled from suffixes in a fixed order: dataset and
    architecture, the group-inference settings (when inference is enabled),
    the training / sampling settings (or the infer-only settings), and
    finally the optimizer hyper-parameters, batch size and seed.
    """
    parts = [f'{args.dataset}']
    add = parts.append

    if 'cmnist' in args.dataset:
        add(f'_{args.p_correlation}')
    add(f'_{args.arch}')
    if args.wide:
        add('_wide')

    # --- group-inference settings ---
    uses_inference = (args.infer != 'none' or args.infer_loss_thresh > 0.
                      or args.infer_loss_diff or args.infer == 'ssa')
    if uses_inference:
        add(f'_{args.infer}')
        if args.infer_arch != '':
            add(f'_{args.infer_arch}')
        add(f'_{args.infer_optimizer}_{args.infer_lr:.0e}_{args.infer_weight_decay:.0e}')

        if args.infer == 'eiil':
            add(f'_{args.eiil_lr:.0e}_{args.eiil_steps:.0e}')
        elif args.infer == 'dfr':
            add('_val' if args.dfr_val else '_train')

        if args.class_dro_infer:
            add('_class_dro')
        if args.infer_loss_thresh > 0.:
            add(f'_loss{args.infer_loss_thresh:.0e}')
        if args.infer_loss_diff:
            add('_diff')
            if args.sep_infer:
                add('_sep')
        elif len(args.infer_steps) > 0:
            # save unit followed by the dash-separated list of infer steps
            add(f'_{args.save_unit}')
            add('-'.join(f'{step}' for step in args.infer_steps))

        if args.num_infer_ckpts > 1:
            add(f'_{args.num_infer_ckpts}ckpts')
        if args.infer_augment:
            add('_infer_aug')
        if args.infer == 'cluster':
            add(f'_cluster_{args.cluster_metric}')
            if args.traject_length > 1:
                add(f'_traject{args.traject_length}')
            if args.cluster_umap:
                add('_umap')
            add(f'_{args.cluster_method}')
            if args.cluster_all:
                add('_all')
            if args.silhouette:
                add('_silhouette')
            elif args.dataset in ['cifar100sup', 'cifar10', 'imagenet', 'cmnist', 'balance_cmnist']:
                add(f'_{args.num_clusters}clusters')
    if args.balance_classes_infer:
        add('_balance')

    # --- infer-only vs. training settings ---
    if args.infer_only:
        add('_infer_only')
        if args.sep_conf:
            add('_sep_conf')
            add('_adap' if args.adaptive_conf_thresh else f'_{args.conf_thresh:.2f}')
    else:
        if args.continue_train:
            add('_continue')
        add(f'_{args.train}_{args.epochs}')
        if args.train_augment:
            add('_train_aug')
        if args.mixup != 'none':
            add(f'_{args.mixup}_{args.mixup_alpha}')
            if args.mixup_lisa:
                add('_lisa')
            if args.mixup_sample_by_cluster:
                add(f'_sample_by_cluster_{args.mixup_sample_by_cluster_power}')
        if args.weight_loss:
            add(f'_weight_loss_{args.weight_loss_power}')
        if args.sample_by_cluster:
            add('_sample_by_cluster')
            if args.adaptive_sample_power:
                add('_adap')
            else:
                if args.upsample_by_cluster_size:
                    add('_upsample_size')
                add(f'_{args.sample_by_cluster_power}')
        if args.sample != 'none':
            add(f'_{args.sample}')
            if args.sample == 'upsample_by_factor':
                add(f'_{args.upsample_factor}')
        if args.sample_every > 0:
            add(f'_every_{args.sample_every}')
        else:
            if args.uniform_group_sampler:
                add('_uniform-group')
            elif args.uniform_class_sampler:
                add('_uniform-class')
            if args.weighted_sampler:
                add('_sampler')

        if args.scheduler != 'none':
            add(f'_{args.scheduler}_{args.milestones}')

    # --- optimizer hyper-parameters, batch size and seed ---
    add(f'_{args.optimizer}_{args.lr:.0e}_{args.momentum}_{args.weight_decay:.0e}_bs{args.batch_size}_seed{args.seed}')
    return ''.join(parts)


save_name = _compose_save_name(args)
args.save_dir = os.path.join(args.save_dir, save_name)

# set wandb
if args.use_wandb:
    # Imported only when wandb logging is requested. The import must stay at
    # module level: main() calls wandb.finish() through this global name.
    import wandb
    # The run name is the last path component of save_dir; os.path.basename
    # (rather than split('/')) keeps this correct on platforms where
    # os.path.join does not use '/' as the separator.
    wandb.init(project=args.wandb_project, config=args, name=os.path.basename(args.save_dir))

# add timestamp to save_dir (spaces, colons and dots replaced by underscores)
args.save_dir = os.path.join(args.save_dir, str(datetime.datetime.now()).replace(' ', '_').replace(':', '_').replace('.', '_'))

# create the run directory; exist_ok avoids the check-then-create race of
# testing os.path.exists() before calling os.makedirs()
os.makedirs(args.save_dir, exist_ok=True)

# set logger and record the full configuration with the start time
logger = logging.getLogger(__name__)
args.logger = set_logger(args, logger)
args.logger.info(str(args) + ' ' + str(datetime.datetime.now()) + '\n')


def _adaptive_sample_power(group_sizes, groups_per_class):
    """Return sqrt(class_size / group_size) for every group.

    Groups are assumed to be laid out class by class, ``groups_per_class``
    consecutive entries per class. The class size is the sum over ALL groups
    of the class (the previous slice dropped the last group of each class).
    """
    group_sizes = np.asarray(group_sizes)
    class_sizes = np.array([
        np.sum(group_sizes[start:start + groups_per_class])
        for start in range(0, len(group_sizes), groups_per_class)
    ])
    # broadcast each class size back onto the groups of that class
    return np.sqrt(np.repeat(class_sizes, groups_per_class) / group_sizes)


def _build_eval_loaders(args, group_sizes):
    """Create the validation and test loaders.

    For 'cifar100sup', 'cifar10' and 'imagenet' a single loader per split is
    returned; for every other dataset a list with one loader per
    [class, within-class group] pair is returned and its sizes are logged.
    """
    def make_loader(split, **dataset_kwargs):
        return DataLoader(
            load_dataset(args, split=split, **dataset_kwargs),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)

    if args.dataset in ['cifar100sup', 'cifar10', 'imagenet']:
        return make_loader('val'), make_loader('test')

    groups = [[i, j] for i in range(args.num_classes)
              for j in range(len(group_sizes) // args.num_classes)]
    val_loader = [make_loader('val', group=group) for group in groups]
    test_loader = [make_loader('test', group=group) for group in groups]

    # print the size of each validation set
    for i, loader in enumerate(val_loader):
        args.logger.info(f'Validation set size for group {i}: {len(loader.dataset)}')

    # print the size of each test set
    for i, loader in enumerate(test_loader):
        args.logger.info(f'Test set size for group {i}: {len(loader.dataset)}')

    return val_loader, test_loader


def _extend_checkpoint_path(args):
    """Append the reference-model settings to ``args.checkpoint_path`` in place."""
    if args.infer_arch == '':
        arch = args.arch
    elif 'resnet' in args.infer_arch:
        # keep only the first 8 characters of resnet-style names
        arch = args.infer_arch[:8]
    else:
        arch = args.infer_arch

    path = os.path.join(args.checkpoint_path, f'{args.dataset}_{arch}', f'{args.infer_optimizer}_{args.infer_lr:.0e}_{args.infer_weight_decay:.0e}')
    if args.wide or 'wide' in args.infer_arch:
        path += '_wide'
    if not args.pretrained or args.infer_arch != '':
        path += '_no_pretrained'
    if args.freeze:
        path += '_freeze'
    if args.infer_augment:
        path += '_aug'
    if args.balance_classes_infer:
        path += '_balance'
    if args.class_dro_infer:
        path += '_classdro'
    if 'cmnist' in args.dataset:
        path += f'_bs{args.batch_size}'
    args.checkpoint_path = path


def _save_scatter(args, X, colors, fpath):
    """Scatter the first two columns of ``X`` coloured by ``colors`` and save to ``fpath``."""
    num_colors = len(np.unique(colors))
    # plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib releases removed
    plt.scatter(X[:, 0], X[:, 1], c=colors, s=1.0,
                cmap=plt.get_cmap(args.cmap, num_colors))
    plt.colorbar(ticks=np.unique(colors))
    plt.savefig(fname=fpath, dpi=300, bbox_inches="tight")
    plt.close()
    args.logger.info(f'Saved UMAP to {fpath}!')


def _visualize_clusters(args, train_dataset, embeds, labels, cluster_labels):
    """Save scatter plots of the embeddings coloured by cluster, target and spurious label."""
    if args.cluster_umap:
        umap_ = umap.UMAP(random_state=args.seed,
                          n_components=2)
        X = umap_.fit_transform(embeds)
    else:
        X = embeds

    _save_scatter(args, X, np.array(cluster_labels).astype(int),
                  os.path.join(args.save_dir, f'umap-init_slice-cr.png'))

    # Save the same projection coloured by the other label types too
    for target_type in ['target', 'spurious']:
        if target_type == 'target':
            colors = np.array(labels).astype(int)
        else:
            # NOTE(review): presumably split_array == 0 selects the training split - confirm
            colors = np.array(train_dataset.dataset.confounder_array[train_dataset.dataset.split_array==0]).astype(int)
        t = f'{target_type[0]}{target_type[-1]}'
        _save_scatter(args, X, colors,
                      os.path.join(args.save_dir, f'umap-init_slice-{t}.png'))


def main():
    """Run the experiment configured in the module-level ``args``.

    Steps: load the data and build the loaders, optionally train / load the
    reference model and store its activations, obtain group (cluster) labels,
    optionally plot them, then (unless ``args.infer_only``) train the final
    model for ``args.epochs`` epochs.
    """

    # load data
    train_dataset = load_dataset(args, split='train', augment=args.train_augment)
    args.num_groups = np.unique(train_dataset.group_array).shape[0]

    # balance classes
    if args.balance_classes_infer:
        args.logger.info('Balancing classes for inference')
        sampler = uniform_sample(np.array(train_dataset.targets))
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    # un-augmented, un-shuffled view of the training split
    train_val_loader = DataLoader(load_dataset(args, split='train', augment=False), batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    args.train_size = len(train_dataset)

    args.save_freq = args.epochs // 5 if args.save_unit == 'epoch' else args.train_size // args.batch_size

    # log group size
    group_sizes = np.bincount(train_dataset.group_array)
    args.logger.info(f'Group sizes: {group_sizes}')

    if args.adaptive_sample_power:
        # per-group power = sqrt(size of the group's class / size of the group)
        args.sample_by_cluster_power = _adaptive_sample_power(
            group_sizes, args.num_groups // args.num_classes)
        args.logger.info(f'Adaptive sample power: {args.sample_by_cluster_power}')

    # per-sample losses for training, mean loss for validation
    train_criterion = nn.CrossEntropyLoss(reduction='none')
    val_criterion = nn.CrossEntropyLoss()

    # load validation and test data
    val_loader, test_loader = _build_eval_loaders(args, group_sizes)

    # train model on all data for a few epochs
    if len(args.infer_steps) > 0 or args.infer_loss_thresh > 0. or args.infer_loss_diff or args.infer == 'ssa':
        _extend_checkpoint_path(args)

    # load model
    epoch_steps = len(train_loader)
    args.logger.info(f'Epoch steps: {epoch_steps}')
    t_total = epoch_steps * args.epochs
    model, optimizer, scheduler = load_model(args, infer=False, t_total=t_total)

    init_model_path = os.path.join(args.save_dir, 'init_model.pth.tar')
    if not args.infer_only and not args.continue_train:
        # save initial model so the final training can restart from it
        torch.save(model.state_dict(), init_model_path)

    # Activations are only produced below when they are not cached on disk.
    # Start from None so later consumers can detect that case instead of
    # failing with a NameError.
    embeds = preds = labels = None

    if len(args.infer_steps) > 0 or args.infer_loss_thresh > 0. or args.infer_loss_diff:
        for seed in range(args.num_infer_ckpts):
            if args.num_infer_ckpts > 1:
                # NOTE(review): args.seed keeps the last value after this loop - confirm intended
                args.seed = seed

            if len(args.infer_steps) > 0:
                # largest step first
                for step in np.sort(args.infer_steps)[::-1]:
                    ckpt_path = os.path.join(args.checkpoint_path, '{}{}/checkpoint_seed{}.pt'.format(args.save_unit, step, args.seed))
                    if not os.path.exists(ckpt_path):
                        args.logger.info('Training refer model')
                        model = train_refer_model(args, train_criterion, val_criterion, val_loader, test_loader, step)

                        args.logger.info('Saving refer model')
                        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
                        torch.save({
                            'model_state_dict': model.state_dict(),
                            }, ckpt_path)
            elif args.infer_loss_thresh > 0. or args.infer_loss_diff:
                args.logger.info('Training refer model')
                model = train_refer_model(args, train_criterion, val_criterion, val_loader, test_loader)

            for infer_step in args.infer_steps:
                # save embeddings and predictions
                activation_path = os.path.join(args.checkpoint_path,
                                                '{}{}/activations_seed{}.pt'.format(args.save_unit, infer_step, args.seed))
                if not os.path.exists(activation_path):
                    os.makedirs(os.path.dirname(activation_path), exist_ok=True)
                    args.logger.info('Saving activations')

                    ckpt_path = os.path.join(args.checkpoint_path, '{}{}/checkpoint_seed{}.pt'.format(args.save_unit, infer_step, args.seed))
                    checkpoint = torch.load(ckpt_path)
                    model.load_state_dict(checkpoint['model_state_dict'])
                    embeds, preds, outputs, labels = save_activations(model, train_val_loader, args)
                    torch.save({
                        'embeds': embeds,
                        'preds': preds,
                        'outputs': outputs,
                        'labels': labels,
                    }, activation_path)

        # load data
        train_dataset = load_dataset(args, split='train', augment=args.train_augment)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    # infer group membership
    single_loader_dataset = args.dataset in ['cifar100sup', 'cifar10', 'imagenet']
    if args.infer == 'none':
        # use the group annotations that come with the dataset
        cluster_labels = np.array(train_dataset.group_array)
    elif single_loader_dataset:
        cluster_labels, test_loader = infer_one_step(args, model, train_dataset, train_val_loader, val_loader, args.infer_steps, test_loader)
    else:
        cluster_labels = infer_one_step(args, model, train_dataset, train_val_loader, val_loader, args.infer_steps)
    if single_loader_dataset:
        # downstream code iterates over a list of loaders
        val_loader = [val_loader]
        test_loader = [test_loader]

    args.num_groups = len(np.unique(cluster_labels))

    if args.visualize and args.infer != 'ssa':
        if embeds is None or labels is None:
            args.logger.info('Skipping visualization: activations were not computed in this run')
        else:
            _visualize_clusters(args, train_dataset, embeds, labels, cluster_labels)

    if args.infer_only:
        return

    if not args.continue_train:
        model, optimizer, scheduler = load_model(args, infer=False, t_total=t_total)

        # load the initial model
        model.load_state_dict(torch.load(init_model_path))

        # delete the initial model
        os.remove(init_model_path)

    if args.sample != 'none' or args.uniform_group_sampler or args.weighted_sampler or args.sample_by_cluster:
        # sample training data
        train_dataset, train_loader = sample_train_data(args, train_dataset, cluster_labels)

    # store all best val and test statistics together
    best_stats = {'best_avg_val': 0, 'best_worst_val': 0, 'best_avg_test': 0, 'best_worst_test': 0}

    if args.train == 'group_dro':
        args.group_weighted_loss = GroupWeightedLoss(args, args.num_groups)

    # train model for args.epochs epochs
    args.logger.info('Training model...')
    for epoch in range(args.epochs):
        best_stats, model, optimizer, scheduler, _, _, _ = train_and_validate(
            args, model, optimizer, train_criterion, val_criterion, train_loader, val_loader, test_loader,
            epoch, save_best=True, cluster_labels=cluster_labels, best_stats=best_stats, scheduler=scheduler)

        # sample by cluster if specified
        if args.sample_every > 0 and epoch % args.sample_every == 0 and epoch < args.epochs - 1:
            # sample training data
            train_dataset = load_dataset(args, split='train', augment=args.train_augment)
            if preds is None:
                # no predictions available in this run: use the three-argument
                # call form already used for the initial sampling above
                train_dataset, train_loader = sample_train_data(args, train_dataset, cluster_labels)
            else:
                train_dataset, train_loader = sample_train_data(args, train_dataset, cluster_labels, preds)

    # finish training
    args.logger.info('Finished training!')

    # finish wandb run
    if args.use_wandb:
        wandb.finish()

# Call main() only when this file is executed as a script.
if __name__ == '__main__':
    main()