import argparse
import torch
import time
from tqdm import trange
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_polygraph.stat_calculators.hybrid_sampling import HybridBeamThenConditionalSamplingCalculator
from lm_polygraph.stat_calculators.concat_greedy_with_samples import ConcatGreedyWithSamplesBase
from lm_polygraph.stat_calculators import *
from lm_polygraph.estimators import *
from lm_polygraph.utils.model import WhiteboxModel
from lm_polygraph.utils.deberta import Deberta
from lm_polygraph.utils.generation_parameters import GenerationParametersFactory
from lm_polygraph.stat_calculators.greedy_cross_encoder_similarity import GreedyCrossEncoderSimilarityMatrixCalculatorBase
from lm_polygraph.stat_calculators.semantic_matrix import SemanticMatrixCalculatorBase


def load_model(model_path: str, device_map: str):
    """Load a causal language model and switch it to evaluation mode.

    The model is loaded with ``attn_implementation="eager"`` and
    ``trust_remote_code=True`` (so only point this at model sources you
    trust).

    Args:
        model_path: Identifier or path forwarded to ``from_pretrained``.
        device_map: Device placement forwarded to ``from_pretrained``.

    Returns:
        The loaded model, after ``model.eval()`` has been called.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device_map,
        trust_remote_code=True,
        attn_implementation="eager",
    )

    # Some configs do not expose these sizes at the top level; in that case
    # copy them up from the nested ``text_config``.
    config = model.config
    for attr_name in ("num_hidden_layers", "num_attention_heads"):
        if hasattr(config, attr_name):
            continue
        setattr(config, attr_name, getattr(config.text_config, attr_name))

    model.eval()
    return model


def load_tokenizer(model_path: str, add_bos_token: bool = True):
    """Load a left-padding tokenizer that is guaranteed to have a pad token.

    Args:
        model_path: Identifier or path forwarded to ``from_pretrained``.
        add_bos_token: Forwarded to ``from_pretrained`` unchanged.

    Returns:
        The tokenizer; if it had no pad token, the EOS token is used as one.
    """
    tok = AutoTokenizer.from_pretrained(
        model_path,
        add_bos_token=add_bos_token,
        padding_side="left",
    )
    if tok.pad_token is not None:
        return tok

    # No dedicated padding token: fall back to EOS.
    tok.pad_token = tok.eos_token
    return tok


def load_man(man_path):
    """Load a previously saved manager dump from *man_path* via ``torch.load``.

    NOTE: ``weights_only=False`` enables full unpickling, which can execute
    arbitrary code embedded in the file. Only load files you trust.
    """
    manager_dump = torch.load(man_path, weights_only=False)
    return manager_dump


def get_parser():
    """Build the command-line parser for this script.

    Options:
        --man-path: path of the saved manager dump to load (required).
        --save-path: path the updated dump is written to (required).
        --model-path: model identifier or path (required).
        --instruct / --no-instruct: boolean flag, default False.
        --ce-sim / --no-ce-sim: boolean flag, default False.
        --nli-sim / --no-nli-sim: boolean flag, default False.
        --device-map: device placement string, default "auto".
        --num-beams: one or more integers, default 1 through 9.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--man-path', type=str, required=True)
    parser.add_argument('--save-path', type=str, required=True)
    parser.add_argument('--model-path', type=str, required=True)
    # BooleanOptionalAction takes no value, so ``type=`` is meaningless here;
    # passing it is deprecated since Python 3.12 and rejected from 3.14 on.
    parser.add_argument('--instruct', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--ce-sim', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--nli-sim', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--device-map', type=str, default="auto")
    parser.add_argument('--num-beams', type=int, nargs='+', default=[1, 2, 3, 4, 5, 6, 7, 8, 9])
    return parser


def _pending_calculators(args, stats, num_beams, sample_source):
    """Return the calculators whose output keys are still absent from *stats*."""
    pending = []
    if f'{sample_source}_texts' not in stats:
        pending.append(
            HybridBeamThenConditionalSamplingCalculator(num_beams=num_beams, num_samples=10)
        )
        pending.append(ConcatGreedyWithSamplesBase(sample_source=sample_source))
    if args.ce_sim and f'greedy_{sample_source}_sentence_similarity' not in stats:
        pending.append(
            GreedyCrossEncoderSimilarityMatrixCalculatorBase(
                batch_size=100,
                sample_source=sample_source,
                progress=False,
            )
        )
    if args.nli_sim and f'greedy+{sample_source}_semantic_matrix_entail' not in stats:
        pending.append(
            SemanticMatrixCalculatorBase(
                nli_model=Deberta(),
                sample_source='greedy+' + sample_source,
                progress=False,
            )
        )
    return pending


def _single_sample_stats(stats, index):
    """Return a dict holding the length-one slice ``[index:index + 1]`` of each stat."""
    sliced = {}
    for key, column in stats.items():
        try:
            sliced[key] = column[index:index + 1]
        except Exception:
            # Best effort: entries that cannot be sliced are simply left out.
            continue
    return sliced


def _absorb_new_stats(stats, new_stats, saved_keys, skipped_keys):
    """Append each value in *new_stats* to the matching list in *stats*.

    Keys created here are recorded in *saved_keys*; keys whose values cannot
    be iterated or appended are reported once and recorded in *skipped_keys*.
    """
    for key, val in new_stats.items():
        try:
            for v in val:
                if key not in stats:
                    stats[key] = []
                    saved_keys.append(key)
                stats[key].append(v)
        except Exception as e:
            if key not in skipped_keys:
                print(f'Skipping saving key {key}: {e}')
                skipped_keys.append(key)


def main(args):
    """Fill in missing hybrid beam/sampling statistics in a saved manager dump.

    For every requested beam count, the calculators whose outputs are not yet
    stored are run one input at a time, their outputs are appended to the
    loaded stats, and the whole dump is written to ``args.save_path`` after
    each calculator finishes.

    Args:
        args: Parsed command-line namespace; reads ``man_path``,
            ``save_path``, ``model_path``, ``device_map``, ``instruct``,
            ``ce_sim``, ``nli_sim`` and ``num_beams``.
    """
    man = load_man(args.man_path)

    base_model = load_model(args.model_path, device_map=args.device_map)
    tokenizer = load_tokenizer(args.model_path)
    generation_params = GenerationParametersFactory.from_params(
        yaml_config={"generate_until": ["\n"]},
        native_config=base_model.generation_config.to_dict(),
    )
    model = WhiteboxModel(
        base_model,
        tokenizer,
        args.model_path,
        generation_parameters=generation_params,
        instruct=args.instruct,
    )

    for num_beams in args.num_beams:
        print('NUM_BEAMS:', num_beams)
        sample_source = f'beamsearch{num_beams}sample{10 - num_beams}'

        stat_calculators = _pending_calculators(args, man['stats'], num_beams, sample_source)
        if not stat_calculators:
            print('Nothing to calculate')

        for stat_calculator in stat_calculators:
            print(f'Calculating {stat_calculator}')
            start_time = time.time()
            saved_keys, skipped_keys = [], []
            stats = man['stats']

            # One input per call: each calculator sees a batch of size one.
            for i in trange(len(stats['input_texts']), desc=str(stat_calculator)):
                batch_stats = _single_sample_stats(stats, i)
                new_stats = stat_calculator(
                    batch_stats,
                    batch_stats['input_texts'],
                    model,
                )
                _absorb_new_stats(stats, new_stats, saved_keys, skipped_keys)

            print(f'Will save keys: {saved_keys}')
            print(f'Done calculating {stat_calculator} in {round(time.time() - start_time, 2)} secs')
            print(f'Saving to {args.save_path}')
            # Written after every calculator, so finished work survives a later failure.
            torch.save(man, args.save_path)
    print('Done!')


if __name__ == '__main__':
    # Script entry point: parse the command line and run the pipeline.
    main(get_parser().parse_args())
