import os
import copy
import json
import logging
from functools import partial
from argparse import ArgumentParser, Namespace

import numpy as np
import torch
import torch.nn
import torch.optim
import datasets
from transformers import TrainingArguments
from torch.utils.data import DataLoader

import train_func
import train_transformers_func
import settings
from helper import utils
from helper.thirdparty.tofu.dataloader import CustomTrainer
from helper.thirdparty.tofu.data_module import custom_data_collator


def sample_poison_ids(ds_size, poison_fraction):
    """Draw the ids of the training samples whose labels will be poisoned.

    Args:
        ds_size: number of training samples; candidate ids are ``0 .. ds_size - 1``.
        poison_fraction: fraction of the training set to poison. The number of
            poisoned samples is ``int(poison_fraction * ds_size)`` (truncated).

    Returns:
        A set of distinct Python ints, drawn without replacement from NumPy's
        global RNG (presumably seeded by ``utils.set_seed`` — confirm).
    """
    poison_size = int(poison_fraction * ds_size)
    candidate_ids = np.arange(0, ds_size, 1)
    chosen = np.random.choice(candidate_ids, size=poison_size, replace=False)
    return set(chosen.tolist())


def label_random_flip(sample, poison_ids, num_class):
    """Give a poisoned sample a uniformly random label different from its own.

    ``sample`` is a dict-like row with ``"id"`` and ``"label"`` keys; rows whose
    id is not in ``poison_ids`` are returned unchanged. Alternative strategy to
    :func:`label_shift_right`.
    """
    if sample["id"] in poison_ids:
        true_label = sample["label"]
        possible_poisoned_labels = [c for c in range(num_class) if c != true_label]
        sample["label"] = np.random.choice(possible_poisoned_labels)
    return sample


def label_shift_right(sample, poison_ids, num_class):
    """Shift a poisoned sample's label to the next class, wrapping to 0.

    ``sample`` is a dict-like row with ``"id"`` and ``"label"`` keys; rows whose
    id is not in ``poison_ids`` are returned unchanged.
    """
    if sample["id"] in poison_ids:
        poison_label = sample["label"] + 1
        if poison_label >= num_class:
            poison_label = 0
        sample["label"] = poison_label
    return sample


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, default="mnist",
                        help="Dataset name")
    parser.add_argument("--num_class", type=int,
                        help="Number of classes, assumed classification problem")
    parser.add_argument("--seed", type=int, default=1,
                        help="Random seed for reproducibility")
    parser.add_argument("--save_dir", type=str, default="outputs/",
                        help="Directory to output results")
    parser.add_argument("--local-rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument("--llama", action="store_true",
                        help="Whether the model is llama")
    parser.add_argument("--train_poison", action="store_true",
                        help="Whether the model will be trained on poisoned data")
    parser.add_argument("--poison_fraction", default=0.0, type=float,
                        help="Fraction of poisoned data")
    # 0.0 means "keep the learning rate from the JSON config".
    parser.add_argument("--learning_rate", default=0.0, type=float,
                        help="Learning rate")
    parser.add_argument("--vit", action="store_true",
                        help="Whether the model is vit")
    args = parser.parse_args()

    # Fail fast on inconsistent CLI arguments instead of deep inside training.
    if args.train_poison and args.num_class is None:
        parser.error("--num_class is required when --train_poison is set")
    if not 0.0 <= args.poison_fraction <= 1.0:
        parser.error("--poison_fraction must be in [0, 1]")
    if args.llama and args.dataset not in ("ag_news", "tofu"):
        # Only these two datasets have an HF training branch below; anything
        # else would silently save an untrained model as original.pt.
        parser.error(f"--llama training is not implemented for dataset '{args.dataset}'")

    utils.set_seed(args.seed)
    os.makedirs(args.save_dir, exist_ok=True)

    # argparse maps "--local-rank" to the attribute "local_rank".
    is_distributed = args.local_rank != -1
    is_main_process = args.local_rank in (-1, 0)

    # Load training configuration; Namespace allows attribute access (config.x).
    with open(f"configs/train/{args.dataset}.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    print("train_config:", config)
    config = Namespace(**config)

    # Apply the CLI override where each training path actually reads it:
    # the HF Trainer path uses config.training_arguments, the plain PyTorch
    # path builds its optimizer from config.learning_rate.
    if args.learning_rate != 0.0:
        if args.llama:
            config.training_arguments["learning_rate"] = args.learning_rate
        else:
            config.learning_rate = args.learning_rate

    if not is_distributed:
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        # Initialize the distributed backend which synchronizes nodes/GPUs.
        torch.cuda.set_device(args.local_rank)
        DEVICE = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1

    # Initialize data and model (and tokenizer for the HF path).
    if args.llama:
        (train_ds, test_ds), (model, tokenizer), model_config = \
            settings.prepare_data_and_model(args.dataset, config)
    elif args.vit:
        print("Please use the code in settings/vit_finetune to train a ViT model.")
        raise NotImplementedError
    else:
        (train_ds, test_ds), model, model_config = settings.prepare_data_and_model(args.dataset)

    train_ds_size = len(train_ds)
    test_ds_size = len(test_ds)
    print("train_size:", train_ds_size)
    print("test_size:", test_ds_size)

    if args.train_poison:
        if isinstance(train_ds, datasets.Dataset):
            if args.dataset == "ag_news":  # sample 50% data
                train_ds = train_ds.train_test_split(train_size=0.5)["train"]
                train_ds_size = len(train_ds)
                print("sampled_train_size:", train_ds_size)
            # Row ids let the map function decide which rows to poison.
            train_ds = train_ds.add_column("id", range(train_ds_size))
            poison_ids = sample_poison_ids(train_ds_size, args.poison_fraction)
            # Swap in label_random_flip here for random (rather than shifted) labels.
            train_ds = train_ds.map(
                partial(label_shift_right, poison_ids=poison_ids, num_class=args.num_class)
            )
            if is_main_process:
                train_ds.save_to_disk(os.path.join(args.save_dir, "poisoned_ds"))
        else:
            poison_ids = sample_poison_ids(train_ds_size, args.poison_fraction)
            train_ds = settings.PoisonedDataset(
                train_ds, poison_ids=poison_ids, num_class=args.num_class
            )
            if is_main_process:
                torch.save(train_ds, os.path.join(args.save_dir, "poisoned_ds.pt"))

        # Only one process writes shared files to avoid concurrent writes.
        ids_save_path = os.path.join(args.save_dir, "unlearn_order.pt")
        if is_main_process:
            torch.save(list(poison_ids), ids_save_path)
            print(f"Poisoned ids are saved to {ids_save_path}")

    # Start training.
    if args.llama:
        training_args = TrainingArguments(output_dir=args.save_dir, **config.training_arguments)
        if args.dataset == "ag_news":
            peft_config = train_transformers_func.get_peft_config(config.peft_config)
            trainer = train_transformers_func.get_sft_trainer(
                model, model_config["task"], train_ds, test_ds,
                peft_config, tokenizer, config.max_seq_length, training_args,
            )

            print("no_params (trainable): %d" % sum(p.numel() for p in model.parameters() if p.requires_grad))
            train_results = trainer.train()
            trainer.log_metrics("train", train_results.metrics)
            utils.clear_cache()
            with torch.no_grad():
                eval_metrics = trainer.evaluate()
            trainer.log_metrics("eval", eval_metrics)
            model = trainer.model
        elif args.dataset == "tofu":
            trainer = CustomTrainer(
                model=model,
                train_dataset=train_ds,
                eval_dataset=None,  # no evaluation during TOFU finetuning
                args=training_args,
                data_collator=custom_data_collator,
            )
            model.config.use_cache = False  # disable KV cache during finetuning
            trainer.train()
    else:
        train_loader = DataLoader(train_ds, shuffle=True, batch_size=config.train_batch_size)
        test_loader = DataLoader(test_ds, shuffle=False, batch_size=config.eval_batch_size)

        criterion = getattr(torch.nn, config.loss)()
        optimizer_cls = getattr(torch.optim, config.optimizer)
        optimizer = optimizer_cls(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
        # Halves the LR whenever the scheduler's step counter is a multiple of
        # lr_update_interval (whether that counts epochs or batches depends on
        # how train_func.train calls lr_scheduler.step() — confirm).
        lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
            optimizer, lambda step: 0.5 if step % config.lr_update_interval == 0 else 1.0,
        )

        print("no_params (trainable): %d" % sum(p.numel() for p in model.parameters() if p.requires_grad))
        train_func.train(
            model, train_loader, criterion, optimizer,
            eval_loader=test_loader, num_epochs=config.num_epochs,
            log_frequency=config.log_frequency, lr_scheduler=lr_scheduler, device=DEVICE,
        )

    path = os.path.join(args.save_dir, "original.pt")
    if is_main_process:
        utils.save_model(model, path)
    if is_distributed:
        # Every rank (including rank 0) must reach the barrier; previously only
        # non-zero ranks called it, so they blocked forever.
        torch.distributed.barrier()