import time
import numpy as np
import importlib
import torch
import torch.distributed as dist
from transformers import DebertaV2ForQuestionAnswering, DebertaV2TokenizerFast, DebertaV2Config, DataCollatorWithPadding
from datasets import load_dataset
import argparse
import random
from peft import LoraConfig, get_peft_model
from myLoraModel.mapping import get_our_peft_model
import os

def setup(rank, world_size, port):
    """Configure the rendezvous address and join the NCCL process group.

    Args:
        rank: Rank of the calling process within the group.
        world_size: Total number of processes in the group.
        port: Port exported as ``MASTER_PORT`` (converted with ``str``).
    """
    # The rendezvous is always on this machine; only the port is configurable.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": str(port)})
    dist.init_process_group(rank=rank, world_size=world_size, backend="nccl")


def cleanup():
    """Destroy the default ``torch.distributed`` process group."""
    dist.destroy_process_group()


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line interface. Every option is stored on `args`, which is
    # later handed to the selected method module.
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--world_size", type=int, default=1, help="number of processes participating in the job")
    parser.add_argument("--rank", type=int, default=0, help="rank of the current process")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu id to run")
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument("--fp16", action='store_true', help="float point")
    parser.add_argument("--com_interval", type=int, default=1, help="The interval of communication")
    parser.add_argument("--name_or_path", type=str, default='roberta-base', help="The type of pretrained model")
    parser.add_argument("--dataset", type=str, default='cola', help="The dataset")
    parser.add_argument("--max_seq_length", type=int, default=384, help="The max length of sequence")
    parser.add_argument("--doc_stride", type=int, default=128, help="When splitting up a long document into chunks, how much stride to take between chunks")
    parser.add_argument("--method", type=str, default='lora', help="The optimized method")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="clip parameter")
    parser.add_argument("--lr_A", type=float, default=5e-5, help="learning rate for fine-turning lora_A")
    parser.add_argument("--lr_B", type=float, default=5e-5, help="learning rate for fine-turning lora_B")
    parser.add_argument("--lr_scheduler_type", type=str, default="linear", help="learning schedule type")
    parser.add_argument("--nu", type=float, default=5e-4, help="learning rate for fine-turning lora_B")
    parser.add_argument("--beta", type=float, default=0.99, help="momentum parameter for our method")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size for fine-turning")
    parser.add_argument("--eval_batch_size", type=int, default=16, help="batch size for fine-turning")
    parser.add_argument("--num_epochs", type=int, default=2, help="number of epochs for fine-turning")
    parser.add_argument("--save_path", type=str, default='../output', help="the path to save results")
    parser.add_argument("--port", type=int, default=12355, help="the port to communicate")
    parser.add_argument("--lamb", type=float, default=0.1, help="the balanced parameter for hetlora")
    parser.add_argument("--rank_mat", type=int, default=8, help="the rank for lora")
    parser.add_argument("--rank_max", type=int, default=12, help="the rank for lora")
    parser.add_argument("--rank_min", type=int, default=5, help="the rank for lora")
    parser.add_argument("--gamma", type=float, default=0.5, help="the sparsity of hetlora")
    parser.add_argument("--heterogeneity", type=float, default=0.7, help="the dissimilarity of different client")
    parser.add_argument("--inner_loops", type=int, default=10, help="the number of inner loops")
    parser.add_argument("--z_loops", type=int, default=5, help="the number of loops for updating linear system solution")
    parser.add_argument("--num_gpu", type=int, default=4, help="the number of gpu devices")
    parser.add_argument("--process_per_gpu", type=int, default=2, help="the number of gpu devices")
    parser.add_argument("--hessian_q", type=int, default=5, help="the number of q loops")
    parser.add_argument("--clients_per_gpu", type=int, default=2, help="simulate the multiple clients on per gpu")
    parser.add_argument("--com_rounds", type=int, default=100, help="the total communication rounds")
    parser.add_argument("--maml_in_lr", type=float, default=0.01, help="the total communication rounds")
    args = parser.parse_args()

    # Join the process group before doing any other work.
    setup(args.rank, args.world_size, args.port)
    print(f'{args.method} trains on dataset {args.dataset} ')

    # Seed every RNG used by the script and make cuDNN deterministic.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # For hetlora the LoRA rank is interpolated between rank_min and rank_max
    # according to the position of this rank within the world.
    if args.method == 'hetlora':
        args.rank_mat = args.rank_min + int((args.rank_max - args.rank_min) * args.rank / args.world_size)

    device = torch.device(f"cuda:{args.gpu}")
    torch.cuda.set_device(device)
    raw_dataset = load_dataset(args.dataset)

    # The config dtype is float32 whether or not --fp16 is given; the flag
    # does not influence it here.
    # NOTE(review): confirm whether --fp16 was meant to select a
    # reduced-precision dtype at this point.
    torch_dtype = torch.float32
    model_config = DebertaV2Config.from_pretrained(args.name_or_path, torch_dtype=torch_dtype)
    model = DebertaV2ForQuestionAnswering.from_pretrained(args.name_or_path, config=model_config)
    tokenizer = DebertaV2TokenizerFast.from_pretrained(args.name_or_path)

    # Results go to <save_path>/<model>/<dataset>/<method>_lr_B.._het.._<date>.
    # exist_ok avoids a check-then-create race when several ranks start at once.
    save_direct = os.path.join(args.save_path, f'{args.name_or_path}/{args.dataset}')
    os.makedirs(save_direct, exist_ok=True)
    date = time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time()))
    file_name = f'{args.method}_lr_B{args.lr_B}_het{args.heterogeneity}_{date}'
    args.save_path = os.path.join(save_direct, file_name)

    peft_config = LoraConfig(
        task_type="QUESTION_ANS",
        inference_mode=False,
        use_original_init=True,
        target_modules=["query_proj", "key_proj"],
        r=args.rank_mat,
        lora_alpha=8,
        lora_dropout=0.1,
        fan_in_fan_out=False,
    )
    # The "ours*" variants are wrapped by the project's own PEFT mapping;
    # every other method uses the stock peft wrapper.
    if args.method in ('ours', 'ours_one_step', 'ours_single_opt', 'ours_maml'):
        model = get_our_peft_model(model, peft_config)
    else:
        model = get_peft_model(model, peft_config)

    # Column names used by the preprocessing functions below. Preprocessing
    # is slightly different for training and evaluation.
    column_names = raw_dataset["train"].column_names
    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    answer_column_name = "answers" if "answers" in column_names else column_names[2]
    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    if args.max_seq_length > tokenizer.model_max_length:
        print(
            f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)


    def prepare_train_features(examples):
        """Tokenize a batch of QA examples and label each feature's answer span.

        Long contexts are split into overlapping features (stride
        ``args.doc_stride``), so one example may yield several features. For
        each feature the token indices of the first answer are stored in
        ``start_positions`` / ``end_positions``; when the example has no
        answer, or the answer does not lie inside the feature's context
        window, both are set to the index of the CLS token.

        Args:
            examples: Batch of columns keyed by column name.

        Returns:
            The tokenizer output with ``overflow_to_sample_mapping`` and
            ``offset_mapping`` removed and the two position lists added.
        """
        # Padding side decides the argument order, and therefore which
        # sequence id marks context tokens.
        context_seq_id = 1 if pad_on_right else 0
        first_column = question_column_name if pad_on_right else context_column_name
        second_column = context_column_name if pad_on_right else question_column_name

        tokenized_examples = tokenizer(
            examples[first_column],
            examples[second_column],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # feature index -> index of the example it was cut from
        feature_to_example = tokenized_examples.pop("overflow_to_sample_mapping")
        # per feature: (start_char, end_char) of every token
        all_offsets = tokenized_examples.pop("offset_mapping")

        start_positions = []
        end_positions = []

        for feature_idx, offsets in enumerate(all_offsets):
            input_ids = tokenized_examples["input_ids"][feature_idx]
            # Unanswerable features are labelled with the CLS position.
            cls_index = input_ids.index(tokenizer.cls_token_id)
            sequence_ids = tokenized_examples.sequence_ids(feature_idx)
            answers = examples[answer_column_name][feature_to_example[feature_idx]]

            if len(answers["answer_start"]) == 0:
                start_positions.append(cls_index)
                end_positions.append(cls_index)
                continue

            # Character span of the first answer in the original text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # First and last token of the context inside this feature.
            span_start = 0
            while sequence_ids[span_start] != context_seq_id:
                span_start += 1
            span_end = len(input_ids) - 1
            while sequence_ids[span_end] != context_seq_id:
                span_end -= 1

            answer_in_window = (
                offsets[span_start][0] <= start_char and offsets[span_end][1] >= end_char
            )
            if not answer_in_window:
                start_positions.append(cls_index)
                end_positions.append(cls_index)
                continue

            # Walk both ends inwards until they overshoot the answer by one,
            # then step back. The bound check covers an answer that ends on
            # the very last token.
            while span_start < len(offsets) and offsets[span_start][0] <= start_char:
                span_start += 1
            start_positions.append(span_start - 1)
            while offsets[span_end][1] >= end_char:
                span_end -= 1
            end_positions.append(span_end + 1)

        tokenized_examples["start_positions"] = start_positions
        tokenized_examples["end_positions"] = end_positions
        return tokenized_examples

    def prepare_validation_features(examples):
        """Tokenize a batch of QA examples for evaluation.

        Long contexts are split into overlapping features exactly as for
        training. Each feature records the ``id`` of the example it came from
        in ``example_id``, and its ``offset_mapping`` is kept with every
        non-context token's entry replaced by ``None`` so that context
        positions can be recognised later.

        Args:
            examples: Batch of columns keyed by column name; must contain ``id``.

        Returns:
            The tokenizer output with ``overflow_to_sample_mapping`` removed,
            ``example_id`` added and ``offset_mapping`` masked in place.
        """
        # Padding side decides the argument order, and therefore which
        # sequence id marks context tokens.
        context_index = 1 if pad_on_right else 0
        first_column = question_column_name if pad_on_right else context_column_name
        second_column = context_column_name if pad_on_right else question_column_name

        tokenized_examples = tokenizer(
            examples[first_column],
            examples[second_column],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # feature index -> index of the example it was cut from
        feature_to_example = tokenized_examples.pop("overflow_to_sample_mapping")
        all_offsets = tokenized_examples["offset_mapping"]

        example_ids = []
        for feature_idx in range(len(tokenized_examples["input_ids"])):
            sequence_ids = tokenized_examples.sequence_ids(feature_idx)
            example_ids.append(examples["id"][feature_to_example[feature_idx]])

            # Keep offsets only for context tokens; everything else is None.
            masked_offsets = []
            for token_idx, span in enumerate(all_offsets[feature_idx]):
                if sequence_ids[token_idx] == context_index:
                    masked_offsets.append(span)
                else:
                    masked_offsets.append(None)
            all_offsets[feature_idx] = masked_offsets

        tokenized_examples["example_id"] = example_ids
        return tokenized_examples


    # ------------------------------------------------------------------
    # Build the heterogeneous per-client split. Every rank recomputes the
    # whole partition from the same seed and keeps only its own slice, so
    # the computation below must be identical on all ranks.
    # ------------------------------------------------------------------
    train_data = raw_dataset["train"]

    dist.barrier()
    if args.rank == 0:
        print('Construct the heterogeneous data for each client')

    # Group example indices by title in a single pass over the split.
    train_indices = {}
    for i, example in enumerate(train_data):
        train_indices.setdefault(example['title'], []).append(i)
    # Sorted on purpose: iteration order of a set of str depends on the
    # per-process hash seed, which would make the seeded shuffles below (and
    # therefore the partition) differ between ranks and between runs.
    unique_titles = sorted(train_indices)

    # 80% of the train split is used for client training; the remaining 20%
    # is held out for client evaluation.
    train_ratio = 0.8
    total_samples = len(train_data)
    train_dataset_size = int(total_samples * train_ratio)
    eval_dataset_size = total_samples - train_dataset_size

    for title in unique_titles:
        random.shuffle(train_indices[title])

    # Split every title's samples into an iid pool and a non-iid pool. The
    # iid pool is shuffled across titles; the non-iid pool stays grouped by
    # title so that consecutive slices are dominated by few titles.
    iid_pool = []
    non_iid_pool = []
    for title in unique_titles:
        title_idxs = train_indices[title]
        iid_split = int((1.0 - args.heterogeneity) * len(title_idxs))
        iid_pool += title_idxs[:iid_split]
        non_iid_pool += title_idxs[iid_split:]
    random.shuffle(iid_pool)

    # Per-worker quota, split into its iid and non-iid share.
    partition_size_train = train_dataset_size // args.world_size
    partition_size_eval = eval_dataset_size // args.world_size
    num_iid_train = int((1.0 - args.heterogeneity) * partition_size_train)
    num_non_iid_train = partition_size_train - num_iid_train
    num_iid_eval = int((1.0 - args.heterogeneity) * partition_size_eval)
    num_non_iid_eval = partition_size_eval - num_iid_eval
    if args.rank == 0:
        print(f'The training size: {train_dataset_size}, test size: {eval_dataset_size}')

    # Hand out consecutive slices of both pools: first a worker's training
    # share, then its evaluation share, so the two never overlap.
    iid_start = 0
    non_iid_start = 0
    train_worker_idxs = [[] for _ in range(args.world_size)]
    eval_worker_idxs = [[] for _ in range(args.world_size)]
    train_lower_idxs = [[] for _ in range(args.world_size)]
    for j in range(args.world_size):
        train_share = (
            iid_pool[iid_start: iid_start + num_iid_train]
            + non_iid_pool[non_iid_start: non_iid_start + num_non_iid_train]
        )
        # The "lower" set holds the same samples in an independent order.
        train_worker_idxs[j] = list(train_share)
        train_lower_idxs[j] = list(train_share)
        iid_start += num_iid_train
        non_iid_start += num_non_iid_train
        random.shuffle(train_worker_idxs[j])
        random.shuffle(train_lower_idxs[j])

        eval_worker_idxs[j] = (
            iid_pool[iid_start: iid_start + num_iid_eval]
            + non_iid_pool[non_iid_start: non_iid_start + num_non_iid_eval]
        )
        iid_start += num_iid_eval
        non_iid_start += num_non_iid_eval
        random.shuffle(eval_worker_idxs[j])

    print(f'Rank: {args.rank}  training dataset size: {len(train_worker_idxs[args.rank])},  test dataset size: {len(eval_worker_idxs[args.rank])}')
    train_dataset = train_data.select(train_worker_idxs[args.rank])
    train_lower = train_data.select(train_lower_idxs[args.rank])
    # Evaluation samples are also drawn from the train split.
    eval_dataset = train_data.select(eval_worker_idxs[args.rank])
    train_datasets = train_dataset.map(prepare_train_features, batched=True, remove_columns=column_names)
    eval_datasets = eval_dataset.map(prepare_validation_features, batched=True, remove_columns=column_names)
    dist.barrier()

    # Import the method implementation by module name. A name starting with
    # '.' is treated by import_module as a relative import and fails without
    # a `package` argument.
    # NOTE(review): assumes the method modules are importable from sys.path
    # (e.g. they sit next to this script) - confirm against the repo layout.
    method = importlib.import_module(args.method)
    model = method.Model(args, model, dist)
    if args.method == 'pf2lora':
        train_lower_data = train_lower.map(prepare_train_features, batched=True, remove_columns=column_names)
        model.train(train_datasets, eval_datasets, train_lower_data, eval_dataset, tokenizer)
    else:
        model.train(train_datasets, eval_datasets, eval_dataset, tokenizer)
    cleanup()
