import copy
import csv
import json
from pathlib import Path
from statistics import fmean

from loguru import logger
from tqdm import tqdm
from transformers import (
    AutoConfig, GPT2LMHeadModel, GPTNeoXForCausalLM,
)
from vds_load import NeoLoader
from vds_shared import DEVICE, REPORT_OUTS_DIR
from vds_util import format_score

from utils.dataset import load_dataset
from utils.template import make_prompt


def load_model(model_name):
    """Load a GPT-2 or Pythia causal LM via ``from_pretrained``.

    Names containing ``"gpt2"`` load as ``GPT2LMHeadModel``; names containing
    ``"pythia"`` load as ``GPTNeoXForCausalLM``. The config is built with
    ``output_scores=True``.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.

    Returns:
        The loaded model.

    Raises:
        NotImplementedError: if ``model_name`` matches neither family. The
            check happens before any config is fetched.
    """
    # Decide the model class first so unsupported names fail fast, without
    # fetching a config.
    if "gpt2" in model_name:
        model_cls = GPT2LMHeadModel
    elif "pythia" in model_name:
        model_cls = GPTNeoXForCausalLM
    else:
        raise NotImplementedError(
            f'unsupported model {model_name!r}; expected a gpt2 or pythia checkpoint'
        )

    config = AutoConfig.from_pretrained(model_name, output_scores=True)
    model = model_cls.from_pretrained(model_name, config=config)
    if model_cls is GPT2LMHeadModel:
        # Architecture printout, kept from the original (GPT-2 only).
        print(model)
    return model


def dump_model_info(model_tag, model_repr, folder='../docs'):
    """Write a model's textual representation to ``<folder>/<model_tag>.txt``.

    Args:
        model_tag: File stem for the output file.
        model_repr: Text to write (``main`` passes ``repr(model)``).
        folder: Output directory, created if missing. Defaults to ``../docs``,
            which is resolved against the current working directory.
    """
    out_dir = Path(folder)
    out_dir.mkdir(parents=True, exist_ok=True)
    # model_repr is one string: write() it directly. writelines() would
    # iterate it one character at a time.
    with open(out_dir / f'{model_tag}.txt', 'w', encoding='utf-8') as handle:
        handle.write(model_repr)


def main():
    """Dump the module structure (``repr``) of each GPT-2 / Pythia model.

    Each model is loaded with ``load_model`` and its ``repr`` is handed to
    ``dump_model_info`` under the last component of the model name.

    Models                          Parameters  Layers  Dimension
    gpt2                            117M        12      768
    gpt2-medium                     345M        24      1024
    gpt2-large                      762M        36      1280
    gpt2-xl                         1542M       48      1600
    EleutherAI/pythia-70m-deduped   70M         6       512
    EleutherAI/pythia-160m-deduped  160M        12      768
    EleutherAI/pythia-410m-deduped  410M        24      1024
    EleutherAI/pythia-1b-deduped    1.0B        16      2048
    EleutherAI/pythia-1.4b-deduped  1.4B        24      2048
    EleutherAI/pythia-2.8b-deduped  2.8B        32      2560
    EleutherAI/pythia-6.9b-deduped  6.9B        32      4096
    EleutherAI/pythia-12b-deduped   12.0B       36      5120
    """
    gpt2_model_names = ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
    pythia_model_names = [
        'EleutherAI/pythia-70m-deduped',
        'EleutherAI/pythia-160m-deduped',
        'EleutherAI/pythia-410m-deduped',
        'EleutherAI/pythia-1b-deduped',
        'EleutherAI/pythia-1.4b-deduped',
        'EleutherAI/pythia-2.8b-deduped',
        'EleutherAI/pythia-6.9b-deduped',
        # 'EleutherAI/pythia-12b-deduped',
    ]
    model_names = gpt2_model_names + pythia_model_names
    for model_name in model_names:
        model = load_model(model_name)
        model_repr = repr(model)
        # Drop the reference before the next (possibly larger) checkpoint is
        # loaded. Otherwise the previous model stays alive during that load.
        del model
        model_tag = model_name.split('/')[-1]
        dump_model_info(model_tag, model_repr)



"""
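Output of check_context_allowed_shots(): the largest shot count (tried 2, 4, 8, 16, 32;
1 if none fits) for which at most 5% of dev prompts exceed the model's context length.
# gpt2, max context 1024 tokens
allowed_shots={'agnews': 2, 'mr': 8, 'mrpc': 4, 'sst5': 4, 'subj': 8, 'trec': 8, 'webss': 2}
# pythia, max context 2048 tokens
allowed_shots={'agnews': 4, 'mr': 16, 'mrpc': 8, 'sst5': 8, 'subj': 16, 'trec': 16, 'webss': 4}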
"""
def check_context_allowed_shots():
    """Log, per dataset, the largest n-shot prompt that fits the model context.

    For each tokenizer, shot counts 2, 4, 8, 16 and 32 are tried in order. For
    each count, every dev instance is appended to the n-shot training prefix and
    tokenized. A shot count is accepted while at most 5% of the dev prompts
    exceed the context length. The search stops at the first count that is not
    accepted. A dataset whose 2-shot prompts already fail is reported as 1.
    Results are logged with ``logger.warning``, not returned.
    """
    model_names = ['gpt2', 'EleutherAI/pythia-160m-deduped']
    data_codes = ['agnews', 'mr', 'mrpc', 'sst5', 'subj', 'trec', 'webss']
    for model_name in model_names:
        # Context limit used for the check: 1024 for gpt2, 2048 otherwise.
        max_context_len = 1024 if 'gpt2' in model_name else 2048
        tokenizer = NeoLoader.load_tokenizer(model_name)

        allowed_shots = dict()
        for data_code in data_codes:
            _train_data, _dev_data = load_dataset(dataset=data_code)
            allowed_shots[data_code] = 1
            for n_train_shot in [2, 4, 8, 16, 32]:
                # Fresh copies: subsamplebyshot() is applied to the train copy.
                train_data = copy.deepcopy(_train_data)
                dev_data = copy.deepcopy(_dev_data)
                train_data.subsamplebyshot(n_train_shot)
                prompt_prefix = make_prompt(train_data, data_code, mode='train')
                truncation_count = 0
                truncation_limit = len(dev_data) * 0.05
                for ins in tqdm(dev_data.data, total=len(dev_data)):
                    prompt = prompt_prefix + make_prompt(ins, data_code, mode='inference')
                    # Only the token count is needed, so no device transfer.
                    inputs = tokenizer.encode_plus(prompt, return_tensors="pt", padding=True)
                    if inputs['input_ids'].shape[1] > max_context_len:
                        truncation_count += 1
                        if truncation_count > truncation_limit:
                            break  # already over budget; skip the remaining prompts
                if truncation_count > truncation_limit:
                    break  # larger shot counts only produce longer prompts
                allowed_shots[data_code] = n_train_shot
        logger.warning(f'{allowed_shots=}')


"""
# Output of check_corpus_stats() with the GPT-2 tokenizer
corpus_stats={
   "agnews":{
      "avg_prompt_length":58.465,
      "train_data_size":120000,
      "test_data_size":7600
   },
   "mr":{
      "avg_prompt_length":31.634,
      "train_data_size":8662,
      "test_data_size":2000
   },
   "mrpc":{
      "avg_prompt_length":62.79,
      "train_data_size":4076,
      "test_data_size":1725
   },
   "sst5":{
      "avg_prompt_length":28.724,
      "train_data_size":8544,
      "test_data_size":2210
   },
   "subj":{
      "avg_prompt_length":33.931,
      "train_data_size":8000,
      "test_data_size":2000
   },
   "trec":{
      "avg_prompt_length":16.292,
      "train_data_size":5452,
      "test_data_size":500
   },
   "webss":{
      "avg_prompt_length":29.838,
      "train_data_size":10060,
      "test_data_size":2280
   }
}

# Output of check_corpus_stats() with the Pythia tokenizer
corpus_stats={
   "agnews":{
      "avg_prompt_length":59.882,
      "train_data_size":120000,
      "test_data_size":7600
   },
   "mr":{
      "avg_prompt_length":32.113,
      "train_data_size":8662,
      "test_data_size":2000
   },
   "mrpc":{
      "avg_prompt_length":63.697,
      "train_data_size":4076,
      "test_data_size":1725
   },
   "sst5":{
      "avg_prompt_length":29.201,
      "train_data_size":8544,
      "test_data_size":2210
   },
   "subj":{
      "avg_prompt_length":34.363,
      "train_data_size":8000,
      "test_data_size":2000
   },
   "trec":{
      "avg_prompt_length":16.443,
      "train_data_size":5452,
      "test_data_size":500
   },
   "webss":{
      "avg_prompt_length":30.005,
      "train_data_size":10060,
      "test_data_size":2280
   }
}
"""
def check_corpus_stats():
    """Log the average prompt length and split sizes of each dataset.

    For each tokenizer, every train and dev instance is rendered with
    ``make_prompt(..., mode='inference')`` and tokenized. The mean token count
    over both splits combined is reported together with the two split sizes.
    Results are logged with ``logger.warning``, not returned.
    """
    model_names = ['gpt2', 'EleutherAI/pythia-160m-deduped']
    data_codes = ['agnews', 'mr', 'mrpc', 'sst5', 'subj', 'trec', 'webss']
    for model_name in model_names:
        tokenizer = NeoLoader.load_tokenizer(model_name)

        corpus_stats = dict()
        for data_code in data_codes:
            train_data, dev_data = load_dataset(dataset=data_code)
            train_data_size = len(train_data)
            test_data_size = len(dev_data)
            length_notes = list()
            # Same measurement for both splits; the average covers train + dev.
            for split, split_size in ((train_data, train_data_size), (dev_data, test_data_size)):
                for ins in tqdm(split.data, total=split_size):
                    prompt = make_prompt(ins, data_code, mode='inference')
                    # Only the token count is needed, so no device transfer.
                    inputs = tokenizer.encode_plus(prompt, return_tensors="pt", padding=True)
                    length_notes.append(inputs['input_ids'].shape[1])
            corpus_stats[data_code] = {
                'avg_prompt_length': format_score(fmean(length_notes)),
                'train_data_size': train_data_size,
                'test_data_size': test_data_size,
            }
        logger.warning(f'{corpus_stats=}')


def extract_runtime(identifier, log_file):
    """Parse one experiment log and return the runtime(s) it records, in hours.

    The ``args=Namespace(...)`` line supplies the dataset and model. The first
    two ``', '``-separated fields are read as ``key=value``, in that order.
    Runtimes are then derived according to ``identifier``:

    * ``'lora'``, ``'ia3'``, ``'sc_cluster'``: the elapsed time of the last
      tqdm bar showing `` 1/1 `` (lora/ia3) or `` 100/100 `` (sc_cluster).
    * ``'icl'``, ``'kp'``, ``'lora'``, ``'ia3'``, ``'sc_cache'``: wall-clock
      spans between ``+0800``-stamped log lines.
      - The start is the last ``':   warnings.warn('`` line.
      - The end is the last line whose text after the colon is a single space
        and a newline.
      - ``kp`` reports two spans (start->mid, mid->end); its mid point is the
        line right after the start line.
      - ``sc_cache`` also reports two spans; its mid point is the next
        empty-message line after a start line or a ``'Loading checkpoint
        shards'`` line.
      - ``lora``/``ia3`` subtract the progress-bar time from the wall-clock span.

    Args:
        identifier: Method tag selecting the parsing rules above.
        log_file: Path of the log, read as ISO-8859-1.

    Returns:
        ``(llm_info, data_info, runtimes_in_hour)``. ``llm_info`` and
        ``data_info`` stay ``'='`` if no ``args=Namespace`` line is found.
        ``runtimes_in_hour`` is empty for an unrecognised ``identifier``.

    Raises:
        ValueError: if a marker required for ``identifier`` is missing from
            the log, or a matched line cannot be split/parsed as expected.
    """

    from datetime import datetime

    with open(log_file, 'r', encoding="ISO-8859-1") as f:
        lines = f.readlines()

    llm_info = '='
    data_info = '='
    kp_time_flag = False
    sc_cache_time_flag = False
    runtimes_in_hour = list()
    timestamp_format = '%Y-%m-%dT%H:%M:%S'
    # Sentinels: these stay None when the corresponding marker never appears.
    runtime_in_sec = None
    start_timestamp = mid_timestamp = end_timestamp = None

    for line in lines:

        if 'args=Namespace' in line:
            _, config_str = line.split('args=Namespace')
            data_str, model_str, *_ = config_str.split(', ')
            _, llm_info = model_str.split('=')
            _, data_info = data_str.split('=')

        if identifier in ['lora', 'ia3', 'sc_cluster']:
            runtime_str = None
            if ' 1/1 ' in line and identifier in ['lora', 'ia3']:
                _, info_str = line.split(' 1/1 ')
                runtime_str, _ = info_str.strip('[]').split('<')
            elif ' 100/100 ' in line and identifier in ['sc_cluster']:
                _, info_str = line.split(' 100/100 ')
                runtime_str, _ = info_str.strip('[]').split('<')

            if runtime_str is not None:
                # Elapsed field is [H:]MM:SS; left-pad with zero hours.
                runtime_list = runtime_str.split(':')
                runtime_list = [int(runtime) for runtime in runtime_list]
                while len(runtime_list) < 3:
                    runtime_list = [0] + runtime_list
                runtime_in_sec = runtime_list[0] * 3600 + runtime_list[1] * 60 + runtime_list[2]

        if identifier in ['icl', 'kp', 'lora', 'ia3', 'sc_cache']:
            # do not break the loop, since we need the over-written values
            if kp_time_flag:
                kp_time_flag = False
                timestamp_str, _ = line.split('+0800:')
                mid_timestamp = datetime.strptime(timestamp_str, timestamp_format)
            if ':   warnings.warn(\n' in line:
                timestamp_str, _ = line.split('+0800:')
                start_timestamp = datetime.strptime(timestamp_str, timestamp_format)
                if identifier == 'kp':
                    kp_time_flag = True
                if identifier == 'sc_cache':
                    sc_cache_time_flag = True
            if 'Loading checkpoint shards' in line:
                if identifier == 'sc_cache':
                    sc_cache_time_flag = True
            if ': \n' in line:
                if sc_cache_time_flag:
                    sc_cache_time_flag = False
                    timestamp_str, _ = line.split('+0800:')
                    mid_timestamp = datetime.strptime(timestamp_str, timestamp_format)
                else:
                    timestamp_str, _ = line.split('+0800:')
                    end_timestamp = datetime.strptime(timestamp_str, timestamp_format)

    def _require(value, what):
        # Raise ValueError (which callers such as process_logs handle) rather
        # than letting an unbound local surface as UnboundLocalError.
        if value is None:
            raise ValueError(f'{log_file}: no {what} found for {identifier=}')
        return value

    if identifier in ['lora', 'ia3', 'sc_cluster']:
        runtime_in_hour = _require(runtime_in_sec, 'progress-bar runtime') / 3600
        runtimes_in_hour.append(runtime_in_hour)

    if identifier in ['icl', 'kp', 'lora', 'ia3', 'sc_cache']:
        start_timestamp = _require(start_timestamp, "'warnings.warn(' start line")
        end_timestamp = _require(end_timestamp, 'empty-message end line')
        if identifier in ['kp', 'sc_cache']:
            mid_timestamp = _require(mid_timestamp, 'mid-point line')
            duration_in_sec = (mid_timestamp - start_timestamp).total_seconds()
            runtimes_in_hour.append(duration_in_sec / 3600)
            duration_in_sec = (end_timestamp - mid_timestamp).total_seconds()
            runtimes_in_hour.append(duration_in_sec / 3600)
        else:
            duration_in_sec = (end_timestamp - start_timestamp).total_seconds()
            duration_in_hour = duration_in_sec / 3600
            if identifier in ['lora', 'ia3']:
                # Wall-clock span minus the progress-bar time.
                duration_in_hour -= runtime_in_hour
            runtimes_in_hour.append(duration_in_hour)

    return llm_info, data_info, runtimes_in_hour


def report_stats(identifier, llm_filter, *triplet):
    """Append the runtimes parsed from one log to ``runtime_<llm_filter>.csv``.

    ``triplet`` is the ``(llm_info, data_info, runtimes_in_hour)`` tuple
    produced by ``extract_runtime``. The CSV lives in ``REPORT_OUTS_DIR`` and
    gets a ``method,dataset,runtime`` header when it is first created. Rows
    are written (and logged) only when ``llm_info`` equals ``llm_filter``. The
    i-th runtime is labelled ``<identifier>_<i>``.
    """
    REPORT_OUTS_DIR.mkdir(parents=True, exist_ok=True)
    out_path = REPORT_OUTS_DIR / f'runtime_{llm_filter}.csv'
    needs_header = not out_path.exists()
    with open(out_path, 'a+', newline='') as out_file:
        writer = csv.writer(out_file)
        if needs_header:
            writer.writerow(['method', 'dataset', 'runtime'])
        llm_info, data_info, runtimes_in_hour = triplet
        if llm_info != llm_filter:
            return
        for idx, runtime in enumerate(runtimes_in_hour):
            runtime = format_score(runtime)
            logger.warning(f'{identifier=}_{idx}, {data_info=}, {runtime=}')
            writer.writerow([identifier + f'_{idx}', data_info, runtime])


def process_logs(identifier, llm_filter='llama3-8b', log_dir=None):
    """Parse every log of one method and append its runtimes to the CSV report.

    Args:
        identifier: Method tag (e.g. ``'icl'``, ``'kp'``). It selects the
            parsing rules in ``extract_runtime`` and the default log folder.
        llm_filter: Only logs whose parsed model name equals this are written.
        log_dir: Folder to scan. Defaults to
            ``../logs_llama3/llama3_<identifier>``.

    Hidden entries (e.g. ``.DS_Store``) and non-files are skipped. Logs that
    fail with ``ValueError`` are logged and skipped.
    """
    import os

    log_path = log_dir if log_dir is not None else f'../logs_llama3/llama3_{identifier}'
    # os.listdir order is arbitrary; sort so CSV row order is reproducible.
    for log_file in sorted(os.listdir(log_path)):
        full_path = os.path.join(log_path, log_file)
        if log_file.startswith('.') or not os.path.isfile(full_path):
            continue
        try:
            triplet = extract_runtime(identifier, full_path)
            report_stats(identifier, llm_filter, *triplet)
        except ValueError as e:
            logger.warning(f'skipping {log_file=}: {e}')


def stats_computation_cost():
    """Build the runtime report from the 'icl', 'kp', 'lora' and 'ia3' logs."""
    # 'sc' logs are currently left out of the report.
    for method in ('icl', 'kp', 'lora', 'ia3'):
        process_logs(method)


if __name__ == '__main__':
    # Only the runtime report runs by default. main(),
    # check_context_allowed_shots() and check_corpus_stats() are one-off
    # analyses kept for manual use.
    stats_computation_cost()
