# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
import random
from pathlib import Path

from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from tqdm import tqdm
from llama import ModelArgs, Transformer, Tokenizer, LLaMA, FunctionLM, EmbedGenLM
import wandb

# Seed torch's and Python's global RNGs as soon as this module is imported.
torch.manual_seed(0)
random.seed(0)

def setup_model_parallel() -> Tuple[int, int]:
    """Join the NCCL process group and configure fairscale model parallelism.

    ``LOCAL_RANK`` and ``WORLD_SIZE`` are read from the environment, each falling
    back to -1 when unset. The world size is used as the model-parallel size, and
    the CUDA device is chosen by the local rank. Torch's RNG is then reseeded with 1.

    Returns:
        ``(local_rank, world_size)`` as parsed from the environment.
    """
    rank, size = (int(os.environ.get(var, -1)) for var in ("LOCAL_RANK", "WORLD_SIZE"))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(size)
    torch.cuda.set_device(rank)

    # All processes must share the same seed.
    torch.manual_seed(1)
    return rank, size


def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int, api_texts, feature, length_map, shuffle: bool, generation_args: dict, paraphrase_dict=None) -> EmbedGenLM:
    """Build this rank's LLaMA ``Transformer`` from its checkpoint shard and wrap it in ``EmbedGenLM``.

    Args:
        ckpt_dir: directory holding ``*.pth`` shards (one per model-parallel rank)
            and ``params.json``.
        tokenizer_path: passed to ``Tokenizer(model_path=...)``.
        local_rank: index into the sorted list of shards; selects the shard this process loads.
        world_size: must equal the number of ``*.pth`` shards found in ``ckpt_dir``.
        api_texts, feature, length_map, shuffle, generation_args, paraphrase_dict:
            forwarded unchanged to ``EmbedGenLM``.

    Returns:
        The constructed ``EmbedGenLM``.

    Raises:
        ValueError: if the number of shards does not match ``world_size``.
    """
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    # Raise explicitly rather than assert, so the check survives `python -O`.
    if world_size != len(checkpoints):
        raise ValueError(f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}")
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    # Construct the model with fp16 CUDA tensors by default. Restore the float default
    # even if construction fails, so tensors created later are unaffected.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    try:
        model = Transformer(model_args).cuda().half()
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)
    # strict=False: missing/unexpected keys are ignored rather than raising.
    model.load_state_dict(checkpoint, strict=False)
    # load_state_dict copied the weights into the model; drop the CPU copy of the
    # shard before building the wrapper to lower peak host memory.
    del checkpoint
    funcmodel = EmbedGenLM(model, tokenizer, feature=feature, length_map=length_map, d_model=model_args.dim, shuffle=shuffle, generation_args=generation_args, api_texts=api_texts, paraphrase_dict=paraphrase_dict)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return funcmodel


def _mean_metrics(results, prefix):
    """Return ``{f"{prefix}_{name}": mean}`` for every metric list in *results*, rounded to 2 places."""
    return {f"{prefix}_{k}": round(sum(v) / len(v), 2) for k, v in results.items()}


def _precision(results):
    """Return sum(tp) / sum(pred_funcs) over collected metrics, or 0.0 when nothing was predicted.

    Guarding the zero denominator avoids a ZeroDivisionError when no functions were
    predicted (or no examples were evaluated). A 0.0 never beats the initial best of 0.
    """
    n_predicted = sum(results.get("pred_funcs", ()))
    if n_predicted == 0:
        return 0.0
    return sum(results.get("tp", ())) / n_predicted


def _evaluate(model, examples, candidate_apis):
    """Score each example against all *candidate_apis* except its gold API; return per-metric value lists."""
    from collections import defaultdict

    results = defaultdict(list)
    for d in tqdm(examples):
        neg_samples = list(candidate_apis)
        neg_samples.remove(d["api_idx"])
        with torch.no_grad():
            _, result = model.get_loss([{**d, "neg_samples": neg_samples}], use_cache=True)
        for k, r in result.items():
            results[k].append(r.item())
    return results


def main(ckpt_dir: str, tokenizer_path: str, save_file: str = "saved_models/v1-embgen", lr: float = 1e-5, data_version: str = "1", n_neg_samples: int = 8, shuffle=False, gen_model="transformer", n_head=4, n_layer=2, warmup_steps=-1, paraphrase=False, n_epochs: int = 5):
    """Train the EmbedGenLM generator and keep the checkpoint with the best held-out precision.

    Each epoch does three things:

    1. Trains on ``data["train"]`` with randomly sampled negative APIs.
    2. Scores ``data["test_examples"]`` against every in-domain API (indices >= n_test_apis).
    3. Scores ``data["test_apis"]`` against every held-out API (indices < n_test_apis).

    On rank 0, metrics are logged to wandb. The generator is saved via
    ``save_gen_model`` whenever held-out precision, sum(tp) / sum(pred_funcs), improves.

    Args:
        ckpt_dir: LLaMA checkpoint directory, forwarded to ``load``.
        tokenizer_path: tokenizer model path, forwarded to ``load``.
        save_file: directory prefix for saved generator checkpoints.
        lr: Adam learning rate.
        data_version: "1", "1+kamel" or "1+kamel-question" (an int 1 is also accepted).
        n_neg_samples: maximum number of negative APIs sampled per training example.
        shuffle: forwarded to ``EmbedGenLM`` via ``load``.
        gen_model: generator type; any value other than "transformer" forces
            ``n_head`` and ``n_layer`` to 0.
        n_head, n_layer: forwarded in ``generation_args`` (presumably the
            generator's head/layer counts).
        warmup_steps: linear LR warmup length in optimizer steps; <= 0 disables it.
        paraphrase: if true, load ``data/function/paraphrase.json`` and pass it to ``load``.
        n_epochs: number of training epochs.

    Raises:
        NotImplementedError: for an unknown ``data_version``.
    """
    from collections import defaultdict

    random.seed(0)
    torch.manual_seed(0)

    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        # Silence stdout on every rank except 0 for the rest of the process.
        sys.stdout = open(os.devnull, 'w')

    if gen_model != "transformer":
        n_layer = 0
        n_head = 0

    print(data_version)
    # fire parses a numeric CLI value such as `1` into an int.
    if isinstance(data_version, int):
        data_version = str(data_version)
    input_files = {
        "1": "data/function/gpt_pretrain_v1.json",
        "1+kamel": "data/function/gpt_pretrain_v1+kamel.json",
        "1+kamel-question": "data/function/gpt_pretrain_v1+kamel-question.json",
    }
    if data_version not in input_files:
        raise NotImplementedError(f"unknown data_version: {data_version!r}")
    input_file = input_files[data_version]

    # Parse the data file once and close the handle.
    with open(input_file, "r") as f:
        dataset = json.load(f)
    data = dataset["data"]
    apis = dataset["api_info"]
    n_apis = len(apis)

    train_data = data["train"]
    in_domain_test_data = data["test_examples"]
    out_domain_test_data = data["test_apis"]
    api_texts = [text for text, _ in apis]
    length_map = [a[1] for a in apis]
    # API indices [0, n_test_apis) are used only for out-of-domain evaluation.
    # Training and in-domain evaluation use indices [n_test_apis, n_apis).
    n_test_apis = 100

    paraphrase_dict = None
    if paraphrase:
        with open("data/function/paraphrase.json", "r") as f:
            paraphrase_dict = json.load(f)

    print("n_apis", n_apis)
    print("n_examples", len(train_data) + len(in_domain_test_data) + len(out_domain_test_data))

    generation_args = {
        "model": gen_model,
        "n_head": n_head,
        "n_layer": n_layer,
        "n_apis": n_apis,
    }

    if local_rank == 0:
        wandb.init(project="func-pretrain-llama-v1", name=f"lr-{lr}_warmup-{warmup_steps}_shuffle-{shuffle}_size-{world_size}_gen-{gen_model}:fixed_head-{n_head}_layer-{n_layer}_neg-{n_neg_samples}_dverision-{data_version}-paraphrase-{paraphrase}", config={"lr": lr, "warm": warmup_steps, "shuffle": shuffle, "size": world_size, "gen_model": gen_model, "n_head": n_head, "n_layer": n_layer, "n_neg_samples": n_neg_samples, "data_version": data_version, "paraphrase": paraphrase})

    embedgenmodel = load(ckpt_dir, tokenizer_path, local_rank, world_size, api_texts, None, length_map, shuffle=shuffle, generation_args=generation_args, paraphrase_dict=paraphrase_dict)

    # Only parameters that require gradients are optimized; list them for the log.
    optimizer = torch.optim.Adam([p for p in embedgenmodel.parameters() if p.requires_grad], lr=lr)
    for name, param in embedgenmodel.named_parameters():
        if param.requires_grad:
            print(name)

    scheduler = None
    if warmup_steps > 0:
        # Scale the LR linearly up to `lr` over the first `warmup_steps` optimizer steps.
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 1.0))

    # Candidate API pools, built once instead of once per example.
    in_domain_apis = [i + n_test_apis for i in range(n_apis - n_test_apis)]
    out_domain_apis = list(range(n_test_apis))
    # Draw one extra candidate. After dropping the gold API (if drawn), or otherwise
    # the last draw, at most n_neg_samples negatives remain.
    train_sample_size = min(n_neg_samples + 1, n_apis - n_test_apis)

    best_p = 0
    for e in range(n_epochs):
        train_results = defaultdict(list)
        embedgenmodel.train()

        random.shuffle(train_data)
        for case_idx, d in tqdm(enumerate(train_data), total=len(train_data)):
            optimizer.zero_grad()
            neg_samples = random.sample(in_domain_apis, train_sample_size)
            if d["api_idx"] in neg_samples:
                neg_samples.remove(d["api_idx"])
            else:
                neg_samples = neg_samples[:-1]

            loss, result = embedgenmodel.get_loss([{**d, "neg_samples": neg_samples}])
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            for k, r in result.items():
                train_results[k].append(r.item())

            # Every 500 steps, report the running averages (rank 0) and reset them.
            if (case_idx + 1) % 500 == 0:
                if local_rank == 0:
                    train_metrics = _mean_metrics(train_results, "train")
                    wandb.log({"loss": loss.item()})
                    wandb.log(train_metrics)
                    print(train_metrics)
                train_results = defaultdict(list)

        embedgenmodel.eval()
        embedgenmodel.update_cache()

        id_results = _evaluate(embedgenmodel, in_domain_test_data, in_domain_apis)
        if local_rank == 0:
            wandb.log(_mean_metrics(id_results, "id"))

        ood_results = _evaluate(embedgenmodel, out_domain_test_data, out_domain_apis)
        if local_rank == 0:
            wandb.log(_mean_metrics(ood_results, "ood"))
            prec = _precision(ood_results)
            if prec > best_p:
                best_p = prec
                save_name = f"{save_file}/lr-{lr}_shuffle-{shuffle}_size-{world_size}_gen-{gen_model}_head-{n_head}_layer-{n_layer}_neg-{n_neg_samples}_dversion-{data_version}_paraphrase-{paraphrase}_epoch-{e}.pt"
                embedgenmodel.save_gen_model(save_name)
if __name__ == "__main__":
    # Expose main()'s parameters as command-line flags via fire.
    fire.Fire(component=main)