import random
import os
import gc

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset

from settings import train_func


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: value passed to every generator and exported as ``PYTHONHASHSEED``.

    Note:
        Python reads ``PYTHONHASHSEED`` at interpreter start-up, so setting it
        here only influences child processes launched afterwards, not the
        current process's ``hash()`` randomization.

    References: https://wandb.ai/sauravmaheshkar/RSNA-MICCAI/reports/How-to-Set-Random-Seeds-in-PyTorch-and-Tensorflow--VmlldzoxMDA2MDQy
    """
    # Standard-library and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU + CUDA).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: force deterministic kernels and disable the autotuner, which may
    # otherwise pick different algorithms from run to run.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def get_dataloader(dataset, *args, **kwargs):
    """Build a ``DataLoader`` over ``dataset``, capping ``batch_size`` at its size.

    All positional and keyword arguments are forwarded to
    ``torch.utils.data.DataLoader``. When ``batch_size`` is passed as a keyword
    and the dataset is non-empty, it is capped at ``len(dataset)``. With
    ``drop_last=True`` this means a dataset smaller than ``batch_size`` still
    produces one batch.

    ``batch_size=None`` (automatic batching disabled) and empty datasets are
    passed through unchanged. Capping an empty dataset would give
    ``batch_size=0``, which ``DataLoader`` rejects.

    Returns:
        The constructed ``DataLoader``.
    """
    batch_size = kwargs.get("batch_size")
    if batch_size is not None:
        # len() is only needed here; map-style datasets without a batch size
        # and iterable datasets without __len__ never reach this call.
        num_items = len(dataset)
        if num_items > 0:
            kwargs["batch_size"] = min(batch_size, num_items)
    return DataLoader(dataset, *args, **kwargs)

def sample(dataset, sample_ids):
    """Return a ``Subset`` of ``dataset`` containing the positions ``sample_ids``.

    If ``dataset`` is already a ``Subset``, the positions are translated through
    its ``indices``. The new ``Subset`` then wraps ``dataset.dataset`` directly
    instead of adding another wrapper layer on top of ``dataset``.
    """
    if not isinstance(dataset, Subset):
        return Subset(dataset, sample_ids)

    parent_indices = dataset.indices
    remapped = [parent_indices[position] for position in sample_ids]
    return Subset(dataset.dataset, remapped)

def shard(dataset, num_shards, index):
    """Return the ``index``-th of ``num_shards`` strided shards of ``dataset``.

    Shard ``index`` holds positions ``index, index + num_shards,
    index + 2 * num_shards, ...``. The result is built with ``sample``.

    Raises:
        ValueError: if ``index`` is not in ``[0, num_shards - 1]``.
    """
    index_in_range = 0 <= index < num_shards
    if not index_in_range:
        raise ValueError("index should be in [0, num_shards-1]")
    return sample(dataset, np.arange(index, len(dataset), num_shards))

def merge_dict(from_dict, to_dict):
    """Recursively overlay ``to_dict`` onto ``from_dict`` and return the result.

    Values in ``to_dict`` win. When a key maps to a dict in both inputs, the two
    dicts are merged recursively instead of replaced. Neither input is mutated.
    The copy is shallow, so untouched nested values are shared with the inputs.
    """
    merged = from_dict.copy()
    for key, incoming in to_dict.items():
        existing = from_dict[key] if key in from_dict else None
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merged[key] = merge_dict(existing, incoming)
        else:
            merged[key] = incoming
    return merged

def get_trainable_state_dict(model):
    """Return the entries of ``model.state_dict()`` that are trainable parameters.

    A key is kept when it names a parameter with ``requires_grad=True`` in
    ``model.named_parameters()``. Buffers and frozen parameters are dropped.
    Keys keep their ``state_dict`` order.

    Returns:
        A plain ``dict`` mapping parameter name to tensor.
    """
    # Build the name lookup as a set once. Testing membership against a list
    # for every state_dict key was accidentally O(n^2).
    trainable_names = {
        name for name, param in model.named_parameters() if param.requires_grad
    }
    return {
        key: value
        for key, value in model.state_dict().items()
        if key in trainable_names
    }

def save_trainable_model(model, path):
    """Save only the parameters of ``model`` that have ``requires_grad=True``.

    The file written with ``torch.save`` is a flat dict mapping parameter name
    to a detached tensor. Buffers and frozen parameters are not included.
    """
    trainable_params = {
        name: param.detach()
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    torch.save(trainable_params, path)


def load_trainable_model(model, path, map_location=None):
    """Load a partial checkpoint into ``model`` (in place) and return ``model``.

    Args:
        model: ``torch.nn.Module`` whose matching entries are overwritten.
        path: file written by ``torch.save``. Two layouts are accepted:
            a flat ``{name: tensor}`` mapping, as written by
            ``save_trainable_model``, or ``{"state_dict": {...}}``, as written
            by ``save_model``.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"`` loads
            a checkpoint saved on GPU on a CPU-only machine. ``None`` keeps
            ``torch.load``'s default.

    Loading uses ``strict=False``, so keys missing from the checkpoint keep
    their current values and unknown keys are ignored. The printed
    "x/y layers" line shows how many of the model's entries were matched.
    """
    checkpoint = torch.load(path, map_location=map_location)
    # Unwrap save_model's {"state_dict": ...} layout. Otherwise strict=False
    # would silently ignore the single unexpected "state_dict" key and load nothing.
    if (
        isinstance(checkpoint, dict)
        and set(checkpoint) == {"state_dict"}
        and isinstance(checkpoint["state_dict"], dict)
    ):
        checkpoint = checkpoint["state_dict"]
    trainable_layers = set(checkpoint.keys())
    all_layers = set(model.state_dict().keys())
    num_match_layers = len(trainable_layers.intersection(all_layers))
    print('Load trainable parameters for {}/{} layers'.format(num_match_layers, len(all_layers)))
    model.load_state_dict(checkpoint, strict=False)
    return model

def clear_cache():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first. Tensors kept alive only by reference cycles are
    freed, and their blocks go back to PyTorch's caching allocator.
    ``torch.cuda.empty_cache()`` can then hand those blocks back to the driver.
    In the opposite order, memory freed by the collection would stay cached.
    """
    gc.collect()
    torch.cuda.empty_cache()

def convert_torch_to_numpy(tensor):
    """Convert a torch tensor to a NumPy array.

    The tensor is first detached from autograd and moved to CPU.
    ``Tensor.numpy()`` rejects tensors that require grad and tensors on other
    devices, so both are accepted here. For a tensor already on CPU, no data is
    copied and the array shares memory with ``tensor``.

    Dtypes NumPy cannot represent, such as ``torch.bfloat16``, make
    ``Tensor.numpy()`` raise ``TypeError``. Those tensors are upcast with
    ``.float()`` and then converted.
    """
    cpu_tensor = tensor.detach().cpu()
    try:
        return cpu_tensor.numpy()
    except TypeError:
        # e.g. torch.bfloat16 has no NumPy equivalent.
        return cpu_tensor.float().numpy()

def eval_unlearning(
    model, accu_forget_set, retain_set, test_set,
    config, trainer_init_func=None, trainer_init_kwargs=None, device=None,
):
    """Evaluate ``model`` on the accumulated-forget, retain and test splits.

    Args:
        model: model to evaluate.
        accu_forget_set: accumulated forget set.
        retain_set: retain set.
        test_set: test set.
        config: run configuration. ``config.llama`` selects the trainer-based
            path. Otherwise ``config.eval_batch_size`` sets the batch size of
            the per-split ``DataLoader``.
        trainer_init_func: (``config.llama`` only) callable that builds a
            trainer exposing ``evaluate(dataset) -> dict`` of metrics.
        trainer_init_kwargs: (``config.llama`` only) namespace-like object. Its
            ``vars()`` plus ``model=model`` are passed to ``trainer_init_func``.
            Its ``task_type`` must be ``"classification"`` or ``"summarization"``.
            The object itself is not modified.
        device: (non-``llama`` path) forwarded to ``train_func.evaluate``.

    Returns:
        dict keyed ``"<split>_<metric>"``, where split is ``accu_forget``,
        ``retain`` or ``test``. Examples: ``"retain_eval_accuracy"`` on the
        trainer path and ``"retain_acc"`` otherwise.

    Raises:
        ValueError: ``config.llama`` is set but a trainer argument is missing.
        NotImplementedError: ``trainer_init_kwargs.task_type`` is unsupported.
    """
    # (result-key prefix, description used in progress messages, dataset)
    splits = (
        ("accu_forget", "accumulated forget set", accu_forget_set),
        ("retain", "retain set", retain_set),
        ("test", "test set", test_set),
    )
    if config.llama:
        return _eval_with_trainer(model, splits, trainer_init_func, trainer_init_kwargs)
    return _eval_with_loaders(model, splits, config.eval_batch_size, device)


def _eval_with_trainer(model, splits, trainer_init_func, trainer_init_kwargs):
    """Build one trainer for ``model`` and evaluate every split with it once."""
    if trainer_init_func is None or trainer_init_kwargs is None:
        raise ValueError(
            "trainer_init_func and trainer_init_kwargs are required when config.llama is set"
        )
    task_metrics = {
        "classification": ["eval_accuracy"],
        "summarization": ["eval_rouge"],
    }
    task_type = trainer_init_kwargs.task_type
    if task_type not in task_metrics:
        raise NotImplementedError(f"Unsupported task_type: {task_type!r}")
    metrics = task_metrics[task_type]

    # Copy the attributes instead of assigning ``trainer_init_kwargs.model``,
    # so the caller's object is not mutated.
    trainer = trainer_init_func(**{**vars(trainer_init_kwargs), "model": model})

    res = {}
    for split_name, _description, dataset in splits:
        # A single evaluate() pass per split yields every requested metric.
        split_results = trainer.evaluate(dataset)
        clear_cache()
        for metric in metrics:
            res[f"{split_name}_{metric}"] = split_results[metric]
    return res


def _eval_with_loaders(model, splits, batch_size, device):
    """Evaluate every split via ``train_func.evaluate`` on an unshuffled DataLoader."""
    res = {}
    for split_name, description, dataset in splits:
        print(f"Evaluating on {description}...")
        loader = DataLoader(dataset, shuffle=False, batch_size=batch_size)
        res[f"{split_name}_acc"] = train_func.evaluate(model, loader, device=device)
    return res

class CustomDataset(Dataset):
    """Map-style dataset that turns dialogue/summary rows into tokenized prompts.

    Args:
        dataframe: table with ``dialogue`` and ``summary`` columns, read
            positionally through ``.iloc`` (presumably a ``pandas.DataFrame``).
        tokenizer: object exposing ``encode_plus`` (presumably a Hugging Face
            tokenizer).
        max_length: length every example is truncated / padded to.
    """

    def __init__(self, dataframe, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'attention_mask'}`` tensors for row ``idx``.

        Assumes the tokenizer returns tensors of shape ``(1, max_length)``. The
        values are then 1-D tensors of length ``max_length``.
        """
        row = self.data.iloc[idx]
        formatted_example = self.format_example(row['dialogue'], row['summary'])
        inputs = self.tokenizer.encode_plus(
            formatted_example,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of size 1; drop only
        # that dim. A bare squeeze() would also collapse the sequence dim when
        # max_length == 1 and return 0-d tensors.
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)
        return {'input_ids': input_ids, 'attention_mask': attention_mask}

    @staticmethod
    def format_example(dialogue, summary):
        """Render one dialogue/summary pair as the instruction-style prompt.

        NOTE: the literal below is the exact model input text. Its 8-space line
        indentation and backslash line joins are kept verbatim, because changing
        them changes the tokenized prompt.
        """
        return f"""<s>### Instruction:
        You are a helpful, respectful, and honest assistant. \
        Your task is to summarize the following dialogue. \
        Your answer should be based on the provided dialogue only.

        ### Dialogue:
        {dialogue}

        ### Summary:
        {summary} </s>"""

def save_model(model, path, save_trainable: bool = False):
    """Write ``{"state_dict": ...}`` for ``model`` to ``path`` with ``torch.save``.

    Args:
        model: module to serialize.
        path: destination file.
        save_trainable: when True, only the trainable-parameter entries from
            ``get_trainable_state_dict`` are saved. Otherwise the full
            ``model.state_dict()`` is saved.
    """
    state_dict = get_trainable_state_dict(model) if save_trainable else model.state_dict()
    torch.save({"state_dict": state_dict}, path)
    print(f"Model is saved to {path}")