import os
import json
import yaml
import random
import inspect
from typing import *
from collections import defaultdict

import numpy as np
import munch
import torch
import torch.multiprocessing as mp
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset

import data
import models
# llms utils
from models.llm_loader import load_model, load_tokenizer
from data.llm_dataloader import load_llm_dataset

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Reference:
    https://wandb.ai/sauravmaheshkar/RSNA-MICCAI/reports/How-to-Set-Random-Seeds-in-PyTorch-and-Tensorflow--VmlldzoxMDA2MDQy

    Args:
        seed: integer seed applied to every generator below.
    """
    # Pure-Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU/default, then the current CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: the running interpreter's hash seed is fixed at startup, so this
    # mainly affects child processes launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def create_mp_process(func, *args, **kwargs):
    """Build (but do not start) a ``torch.multiprocessing`` process.

    Args:
        func: callable the child process will run.
        *args: positional arguments forwarded to ``func``.
        **kwargs: keyword arguments forwarded to ``func``.

    Returns:
        An unstarted ``mp.Process``; the caller is responsible for
        ``start()``/``join()``.
    """
    return mp.Process(target=func, args=args, kwargs=kwargs)

def init_data(config):
    """Build the dataset named by ``config.data.name``.

    A name containing the substring ``'@llm'`` is delegated to
    ``load_llm_dataset`` together with the whole config. Any other name is
    looked up as an attribute of the project ``data`` module, and that
    attribute is called with no arguments.
    """
    dataset_name = config.data.name
    if '@llm' not in dataset_name:
        return getattr(data, dataset_name)()
    return load_llm_dataset(config)

def init_model(config):
    """Instantiate the model described by ``config.model.type``.

    Two forms of ``config.model.type`` are handled:

    * a string containing ``'huggingface:'`` (e.g. ``'huggingface:<name>'``):
      the part after the first ``':'`` is passed to ``load_model`` together
      with ``config.llm.fine_tune_config`` and ``config.huggingface.token``;
    * anything else: looked up as an attribute of the project ``models``
      package and called with every other ``config.model`` entry as a
      keyword argument.

    ``config`` itself is left unmodified, so the function can be called more
    than once with the same config.
    """
    model_type = config.model.type
    if 'huggingface:' in model_type:
        # maxsplit=1 keeps names/paths that themselves contain ':' intact.
        _, pretrained_model_name_or_path = model_type.split(':', 1)
        return load_model(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            fine_tune_config=config.llm.fine_tune_config,
            token=config.huggingface.token,
        )

    model_init_func = getattr(models, model_type)
    # Filter out 'type' into a fresh dict instead of popping it from the
    # caller's config (popping made a second call fail with a missing 'type'
    # and removed it from any config dumped afterwards).
    model_kwargs = {key: value for key, value in config.model.items() if key != 'type'}
    return model_init_func(**model_kwargs)

def load_yaml_object(path):
    """Parse the YAML file at *path* and return it passed through ``munch.munchify``.

    Parsing uses ``yaml.safe_load``; the parsed structure is then munchified
    so mappings can be read with attribute access.
    """
    with open(path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return munch.munchify(parsed)

def dump_yaml_object(path, data):
    """Write *data* to *path* as YAML via ``yaml.safe_dump``.

    The target file is opened in ``'w'`` mode, so any existing content is
    replaced.
    """
    with open(path, 'w') as stream:
        yaml.safe_dump(data, stream=stream)

def get_kwargs(func, config):
    """Select the entries of *config* whose keys are parameter names of *func*.

    Args:
        func: any callable ``inspect.signature`` can introspect.
        config: mapping supporting ``in`` and ``[]`` lookups.

    Returns:
        A new dict, ordered like *func*'s signature, holding only the
        parameters that *config* provides.
    """
    parameter_names = inspect.signature(func).parameters
    return {name: config[name] for name in parameter_names if name in config}

def get_dataloader(dataset, *args, **kwargs):
    """Wrap *dataset* in a ``torch.utils.data.DataLoader``.

    All extra positional and keyword arguments are forwarded unchanged.

    WARNING: Setting workers may slow down running time :/
    """
    return DataLoader(dataset, *args, **kwargs)

def check_loss(model, loader, compute_loss: Callable, device=None):
    """Return the sample-weighted average of ``compute_loss`` over *loader*.

    The model is moved to *device* and switched to eval mode, and gradients
    are disabled during the pass.

    Args:
        model: module called as ``model(x)``.
        loader: iterable of ``(x, y)`` tensor pairs.
        compute_loss: ``compute_loss(logits, y)`` returning a scalar tensor;
            presumably a per-batch *mean* loss, since it is re-weighted by
            ``y.size(0)`` below (TODO confirm with callers).
        device: target device; ``None`` leaves placement to ``.to``.

    Returns:
        Average loss per sample, or ``0.0`` when *loader* yields no samples
        (matching the empty-input behaviour of the LLM evaluation helpers).
    """
    model.to(device)
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            logits = model(x)
            loss = compute_loss(logits, y)
            batch_size = y.size(0)
            num_samples += batch_size
            # Undo the per-batch mean so a short final batch is not over-weighted.
            total_loss += loss.item() * batch_size
    if num_samples == 0:
        # Previously raised ZeroDivisionError on an empty loader.
        return 0.0
    return total_loss / num_samples

def check_accuracy(model, loader, device=None):
    """Return classification accuracy (in percent) of *model* over *loader*.

    The model is moved to *device* and switched to eval mode, and gradients
    are disabled during the pass.

    Args:
        model: module called as ``model(x)``, producing scores whose dim 1 is
            taken as the class dimension (assumed ``(batch, num_classes)`` —
            TODO confirm for multi-dim outputs).
        loader: iterable of ``(x, y)`` tensor pairs, ``y`` holding class indices.
        device: target device; ``None`` leaves placement to ``.to``.

    Returns:
        ``100 * correct / samples``, or ``0.0`` when *loader* yields no
        samples (matching the empty-input behaviour of ``check_accuracy_llm``).
    """
    model.to(device)
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            scores = model(x)
            _, predictions = scores.max(1)
            # Accumulated as a tensor; converted once at the end.
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    if num_samples == 0:
        # Previously raised ZeroDivisionError on an empty loader.
        return 0.0
    return float(num_correct) / float(num_samples) * 100

def log_stats(log_path, **kwargs):
    """Append one record to the JSON list stored at *log_path*.

    If the file exists, its JSON content (expected to be a list) is loaded,
    the keyword arguments are appended as a single dict, and the whole list
    is rewritten with ``indent=2``. Otherwise a new one-element list is written.

    Args:
        log_path: path of the JSON log file.
        **kwargs: JSON-serialisable fields of the new record.
    """
    stats = []
    if os.path.exists(log_path):
        # Read inside ``with`` so the handle is closed before the file is
        # reopened for writing (the previous bare ``open`` leaked it).
        with open(log_path, 'r') as f:
            stats = json.load(f)
    stats.append(kwargs)
    with open(log_path, 'w') as f:
        json.dump(stats, f, indent=2)


def check_accuracy_llm(model, loader, tokenizer, device=None):
    """Return next-token prediction accuracy (in percent) of a causal LM.

    For each batch the argmax prediction at position ``t`` is compared with
    the label at position ``t + 1``; positions whose shifted label equals
    ``IGNORE_INDEX`` are excluded. The running tally is printed at the end.

    Args:
        model: module returning an object with a ``.logits`` tensor whose last
            dim is the vocabulary.
        loader: iterable of dict batches with ``'input_ids'`` and ``'labels'``
            tensors and, optionally, ``'attention_mask'``.
        tokenizer: unused; kept for interface compatibility with callers.
        device: target device; ``None`` leaves placement to ``.to``.

    Returns:
        Accuracy in percent, or ``0`` when no position was scored.
    """
    IGNORE_INDEX = -100  # label value for positions that must not be scored
    model.to(device)
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="predicting"):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            # Use the same attention mask as check_loss_llm when the batch
            # provides one, so padded positions are masked out of the forward
            # pass. When absent, this is identical to calling model(input_ids).
            attention_mask = batch.get('attention_mask')
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            # Logits at position t predict token t + 1.
            shifted_preds = preds[:, :-1]
            shifted_labels = labels[:, 1:]
            mask = shifted_labels != IGNORE_INDEX

            num_correct += (shifted_preds[mask] == shifted_labels[mask]).sum().item()
            num_samples += mask.sum().item()

    accuracy = 100 * float(num_correct) / num_samples if num_samples != 0 else 0
    print(f'Got {num_correct} / {num_samples} correct ({accuracy:.2f}%)')
    return accuracy

def check_loss_llm(model, loader, tokenizer, device=None):
    """Return the token-weighted mean loss of a causal LM over *loader*.

    Each batch's ``outputs.loss`` is treated as a mean over the scored
    tokens. It is multiplied back by that token count and the sum is divided
    by the total count, so batches with more scored tokens weigh more.
    Batches whose loss is NaN are skipped.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` and returning an object with a scalar ``.loss``.
        loader: iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        tokenizer: unused; kept for interface compatibility with callers.
        device: target device; ``None`` leaves placement to ``.to``.

    Returns:
        Average loss per scored token, or ``0.0`` when no token was counted.
    """
    import math

    IGNORE_INDEX = -100  # label value excluded from the loss
    model.to(device)
    model.eval()

    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            batch_loss = outputs.loss.item()
            # Presumably NaN arises when every label in the batch is ignored
            # (an empty mean); such a batch carries no tokens, so skip it.
            if math.isnan(batch_loss):
                continue

            # Assumes the standard causal-LM loss: mean over labels[:, 1:] that
            # are not IGNORE_INDEX (TODO confirm for non-HF models). Counting the
            # unshifted labels over-weighted every batch by its first position.
            num_tokens = (labels[:, 1:] != IGNORE_INDEX).sum().item()
            total_tokens += num_tokens
            total_loss += batch_loss * num_tokens

    return total_loss / total_tokens if total_tokens > 0 else 0.0

