import os
import math

import torch
import torch.utils.data as torch_data
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from accelerate.logging import get_logger

import utils
import losses


logger = get_logger(__name__, log_level='INFO')

# Dataset name -> groups of class indices, forwarded as ``merge_class_list`` to
# ``utils.eval_acc`` by ``evaluate_model``. HANS groups classes 1 and 2
# (presumably scored as a single class -- see utils.eval_acc).
EVAL_MERGE_CLASS_LIST_DICT = {
    'hans': [[1, 2]],
}

# Dataset name -> list of group attribute names. Not read by evaluate_model,
# test_saved_models or train_erm.
GROUP_ATTRIBUTE_DICT = {
    'neg_mnli': ['sentence2_has_negation'],
}

# Task name (``args.task_name``) -> metric name passed as ``metric_name``
# to ``evaluate_model``.
EVAL_METRIC_DICT = {
    'nli': 'acc',
    'td': 'macro_f1',
}


def evaluate_model(model, eval_dataloader_dict, accelerator,
                   metric_name='acc', generative_forward=False, verbalizer_indices=None):
    """Score ``model`` on every dataloader in ``eval_dataloader_dict``.

    The model is put in eval mode while scoring and is unconditionally switched
    back to train mode before returning (regardless of its mode on entry).

    Args:
        model: Model to evaluate.
        eval_dataloader_dict: Mapping of dataset name -> dataloader.
        accelerator: ``accelerate.Accelerator`` used for printing and evaluation.
        metric_name: Metric name forwarded to ``utils.eval_acc``.
        generative_forward: Forwarded to ``utils.eval_acc``.
        verbalizer_indices: Forwarded to ``utils.eval_acc``.

    Returns:
        ``(results, logits)``: two dicts keyed by dataset name holding, respectively,
        the metric value and the logits returned by ``utils.eval_acc``.
    """
    model.eval()
    scores, logits_by_dataset = {}, {}
    for name, loader in eval_dataloader_dict.items():
        accelerator.print(f'Evaluating on {name}')
        # Datasets listed in EVAL_MERGE_CLASS_LIST_DICT get their class groups
        # passed through; every other dataset gets ``None``.
        score, logits = utils.eval_acc(
            model, loader, accelerator,
            merge_class_list=EVAL_MERGE_CLASS_LIST_DICT.get(name),
            return_logits=True,
            metric_name=metric_name,
            generative_forward=generative_forward,
            verbalizer_indices=verbalizer_indices,
        )
        scores[name] = score
        logits_by_dataset[name] = logits

    summary_lines = [f'{name} {metric_name}: {score:.4f}' for name, score in scores.items()]
    accelerator.print('\n'.join(summary_lines))

    model.train()
    return scores, logits_by_dataset


def test_saved_models(test_data_dict, dataloader, data_collator, accelerator, args, train_data=None):
    """Evaluate every saved ``step*`` checkpoint under ``args.save_dir`` on the test sets.

    A checkpoint is a sub-directory of ``args.save_dir`` whose name starts with
    ``'step'`` and which contains exactly one ``*.bin`` file. For each checkpoint
    the model is loaded (sequence classification when ``args.model_name`` is a
    discriminative model type, seq2seq otherwise), prepared with ``accelerator``,
    scored with ``evaluate_model``, and the results (plus optional train logits)
    are written back into the checkpoint directory via
    ``utils.save_model(prefix='test')``. The ``*.bin`` weight file is deleted
    afterwards unless ``args.keep_weights_after_test`` is set.

    Args:
        test_data_dict: Mapping of dataset name -> test dataset.
        dataloader: Object exposing ``model_types`` (with ``'generative'`` and
            ``'discriminative'`` entries) and ``verbalizer_indices_dict``.
        data_collator: ``collate_fn`` for the test DataLoaders and train logits.
        accelerator: ``accelerate.Accelerator`` used for preparation and printing.
        args: Run configuration; reads ``save_dir``, ``eval_batch_size``,
            ``model_name``, ``task_name``, ``compute_train_logits`` and
            ``keep_weights_after_test``.
        train_data: Training dataset; required when ``args.compute_train_logits``.
    """
    test_datasets = list(test_data_dict)
    test_dataloader_dict = {
        dataset_name: torch_data.dataloader.DataLoader(
            test_data_dict[dataset_name], shuffle=False, batch_size=args.eval_batch_size,
            collate_fn=data_collator)
        for dataset_name in test_datasets}

    generative_forward = args.model_name in dataloader.model_types['generative']
    verbalizer_indices = dataloader.verbalizer_indices_dict[args.model_name] if generative_forward else None

    # os.listdir order is arbitrary; sorting gives every process the same,
    # reproducible checkpoint order (important if evaluation synchronizes
    # across processes).
    for checkpoint_name in sorted(os.listdir(args.save_dir)):
        # Select directories with saved models
        if not checkpoint_name.startswith('step'):
            continue
        model_dir = os.path.join(args.save_dir, checkpoint_name)
        if not os.path.isdir(model_dir):
            continue
        bin_filenames = [filename for filename in os.listdir(model_dir) if filename.endswith('.bin')]
        if len(bin_filenames) != 1:
            continue

        # Load and prepare model and dataloaders
        accelerator.print(f'Testing model {os.path.basename(model_dir)}')
        if args.model_name in dataloader.model_types['discriminative']:
            model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        else:
            # Generative models use sequence-to-sequence models
            model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

        # NOTE(review): Accelerator.prepare keeps internal references to each prepared
        # model/dataloader, so successive checkpoints may stay resident in device memory;
        # consider accelerator.free_memory() if this OOMs -- confirm for your accelerate version.
        prepared = accelerator.prepare(
            model, *[test_dataloader_dict[dataset_name] for dataset_name in test_datasets])
        # Accelerator.prepare returns a bare object (not a 1-tuple) when given a single
        # argument, which is what happens when there are no test datasets.
        if not isinstance(prepared, (tuple, list)):
            prepared = (prepared,)
        model = prepared[0]
        prepared_test_dataloaders = dict(zip(test_datasets, prepared[1:]))

        test_results_dict, test_logits_dict = evaluate_model(
            model, prepared_test_dataloaders, accelerator, metric_name=EVAL_METRIC_DICT[args.task_name],
            generative_forward=generative_forward, verbalizer_indices=verbalizer_indices)

        # Compute train data logits
        train_logits = None
        if args.compute_train_logits:
            assert train_data is not None
            # evaluate_model leaves the model in train mode; switch back so dropout-style
            # layers are inactive whether or not compute_logits changes the mode itself.
            model.eval()
            train_logits = utils.compute_logits(
                model, train_data, args.eval_batch_size, data_collator, accelerator,
                generative_forward=generative_forward, verbalizer_indices=verbalizer_indices)

        # Save results and delete model weights if requested (to save disk space)
        utils.save_model(
            model_dir, accelerator, prefix='test', model=None,
            eval_result=test_results_dict, eval_logits_dict=test_logits_dict, train_logits=train_logits)
        if not args.keep_weights_after_test and accelerator.is_local_main_process:
            os.remove(os.path.join(model_dir, bin_filenames[0]))


# Values of ``args.erm_loss_func`` understood by ``_compute_erm_loss``.
_SUPPORTED_ERM_LOSSES = ('cross_entropy', 'generalized_cross_entropy')


def _compute_erm_loss(logits, labels, args):
    """Return the mean of the loss selected by ``args.erm_loss_func``.

    Raises:
        ValueError: if ``args.erm_loss_func`` is not in ``_SUPPORTED_ERM_LOSSES``.
    """
    if args.erm_loss_func == 'cross_entropy':
        return losses.compute_ce(logits, labels).mean()
    if args.erm_loss_func == 'generalized_cross_entropy':
        return losses.compute_gce(logits, labels, args.gce_q).mean()
    raise ValueError(
        f'Unsupported erm_loss_func {args.erm_loss_func!r}; expected one of {_SUPPORTED_ERM_LOSSES}')


def _evaluate_splits(model, eval_dataloader_dict, test_dataloader_dict, accelerator, metric_name,
                     generative_forward, verbalizer_indices):
    """Run ``evaluate_model`` on the validation, then the test dataloaders.

    Returns ``(validation_result, validation_logits, test_result, test_logits)``.
    """
    validation_result, validation_logits = evaluate_model(
        model, eval_dataloader_dict, accelerator, metric_name=metric_name,
        generative_forward=generative_forward, verbalizer_indices=verbalizer_indices
    )
    test_result, test_logits = evaluate_model(
        model, test_dataloader_dict, accelerator, metric_name=metric_name,
        generative_forward=generative_forward, verbalizer_indices=verbalizer_indices
    )
    return validation_result, validation_logits, test_result, test_logits


def _save_split_results(save_path, accelerator, model_to_save, validation_result, validation_logits,
                        test_result, test_logits):
    """Write validation results (plus ``model_to_save`` if not None) and test results to ``save_path``."""
    utils.save_model(
        save_path, accelerator, prefix='validation', model=model_to_save,
        eval_result=validation_result, eval_logits_dict=validation_logits)
    utils.save_model(
        save_path, accelerator, prefix='test', model=None,
        eval_result=test_result, eval_logits_dict=test_logits)


def train_erm(model, train_data, eval_data_dict, test_data_dict, optimizer, scheduler,
              dataloader, data_collator, accelerator, args):
    """Train ``model`` with the ERM loss ``args.erm_loss_func``, checkpointing as it goes.

    Every ``eval_interval`` optimizer updates (``args.eval_interval`` or, if unset,
    once per epoch) and at ``args.max_train_steps``, unless ``args.no_checkpoint``
    is set, the model is evaluated on the validation and test dataloaders. The
    results go to ``<args.save_dir>/step_<n>``, with weights only when
    ``args.save_every`` is set. After training, the final step is evaluated and
    saved the same way (weights when ``args.save_last``) if its directory does not
    already exist. Training logits gathered across processes are recorded per epoch
    and written to ``logits_log.pt`` / ``data_indices_log.pt`` in ``args.save_dir``.

    Args:
        model, optimizer, scheduler: Objects to train; passed through ``accelerator.prepare``.
            ``scheduler`` may be None.
        train_data: Training dataset; batches must contain ``'data_idx'`` and
            ``'labels'`` (or ``'restricted_labels'``, which takes precedence).
        eval_data_dict, test_data_dict: Mapping of dataset name -> dataset.
        dataloader: Object exposing ``model_types`` and ``verbalizer_indices_dict``.
        data_collator: ``collate_fn`` for all DataLoaders.
        accelerator: ``accelerate.Accelerator`` driving training.
        args: Run configuration.

    Returns:
        ``(model, logits_dict)``: the prepared model and a dict mapping epoch index to a
        zero-initialised ``(args.num_samples, args.num_labels)`` tensor whose rows at the
        gathered ``batch['data_idx']`` positions hold that epoch's training logits.

    Raises:
        ValueError: if ``args.erm_loss_func`` is not supported.
    """
    # Fail fast: an unknown loss name would otherwise leave ``loss`` unbound on the first batch.
    if args.erm_loss_func not in _SUPPORTED_ERM_LOSSES:
        raise ValueError(
            f'Unsupported erm_loss_func {args.erm_loss_func!r}; expected one of {_SUPPORTED_ERM_LOSSES}')

    model.train()

    # Prepare dataloaders
    train_dataloader = torch_data.dataloader.DataLoader(
        train_data, shuffle=True, batch_size=args.train_batch_size, collate_fn=data_collator)
    eval_dataloader_dict = {dataset_name: torch_data.dataloader.DataLoader(
        eval_data_dict[dataset_name], shuffle=False, batch_size=args.eval_batch_size, collate_fn=data_collator)
        for dataset_name in eval_data_dict}
    test_dataloader_dict = {dataset_name: torch_data.dataloader.DataLoader(
        test_data_dict[dataset_name], shuffle=False, batch_size=args.eval_batch_size, collate_fn=data_collator)
        for dataset_name in test_data_dict}
    model, optimizer, scheduler, train_dataloader = accelerator.prepare(
        model, optimizer, scheduler, train_dataloader)
    eval_dataloader_dict = {dataset_name: accelerator.prepare(loader)
                            for dataset_name, loader in eval_dataloader_dict.items()}
    test_dataloader_dict = {dataset_name: accelerator.prepare(loader)
                            for dataset_name, loader in test_dataloader_dict.items()}
    # Training with generative models (e.g., T5) has to forward with labels (i.e. teacher forcing)
    generative_forward = args.model_name in dataloader.model_types['generative']
    verbalizer_indices = dataloader.verbalizer_indices_dict[args.model_name] if generative_forward else None

    # Evaluation interval is counted in optimizer updates (not micro-batches)
    eval_interval = args.eval_interval if args.eval_interval else \
        math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    accelerator.wait_for_everyone()

    completed_steps, avg_loss, epoch_idx = 0, 0, 0
    # Save logits and data indices for each epoch
    logits_dict, data_indices_dict = {}, {}
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    while completed_steps < args.max_train_steps:
        logits_dict[epoch_idx] = torch.zeros((args.num_samples, args.num_labels))
        epoch_logits, epoch_data_indices = [], []
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                logits, _ = utils.forward_model(model, batch, generative=generative_forward, verbalizer_indices=verbalizer_indices)
                batch_labels = batch['restricted_labels'] if 'restricted_labels' in batch else batch['labels']
                with accelerator.autocast():
                    loss = _compute_erm_loss(logits, batch_labels, args)
                accelerator.backward(loss)
                # Clip only once the accumulated gradient is complete; clipping on
                # intermediate micro-batches would rescale partial gradients.
                if accelerator.sync_gradients and args.max_grad_norm is not None \
                        and args.optimizer != 'adafactor':
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                if scheduler is not None and args.optimizer == 'adam':
                    if not accelerator.optimizer_step_was_skipped:
                        scheduler.step()  # Adjust LR by step (Adam's default)
                optimizer.zero_grad()

            # Steps and the running loss average advance once per optimizer update;
            # the loss recorded is that of the micro-batch on which gradients synced.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1
                avg_loss = (completed_steps - 1) / completed_steps * avg_loss + loss.detach() / completed_steps
                if args.use_wandb:
                    accelerator.log({
                        'step': completed_steps,
                        'train_loss': loss.detach().item(),
                        'train_avg_loss': avg_loss.item()
                    })

            # NOTE(review): plain gather may include sampler padding duplicates on the last
            # batch; gather_for_metrics would drop them -- confirm whether that matters here.
            data_indices, logits = accelerator.gather((batch['data_idx'], logits))
            epoch_data_indices.append(data_indices.cpu())
            epoch_logits.append(logits.detach().type(torch.float32).cpu())

            # evaluation
            if (completed_steps % eval_interval == 0 or completed_steps == args.max_train_steps) \
                    and completed_steps != 0 and not args.no_checkpoint and accelerator.sync_gradients:
                model.eval()
                # The cross-process mean replaces the local running average and keeps accumulating.
                avg_loss = accelerator.reduce(avg_loss, reduction='mean')
                accelerator.print(f'Evaluating at step #{completed_steps} in epoch {epoch_idx}')
                accelerator.print(f'average loss: {avg_loss.item()}')
                # Run reduce on plateau scheduler
                if args.scheduler == 'reduce_lr_on_plateau':
                    scheduler.step(avg_loss)

                validation_result, validation_logits, test_result, test_logits = _evaluate_splits(
                    model, eval_dataloader_dict, test_dataloader_dict, accelerator,
                    EVAL_METRIC_DICT[args.task_name], generative_forward, verbalizer_indices)
                if args.use_wandb and accelerator.is_local_main_process:
                    accelerator.log({
                        'step': completed_steps,
                        **{f'val_{dataset_name}': validation_result[dataset_name] for dataset_name in validation_result},
                        **{f'test_{dataset_name}': test_result[dataset_name] for dataset_name in test_result},
                    })

                accelerator.print(f'Saving at step #{completed_steps} in epoch {epoch_idx}')
                save_path = os.path.join(args.save_dir, f'step_{completed_steps}')
                if accelerator.is_local_main_process:
                    os.makedirs(save_path, exist_ok=True)
                accelerator.wait_for_everyone()
                _save_split_results(
                    save_path, accelerator, model if args.save_every else None,
                    validation_result, validation_logits, test_result, test_logits)

            model.train()
            if completed_steps == args.max_train_steps:
                break

        # End of epoch: save logits and data indices
        epoch_logits, epoch_data_indices = torch.cat(epoch_logits, dim=0), torch.cat(epoch_data_indices, dim=0)
        data_indices_dict[epoch_idx] = epoch_data_indices
        logits_dict[epoch_idx][data_indices_dict[epoch_idx]] = epoch_logits
        epoch_idx += 1

    model.eval()
    # save the final model if necessary
    save_path = os.path.join(args.save_dir, f'step_{completed_steps}')
    needs_final_save = not os.path.exists(save_path)
    # Every process must finish the existence check before the local main process
    # creates save_path below; otherwise processes can disagree on needs_final_save,
    # and one would enter the evaluation while another skips it (a hang).
    accelerator.wait_for_everyone()
    if needs_final_save:
        accelerator.print(f'Evaluating and saving the final model at step #{completed_steps} in epoch {epoch_idx}')
        if accelerator.is_local_main_process:
            os.makedirs(save_path, exist_ok=True)
        validation_result, validation_logits, test_result, test_logits = _evaluate_splits(
            model, eval_dataloader_dict, test_dataloader_dict, accelerator,
            EVAL_METRIC_DICT[args.task_name], generative_forward, verbalizer_indices)
        accelerator.wait_for_everyone()
        _save_split_results(
            save_path, accelerator, model if args.save_last else None,
            validation_result, validation_logits, test_result, test_logits)

    if accelerator.is_local_main_process:
        torch.save(logits_dict, os.path.join(args.save_dir, 'logits_log.pt'))
        torch.save(data_indices_dict, os.path.join(args.save_dir, 'data_indices_log.pt'))

    return model, logits_dict

    