import os
import math
import wandb
import torch
import argparse
import numpy as np
import importlib.util
from tqdm import tqdm
from typing import Any
from scipy.stats import spearmanr
from sklearn.metrics import ndcg_score
from torch.utils.data import DataLoader

from .data import RankingDataset, EmbedMapper
from .model import get_model, get_loss_func, get_optimizer


def load_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: Parsed options; ``cfg_path`` holds the value of
        the required ``--cfg_path`` option.
    """
    cli = argparse.ArgumentParser(
        description="Load config and print seq_path variable",
    )
    cli.add_argument(
        "--cfg_path",
        required=True,
        type=str,
        help="Path to the config file",
    )
    return cli.parse_args()


def import_config(config_path):
    """Execute the Python file at ``config_path`` and return it as a module.

    The module is created under the name ``"config"``.

    Args:
        config_path: Location of the Python source file to execute.

    Returns:
        The module object produced by executing the file.

    Raises:
        AssertionError: If no import spec (or no loader) can be built for
            ``config_path``.
    """
    module_spec = importlib.util.spec_from_file_location("config", config_path)
    loadable = module_spec is not None and module_spec.loader is not None
    assert loadable, f"Failed to load config from {config_path}"

    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module


def load_config(config_path: Any = None):
    """Load the config module, taking the path from the CLI when none is given.

    Args:
        config_path: Path of the config file. When ``None``, the path is read
            from the ``--cfg_path`` command-line option via ``load_args``.

    Returns:
        The module returned by ``import_config`` for the resolved path.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
    """
    resolved_path = load_args().cfg_path if config_path is None else config_path

    if os.path.exists(resolved_path):
        return import_config(resolved_path)

    raise FileNotFoundError(f"Config file {resolved_path} does not exist.")


def set_cfg_value(cfg, key, value):
    """Assign ``value`` to ``cfg.<key>`` when ``key`` is an overridable setting.

    Args:
        cfg: Object whose attribute is assigned.
        key: Name of the setting to override.
        value: New value for the setting.

    Returns:
        bool: ``True`` if ``key`` was one of the recognised settings and the
        attribute was assigned, ``False`` otherwise (``cfg`` is untouched).
    """
    overridable = (
        'head_model_type',
        'learning_rate',
        'batch_size',
        'optimizer_type',
        'hidden_size',
        'shuffle',
        'k',
    )
    for name in overridable:
        if key == name:
            setattr(cfg, name, value)
            return True
    return False


def single_mrr(predicted, actual):
    """Return the reciprocal rank of the target item for one ranking.

    The target is the entry whose value in ``actual`` equals 0. Its rank is
    the number of entries of ``predicted`` that are less than or equal to the
    target's predicted value, so a smaller predicted value gives a better
    rank and ties count against the target.

    Args:
        predicted: 1-D array-like of predicted values.
        actual: 1-D array-like aligned element-wise with ``predicted``.

    Returns:
        float: ``1 / rank`` of the target. ``0.0`` when ``actual`` contains no
        0. When several entries of ``actual`` are 0, the best-ranked one
        (smallest predicted value) is used.
    """
    target_value = 0
    predicted = np.asarray(predicted)
    target_positions = np.where(np.asarray(actual) == target_value)[0]

    # No target present: report a reciprocal rank of zero instead of failing
    # on an empty comparison / division by zero.
    if target_positions.size == 0:
        return 0.0

    # With more than one target, rank the one placed best by the predictions;
    # with exactly one target this is simply that target's predicted value.
    best_target_value = predicted[target_positions].min()
    predicted_rank = int(np.count_nonzero(predicted <= best_target_value))

    return 1 / predicted_rank


def mrr_score(actual_batch, predicted_batch):
    """Average ``single_mrr`` over the rows of a batch.

    Args:
        actual_batch: Indexable batch of ground-truth rows.
        predicted_batch: Array of predicted rows; ``shape[0]`` gives the
            number of rows that are scored.

    Returns:
        The sum of the per-row reciprocal ranks divided by the number of rows.
    """
    num_rows = predicted_batch.shape[0]
    reciprocal_ranks = (
        single_mrr(predicted_batch[row], actual_batch[row])
        for row in range(num_rows)
    )
    return sum(reciprocal_ranks, 0.) / num_rows


def sp_score(actual_batch, predicted_batch):
    """Average the Spearman correlation between predictions and ground truth.

    Each row of ``predicted_batch`` is correlated with the matching row of
    ``actual_batch`` using ``scipy.stats.spearmanr``; a NaN correlation is
    counted as 0.

    Args:
        actual_batch: Indexable batch of ground-truth rows.
        predicted_batch: Array of predicted rows; ``shape[0]`` gives the
            number of rows that are scored.

    Returns:
        The sum of the per-row correlations divided by the number of rows.
    """
    num_rows = predicted_batch.shape[0]
    accumulated = 0.

    for row in range(num_rows):
        correlation = spearmanr(predicted_batch[row], actual_batch[row])[0]
        # An undefined correlation contributes nothing to the average.
        accumulated += np.float64(0.0) if math.isnan(correlation) else correlation

    return accumulated / num_rows


def average_precision(predicted, actual):
    """Compute the mean of the running precision over matching positions.

    Position ``i`` (0-based) counts as a match when ``predicted[i]`` occurs
    among the first ``i + 1`` entries of ``actual``. At every match the running
    precision ``matches_so_far / (i + 1)`` is recorded.

    Args:
        predicted: Sequence of predicted values.
        actual: Sliceable sequence of ground-truth values.

    Returns:
        The mean of the recorded precisions, or ``0`` when nothing matched.
    """
    hits = 0
    precisions = []

    for position, item in enumerate(predicted, start=1):
        if item not in actual[:position]:
            continue
        hits += 1
        precisions.append(hits / position)

    return np.mean(precisions) if precisions else 0


def rank_elements(predicted):
    """Return the 0-based ascending rank of every element of a 1-D sequence.

    The smallest element gets rank 0 and the largest gets
    ``len(predicted) - 1``; the order among equal elements follows
    ``np.argsort``.

    Args:
        predicted: 1-D array-like of values to rank.

    Returns:
        numpy.ndarray: Integer array in which entry ``i`` is the rank of
        ``predicted[i]``.
    """
    # The argsort of a permutation is its inverse, which maps each original
    # index to its position in sorted order.
    ascending_order = np.argsort(predicted)
    return np.argsort(ascending_order)


def map_score(actual_batch, predicted_batch):
    """Average ``average_precision`` over the rows of a batch.

    Each predicted row is converted to ranks with ``rank_elements`` before
    being scored; the matching ground-truth row is passed through unchanged.

    Args:
        actual_batch: Indexable batch of ground-truth rows.
        predicted_batch: Array of predicted rows; ``shape[0]`` gives the
            number of rows that are scored.

    Returns:
        The sum of the per-row average precisions divided by the number of
        rows.
    """
    num_rows = predicted_batch.shape[0]
    per_row_ap = (
        average_precision(rank_elements(predicted_batch[row]), actual_batch[row])
        for row in range(num_rows)
    )
    return sum(per_row_ap, 0.) / num_rows


def get_param(model):
    """Return the backbone parameters of ``model`` as a list.

    For a backbone named ``'LucaOne'`` the parameters are taken from
    ``model.backbone.model['model']``; for any other backbone they are taken
    from ``model.backbone.model`` directly.

    Args:
        model: Object exposing ``backbone_name`` and ``backbone.model``.

    Returns:
        list: The items yielded by the selected object's ``parameters()``.
    """
    is_luca_one = model.backbone_name == 'LucaOne'
    container = model.backbone.model
    backbone_module = container['model'] if is_luca_one else container
    return list(backbone_module.parameters())


def _evaluate_on_test(model, testloader, embed_mapper, is_linear_probe, device):
    """Score ``model`` on ``testloader`` and return the per-batch mean metrics.

    Puts the model in eval mode and runs every test batch with gradients
    disabled.

    Args:
        model: Callable model; receives the mapped embeddings in linear-probe
            mode and the raw batch data otherwise.
        testloader: Loader yielding ``(test_data, test_labels)`` batches.
        embed_mapper: Lookup used as ``embed_mapper[test_data]`` in
            linear-probe mode; unused otherwise.
        is_linear_probe: Whether inputs must be mapped through
            ``embed_mapper`` first.
        device: Device the labels (and mapped embeddings) are moved to.

    Returns:
        tuple: ``(ndcg, mrr, spearman)``, each the sum of the per-batch score
        divided by ``len(testloader)``.
    """
    model.eval()
    ndcg_total = 0.0
    mrr_total = 0.0
    sp_total = 0.0

    with torch.no_grad():
        for test_data, test_labels in tqdm(testloader):
            test_ranks = torch.stack(test_labels).T.to(device)
            if is_linear_probe:
                test_embed = embed_mapper[test_data].to(device)
            else:
                test_embed = test_data
            predicted_scores = model(test_embed)

            # Convert to NumPy once and share the arrays between all metrics.
            ranks_np = test_ranks.to('cpu').numpy()
            scores_np = predicted_scores.to('cpu').numpy()

            ndcg_total += ndcg_score(ranks_np, scores_np)
            mrr_total += mrr_score(ranks_np, scores_np)
            sp_total += sp_score(ranks_np, scores_np)

    num_batches = len(testloader)
    return ndcg_total / num_batches, mrr_total / num_batches, sp_total / num_batches


def _log_test_scores(ndcg, mrr, sp):
    """Send the three test metrics to wandb, one ``wandb.log`` call each."""
    wandb.log({'test NDCG score': ndcg})
    wandb.log({'test MRR score': mrr})
    wandb.log({'test SP score': sp})


def train_and_test(cfg):
    """Train the ranking model described by ``cfg`` and evaluate it on the test set.

    The model is evaluated once before the first epoch, then trained for
    ``cfg.num_epochs`` epochs. It is evaluated again after every epoch when
    ``cfg.eval_each_epoch`` is true, otherwise only after the last epoch.
    Progress is printed and, when ``cfg.use_wandb`` is true, logged to wandb.

    When ``cfg`` has no ``backbone`` attribute the run is treated as a linear
    probe: batch data is mapped to embeddings through an ``EmbedMapper`` and
    only ``model.parameters()`` are optimised. Otherwise the raw batch data is
    passed to the model and the backbone parameters from ``get_param`` are
    added to the optimiser.

    Args:
        cfg: Config object. Attributes read here: ``folder``, ``k``,
            ``shuffle``, ``test_batch_size``, ``batch_size``,
            ``head_model_type``, ``optimizer_type``, ``learning_rate``,
            ``num_epochs``, ``use_wandb``, ``eval_each_epoch`` and, in
            linear-probe mode, ``folder_embed``, ``embed_name``, ``seq_type``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    is_linear_probe = not hasattr(cfg, 'backbone')

    testset = RankingDataset(cfg.folder, cfg.k, test_mode=True)
    trainset = RankingDataset(cfg.folder, cfg.k, test_mode=False, shuffle=cfg.shuffle)
    testloader = DataLoader(testset, batch_size=cfg.test_batch_size, shuffle=False)
    trainloader = DataLoader(trainset, batch_size=cfg.batch_size, shuffle=True)
    # Bound in both modes so it can always be handed to the evaluation helper.
    embed_mapper = None
    if is_linear_probe:
        embed_mapper = EmbedMapper(cfg.folder_embed, cfg.embed_name, cfg.seq_type)

    model = get_model(cfg.head_model_type, device, cfg)
    loss_func = get_loss_func(cfg.head_model_type)

    if is_linear_probe:
        optimizer = get_optimizer(cfg.optimizer_type, model.parameters(), lr=cfg.learning_rate)
    else:
        optimizer = get_optimizer(cfg.optimizer_type,
            list(model.parameters()) + get_param(model), lr=cfg.learning_rate)
        # Print every parameter handed to the optimizer so it can be checked
        # that the backbone parameters are being optimized.
        for group in optimizer.param_groups:
            for param in group['params']:
                print(f"Parameter: {param.size()}, requires_grad={param.requires_grad}")

    print("Training...")
    for epoch in range(cfg.num_epochs):
        if epoch == 0:
            print("Evaluating before training...")
            ndcg, mrr, sp = _evaluate_on_test(
                model, testloader, embed_mapper, is_linear_probe, device)
            print(f"Before training, Test NDCG Score: {ndcg}, Test MRR Score: {mrr}, Test SP Score: {sp}")
            if cfg.use_wandb:
                _log_test_scores(ndcg, mrr, sp)

        epoch_loss = 0.0
        model.train()
        for data, labels in tqdm(trainloader, desc=f'Epoch {epoch+1}/{cfg.num_epochs}'):
            # data/label: shape = [k, batch_size]
            if is_linear_probe:
                embed = embed_mapper[data].to(device)
            else:
                embed = data
            ranks = torch.stack(labels).T.to(device)
            optimizer.zero_grad()
            predicted_scores = model(embed)
            loss = loss_func(predicted_scores, ranks)
            # Read the scalar once; it is used for both logging and the epoch total.
            loss_value = loss.item()
            if cfg.use_wandb:
                wandb.log({'train loss': loss_value})
            loss.backward()
            optimizer.step()
            epoch_loss += loss_value
        print(f"Epoch {epoch+1}/{cfg.num_epochs}, Train Loss: {epoch_loss/len(trainloader)}")

        if cfg.eval_each_epoch or (epoch == (cfg.num_epochs - 1)):
            print("Evaluating...")
            ndcg, mrr, sp = _evaluate_on_test(
                model, testloader, embed_mapper, is_linear_probe, device)
            print(f"Epoch {epoch+1}/{cfg.num_epochs}, Test NDCG Score: {ndcg}, Test MRR Score: {mrr}, Test SP Score: {sp}")
            if cfg.use_wandb:
                _log_test_scores(ndcg, mrr, sp)
                if not cfg.eval_each_epoch:
                    # The final scores are logged four more times when only the
                    # last epoch is evaluated.
                    # NOTE(review): presumably to add extra points to the wandb
                    # series - confirm the intent.
                    for _ in range(4):
                        _log_test_scores(ndcg, mrr, sp)
        print("#" * 100)
        print(" ")
