import torch
import numpy as np

from settings import train_func
from helper import utils

def random_labels(model,
                  forget_set,
                  retain_set,
                  config,
                  **kwargs,
                  ):
    """Train ``model`` on the forget set relabelled with labels from the retain set.

    Each forget sample keeps its input but receives the label of a retain
    sample whose index is drawn with ``np.random.randint`` (one independent
    draw per forget sample, so retain samples may be reused). The model is
    then trained on the resulting noisy set, either through an externally
    supplied trainer (``config.llama`` truthy) or through
    ``train_func.train`` with a plain torch optimizer.

    Args:
        model: Model to be trained on the noisy labels.
        forget_set: Sized, integer-indexable collection. On the trainer path
            each sample must support ``.copy()`` and contain a ``"label"``
            key; otherwise ``sample[0]`` is used as the input.
        retain_set: Sized, integer-indexable collection with the same sample
            layout; only its labels (``["label"]`` or ``[1]``) are read.
        config: Settings object. ``config.llama`` selects the trainer path.
            The other path reads ``train_batch_size``, ``loss``,
            ``optimizer``, ``learning_rate``, ``weight_decay``,
            ``lr_update_interval``, ``num_epochs``, ``log_frequency`` and
            ``device``.
        **kwargs: On the trainer path must contain ``"trainer_init_func"``
            (callable invoked with the attributes of the kwargs object as
            keyword arguments) and ``"trainer_init_kwargs"`` (attribute
            container with at least ``task_type``). Note that
            ``trainer_init_kwargs.model`` is overwritten with ``model``.

    Returns:
        The trained model: ``trainer.model`` on the trainer path, otherwise
        the ``model`` object that was passed in (presumably updated in place
        by ``train_func.train`` -- confirm against that helper).

    Raises:
        KeyError: If a required entry is missing from ``kwargs`` on the
            trainer path.
        AssertionError: If the task type is not ``"classification"`` or a
            sample lacks a ``"label"`` key (trainer path only).
    """
    num_forget = len(forget_set)
    num_retain = len(retain_set)

    if config.llama:
        trainer_init_func = kwargs.pop("trainer_init_func")
        trainer_init_kwargs = kwargs.pop("trainer_init_kwargs")
        assert trainer_init_kwargs.task_type == "classification", \
            f"Only support random labels in classification task. Current task is {trainer_init_kwargs.task_type}."

        noisy_set = []
        for forget_idx in range(num_forget):
            # Copy so the caller's forget sample keeps its true label.
            sample = forget_set[forget_idx].copy()
            donor = retain_set[np.random.randint(num_retain)]
            assert "label" in sample and "label" in donor
            sample["label"] = donor["label"]
            noisy_set.append(sample)

        trainer_init_kwargs.model = model
        trainer = trainer_init_func(**vars(trainer_init_kwargs))
        # Must be assigned after construction to bypass the trainer's
        # invalid processing of a list dataset.
        trainer.train_dataset = noisy_set
        trainer.train()
        model = trainer.model
        # Can be set to any non-list-typed dataset.
        trainer.train_dataset = forget_set
        return model

    # Plain torch path: samples are indexable pairs, input at [0], label at [1].
    noisy_set = []
    for forget_idx in range(num_forget):
        sample = forget_set[forget_idx]
        donor = retain_set[np.random.randint(num_retain)]
        noisy_set.append((sample[0], donor[1]))

    train_loader = utils.get_dataloader(noisy_set,
                                        shuffle=True,
                                        batch_size=config.train_batch_size)
    # Loss and optimizer classes are looked up by name from the config.
    loss_fn = getattr(torch.nn, config.loss)()
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(model.parameters(),
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay)
    # Multiply the learning rate by 0.5 whenever the counter handed to the
    # lambda is a multiple of lr_update_interval; leave it unchanged otherwise.
    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer,
        lambda step: 0.5 if step % config.lr_update_interval == 0 else 1.0)
    train_func.train(model,
                     train_loader,
                     loss_fn,
                     optimizer,
                     eval_dataloader=None,
                     num_epochs=config.num_epochs,
                     log_frequency=config.log_frequency,
                     lr_scheduler=lr_scheduler,
                     device=config.device)
    return model
    # elif cfg.llama:
    #     from trl import SFTTrainer
    #     noisy_set = []
    #     for i in range(len(df_set)):
    #         sample = df_set[i]
    #         sample_id = numpy.random.randint(len(dr_set))
    #         noisy_sample = dr_set[sample_id] 
    #         sample['labels'] = noisy_sample['labels']
    #         noisy_set.append(sample)
            
    #     cfg.training_args.num_train_epochs = 1
    #     trainer = SFTTrainer(
    #         model=model,
    #         train_dataset=noisy_set,
    #         # peft_config=cfg.peft_config,
    #         max_seq_length=1024,
    #         tokenizer=cfg.tokenizer,
    #         dataset_text_field="text",
    #         args=cfg.training_args,
    #     )
    #     trainer.train()
    