import os
from functools import partial

# Environment configuration, assigned before the torch/transformers imports
# further down in this file.
# NOTE(review): presumably these must be set ahead of those imports so the
# libraries pick them up at import time — confirm before reordering.
# The two cache variables point at project-relative directories under data/hg_data.
os.environ['TRANSFORMERS_CACHE'] = 'data/hg_data/transformers'
os.environ['HF_DATASETS_CACHE'] = 'data/hg_data/datasets'
# Sets the tokenizers parallelism switch to "false".
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.utils.data as torch_data
from transformers import (
    AutoConfig, 
    AutoModelForSequenceClassification, 
    AutoModelForSeq2SeqLM,
    DataCollatorWithPadding, 
    DataCollatorForSeq2Seq, 
)
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed, broadcast
from tqdm.auto import tqdm
from ptflops import get_model_complexity_info

from dataloader import DATALOADER_DICT
from config import parse_args


def construct_inputs(batch, useful_keys=None):
    """Pair each positional element of ``batch`` with its keyword name.

    Used as the ``input_constructor`` callback in ``main`` (with
    ``useful_keys`` bound via ``functools.partial``) so that a tuple of
    values can be turned back into a keyword-argument dict.

    Args:
        batch: Iterable of values, ordered like ``useful_keys``.
        useful_keys: Iterable of key names; must not be ``None``.

    Returns:
        A dict mapping each key to the value at the same position. Pairing
        stops at the shorter of the two iterables.
    """
    assert useful_keys is not None
    return dict(zip(useful_keys, batch))


def flops_to_string(flops, precision=2):
    """Format a FLOPs figure as a string suffixed with ``'K GFLOPs'``.

    Args:
        flops: Numeric value, already expressed in thousands of GFLOPs by
            the caller (no unit conversion happens here).
        precision: Number of decimal places passed to ``round``.

    Returns:
        The rounded value followed directly by ``'K GFLOPs'``.
    """
    rounded = round(flops, precision)
    return f'{rounded}K GFLOPs'


def params_to_string(params_num, precision=2):
    """Render a parameter count in a compact human-readable form.

    Args:
        params_num: Number of parameters.
        precision: Number of decimal places passed to ``round`` for the
            scaled representations.

    Returns:
        ``'<x> M'`` when the count holds at least one whole million,
        ``'<x> k'`` when its floor division by one thousand is non-zero,
        otherwise the plain ``str`` of the count.
    """
    million = 10 ** 6
    thousand = 10 ** 3

    if params_num // million > 0:
        scaled = round(params_num / million, precision)
        return f'{scaled} M'
    if params_num // thousand != 0:
        scaled = round(params_num / thousand, precision)
        return f'{scaled} k'
    return str(params_num)


def main():
    """Estimate per-batch and total FLOPs, plus parameter count, for a model.

    Reads the run configuration from ``parse_args()``, loads the training
    split for ``args.task_name``, builds either a sequence-classification or
    a seq2seq model depending on which ``dataloader.model_types`` group
    ``args.model_name`` belongs to, and runs ptflops'
    ``get_model_complexity_info`` on every batch of the training data.
    Per-step and final totals are sent to wandb when ``args.use_wandb`` is
    set, and a two-line summary is printed at the end.

    Raises:
        ValueError: If ``args.model_name`` is neither a discriminative nor a
            generative model according to ``dataloader.model_types``.
    """
    args = parse_args()
    set_seed(args.rng_seed)

    accelerator_kwargs = {
        'mixed_precision': args.mixed_precision,
        'cpu': False,
    }
    if args.use_wandb:
        accelerator_kwargs['log_with'] = 'wandb'
    accelerator = Accelerator(**accelerator_kwargs)
    if args.use_wandb:
        wandb_kwargs = {'wandb': {'name': args.wandb_run_name}} if args.wandb_run_name is not None else {}
        accelerator.init_trackers(args.wandb_project_name, config=vars(args), init_kwargs=wandb_kwargs)

    dataloader = DATALOADER_DICT[args.task_name](max_length=args.max_length)
    train_data, _ = dataloader.load_train(model_name=args.model_name)

    is_discriminative = args.model_name in dataloader.model_types['discriminative']
    generative_forward = args.model_name in dataloader.model_types['generative']

    # Initialize data collator
    tokenizer = dataloader.tokenizer_dict[args.model_name]
    if is_discriminative:
        data_collator = DataCollatorWithPadding(tokenizer)
    elif generative_forward:
        data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8)
    else:
        raise ValueError(f'Unknown model type for {args.model_name}')

    # Write extra info to args
    label_names = dataloader.dataset_info[dataloader.train_dataset_name]['label_names']
    args.num_labels = len(label_names)
    args.num_samples = max(train_data['data_idx']) + 1

    # Load model
    if is_discriminative:
        # Discriminative models use classification heads
        config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
        model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
    else:
        # Generative models use sequence-to-sequence models
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    train_dataloader = torch_data.DataLoader(
        train_data, shuffle=False, batch_size=args.eval_batch_size, collate_fn=data_collator)
    train_dataloader, model = accelerator.prepare(train_dataloader, model)

    print('Start estimating FLOPs and parameters')
    print(f'Model name: {args.model_name}')

    # Fixed candidate order keeps the positional tuple handed to ptflops
    # deterministic from run to run.
    candidate_keys = ('input_ids', 'attention_mask', 'token_type_ids')
    total_flops_count = 0
    # Stays None when the dataloader yields no batches; handled after the loop.
    params_count = None

    # Wrapping the dataloader (rather than the enumerate object) lets tqdm
    # read its length and show real progress.
    for step, batch in enumerate(tqdm(train_dataloader)):
        # Use the accelerator's device instead of a hard-coded 'cuda' so the
        # batch always lands on the same device as the prepared model.
        batch = batch.to(accelerator.device)
        available_keys = set(batch.keys())
        useful_keys = [key for key in candidate_keys if key in available_keys]
        if generative_forward:
            useful_keys.append('labels')
        model_inputs = tuple(batch[key] for key in useful_keys)
        with torch.no_grad():
            flops_count, params_count = get_model_complexity_info(
                model, model_inputs, as_strings=False,
                input_constructor=partial(construct_inputs, useful_keys=useful_keys),
                ost=None,
                print_per_layer_stat=False
            )
        # Convert to thousands of GFLOPs (matches the 'K GFLOPs' label).
        # NOTE(review): the factor 2 presumably turns ptflops' multiply-accumulate
        # count into FLOPs — confirm against the installed ptflops version.
        flops_count = flops_count / 10**9 / 10**3 * 2
        total_flops_count += flops_count
        if args.use_wandb:
            accelerator.log({'flops_count': flops_count, 'total_flops_count': total_flops_count, 'step': step})

    if args.use_wandb:
        accelerator.log({'final_flops_count': total_flops_count})
        # Close the trackers so the wandb run is finalized.
        accelerator.end_training()

    if params_count is None:
        # No batch was processed, so ptflops never reported a parameter
        # count; fall back to counting trainable parameters directly.
        params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)

    total_flops_count = flops_to_string(total_flops_count)
    params_count = params_to_string(params_count)
    print('{:<30}  {:<8}'.format('Computational complexity: ', total_flops_count))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params_count))


# Run the estimation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
