import gzip
import json
import math
import re
from collections import defaultdict
from difflib import SequenceMatcher
from functools import partial
from statistics import median, pstdev, fmean

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from evaluate import load
from loguru import logger
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoConfig,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
    Starcoder2ForCausalLM, CodeGenForCausalLM
)

from me_shared import (
    DRYRUN_TEST_NUM, DATA_DIR,
    CFG_MODE_DRYRUN, GEN_TEMP, GEN_OUTS_DIR
)
from me_util import format_score, format_ratio, format_rate

# import os
# import torch.distributed as dist
# from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
# # Initialize distributed training (required)
# dist.init_process_group(backend="nccl")
# local_rank = int(os.environ["LOCAL_RANK"])  # Mandatory for pure FSDP
# torch.cuda.set_device(local_rank)          # Mandatory for pure FSDP


class NeoLoader:
    @staticmethod
    def _load_conala(data_folder, split: str, data_type: str = 'annotated'):
        assert data_type in ('annotated', 'mined')
        if data_type == 'mined':
            anno_lines = list()
            code_lines = list()
            with open(data_folder / 'conala-mined.jsonl', 'r') as file:
                for line in file:
                    datum = json.loads(line.strip())
                    anno_lines.append(datum['intent'])
                    code_lines.append(datum['snippet'])
            anno_lines_train, anno_lines_test, code_lines_train, code_lines_test = train_test_split(
                anno_lines, code_lines, test_size=1000)
            return anno_lines_train, anno_lines_test, code_lines_train, code_lines_test
        else:
            dataset = json.load(open(data_folder / f'conala-{split}.json'))
            anno_lines = list()
            code_lines = list()
            for datum in dataset:
                # in some cases, rewritten_intent = null
                intent, rewritten_intent = datum['intent'], datum['rewritten_intent']
                anno_line = str(rewritten_intent if rewritten_intent else intent)
                anno_line = anno_line.replace('"""', '"')
                anno_lines.append(anno_line)
                code_line = datum['snippet']
                code_lines.append(code_line)
            return anno_lines, code_lines

    @staticmethod
    def _load_ia32(data_folder, split: str):
        with open(data_folder / 'Shellcode_IA32.tsv') as file:
            df = pd.read_csv(file, delimiter='\t')

        # the operation is specified in the dataset
        train = df.sample(frac=0.8, random_state=0)
        test = df.drop(train.index)
        # dev = test.sample(frac=0.5, random_state=0)
        # test = test.drop(dev.index)

        if split == 'train':
            data = train
        # elif split == 'dev':
        #     data = dev
        elif split == 'test':
            data = test
        else:
            raise NotImplementedError

        intents = list()
        snippets = list()
        for idx, row in data.iterrows():
            intents.append(row["INTENTS"])
            snippets.append(row["SNIPPETS"])
        return intents, snippets

    @staticmethod
    def _load_spider(data_folder, split: str):
        if split == 'train':
            dataframe = pd.read_parquet(data_folder / 'spider-train.parquet', columns=['question_toks', 'query_toks'])
            # logger.info(dataframe.head(3))
        elif split == 'test':
            dataframe = pd.read_parquet(data_folder / 'spider-validation.parquet', columns=['question_toks', 'query_toks'])
            # logger.info(dataframe.head(3))
        else:
            raise NotImplementedError
        text_seqs, sql_seqs = dataframe['question_toks'].tolist(), dataframe['query_toks'].tolist()
        text_lines = [' '.join(text_seq) for text_seq in text_seqs]
        sql_lines = [' '.join(sql_seq) for sql_seq in sql_seqs]
        # logger.info(f'{text_lines[0]=}')
        # logger.info(f'{sql_lines[0]=}')
        # logger.info(f'{len(text_lines)=}')
        # logger.info(f'{len(sql_lines)=}')
        return text_lines, sql_lines

    @staticmethod
    def _load_tldr(data_folder, split: str):
        with open(data_folder / f'tldr-{split}.jsonl') as file:
            data_lines = file.readlines()
        nl_lines = list()
        cmd_lines = list()
        for line in data_lines:
            instance = json.loads(line.strip())
            nl_lines.append(instance['nl'])
            cmd_lines.append(instance['cmd'])
        return nl_lines, cmd_lines

    @staticmethod
    def load_data(data_name):
        """Return train/test text and code sequences for the dataset ``data_name``.

        Supported names are ``'conala'``, ``'ia32'``, ``'spider'`` and ``'tldr'``;
        files are read from ``DATA_DIR / data_name``. When ``CFG_MODE_DRYRUN`` is
        truthy the training lists are cut to ``DRYRUN_TEST_NUM * 10`` items and
        the test lists to ``DRYRUN_TEST_NUM`` items.

        Returns:
            ``(train_text_seqs, train_code_seqs, test_text_seqs, test_code_seqs)``.

        Raises:
            NotImplementedError: If ``data_name`` is not one of the supported names.
        """
        # load data
        data_folder = DATA_DIR / data_name
        split_loaders = {
            'conala': NeoLoader._load_conala,
            'ia32': NeoLoader._load_ia32,
            'spider': NeoLoader._load_spider,
            # tldr also ships a 'dev' split, which is deliberately not merged into train
            'tldr': NeoLoader._load_tldr,
        }
        if data_name not in split_loaders:
            raise NotImplementedError
        read_split = split_loaders[data_name]
        train_text_seqs, train_code_seqs = read_split(data_folder, split='train')
        test_text_seqs, test_code_seqs = read_split(data_folder, split='test')

        if CFG_MODE_DRYRUN:
            train_cap = DRYRUN_TEST_NUM * 10
            train_text_seqs, train_code_seqs = train_text_seqs[:train_cap], train_code_seqs[:train_cap]
            test_text_seqs, test_code_seqs = test_text_seqs[:DRYRUN_TEST_NUM], test_code_seqs[:DRYRUN_TEST_NUM]
        return train_text_seqs, train_code_seqs, test_text_seqs, test_code_seqs

    @staticmethod
    def load_corpus(data_name):
        """Return ``(text_data, code_data)`` for a benchmark corpus.

        Supported names: ``'demo'`` (one hard-coded example), ``'bigcode'``,
        ``'human'``, ``'numpy'`` and ``'pandas'``. Both lists are cut to
        ``DRYRUN_TEST_NUM`` items when ``CFG_MODE_DRYRUN`` is truthy, and the mean
        character length of inputs and outputs is printed before returning.

        Raises:
            NotImplementedError: For any other ``data_name``.
        """
        # in the future, we may turn to use EvalPlus (HumanEval+, MBPP+, BigcodeBench+, ...)
        if data_name == 'demo':
            text_data, code_data = (
                ['The countries of the European Union are:\n1. Austria\n2. Belgium\n3. Bulgaria\n4.'],
                ['Denmark'],
            )
        elif data_name == 'bigcode':
            text_data, code_data = _load_bcbh()
        elif data_name == 'human':
            text_data, code_data = _load_he()
        elif data_name in ['numpy', 'pandas']:
            text_data, code_data = _load_cert(lib=data_name)
        else:
            raise NotImplementedError

        if CFG_MODE_DRYRUN:
            text_data, code_data = text_data[:DRYRUN_TEST_NUM], code_data[:DRYRUN_TEST_NUM]

        # the two names below appear verbatim in the printed report
        avg_input_lens = fmean([len(text) for text in text_data])
        avg_output_lens = fmean([len(code) for code in code_data])
        print(f'{avg_input_lens=}, {avg_output_lens=}')
        return text_data, code_data

    @staticmethod
    def load_tokenizer(model_name: str):
        if "gpt2" in model_name:
            tokenizer = GPT2Tokenizer.from_pretrained(model_name)
            tokenizer.padding_side = "left"
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        elif "starcoder" in model_name:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        elif "CodeLlama" in model_name or "Llama" in model_name:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            tokenizer.padding_side = "left"
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        elif "OpenCoder" in model_name:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        elif "Qwen3" in model_name or "Qwen2.5-Coder" in model_name:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        else:
            raise ValueError(f"Model {model_name} not supported")
        return tokenizer

    @staticmethod
    def load_model(model_name: str):
        """Load the config and causal LM for ``model_name`` plus a module-path map.

        The model family is picked by substring match on ``model_name``, checked
        in this order: gpt2, starcoder, OpenCoder, Llama (which also covers
        CodeLlama), Qwen3 / Qwen2.5-Coder.

        Returns:
            ``(config, model, attrs)``. ``model`` has been switched to eval mode.
            ``attrs`` maps ``layers``, ``ff_act``, ``ff_gate``, ``ff_input``,
            ``ff_output``, ``embedding`` and ``lm_head`` to dotted attribute
            paths; ``'...'`` is the ``ff_gate`` placeholder for families where
            no gate path is given.

        Raises:
            ValueError: If ``model_name`` matches no supported family.
        """
        config = AutoConfig.from_pretrained(model_name, output_scores=True)
        # Alternative config kept for reference:
        # config = AutoConfig.from_pretrained(
        #     model_name,
        #     device_map='auto',
        #     output_scores=True,
        #     output_hidden_states=True,
        #     # trust_remote_code=True,
        #     load_in_4bit=(DEVICE == 'cuda')
        # )
        if "gpt2" in model_name:
            attrs = {
                'layers': 'transformer.h',
                'ff_act': 'mlp.act',
                'ff_gate': '...',
                'ff_input': 'mlp.c_fc',
                'ff_output': 'mlp.c_proj',
                'embedding': 'transformer.wte',
                'lm_head': 'lm_head',
            }
            model = GPT2LMHeadModel.from_pretrained(model_name, config=config)
        elif "starcoder" in model_name:
            attrs = {
                'layers': 'model.layers',
                'ff_act': 'mlp.act',
                'ff_gate': '...',
                'ff_input': 'mlp.c_fc',
                'ff_output': 'mlp.c_proj',
                'embedding': 'model.embed_tokens',
                'lm_head': 'lm_head',
            }
            model = Starcoder2ForCausalLM.from_pretrained(model_name, config=config)
        elif "OpenCoder" in model_name:
            attrs = {
                'layers': 'model.layers',
                'ff_act': 'mlp.act_fn',
                'ff_gate': '...',
                'ff_input': 'mlp.up_proj',
                'ff_output': 'mlp.down_proj',
                'embedding': 'model.embed_tokens',
                'lm_head': 'lm_head',
            }
            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True)
        elif "Llama" in model_name:  # every "CodeLlama" name contains "Llama" too
            attrs = {
                'layers': 'model.layers',
                'ff_act': 'mlp.act_fn',
                'ff_gate': 'mlp.gate_proj',
                'ff_input': 'mlp.up_proj',
                'ff_output': 'mlp.down_proj',
                'embedding': 'model.embed_tokens',
                'lm_head': 'lm_head',
            }
            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, attn_implementation="eager")
        elif "Qwen3" in model_name or "Qwen2.5-Coder" in model_name:
            attrs = {
                'layers': 'model.layers',
                'ff_act': 'mlp.act_fn',
                'ff_gate': 'mlp.gate_proj',
                'ff_input': 'mlp.up_proj',
                'ff_output': 'mlp.down_proj',
                'embedding': 'model.embed_tokens',
                'lm_head': 'lm_head',
            }
            # NOTE(review): an earlier note says the Qwen-Coder scaling runs load
            # in BF16 to save memory, but the active call passes no torch_dtype;
            # the BF16 variant is the commented-out line below. Confirm which is
            # intended.
            model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
            # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, config=config)
            # if '32B' in model_name:
            #     model = FSDP(model, sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD)
            #     model.gradient_checkpointing_enable()
        else:
            raise ValueError(f"Model {model_name} not supported")

        model.eval()
        # Gradients are left enabled; the disabling loop is kept for reference:
        # for param in model.parameters():
        #     param.requires_grad = False
        return config, model, attrs


def _load_bcbh():
    """Stream BigCodeBench-Hard (split ``v0.1.4``) and return ``(prompts, solutions)``.

    Prompts come from the ``code_prompt`` field; the ``complete_prompt`` (base
    LMs) and ``instruct_prompt`` (chat LMs) fields are alternatives that are not
    used here. In each canonical solution every run of four spaces is replaced
    by a tab.
    """
    stream = load_dataset('bigcode/bigcodebench-hard', split='v0.1.4', streaming=True)

    pairs = [
        (item["code_prompt"], item["canonical_solution"].replace(' ' * 4, '\t'))
        for item in stream
    ]
    prompts = [prompt for prompt, _ in pairs]
    solutions = [solution for _, solution in pairs]
    return prompts, solutions


def _load_he():
    """Load the HumanEval ``test`` split and return ``(prompts, solutions)``.

    In each canonical solution every run of four spaces is replaced by a tab.
    The ``entry_point`` and ``test`` fields of each record are not collected.
    """
    records = load_dataset('openai_humaneval', split='test')

    pairs = [
        (record["prompt"], record["canonical_solution"].replace(' ' * 4, '\t'))
        for record in records
    ]
    prompts = [prompt for prompt, _ in pairs]
    solutions = [solution for _, solution in pairs]
    return prompts, solutions


def _load_cert(lib: str):
    """Read ``{Lib}Eval.jsonl.gz`` from ``DATA_DIR / 'cert'``.

    ``lib`` must be ``'numpy'`` or ``'pandas'``. Whitespace-only lines are
    skipped. Only the first canonical solution of each record is kept, with
    every run of four spaces replaced by a tab.

    Returns:
        ``(prompts, solutions)`` as two parallel lists.
    """
    assert lib in ('numpy', 'pandas')
    data_dir = DATA_DIR / 'cert'

    prompts, solutions = [], []
    with open(data_dir / f'{lib.title()}Eval.jsonl.gz', 'rb') as raw, gzip.open(raw, 'rt') as fp:
        for line in fp:
            if not line.strip():
                # nothing but whitespace on this line
                continue
            record = json.loads(line)
            prompt = record["prompt"]
            solution = record["canonical_solution"][0].replace(' ' * 4, '\t')
            prompts.append(prompt)
            solutions.append(solution)
    return prompts, solutions


def well_load(func, folder, filename):
    """Return the pickled result at ``folder / filename``, computing it on a miss.

    If the file exists it is unpickled and returned. Otherwise ``func()`` is
    called, ``folder`` is created (including parents), the result is pickled to
    the file, and the result is returned.
    """
    import pickle
    cache_path = folder / filename
    if cache_path.is_file():
        with open(cache_path, 'rb') as handle:
            return pickle.load(handle)

    data = func()
    folder.mkdir(parents=True, exist_ok=True)
    with open(cache_path, 'wb') as handle:
        pickle.dump(data, handle)
    return data


def well_load_gens(data_name, model_name):
    """Load the stored pre/post/oracle generations for a dataset/model pair.

    One JSON-lines file per generation kind is read from
    ``GEN_OUTS_DIR / f'{data_name}_{model_name}'``, which must already exist.

    Returns:
        ``(pre_gens, post_gens, oracle_gens)``, each a list with one decoded
        JSON value per line.
    """
    folder = GEN_OUTS_DIR / f'{data_name}_{model_name}'
    assert folder.exists()
    logger.info("load generations")

    loaded = []
    for gen_mark in ('pre_gens', 'post_gens', 'oracle_gens'):
        path = folder / GEN_TEMP.format(data_name=data_name, gen_mark=gen_mark)
        with open(path, 'r') as file:
            raw_lines = file.readlines()
        loaded.append([json.loads(raw) for raw in raw_lines])
    pre_gens, post_gens, oracle_gens = loaded
    return pre_gens, post_gens, oracle_gens


def well_dump_gens(data_name, model_name, pre_gens, post_gens, oracle_gens):
    """Write the pre/post/oracle generations as JSON-lines files.

    The files go to ``GEN_OUTS_DIR / f'{data_name}_{model_name}'`` (created if
    missing), one file per generation kind and one JSON value per line.
    """
    folder = GEN_OUTS_DIR / f'{data_name}_{model_name}'
    folder.mkdir(parents=True, exist_ok=True)
    logger.info("save generations")

    marked_gens = (
        ('pre_gens', pre_gens),
        ('post_gens', post_gens),
        ('oracle_gens', oracle_gens),
    )
    for gen_mark, gens in marked_gens:
        path = folder / GEN_TEMP.format(data_name=data_name, gen_mark=gen_mark)
        with open(path, 'w') as file:
            file.writelines(json.dumps(gen) + '\n' for gen in gens)


class Metric:
    @staticmethod
    def longest_match(gens: list[list[str]], refs: list[list[str]]):
        """Longest Common Substring on the token-level.

        For each generated/reference pair, the length of the longest contiguous
        matching block is divided by the reference length; the formatted mean
        over all pairs is returned.
        """
        def _coverage(gen, ref):
            gen_tokens, ref_tokens = tuple(gen), tuple(ref)
            longest = SequenceMatcher(a=gen_tokens, b=ref_tokens, autojunk=False).find_longest_match()
            return longest.size / len(ref_tokens)

        scores = [_coverage(gen, ref) for gen, ref in zip(gens, refs)]
        return format_score(np.average(scores))

    @staticmethod
    def exact_matching(gens: list[list[str]], refs: list[list[str]]):
        """Exact Match on the token-level.

        The ``exact_match`` metric is computed separately for each
        generated/reference token list; the formatted mean of the per-pair
        scores is returned.
        """
        metric = load('exact_match')
        scores = [
            metric.compute(predictions=gen, references=ref)['exact_match']
            for gen, ref in zip(gens, refs)
        ]
        return format_score(np.average(scores))

    @staticmethod
    def bleu_score(predictions: list[str], references: list[str]):
        """BLEU on the char-level.

        Each reference string is wrapped in its own single-item list before the
        ``bleu`` metric is computed over the whole corpus.
        """
        candidate_texts = list(predictions)
        wrapped_refs = [[reference] for reference in references]
        bleu = load('bleu')
        result = bleu.compute(predictions=candidate_texts, references=wrapped_refs)
        return format_score(result['bleu'])

    @staticmethod
    def meteor_score(predictions: list[str], references: list[str]):
        """METEOR-1.0 on the char-level.

        Returns the formatted ``meteor`` value computed over all pairs.
        """
        candidate_texts = list(predictions)
        reference_texts = list(references)
        meteor = load('meteor')
        result = meteor.compute(predictions=candidate_texts, references=reference_texts)
        return format_score(result['meteor'])

    @staticmethod
    def rouge_score(predictions: list[str], references: list[str]):
        """ROUGE on the char-level.

        Only the ``rougeL`` component of the metric output is returned (formatted).
        """
        candidate_texts = list(predictions)
        reference_texts = list(references)
        rouge = load('rouge')
        result = rouge.compute(predictions=candidate_texts, references=reference_texts)
        return format_score(result['rougeL'])

    @staticmethod
    def edit_similarity(gens: list[str], refs: list[str]):
        """Edit Similarity on the char-level.

        For each pair the score is ``1 - edit_distance / max(len(gen), len(ref))``.
        A pair of two empty strings is identical and scores 1.0; previously this
        case raised ``ZeroDivisionError``.

        Returns:
            The formatted mean score over all generated/reference pairs.
        """
        from nltk import edit_distance
        scores = list()
        for gen, ref in zip(gens, refs):
            longest = max(len(gen), len(ref))
            if longest == 0:
                # both sides empty: nothing to edit, so they match perfectly
                scores.append(1.0)
                continue
            scores.append(1 - edit_distance(gen, ref) / longest)
        avg_score = np.average(scores)
        return format_score(avg_score)

    @staticmethod
    def exact_match_score(predictions: list[str], references: list[str]):
        """Exact Match on the char-level.

        Returns the formatted ``exact_match`` value computed over all pairs at once.
        """
        candidate_texts = list(predictions)
        reference_texts = list(references)
        exact_match = load('exact_match')
        result = exact_match.compute(predictions=candidate_texts, references=reference_texts)
        return format_score(result['exact_match'])

    @staticmethod
    def gestalt_match_score(gens: list[str], refs: list[str]):
        """Gestalt Pattern Matching on the token-level.

        Each pair is scored with ``SequenceMatcher.ratio()`` after both sides
        are converted to tuples; the formatted mean is returned.
        """
        # https://github.com/python/cpython/blob/main/Lib/difflib.py#L597
        ratios = [
            SequenceMatcher(a=tuple(gen), b=tuple(ref), autojunk=False).ratio()
            for gen, ref in zip(gens, refs)
        ]
        return format_score(np.average(ratios))

    @staticmethod
    def contrast_scoring(pre_gens, post_gens, oracle_gens):
        """Log token-level metric scores for the pre and post generations.

        Each enabled metric is evaluated for ``pre_gens`` and ``post_gens``
        against ``oracle_gens``; both scores and their formatted ratio are
        logged. Nothing is returned.
        """
        # Per-token counting variant kept for reference:
        # from collections import Counter
        # pre_cnt = Counter()
        # post_cnt = Counter()
        # for pre_seq, post_seq, oracle_seq in zip(pre_gens, post_gens, oracle_gens):
        #     for pre_token, post_token, oracle_token in zip(pre_seq, post_seq, oracle_seq):
        #         pre_cnt[pre_token == oracle_token] += 1
        #         post_cnt[post_token == oracle_token] += 1
        # logger.info(f'{pre_cnt=}')
        # logger.info(f'{post_cnt=}')
        # logger.info(f'{post_cnt[True] - pre_cnt[True]=}')
        # logger.info(f'{post_cnt[True] + post_cnt[False]=}')
        # ...
        enabled_metrics = (
            Metric.exact_matching,
            # Metric.gestalt_matching,
        )
        for metric_fn in enabled_metrics:
            # these three names appear verbatim in the log line below
            pre_score = metric_fn(pre_gens, oracle_gens)
            post_score = metric_fn(post_gens, oracle_gens)
            ratio = format_ratio(pre_score, post_score)
            logger.success(f'{metric_fn.__name__}: {pre_score=}, {post_score=} ({ratio=})')

    @staticmethod
    def contrast_scoring2(pre_gens, post_gens, oracle_gens):
        """Log string-level metric scores for the pre and post generations.

        Every token sequence is first joined with single spaces. Each enabled
        metric is then evaluated for pre and post against the oracle strings,
        and both scores with their formatted ratio are logged. Nothing is
        returned.
        """
        def _as_text(sequences):
            return [' '.join(tokens) for tokens in sequences]

        pre_gens, post_gens, oracle_gens = _as_text(pre_gens), _as_text(post_gens), _as_text(oracle_gens)
        enabled_metrics = (
            Metric.edit_similarity,
            # Metric.exact_match_score,
            # Metric.gestalt_match_score,
            Metric.bleu_score,
            # Metric.rouge_score,  # ROUGE is less recommended than BLEU
            # Metric.meteor_score,
        )
        for metric_fn in enabled_metrics:
            # these three names appear verbatim in the log line below
            pre_score = metric_fn(pre_gens, oracle_gens)
            post_score = metric_fn(post_gens, oracle_gens)
            ratio = format_ratio(pre_score, post_score)
            logger.success(f'{metric_fn.__name__}: {pre_score=}, {post_score=} ({ratio=})')

    @staticmethod
    def same_accuracy(gens: list[int], refs: list[int]):
        dev_correct = [1 if refs[i] == gens[i] else 0 for i in range(len(refs))]
        acc_score = sum(dev_correct) / len(refs)
        return acc_score

    @staticmethod
    def compute_statistics(gens_rank: list[float], refs_rank: list[float], window_size=None):
        """Return the formatted MAE and RMSE between generated and reference ranks.

        Args:
            gens_rank: Ranks for the generations (treated as predictions).
            refs_rank: Ranks for the references (treated as ground truth).
            window_size: When given, every rank in both lists is capped at this
                value before the errors are computed.

        Returns:
            ``(format_score(mae), format_score(rmse))``.
        """
        if window_size is not None:
            gens_rank = [min(gen_rank, window_size) for gen_rank in gens_rank]
            refs_rank = [min(ref_rank, window_size) for ref_rank in refs_rank]
        mae = mean_absolute_error(refs_rank, gens_rank)
        # ``mean_squared_error(..., squared=False)`` was deprecated in
        # scikit-learn 1.4 and removed in 1.6. For these 1-D inputs the square
        # root of the MSE is the same RMSE and works on every version.
        rmse = np.sqrt(mean_squared_error(refs_rank, gens_rank))
        return format_score(mae), format_score(rmse)
