# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Tuple
import os
import sys
import torch
import fire
import time
import json

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from lavin.eval_model import ModelArgs, Transformer
# from lavin.model import ModelArgs, Transformer
from lavin.tokenizer import Tokenizer
from lavin.generator import LaVIN_Generator
from lavin.mm_adapter import set_MMAdapter,set_Clip_Adapter,set_lora
from util.base_prompt import build_prompt
from dataclasses import dataclass, field
import re
import random

import warnings
import pandas as pd
from PIL import Image

from torchvision.transforms import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from pathlib import Path
import torch.nn as nn
import fairscale.nn.model_parallel.initialize as fs_init
import torch.distributed as dist
from util.apply_delta import apply_model_delta_online
from util.quantization import omniquant,quant_model_omni
from util.int_linear import QuantLinear
from train import get_args_parser
import argparse
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings('ignore')

@dataclass
class PromptArgs:
    """Prompt settings passed to ``build_prompt`` and used when scoring.

    The fields are annotated so that ``@dataclass`` really treats them as
    fields: they can be supplied as keyword arguments, they show up in the
    generated ``repr``/``eq``, and each instance gets its own ``options``
    list instead of sharing one class-level list.
    """
    # Prompt layout identifier forwarded to build_prompt.
    prompt_format: str = 'QCM-ALE'
    # Whether the image caption is made available to the prompt builder.
    use_caption: bool = True
    # Answer letters; a letter's position is the predicted choice index.
    options: list = field(default_factory=lambda: ["A", "B", "C", "D", "E"])

def setup_model_parallel() -> Tuple[int, int]:
    """Start the NCCL process group and fairscale model parallelism.

    ``LOCAL_RANK`` and ``WORLD_SIZE`` are read from the environment (each
    falls back to ``-1`` when unset). The process is bound to the CUDA device
    whose index is the local rank, and torch is seeded with the same value on
    every process.

    Returns:
        ``(local_rank, world_size)`` as integers.
    """
    local_rank, world_size = (
        int(os.environ.get(variable, -1))
        for variable in ("LOCAL_RANK", "WORLD_SIZE")
    )

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # All processes must share one seed.
    torch.manual_seed(1)
    return local_rank, world_size

def _load_and_redistribute_checkpoint(llama_model_path, model_name):
    """Load the LLaMA weights for this rank, resharding them when necessary.

    Three cases are handled:

    * ``model_name == '7B'``: ``consolidated.00.pth`` is loaded and returned
      before any distributed call is made.
    * The number of ``*.pth`` files equals the model-parallel world size:
      each rank loads the file at position ``mp_rank`` (files sorted by path).
    * Otherwise: global rank 0 loads every shard, concatenates them into full
      tensors and broadcasts each tensor in fp16 through the GPU; every rank
      then keeps only its own slice along the recorded split dimension.

    Args:
        llama_model_path: Directory holding ``tokenizer.model`` and one
            sub-directory per model (``str`` or path-like).
        model_name: Name of the model sub-directory, e.g. ``'7B'``.

    Returns:
        ``(checkpoint, tokenizer, params)``: the state dict for this rank, the
        ``Tokenizer`` instance, and the dict parsed from ``params.json``.
    """
    model_dir = Path(llama_model_path) / model_name

    with open(model_dir / 'params.json') as f:
        params = json.load(f)
    tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model'))
    print('Using model path: %s, model_name: %s' % (llama_model_path, model_name))
    if model_name=='7B':
        # Build the path with pathlib like everywhere else in this function;
        # plain string concatenation silently produced a wrong path whenever
        # llama_model_path lacked a trailing separator.
        checkpoint = torch.load(model_dir / 'consolidated.00.pth', map_location="cpu")
        return checkpoint, tokenizer, params

    checkpoints = sorted(model_dir.glob('*.pth'))

    mp_world_size = fs_init.get_model_parallel_world_size()
    mp_rank = fs_init.get_model_parallel_rank()
    if mp_world_size == len(checkpoints):
        print('same number of shards of checkpoints and training, loading directly...')
        dist.barrier()
        print('[rank=%d, mp_rank=%d] loading from %s' % (dist.get_rank(), mp_rank, checkpoints[mp_rank]))
        checkpoint = torch.load(checkpoints[mp_rank], map_location='cpu')
    else:
        print('different number of shards of checkpoints and training, redistributing...')
        if dist.get_rank() == 0:
            loaded = []
            for x in checkpoints:
                print('loading from', x)
                loaded.append(torch.load(x, map_location='cpu'))

            full_state_dict = {}
            split_dims = {}

            def add_weight_with_split_dim(name, dim):
                """Merge ``name`` across the loaded shards and record its split dim."""
                if dim < 0:  # bcast without split
                    full_state_dict[name] = loaded[0][name].clone()
                else:
                    full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim)
                # Drop the per-shard copies as soon as they have been merged.
                for x in loaded:
                    del x[name]
                split_dims[name] = dim

            add_weight_with_split_dim('tok_embeddings.weight', 1)
            add_weight_with_split_dim('norm.weight', -1)
            add_weight_with_split_dim('output.weight', 0)
            for i in range(params['n_layers']):
                print('gathering layer %d of %d' % (i, params['n_layers']))
                layer_prefix = f'layers.{i}.'
                bcast_names = [
                    'attention_norm.weight',
                    'ffn_norm.weight',
                ]
                column_parallel_names = [
                    'attention.wq.weight',
                    'attention.wk.weight',
                    'attention.wv.weight',
                    'feed_forward.w1.weight',
                    'feed_forward.w3.weight',
                ]
                row_parallel_names = [
                    'attention.wo.weight',
                    'feed_forward.w2.weight',
                ]
                for key in bcast_names:
                    add_weight_with_split_dim(layer_prefix + key, -1)
                for key in column_parallel_names:
                    add_weight_with_split_dim(layer_prefix + key, 0)
                for key in row_parallel_names:
                    add_weight_with_split_dim(layer_prefix + key, 1)

            # Only shapes and split dims are sent as Python objects; the tensor
            # data follows one tensor at a time in the loop below.
            full_state_dict_meta = {k: v.shape for k, v in full_state_dict.items()}
            dist.broadcast_object_list([full_state_dict_meta, split_dims], src=0)

        else:  # dist.get_rank() != 0
            recv_objs = [None, None]
            dist.broadcast_object_list(recv_objs, src=0)
            full_state_dict_meta, split_dims = recv_objs

        local_state_dict = {}
        # Sorted keys give every rank the same broadcast order.
        for k in sorted(full_state_dict_meta.keys()):
            print('redistributing weights: %s' % k)
            if dist.get_rank() == 0:
                value = full_state_dict[k].cuda().half()
                del full_state_dict[k]
            else:
                value = torch.empty(full_state_dict_meta[k], device='cuda', dtype=torch.half)
            dist.broadcast(value, src=0)
            value = value.cpu()
            if split_dims[k] < 0:
                local_state_dict[k] = value
            else:
                dim = split_dims[k]
                assert dim >= 0 and dim < value.ndim and value.size(dim) % mp_world_size == 0
                shard_size = value.size(dim) // mp_world_size
                shard_st, shard_ed = shard_size * mp_rank, shard_size * (mp_rank + 1)
                # TODO: make more general
                if dim == 0:
                    value = value[shard_st: shard_ed]
                elif dim == 1:
                    value = value[:, shard_st: shard_ed]
                else:
                    raise NotImplementedError()
                # clone() so the slice does not keep the full tensor alive.
                local_state_dict[k] = value.clone()

        checkpoint = local_state_dict

    return checkpoint, tokenizer, params



def get_acc_with_contion(res_pd, key, values):
    """Return the accuracy over the rows of ``res_pd`` selected by ``key``.

    Rows are kept when ``res_pd[key]`` is one of ``values`` (if ``values`` is
    a list) or equals ``values`` (otherwise). The result is the share of kept
    rows whose ``true_false`` column is ``True``, as a percentage string with
    two decimals.
    """
    column = res_pd[key]
    if isinstance(values, list):
        selector = column.isin(values)
    else:
        selector = column == values
    subset = res_pd[selector]
    hits = subset[subset['true_false'] == True]
    return f"{len(hits) / len(subset) * 100:.2f}"


def get_scores(result_file, data_file):
    """Compute per-category accuracies for test-split predictions.

    Args:
        result_file: Path to a JSON object mapping problem id to the predicted
            choice index. It must contain exactly 4241 entries.
        data_file: Path to the problems JSON (problem id -> problem record
            with at least ``split``, ``hint``, ``image``, ``answer``,
            ``subject`` and ``grade``).

    Returns:
        Dict of accuracy strings (percent, two decimals), one per category
        plus ``acc_average``.

    Raises:
        AssertionError: if the result file does not hold 4241 predictions.
        KeyError: if a test-split problem has no entry in the result file.
    """
    # Context managers close the files; the bare open() calls leaked them.
    with open(result_file) as f:
        results = json.load(f)
    num = len(results)
    assert num == 4241, "expected 4241 predictions, got %d" % num

    with open(data_file) as f:
        sqa_data = json.load(f)

    # One row per problem, indexed by problem id.
    sqa_pd = pd.DataFrame(sqa_data).T
    # Explicit copy: new columns are added below, so do not work on a slice.
    res_pd = sqa_pd[sqa_pd['split'] == 'test'].copy()

    # Context flags use plain truthiness, so None and '' both count as absent.
    has_text = res_pd['hint'].map(bool)
    has_image = res_pd['image'].map(bool)
    res_pd['no_context'] = ~has_text & ~has_image
    res_pd['has_text'] = has_text
    res_pd['has_image'] = has_image
    res_pd['has_text_image'] = has_text & has_image

    # Whole columns are built at once rather than cell by cell in iterrows().
    preds = [int(results[qid]) for qid in res_pd.index]
    res_pd['pred'] = preds
    res_pd['true_false'] = [label == pred for label, pred in zip(res_pd['answer'], preds)]

    # accuracy scores
    acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100

    scores = {
        'acc_natural':
        get_acc_with_contion(res_pd, 'subject', 'natural science'),
        'acc_social':
        get_acc_with_contion(res_pd, 'subject', 'social science'),
        'acc_language':
        get_acc_with_contion(res_pd, 'subject', 'language science'),
        'acc_has_text':
        get_acc_with_contion(res_pd, 'has_text', True),
        'acc_has_image':
        get_acc_with_contion(res_pd, 'has_image', True),
        'acc_no_context':
        get_acc_with_contion(res_pd, 'no_context', True),
        'acc_grade_1_6':
        get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']),
        'acc_grade_7_12':
        get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']),
        'acc_average':
        "{:.2f}".format(acc_average),
    }

    return scores


def print_scores(scores):
    """Print each score on its own line, then one LaTeX table row.

    The first four characters of every key are dropped in the per-line
    output. The final line joins all values as ``& value`` cells and ends
    with a LaTeX row terminator.
    """
    cells = []
    for name, value in scores.items():
        print(f"{name[4:]}: \t{value}")
        cells.append(f"& {value} ")
    print("".join(cells) + "\\\\")

def load(
    ckpt_dir: str,
    llm_model:str,
    tokenizer_path: str,
    adapter_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
    adapter_type: str,
    adapter_dim:int,
    adapter_scale:float,
    hidden_proj:int,
    visual_adapter_type: str,
    temperature: float,
    use_vicuna: bool,
    bits: str='16bits',
    cpu_load:bool=False,
    args: argparse.Namespace=None,
) -> LaVIN_Generator:
    """Build the (optionally quantized) model, load all weights, wrap it in a generator.

    Steps, in order: load the LLaMA checkpoint for this rank, instantiate the
    ``Transformer`` in fp16, quantize and/or load the base weights according
    to ``args``, attach the multimodal and CLIP adapters, optionally apply
    the Vicuna delta, load the adapter checkpoint and move the model to CUDA.

    Args:
        ckpt_dir: Root directory of the LLaMA weights.
        llm_model: Model sub-directory name, e.g. ``'7B'``.
        tokenizer_path: Not referenced here; the tokenizer is loaded from
            ``ckpt_dir``.
        adapter_path: Checkpoint whose ``'model'`` entry holds adapter weights.
        local_rank, world_size: Not referenced in this function.
        max_seq_len, max_batch_size, hidden_proj: Forwarded to ``ModelArgs``.
        adapter_type, adapter_dim, adapter_scale, temperature,
        visual_adapter_type: Forwarded to ``set_MMAdapter`` /
            ``set_Clip_Adapter``.
        use_vicuna: Apply the Vicuna delta weights for ``llm_model``.
        bits: Currently unused; ``args.bits`` selects the quantization path.
        cpu_load: Build the model on CPU in fp16 instead of directly on GPU.
        args: Parsed command-line namespace. Despite the ``None`` default it
            is required: ``args.bits`` is read unconditionally, and the
            quantized paths also read ``lwc``, ``quant_resume``,
            ``scaling_resume``, ``start_layer`` and the quant-param dicts.

    Returns:
        A ``LaVIN_Generator`` wrapping the loaded model and tokenizer.
    """
    print("test args: ",args)
    start_time = time.time()
    checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(ckpt_dir, llm_model)

    print("Loading")
    adapter_checkpoint = torch.load(adapter_path, map_location="cpu")


    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size,hidden_proj=hidden_proj, **params
    )
    model_args.vocab_size = tokenizer.n_words

    if cpu_load:
        # CPU load is slow, but is friendly for GPUs with limited memory.
        torch.set_default_tensor_type(torch.HalfTensor)
    else:
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model = Transformer(model_args)
    # delete language encoder
    del model.backbone.transformer

    # Restore the default so tensors created later are fp32 CPU tensors again.
    torch.set_default_tensor_type(torch.FloatTensor)

    if args.bits in ['4bit','8bit']:
        if args.lwc is False:
            from util.quantization import quant_model_bnb
            # NOTE(review): quant_bit is fixed to '4bit' even when args.bits
            # is '8bit' -- confirm whether that is intended.
            model.layers=quant_model_bnb(model.layers,quant_bit='4bit')
            print("bnb quantization finish ...")
            model.load_state_dict(checkpoint, strict=False)
        else:
            print("==== weight quant params ====")
            print(args.weight_quant_params)
            print("==== act quant params ====")
            print(args.act_quant_params)
            model.load_state_dict(checkpoint, strict=False)
            model.layers = quant_model_omni(model.layers, args)
            # With no resume file there is nothing to restore per layer.
            # Previously an empty dict was indexed by layer number, which
            # always raised KeyError on the first layer.
            if args.quant_resume:
                omni_parameters = torch.load(args.quant_resume,map_location="cpu")
            else:
                omni_parameters = None
            print("load quant module...")
            for i, layer in enumerate(model.layers):
                if i > args.start_layer:
                    # Register the parameter first so the per-layer state
                    # dict loaded below can fill it in.
                    for module in layer.modules():
                        if isinstance(module, QuantLinear):
                            module.weight_quantizer.register_parameter('dynamic_scale',
                                     nn.Parameter(torch.ones((module.weight_quantizer.group_num,1))))
                if omni_parameters is not None:
                    layer.load_state_dict(omni_parameters[i],strict=False)
                for module in layer.modules():
                    if isinstance(module, QuantLinear):
                        module.weight_quantizer.mode = 'test'
                        module.weight_quantizer.per_token_dynamic_calibration(module.weight)
                        module.weight_quantizer.register_scales_and_zeros()
            print("omniquant finish ...")

            if args.scaling_resume:
                print("loading scaling parameters")
                # map_location keeps this consistent with every other load in
                # this function and avoids needing the device it was saved on.
                ckpt = torch.load(args.scaling_resume, map_location="cpu")['model']
                state_dict = {key.replace('module.',''): value for key, value in ckpt.items()}
                model.load_state_dict(state_dict, strict=False)

            # Replace each weight by its quantizer output once, then switch
            # off on-the-fly weight quantization.
            for module in model.modules():
                if isinstance(module, QuantLinear):
                    module.weight = module.weight_quantizer(module.weight)
                    module.use_temporary_parameter=False
                    module.use_weight_quant = False

    else:
        model.load_state_dict(checkpoint, strict=False)
        print("load fp model finish ...")

    set_MMAdapter(model, adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature)
    set_Clip_Adapter(model.backbone.visual, visual_adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature)


    if use_vicuna:
        apply_model_delta_online(model,'../data/weights/vicuna_'+llm_model)

    # Strip the 'module.' prefix from the checkpoint keys before loading.
    state_dict = {
        key.replace('module.',''): value
        for key, value in adapter_checkpoint['model'].items()
    }
    model.load_state_dict(state_dict, strict=False)
    model.to(torch.device('cuda'))

    generator = LaVIN_Generator(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator

def get_pred_idx(prediction, choices, options):
    """
    Map a predicted letter (e.g. 'C') to its choice index (e.g. 2).

    Only the first ``len(choices)`` letters of ``options`` are accepted; any
    other prediction is replaced by a uniformly random index into ``choices``.
    """
    valid_letters = options[:len(choices)]
    if prediction not in valid_letters:
        return random.choice(range(len(choices)))
    return options.index(prediction)

def main(
    ckpt_dir: str,
    tokenizer_path: str,
    adapter_path: str,
    data_root:str,
    caption_file:str,
    max_seq_len: int,
    max_batch_size: int,
    llm_model:str='7B',
    generation_temperature: float = 0.1,
    top_p: float = 0.75,
    split='val',
    prompt_format='QCM-ALE',
    use_caption=False,
    options=["A", "B", "C", "D", "E"],
    adapter_type='repattn',
    adapter_dim=8,
    adapter_scale=1,
    n_prompt=10,
    hidden_proj=128,
    visual_adapter_type='normal',
    temperature=10.,
    use_vicuna=False,
    bits: str='16bits',
    cpu_load:bool=False,
    lwc: bool=False,
    lora_enable: bool=False,
    quant_resume: str=None,
    scaling_resume: str=None,
    start_layer: int=100000,
):
    """Evaluate the model on one ScienceQA split and write the predictions.

    The command line is parsed twice: once into this function's parameters
    and once by ``get_args_parser()``, whose namespace supplies the
    quantization settings passed to ``load``. ``lwc``, ``lora_enable``,
    ``quant_resume``, ``scaling_resume`` and ``start_layer`` are therefore
    not read from the parameters here.

    Predictions go to ``./preds.json``. The per-category breakdown from
    ``get_scores`` only supports the ``test`` split, so for any other split
    the run stops after printing the overall accuracy.

    Args:
        ckpt_dir, tokenizer_path, adapter_path, llm_model, max_seq_len,
        max_batch_size, adapter_type, adapter_dim, adapter_scale, hidden_proj,
        visual_adapter_type, temperature, use_vicuna, bits, cpu_load:
            Forwarded to ``load``.
        data_root: Directory with ``problems.json``, ``pid_splits.json`` and
            ``images/<split>/<qid>/image.png``.
        caption_file: JSON file whose ``"captions"`` entry maps qid to caption.
        generation_temperature, top_p, n_prompt: Forwarded to
            ``generator.generate``.
        split: Key of ``pid_splits.json`` selecting the problems to evaluate.
        prompt_format, use_caption, options: Prompt settings for
            ``build_prompt``; ``options`` also maps answer letters to indices.
    """
    parser = get_args_parser()
    args = parser.parse_args()
    # define quant config
    args.weight_quant_params = {
        "n_bits": args.wbits,
        "per_channel_axes": [0],
        "symmetric": args.symmetric,
        "dynamic_method": args.w_dynamic_method,
        "group_size": args.group_size,
        "lwc":args.lwc
    }
    args.act_quant_params = {
        "n_bits":  args.abits,
        "per_channel_axes": [],
        "symmetric": False,
        "dynamic_method": args.a_dynamic_method,
    }


    print(max_batch_size,max_seq_len)
    print('use caption: ',use_caption)
    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        # Only local rank 0 prints; the handle is kept for the process lifetime.
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir,llm_model, tokenizer_path, adapter_path, local_rank, world_size, max_seq_len, max_batch_size,
        adapter_type,adapter_dim,adapter_scale,hidden_proj,visual_adapter_type,
    temperature,use_vicuna,bits=bits,cpu_load=cpu_load,args = args)

    print('split: ', split)
    # Context managers close the data files; bare open() calls leaked them.
    with open(os.path.join(data_root, 'problems.json')) as f:
        problems = json.load(f)
    with open(os.path.join(data_root, 'pid_splits.json')) as f:
        pid_splits = json.load(f)
    with open(caption_file) as f:
        captions = json.load(f)["captions"]
    image_path = os.path.join(data_root, 'images', split)
    qids = pid_splits['%s' % (split)]
    total_items=len(qids)
    for qid in problems:
        problems[qid]['caption'] = captions[qid] if qid in captions else ""
    print('total_items: ',total_items)


    image_transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])

    prompt_args=PromptArgs()
    prompt_args.prompt_format = prompt_format
    prompt_args.use_caption = use_caption
    prompt_args.options = options

    pattern = re.compile(r'The answer is ([A-Z]).')

    answers = []
    preds=[]
    for i in range(total_items//max_batch_size+1):
        print('progresses: ',i,' / ', total_items//max_batch_size+1)
        batch_qids=qids[i*max_batch_size:(i+1)*max_batch_size]
        # The "+1" above yields an empty trailing batch when the item count
        # is an exact multiple of the batch size.
        if len(batch_qids)==0:
            break
        indicators = []
        prompts=[]
        images = []
        for qid in batch_qids:
            prompt,_ = build_prompt(problems, qid, prompt_args)

            answer=problems[qid]["answer"]
            if problems[qid]['image'] is not None:
                # Close the image file as soon as the pixels are converted.
                with Image.open(os.path.join(image_path, qid, 'image.png')) as img:
                    image = image_transforms(img.convert('RGB'))
                indicator = 1
            else:
                # All-zero placeholder; indicator 0 marks it as "no image".
                image = torch.zeros(3, 224, 224).float()
                indicator = 0
            prompts.append(prompt)
            answers.append(answer)
            images.append(image.unsqueeze(0))
            indicators.append(indicator)
        images=torch.cat(images,0)


        results = generator.generate(
            prompts,images=images,indicators=indicators, max_gen_len=64, temperature=generation_temperature, top_p=top_p,n_feats=n_prompt
        )

        for result in results:
            pred = pattern.findall(result)

            if len(pred) >= 1:
                pred = pred[0]  # 'A', 'B', ...
            else:
                print(result)
                pred = "FAILED"
                print(pred)
            preds.append(pred)

    #evaluations
    results={}
    correct=0
    for i, prediction in enumerate(preds):
        pred_idx = get_pred_idx(prediction, problems[qids[i]]["choices"],
                                prompt_args.options)  # 0, 1, ..., 4
        if pred_idx == answers[i]:
            correct += 1
        results[qids[i]] = pred_idx
    acc = correct / len(results) * 100
    print('overall accuracy: ', acc)

    with open('./preds.json', 'w') as f:
        json.dump(results,f)

    # get_scores() looks up a prediction for every test-split problem and
    # insists on 4241 results, so on any other split it can only crash after
    # the whole inference run has finished.
    if split != 'test':
        print('detailed scores are only available for the test split; skipping.')
        return

    scores=get_scores('./preds.json',os.path.join(data_root, 'problems.json'))
    print(scores)
    # Record the scores next to the settings that produced them.
    with open(str(time.time())+'.txt','w') as f:
        f.write(str(scores))
        f.write(str(args.quant_resume))
        f.write(str(args.adapter_path))

if __name__ == "__main__":
    # Run main() through python-fire, which fills its parameters from the
    # command line.
    fire.Fire(main)
