import datetime
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset import *
from sample import *
from train import train_and_validate
from utils import *
from inference import infer_one_step

# ---------------------------------------------------------------------------
# Module-level configuration: parse the CLI, fill in per-dataset settings,
# derive the output directory name and set up wandb and the logger.
# Everything below mutates the shared `args` namespace that main() reads.
# ---------------------------------------------------------------------------

# parse arguments
parser = get_parser()
args = parser.parse_args()

# set gpu
# if CUDA_VISIBLE_DEVICES is not already set, expose every gpu id listed in args.gpu
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in args.gpu)
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# set class number (and the other settings that are fixed per dataset);
# a dataset matching none of the branches keeps whatever the parser provided
if args.dataset in ['waterbirds', 'celeba']:
    args.num_classes = 2
    args.epochs = 300 if args.dataset == 'waterbirds' else 50
elif args.dataset in ['multinli', 'civilcomments']:
    args.num_classes = 3 if args.dataset == 'multinli' else 2
    args.epochs = 5
    args.optimizer = 'adamw'
    args.max_grad_norm = 1.0
    args.adam_epsilon = 1e-8
    args.warmup_steps = 0
    args.arch = 'bert'
elif 'cmnist' in args.dataset:
    args.num_classes = 5
    args.epochs = 20
    args.arch = 'lenet5'
elif args.dataset == 'cifar100sup':
    args.num_classes = 20
    args.epochs = 200
    args.arch = 'resnet18'
elif args.dataset == 'cifar10':
    args.num_classes = 10
    args.epochs = 200
    args.arch = 'resnet18'
elif args.dataset == 'imagenet':
    args.num_classes = 9
    args.batch_size = 256
    args.epochs = 90
    args.arch = 'resnet50'

# set save path: the run name is assembled from the options, in a fixed order,
# so each configuration gets its own directory
name_parts = [f'{args.dataset}']
if 'cmnist' in args.dataset:
    name_parts.append(f'_{args.p_correlation}')
name_parts.append(f'_{args.arch}')
if args.wide:
    name_parts.append('_wide')
name_parts.append('_infer_all')

name_parts.append(f'_{args.infer}_{args.infer_optimizer}_{args.infer_lr:.0e}_{args.infer_weight_decay:.0e}')
if args.class_dro_infer:
    name_parts.append('_class_dro')
name_parts.append(f'_{args.infer_steps[0]}steps')

if not args.pretrained:
    name_parts.append('_no_pretrained')
if args.freeze:
    name_parts.append('_freeze')

if args.num_infer_ckpts > 1:
    name_parts.append(f'_{args.num_infer_ckpts}ckpts')
if args.infer_augment:
    name_parts.append('_infer_aug')
if args.infer == 'cluster':
    name_parts.append(f'_cluster_{args.cluster_metric}')
    if args.traject_length > 1:
        name_parts.append(f'_traject{args.traject_length}')
    if args.cluster_umap:
        name_parts.append('_umap')
    name_parts.append(f'_{args.cluster_method}')
    if args.cluster_all:
        name_parts.append('_all')
    if args.silhouette:
        name_parts.append('_silhouette')
    elif args.dataset in ['cifar100sup', 'cifar10', 'imagenet', 'cmnist', 'balance_cmnist']:
        name_parts.append(f'_{args.num_clusters}clusters')
if args.balance_classes_infer:
    name_parts.append('_balance')

name_parts.append(f'_bs{args.batch_size}_seed{args.seed}')
save_name = ''.join(name_parts)
args.save_dir = os.path.join(args.save_dir, save_name)

# set wandb
if args.use_wandb:
    import wandb
    # basename (rather than splitting on '/') stays correct with any path
    # separator, since save_dir was built with os.path.join
    wandb.init(project=args.wandb_project, config=args, name=os.path.basename(args.save_dir))

# add timestamp to save_dir (after wandb.init so the run name stays timestamp-free)
args.save_dir = os.path.join(args.save_dir, str(datetime.datetime.now()).replace(' ', '_').replace(':', '_').replace('.', '_'))

# set logger
# exist_ok avoids the check-then-create race when several runs start together
os.makedirs(args.save_dir, exist_ok=True)

# NOTE(review): `logging` is not imported explicitly in this file; presumably
# it arrives through one of the star imports above - confirm.
logger = logging.getLogger(__name__)
args.logger = set_logger(args, logger)
args.logger.info(str(args) + ' ' + str(datetime.datetime.now()) + '\n')


def _group_recall(cluster_labels, groups, num_classes, num_clusters, num_groups):
    """Measure how well inferred clusters line up with the ground-truth groups.

    The index arithmetic assumes cluster ids are laid out class by class
    (``class * num_clusters + local_cluster``) and group ids likewise
    (``class * (num_groups // num_classes) + local_group``).

    For a group that is the largest of its class, recall is the fraction of
    its samples assigned to any "largest" cluster; for every other group it
    is the fraction assigned to any non-largest cluster.

    Args:
        cluster_labels: per-sample inferred cluster id.
        groups: per-sample ground-truth group id (numpy array).
        num_classes: number of target classes.
        num_clusters: number of clusters per class.
        num_groups: total number of distinct groups.

    Returns:
        Tuple ``(group_ids, recalls, largest_group_indices, small_group_indices)``
        where ``recalls[k]`` belongs to ``group_ids[k]``.
    """
    cluster_labels = np.asarray(cluster_labels)
    total_clusters = num_clusters * num_classes

    # find the largest cluster within each class's block of num_clusters ids;
    # minlength keeps the (num_classes, num_clusters) layout even when the
    # highest-numbered clusters received no samples
    cluster_sizes = np.bincount(cluster_labels, minlength=total_clusters).reshape(num_classes, -1)
    largest_cluster_indices = cluster_sizes.argmax(axis=1)
    largest_cluster_indices += np.arange(0, len(largest_cluster_indices) * num_clusters, num_clusters)
    small_cluster_indices = np.setdiff1d(np.arange(total_clusters), largest_cluster_indices)

    # find the largest group for every class
    group_sizes = np.bincount(groups).reshape(num_classes, -1)
    largest_group_indices = group_sizes.argmax(axis=1)
    largest_group_indices += np.arange(0, len(largest_group_indices) * num_groups // num_classes, num_groups // num_classes)
    # the complement is taken over group ids (not cluster ids), because these
    # indices are used to index the per-group recall array
    small_group_indices = np.setdiff1d(np.arange(num_groups), largest_group_indices)

    group_ids = np.unique(groups)
    recalls = []
    for group in group_ids:
        members = cluster_labels[groups == group]
        if group in largest_group_indices:
            expected_clusters = largest_cluster_indices
        else:
            expected_clusters = small_cluster_indices
        recalls.append(np.mean(np.isin(members, expected_clusters)))
    return group_ids, recalls, largest_group_indices, small_group_indices


def main():
    """Run group inference at a series of training checkpoints.

    For every step in ``args.infer_steps`` the matching checkpoint is loaded;
    if loading fails the model is trained up to that step and the checkpoint
    is written. ``infer_one_step`` is then called on the model and, for the
    datasets with group annotations, per-group recall of the inferred clusters
    is logged (and sent to wandb when enabled).

    Reads and mutates the module-level ``args`` namespace.
    """

    # load data
    train_dataset = load_dataset(args, split='train', augment=args.infer_augment)
    args.num_groups = np.unique(train_dataset.group_array).shape[0]

    print('Number of groups:', args.num_groups)
    print('Number of classes:', args.num_classes)

    # balance classes
    if args.balance_classes_infer:
        args.logger.info('Balancing classes for inference')
        sampler = uniform_sample(np.array(train_dataset.targets))
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    # unshuffled, unaugmented pass over the training set, handed to infer_one_step
    train_val_loader = DataLoader(load_dataset(args, split='train', augment=False), batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    args.train_size = len(train_dataset)

    args.save_freq = args.epochs // 5 if args.save_unit == 'epoch' else args.train_size // args.batch_size

    # log group size
    group_sizes = np.bincount(train_dataset.group_array)
    args.logger.info(f'Group sizes: {group_sizes}')

    train_criterion = nn.CrossEntropyLoss(reduction='none')
    val_criterion = nn.CrossEntropyLoss()

    def eval_loader(split, **dataset_kwargs):
        # shared settings for every validation / test loader
        return torch.utils.data.DataLoader(
                load_dataset(args, split=split, **dataset_kwargs),
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.num_workers, pin_memory=True)

    if args.dataset in ['cifar100sup', 'cifar10', 'imagenet']:
        # load validation data
        val_loader = eval_loader('val')
        test_loader = eval_loader('test')
    else:
        # load validation data: one loader per [class, group-within-class] pair
        groups_per_class = len(group_sizes) // args.num_classes
        val_loader = [eval_loader('val', group=[i, j])
                      for i in range(args.num_classes) for j in range(groups_per_class)]
        test_loader = [eval_loader('test', group=[i, j])
                       for i in range(args.num_classes) for j in range(groups_per_class)]

    # load model
    epoch_steps = len(train_loader)
    args.logger.info(f'Epoch steps: {epoch_steps}')
    t_total = epoch_steps * args.epochs
    model, optimizer, scheduler = load_model(args, infer=True, t_total=t_total)

    # train model on all data for a few epochs
    # the checkpoint directory name encodes the training configuration
    if 'cmnist' in args.dataset:
        args.checkpoint_path = os.path.join(args.checkpoint_path, f'{args.dataset}_{args.p_correlation}_{args.arch}', f'{args.infer_optimizer}_{args.infer_lr:.0e}_{args.infer_weight_decay:.0e}')
    else:
        args.checkpoint_path = os.path.join(args.checkpoint_path, f'{args.dataset}_{args.arch}', f'{args.infer_optimizer}_{args.infer_lr:.0e}_{args.infer_weight_decay:.0e}')
    if args.wide:
        args.checkpoint_path += '_wide'
    if not args.pretrained:
        args.checkpoint_path += '_no_pretrained'
    if args.freeze:
        args.checkpoint_path += '_freeze'
    if args.infer_augment:
        args.checkpoint_path += '_aug'
    if args.balance_classes_infer:
        args.checkpoint_path += '_balance'
    if args.class_dro_infer:
        args.checkpoint_path += '_classdro'
    if 'cmnist' in args.dataset:
        args.checkpoint_path += f'_bs{args.batch_size}'

    # split the requested number of steps into num_infer_ckpts evenly spaced checkpoints
    infer_freq = args.infer_steps[0]//args.num_infer_ckpts
    args.infer_steps = np.arange(infer_freq, args.infer_steps[0]+1, infer_freq)
    if args.include_init:
        args.infer_steps = np.concatenate(([0], args.infer_steps))

    if args.save_unit == 'batch':
        # one iterator shared across infer steps so batch-mode training
        # continues through the data instead of restarting each time
        train_loader_iter = iter(train_loader)

    for step in args.infer_steps:
        args.logger.info('Infer step: '+str(step))

        torch.cuda.empty_cache()
        sys.stdout.flush()

        step = int(step)
        ckpt_path = os.path.join(args.checkpoint_path, '{}{}/checkpoint_seed{}.pt'.format(args.save_unit, step, args.seed))

        model.train()
        if step != 0:
            try:
                args.logger.info('loading model at '+ckpt_path+'\n')
                checkpoint = torch.load(ckpt_path)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                if scheduler is not None and step != args.infer_steps[-1]:
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            except Exception as err:
                # Any load failure (missing file, unreadable or incomplete
                # checkpoint) falls back to training, as before. Exception
                # rather than a bare except lets Ctrl-C / SystemExit through,
                # and the reason is logged.
                args.logger.info(f'could not load checkpoint ({err!r}); training instead\n')
                if args.save_unit == 'epoch':
                    args.logger.info('training on all data for '+str(infer_freq)+' epochs with seed '+str(args.seed)+'\n')
                    # the scheduler is passed to and returned by
                    # train_and_validate, so it is not stepped here
                    for epoch in range(infer_freq):
                        _, model, optimizer, scheduler, _, _, _ = train_and_validate(args, model, optimizer, train_criterion, val_criterion, train_loader,
                                        val_loader, test_loader, epoch, pretrain=True,
                                        save_best=False, save_all=True, save_name=args.checkpoint_path, scheduler=scheduler)
                else:
                    args.logger.info('training on all data for '+str(infer_freq)+' batches with seed '+str(args.seed)+'\n')
                    batch_idx = 0
                    while batch_idx < infer_freq:
                        try:
                            data, target, _ = next(train_loader_iter)
                        except StopIteration:
                            # loader exhausted: start another pass over the data
                            train_loader_iter = iter(train_loader)
                            data, target, _ = next(train_loader_iter)
                        batch_idx += 1
                        data, target = data.to(args.device), target.to(args.device)

                        # zero the parameter gradients and forward pass
                        optimizer.zero_grad()
                        output = get_output(model, data, target, args)

                        # compute loss
                        loss = train_criterion(output, target).mean()
                        loss.backward()
                        if args.arch == 'bert' and args.optimizer == 'adamw':
                            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                            # NOTE(review): PyTorch documents calling
                            # optimizer.step() before scheduler.step(); the
                            # original order is kept so existing checkpoints
                            # stay comparable - confirm against train.py.
                            if scheduler is not None:
                                scheduler.step()
                            optimizer.step()
                            model.zero_grad()
                        else:
                            optimizer.step()
                    # Only meaningful in batch mode: batch_idx does not exist
                    # on the epoch path, where this check used to raise
                    # UnboundLocalError.
                    # NOTE(review): this is evaluated once per chunk of
                    # infer_freq batches, not once per batch - confirm intent.
                    if scheduler is not None and args.optimizer != 'adamw' and ((step + batch_idx + 1) % epoch_steps == 0):
                        args.logger.info('scheduler step\n')
                        scheduler.step()

                args.logger.info('saving model at '+ckpt_path+'\n')
                os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'args': args,
                    f'{args.save_unit}': step,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                }, ckpt_path)

        model.eval()

        if args.dataset in ['cifar100sup', 'cifar10', 'imagenet']:
            cluster_labels, _ = infer_one_step(args, model, train_dataset, train_val_loader, val_loader, [step], test_loader)
        else:
            cluster_labels = infer_one_step(args, model, train_dataset, train_val_loader, val_loader, [step])

        if args.dataset not in ['cifar100sup', 'cifar10', 'imagenet']:
            # compute the recall of each group against the predicted clusters
            groups = np.array(train_dataset.group_array)
            group_ids, recalls, largest_group_indices, small_group_indices = _group_recall(
                cluster_labels, groups, args.num_classes, args.num_clusters, args.num_groups)

            log_dict = {}
            for group, recall in zip(group_ids, recalls):
                args.logger.info(f'Group {group} recall: {recall}')
                log_dict[f'Recall/group_{group}'] = recall
            group_recall = np.array(recalls)

            if args.use_wandb:
                log_dict['Recall/mean'] = group_recall.mean()
                log_dict['Recall/large'] = group_recall[largest_group_indices].mean()
                log_dict['Recall/small'] = group_recall[small_group_indices].mean()
                log_dict['Recall/step'] = step
                wandb.log(log_dict)

        # flush
        plt.close('all')
        sys.stdout.flush()


    # finish wandb run
    if args.use_wandb:
        wandb.finish()

# Call main() only when this file is executed directly.
if __name__ == '__main__':
    main()