import os
import random
import pickle
import json
import wandb

import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from rdkit import Chem
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
from scipy.stats import pearsonr, spearmanr



class SmilesDataset(Dataset):
    """Map-style dataset that pairs each SMILES entry with its target.

    ``smiles`` and ``target`` are stored exactly as given and are indexed with
    the same position in ``__getitem__``, so they are expected to be
    index-aligned sequences. The dataset length is taken from ``smiles``.
    """

    def __init__(self, smiles, target):
        self.smiles, self.target = smiles, target

    def __getitem__(self, idx):
        """Return the ``(smiles, target)`` pair stored at position ``idx``."""
        return self.smiles[idx], self.target[idx]

    def __len__(self):
        """Return the number of SMILES entries."""
        return len(self.smiles)
    
    
class Trainer:
    """Fine-tuning driver: loads SMILES datasets, trains a model and evaluates checkpoints.

    The per-epoch hooks ``_train_one_epoch`` and ``_validate_one_epoch`` are
    called by this class but are not defined in it; presumably a subclass (or
    code outside this view) provides them -- TODO confirm.
    """

    def __init__(self, args, cfg, model, optimizer, loss_fn):
        """Store the run configuration, seed the RNGs and build the data loaders.

        Args:
            args: Run arguments. Attributes read by this class: ``seed``,
                ``device``, ``prop_type``, ``data_dir``, ``dataset_name``,
                ``dataset_split_type``, ``checkpoint_path`` and ``wandb_log``.
            cfg: Experiment configuration; ``cfg.exp.batch_size`` and
                ``cfg.exp.num_epochs`` are read here, and ``vars(cfg)`` is
                stored in checkpoints.
            model: Model to train; ``to``, ``train``, ``eval``, ``state_dict``
                and ``load_state_dict`` are called on it.
            optimizer: Optimizer whose ``param_groups`` are read by ``_get_lr``.
            loss_fn: Loss callable (stored; not invoked by the methods visible
                in this class).
        """
        self.args = args
        self.cfg = cfg

        self._set_seed(args.seed)
        self.device = args.device

        # NOTE: this attribute holds ``args.prop_type``, not ``args.dataset_name``.
        self.dataset_name = args.prop_type

        self.batch_size = cfg.exp.batch_size
        self.best_vloss = float('inf')

        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

        self.hparams = cfg
        self.seed = args.seed

        self._prepare_data()

    def normalize_smiles(self, smi, canonical=True, isomeric=False):
        """Return the RDKit-normalized form of ``smi``, or ``None`` on failure.

        Args:
            smi: SMILES string to normalize.
            canonical: Forwarded to ``Chem.MolToSmiles(canonical=...)``.
            isomeric: Forwarded to ``Chem.MolToSmiles(isomericSmiles=...)``.

        Returns:
            The normalized SMILES string, or ``None`` when the input cannot be
            parsed or converted.
        """
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                # RDKit signals an unparsable SMILES by returning None.
                return None
            return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
        except Exception:
            # Best-effort by design: any entry that cannot be processed is
            # reported as invalid. Unlike a bare ``except``, this does not
            # swallow KeyboardInterrupt / SystemExit.
            return None

    def load_dataset(self, split):
        """Load one split from the pickled dataset and drop invalid SMILES.

        The file is read from
        ``<data_dir>/<dataset_name>/<dataset_split_type>/<prop_type>/<prop_type>.pkl``
        and is indexed as ``data[split]['smiles']`` / ``data[split]['targets']``.

        Args:
            split: Key of the split to load (this class uses ``'train'``,
                ``'eval'`` and ``'ood'``).

        Returns:
            ``(canon_smiles, valid_targets)``: two lists of equal length holding
            the normalized SMILES and the targets of the entries that survived
            normalization.
        """
        data_path = os.path.join(
            self.args.data_dir,
            self.args.dataset_name,
            self.args.dataset_split_type,
            self.args.prop_type,
            f"{self.args.prop_type}.pkl",
        )
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(data_path, 'rb') as handle:
            all_dataset = pickle.load(handle)

        split_data = all_dataset[split]

        canon_smiles = []
        valid_targets = []
        for smi, target in zip(split_data['smiles'], split_data['targets']):
            canon_smi = self.normalize_smiles(smi)
            if canon_smi:
                canon_smiles.append(canon_smi)
                valid_targets.append(target)

        # Return the filtered targets so labels stay aligned with the SMILES
        # that were kept (returning the raw targets would shift every label
        # after the first dropped molecule).
        return canon_smiles, valid_targets

    def _prepare_data(self):
        """Build ``train_loader``, ``eval_loader`` and ``ood_loader``.

        Only the training loader shuffles. The number of zero-valued targets in
        the eval and ood splits is printed.
        """
        smiles, targets = self.load_dataset('train')
        train_dataset = SmilesDataset(smiles, targets)
        self.train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True
        )

        smiles, targets = self.load_dataset('eval')
        eval_dataset = SmilesDataset(smiles, targets)
        self.eval_loader = DataLoader(
            eval_dataset, batch_size=self.batch_size, shuffle=False, pin_memory=True
        )

        smiles, targets = self.load_dataset('ood')
        ood_dataset = SmilesDataset(smiles, targets)
        self.ood_loader = DataLoader(
            ood_dataset, batch_size=self.batch_size, shuffle=False, pin_memory=True
        )
        print('eval', (np.array(eval_dataset.target) == 0).sum())
        print('ood', (np.array(ood_dataset.target) == 0).sum())

    def fit(self):
        """Train for ``cfg.exp.num_epochs`` epochs, then save ``final.pt``.

        The per-epoch loss returned by ``_train_one_epoch`` is shown on the
        progress bar and, when ``args.wandb_log`` is truthy, logged to wandb.
        """
        self.model.to(self.device)

        # Stays -1 when num_epochs is 0 so the save below still has a value
        # (the loop variable would otherwise be unbound).
        last_epoch = -1
        pbar = tqdm(range(self.cfg.exp.num_epochs), desc="Training Epochs")
        for epoch in pbar:
            last_epoch = epoch
            self.model.train()
            train_loss = self._train_one_epoch()
            if self.args.wandb_log:
                wandb.log({'train_loss': train_loss})
            pbar.set_postfix(loss=f'{train_loss:.6f}')

        # save checkpoint
        print('Saving checkpoint...')
        self._save_checkpoint(last_epoch, os.path.join(self.args.checkpoint_path, 'final.pt'))

    def evaluate(self):
        """Evaluate the saved checkpoint on the eval split; returns ``(preds, tgts)``."""
        return self._evaluate_on_loader(self.eval_loader, 'eval')

    def evaluate_ood(self):
        """Evaluate the saved checkpoint on the ood split; returns ``(preds, tgts)``."""
        return self._evaluate_on_loader(self.ood_loader, 'ood')

    def _evaluate_on_loader(self, data_loader, prefix):
        """Reload ``final.pt`` into the model and run validation on ``data_loader``.

        Args:
            data_loader: Loader to evaluate on.
            prefix: Label forwarded to ``_validate_one_epoch``.

        Returns:
            ``(preds, tgts)`` as returned by ``_validate_one_epoch``; the loss it
            also returns is discarded.
        """
        # The checkpoint is reloaded on every call, so evaluation always uses
        # the weights saved by ``fit`` rather than the current in-memory state.
        model_inf = self.model
        self._load_checkpoint(os.path.join(self.args.checkpoint_path, 'final.pt'))
        model_inf.to(self.device)
        model_inf.eval()

        preds, tgts, _ = self._validate_one_epoch(data_loader, model_inf, prefix)
        print((tgts == 0).sum())
        return preds, tgts

    def _load_checkpoint(self, filename):
        """Restore model weights and bookkeeping from a checkpoint file.

        Sets ``self.start_epoch`` to the stored epoch + 1 and restores
        ``self.best_vloss``. Tensors are mapped to CPU while loading.
        """
        # NOTE(review): the checkpoint also stores ``vars(hparams)``; depending
        # on the torch version's ``weights_only`` default this may need
        # attention -- confirm against the installed torch.
        ckpt_dict = torch.load(filename, map_location='cpu')
        self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
        self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
        self.best_vloss = ckpt_dict['finetune_info']['best_vloss']

    def _save_checkpoint(self, current_epoch, filename):
        """Write model weights, epoch counter, hparams, best loss and seed to ``filename``.

        The parent directory of ``filename`` is created if it does not exist.
        """
        ckpt_dict = {
            'MODEL_STATE': self.model.state_dict(),
            'EPOCHS_RUN': current_epoch,
            'hparams': vars(self.hparams),
            'finetune_info': {
                'best_vloss': self.best_vloss,
            },
            'seed': self.seed,
        }

        # Guards the checkpoint schema that ``_load_checkpoint`` relies on.
        assert list(ckpt_dict.keys()) == ['MODEL_STATE', 'EPOCHS_RUN', 'hparams', 'finetune_info', 'seed']

        # Avoid losing a finished training run to a missing output directory.
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        torch.save(ckpt_dict, filename)

    def _set_seed(self, value):
        """Seed Python, NumPy and torch RNGs; when CUDA is available also set cuDNN flags."""
        random.seed(value)
        torch.manual_seed(value)
        np.random.seed(value)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(value)
            torch.cuda.manual_seed_all(value)
            cudnn.deterministic = True
            cudnn.benchmark = False

    def _get_lr(self):
        """Return the learning rate of the first optimizer param group (``None`` if there are none)."""
        for param_group in self.optimizer.param_groups:
            return param_group['lr']
        return None

class RMSELoss(nn.Module):
    """Root-mean-squared-error loss built on ``nn.MSELoss``.

    Args:
        reduction: Forwarded to ``nn.MSELoss``. With ``'none'`` the square root
            is applied element-wise to the squared errors.
        eps: Lower bound applied to the MSE before the square root. The
            derivative of ``sqrt`` is infinite at 0, so an exact prediction
            would otherwise produce NaN gradients. MSE values at or above
            ``eps`` are unaffected.
    """

    def __init__(self, reduction='mean', eps=1e-12):
        super().__init__()
        self.reduction = reduction
        self.eps = eps
        self.mse = nn.MSELoss(reduction=reduction)

    def forward(self, pred, target):
        """Return ``sqrt(max(MSE(pred, target), eps))``."""
        mse_loss = self.mse(pred, target)
        # Clamp rather than add eps: values above eps keep their exact RMSE,
        # and the clamped region has zero gradient, so backward stays finite.
        return torch.sqrt(torch.clamp(mse_loss, min=self.eps))

   
def get_optim_groups(module, keep_decoder=False, weight_decay=0.0):
    """Split ``module``'s parameters into decay / no-decay optimizer groups.

    Classification rules, applied to every (module, parameter) pair:
      * names ending in ``bias``                         -> no-decay group
      * ``weight`` of ``nn.Linear``                      -> decay group
      * ``weight`` of ``nn.LayerNorm`` / ``nn.Embedding`` -> no-decay group

    NOTE: parameters matching none of these rules are silently left out of
    both groups, so an optimizer built from the result will not update them.

    Args:
        module: Root ``nn.Module`` whose parameters are grouped.
        keep_decoder: When False, every parameter whose full name contains
            ``'decoder'`` is skipped.
        weight_decay: Weight decay assigned to the decay group. Defaults to
            0.0, which reproduces the previous hard-coded behavior (no decay
            on either group).

    Returns:
        A list of two param-group dicts ``[decay_group, no_decay_group]``
        suitable for a torch optimizer; parameters are ordered by name.
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in module.named_modules():
        # named_parameters() is recursive, so the same tensor is seen once per
        # ancestor module; the sets below deduplicate by full name.
        for pn, _ in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

            if not keep_decoder and 'decoder' in fpn:  # exclude decoder components
                continue

            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    # name -> tensor lookup used to materialize the groups
    param_dict = dict(module.named_parameters())

    return [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]

def setup_optimizer_and_loss(model, cfg):
    """Build the AdamW optimizer and the configured loss for ``model``.

    Reads ``cfg.model.train_decoder``, ``cfg.exp.lr`` and ``cfg.exp.loss_fn``.
    Supported ``loss_fn`` values are ``'rmse'``, ``'mae'`` and ``'bce'``.

    Returns:
        ``(optimizer, loss_function)``.

    Raises:
        ValueError: if ``cfg.exp.loss_fn`` is not one of the supported names.
    """
    train_decoder = bool(cfg.model.train_decoder)
    param_groups = get_optim_groups(model, keep_decoder=train_decoder)
    optimizer = optim.AdamW(param_groups, lr=cfg.exp.lr, betas=(0.9, 0.99))

    loss_name = cfg.exp.loss_fn
    if loss_name == 'rmse':
        return optimizer, RMSELoss()
    if loss_name == 'mae':
        return optimizer, nn.L1Loss()
    if loss_name == 'bce':
        # classification objective
        return optimizer, nn.BCEWithLogitsLoss()
    raise ValueError(f"Unsupported loss_fn: {loss_name}")



def calculate_rmse(y_true, y_pred):
    """
    Root mean squared error between targets and predictions.

    Torch tensors are detached and moved to CPU NumPy arrays first; any other
    input is passed to scikit-learn unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        RMSE value
    """
    truth = y_true.detach().cpu().numpy() if torch.is_tensor(y_true) else y_true
    estimate = y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else y_pred

    mse = mean_squared_error(truth, estimate)
    return np.sqrt(mse)

def calculate_mae(y_true, y_pred):
    """
    Mean absolute error between targets and predictions.

    Torch tensors are detached and moved to CPU NumPy arrays first; any other
    input is passed to scikit-learn unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        MAE value
    """
    truth = y_true.detach().cpu().numpy() if torch.is_tensor(y_true) else y_true
    estimate = y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else y_pred

    mae = mean_absolute_error(truth, estimate)
    return mae

def calculate_r2(y_true, y_pred):
    """
    Coefficient of determination (R²) of predictions against targets.

    Torch tensors are detached and moved to CPU NumPy arrays first; any other
    input is passed to scikit-learn unchanged.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        R² value
    """
    truth = y_true.detach().cpu().numpy() if torch.is_tensor(y_true) else y_true
    estimate = y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else y_pred

    score = r2_score(truth, estimate)
    return score

def calculate_pearson(y_true, y_pred):
    """
    Calculate Pearson correlation coefficient

    Inputs may be torch tensors, NumPy arrays or plain sequences; they are
    converted to NumPy arrays and flattened to 1-D before the computation.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        Pearson correlation coefficient (the p-value is discarded)
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # np.asarray makes lists/tuples work too (they have no ``.shape``);
    # reshape(-1) is a no-op for inputs that are already 1-D.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    return pearsonr(y_true, y_pred)[0]  # Return only the correlation coefficient

def calculate_spearman(y_true, y_pred):
    """
    Calculate Spearman rank correlation coefficient

    Inputs may be torch tensors, NumPy arrays or plain sequences; they are
    converted to NumPy arrays and flattened to 1-D before the computation.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values

    Returns:
        Spearman rank correlation coefficient (the p-value is discarded)
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # np.asarray makes lists/tuples work too (they have no ``.shape``);
    # reshape(-1) is a no-op for inputs that are already 1-D.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    return spearmanr(y_true, y_pred)[0]  # Return only the correlation coefficient

def calculate_metrics(y_true, y_pred, prefix, task):
    """
    Evaluate model performance with multiple metrics

    Inputs may be torch tensors, NumPy arrays or plain sequences; both are
    flattened to 1-D before any metric is computed.

    Args:
        y_true: Array-like of true values
        y_pred: Array-like of predicted values (regression) or scores
            (classification)
        prefix: Prefix prepended to every metric name as ``"<prefix>_<metric>"``
        task: ``'classification'`` selects the classification metrics; any
            other value selects the regression metrics

    Returns:
        Dictionary of metrics. Classification: accuracy, f1, precision, recall,
        auroc. Regression: rmse, mae, r2, pearson, spearman.

    Raises:
        ValueError: propagated from ``roc_auc_score`` when AUROC cannot be
            computed and the failure is not explained by NaN scores.
    """
    # Convert tensors to numpy
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    # Flatten; np.asarray also accepts lists/tuples, which have no ``.reshape``.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    metrics = {}

    if task == 'classification':
        # Predicted probabilities assumed, threshold at 0.5
        # NOTE(review): this file also offers BCEWithLogitsLoss; if raw logits
        # are passed here the 0.5 threshold is not the decision boundary --
        # confirm what the caller provides.
        y_prob = y_pred
        y_label = (y_prob > 0.5).astype(int)

        metrics[f"{prefix}_accuracy"] = accuracy_score(y_true, y_label)
        metrics[f"{prefix}_f1"] = f1_score(y_true, y_label)
        metrics[f"{prefix}_precision"] = precision_score(y_true, y_label)
        metrics[f"{prefix}_recall"] = recall_score(y_true, y_label)

        # AUROC requires both probs and true labels
        try:
            metrics[f"{prefix}_auroc"] = roc_auc_score(y_true, y_prob)
        except ValueError:
            mask = ~np.isnan(y_prob)
            if mask.all():
                # No NaN scores to drop, so retrying on the same data cannot
                # help; surface the original error instead of a misleading
                # "0 values are masked" message.
                raise
            print(f"CAUTION: {(mask==0).sum()} values are masked ")
            metrics[f"{prefix}_auroc"] = roc_auc_score(y_true[mask], y_prob[mask])
    else:
        metrics[f"{prefix}_rmse"] = calculate_rmse(y_true, y_pred)
        metrics[f"{prefix}_mae"] = calculate_mae(y_true, y_pred)
        metrics[f"{prefix}_r2"] = calculate_r2(y_true, y_pred)
        metrics[f"{prefix}_pearson"] = calculate_pearson(y_true, y_pred)
        metrics[f"{prefix}_spearman"] = calculate_spearman(y_true, y_pred)

    return metrics

def save_results(args, results, prefix):
    """Print ``results`` and write them to ``<args.save_path>/<prefix>_results.json``.

    Args:
        args: Object with a ``save_path`` attribute naming the output directory.
        results: Mapping of metric name to a value convertible with ``float()``.
        prefix: File-name prefix for the JSON file.
    """
    results_path = os.path.join(args.save_path, f'{prefix}_results.json')
    # Cast to plain floats so NumPy / torch scalars become JSON-serializable.
    results = {k: float(v) for k, v in results.items()}
    print(results)
    # Context manager guarantees the file is flushed and closed.
    with open(results_path, 'w') as handle:
        json.dump(results, handle, indent=4)
    