from tqdm import tqdm
import torch
from transformers import LlamaTokenizer, AutoTokenizer
from model_factory import create_model, AutoModelForCausalLM
from config import ShareConfig, add_args
from prepare_data import prepare_data
import torch.nn as nn
from typing import Optional
import os
from huggingface_hub import login
from transformers import DataCollatorWithPadding, DataCollatorForLanguageModeling



# Set the CUDA device mask for this process to GPUs 0-3.
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES value
# already set by the caller's environment, and it runs after `import torch`;
# presumably it still takes effect because CUDA is initialised lazily - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

def inspect_state_dict(model: nn.Module, output_file_path: Optional[str] = None):
    """Report every entry of ``model.state_dict()`` with its shape, dtype and size.

    One row is produced per state-dict entry, followed by the summed size of
    all entries in GB. The report is written to ``output_file_path`` when one
    is given and printed to stdout otherwise.

    Inspection is best-effort: an exception raised while walking the state
    dict is recorded as a line of the report instead of propagating, so rows
    collected before the failure are still emitted.

    Args:
        model: Object whose ``state_dict()`` is inspected.
        output_file_path: Text file to (over)write with the report, encoded as
            UTF-8. When falsy, or when writing fails, the report is printed
            to stdout instead.

    Returns:
        None.
    """
    report_lines = ["\n" + "=" * 80, "Inspecting model's state_dict..."]

    try:
        state_dict = model.state_dict()
        total_size_mb = 0

        report_lines.append("-" * 80)
        # Data rows are prefixed with "- " ahead of the 80-wide name field, so
        # the first header column is padded to 82 to keep the "|" separators
        # of the header and of the rows in the same columns.
        report_lines.append(
            f"{'Layer':<82} | {'Shape':<25} | {'Dtype':<15} | {'Size (MB)':<10}"
        )
        report_lines.append("-" * 80)

        for name, tensor in state_dict.items():
            if tensor is None:
                # Entries without a tensor contribute nothing to the total.
                report_lines.append(
                    f"- {name:<80} | {'None':<25} | {'N/A':<15} | {'0.00':<10}"
                )
                continue

            # element count * bytes per element, converted from bytes to MB.
            size_mb = (tensor.nelement() * tensor.element_size()) / 1024 / 1024
            total_size_mb += size_mb
            report_lines.append(
                f"- {name:<80} | {str(tensor.shape):<25} | {str(tensor.dtype):<15} | {size_mb:<10.2f}"
            )

        report_lines.append(f"\nCalculated Total Model Size: {total_size_mb / 1024:.2f} GB")
        report_lines.append("=" * 80 + "\n")

    except Exception as e:
        # Deliberately broad: this is a diagnostic helper and must not abort
        # the caller; the failure is surfaced inside the report itself.
        report_lines.append(f"An error occurred during inspection: {e}")

    if output_file_path:
        try:
            with open(output_file_path, 'w', encoding='utf-8') as f:
                f.writelines(line + '\n' for line in report_lines)
        except Exception as e:
            # Fall through to printing so the report is not lost.
            print(f"Error saving report to file: {e}")
        else:
            print(f"Model state_dict report saved to: {output_file_path}")
            return

    for line in report_lines:
        print(line)



def compute_ppl(max_length, stride, data, model, device):
    """Compute perplexity of ``model`` over ``data.input_ids`` using overlapping windows.

    A window of at most ``max_length`` tokens is started every ``stride``
    tokens along dimension 1 of ``data.input_ids``. Within each window only
    the tokens lying past the end of the previous window keep their labels;
    the leading overlap is set to -100 (presumably the label value the model's
    loss ignores - confirm against the model implementation). The per-window
    losses returned by the model are stacked, averaged with equal weight per
    window, and exponentiated.

    Args:
        max_length: Maximum number of tokens fed to the model per window.
        stride: Step between the starts of consecutive windows.
        data: Object exposing an ``input_ids`` tensor indexed as [:, tokens].
        model: Module called as ``model(input_ids, labels=...)`` whose output
            has a ``loss`` attribute.
        device: Device the model and each window are moved to.

    Returns:
        Tensor holding ``exp(mean(per-window losses))``.
    """
    model.to(device)
    model = model.eval()
    token_ids = data.input_ids
    total_len = token_ids.size(1)

    window_losses = []
    last_stop = 0
    # Evaluation only: gradient tracking is disabled for every forward pass.
    with torch.no_grad():
        for start in tqdm(range(0, total_len, stride)):
            stop = min(start + max_length, total_len)
            # Number of tokens in this window not covered by the previous one.
            fresh = stop - last_stop

            window = token_ids[:, start:stop].to(device)
            labels = window.clone()
            labels[:, :-fresh] = -100

            window_losses.append(model(window, labels=labels).loss)

            last_stop = stop
            if stop == total_len:
                # The window reached the end of the data; later starts would
                # only re-read tokens that were already scored.
                break

    return torch.exp(torch.stack(window_losses).mean())


if __name__ == '__main__':
    # Script entry point: build the config, tokenizer, data and model, dump the
    # model's config and state-dict listing to disk, then print test perplexity.
    import json

    cmd_args = add_args()
    config = ShareConfig(cmd_args)
    print(config.compression_ratio)

    # "llama2" models use LlamaTokenizer; every other model type goes through
    # AutoTokenizer. Both are loaded with the same cache directory.
    tokenizer_cls = LlamaTokenizer if config.model_type == "llama2" else AutoTokenizer
    tokenizer = tokenizer_cls.from_pretrained(config.model_name, cache_dir="llm_weights")
    tokenizer.pad_token = "[PAD]"

    train_dataset, val_dataset, test_dataset, data_collator = prepare_data(
        config.dataset_name,
        tokenizer,
        config.context_length,
        config.dataset_cache_dir,
    )
    print(config.dataset_name)

    model = create_model(config)

    # Snapshot the model's configuration as sorted, indented JSON.
    print(">>> Saving config state AFTER loading with from_pretrained...")
    with open("config_after_load.json", "w") as f:
        json.dump(model.config.to_dict(), f, indent=2, sort_keys=True)

    # Snapshot the per-entry state-dict listing next to the config dump.
    print(">>> Saving weights list AFTER loading with from_pretrained...")
    inspect_state_dict(model, output_file_path="weights_after_load.txt")

    print(">>> Inspection finished. Exiting.")

    # NOTE(review): despite the "Exiting." message above, execution continues
    # and perplexity is evaluated on the test split - confirm which is intended.
    perplexity = compute_ppl(config.context_length, config.stride, test_dataset, model, 'cuda')
    print(perplexity)