import os
import random
import pickle
import warnings
import torch
import wandb
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from .util import *
import sys; sys.path.append('..')
from models.blt import *
from transducers.util import get_deltas_gpu

# Silence every warning category for the whole process. This filter is global:
# it also hides warnings raised by other modules imported after this point.
warnings.filterwarnings('ignore')

def define_model(cfg, input_dim, output_dim):
    """Instantiate the predictor selected by ``cfg.model.model_type``.

    Supported values are ``'mlp'``, ``'bilinear'`` and
    ``'bilinear_scalardelta'``. ``cfg.model.hidden_dim`` and
    ``cfg.model.hidden_depth`` are always read; ``cfg.model.feature_dim`` is
    read only for the two bilinear variants.

    Args:
        cfg: Config object exposing the ``cfg.model.*`` attributes above.
        input_dim: Forwarded to the predictor as ``input_dim``.
        output_dim: Forwarded to the predictor as ``output_dim``.

    Returns:
        The newly constructed predictor instance.

    Raises:
        NotImplementedError: If the model type is not one of the supported
            values. The offending value is included in the message.
    """
    model_type = cfg.model.model_type

    if model_type == 'mlp':
        return MlpPredictor(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dim=cfg.model.hidden_dim,
            hidden_depth=cfg.model.hidden_depth,
        )

    if model_type in ('bilinear', 'bilinear_scalardelta'):
        # Both bilinear variants take the same constructor arguments; only the
        # class differs.
        if model_type == 'bilinear':
            predictor_cls = BilinearPredictor
        else:
            predictor_cls = BilinearPredictorScalarDelta
        return predictor_cls(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dim=cfg.model.hidden_dim,
            feature_dim=cfg.model.feature_dim,
            hidden_depth=cfg.model.hidden_depth,
        )

    # Report the bad value in the exception itself rather than on stdout so it
    # survives in tracebacks and logs.
    raise NotImplementedError(f'model type {model_type!r} is not implemented.')

def train_model(args, cfg, dataset, model):
    """Train ``model`` with Adam on the training split and persist the results.

    The loss for every batch is the mean (over the batch) of the L2 norm of
    the prediction residual along the last dimension. For ``'mlp'`` the model
    is called as ``model(x)``; for any model type containing ``'bilinear'``
    it is called as ``model(x, deltas)`` and compared against the targets
    re-indexed by the ``t2_idx`` returned from ``skew_batch_data``.

    Args:
        args: Runtime options. Reads ``device``, ``wandb_log``,
            ``checkpoint_path`` and ``train_deltas_path``.
        cfg: Config. Reads ``cfg.exp.batch_size``, ``cfg.exp.learning_rate``,
            ``cfg.exp.num_epochs`` and ``cfg.model.model_type``,
            ``skew_direction``, ``similarity_type``, ``store_train_deltas``.
        dataset: Mapping with ``dataset['train']['reps']`` (inputs) and
            ``dataset['train']['targets']`` (targets).
        model: The ``torch.nn.Module`` to optimise in place.

    Returns:
        ``(model, train_deltas)`` where ``train_deltas`` is a list holding one
        randomly sampled delta (as a NumPy array) per bilinear batch when
        ``cfg.model.store_train_deltas`` is truthy, otherwise an empty list.

    Raises:
        NotImplementedError: If ``cfg.model.model_type`` is neither ``'mlp'``
            nor contains ``'bilinear'``.
    """
    model_type = cfg.model.model_type
    is_mlp = model_type == 'mlp'
    # Fail fast instead of hitting an unbound ``loss`` inside the first batch.
    if not is_mlp and 'bilinear' not in model_type:
        raise NotImplementedError(f'model type {model_type!r} is not implemented.')

    device = args.device
    print(f"Training on device: {device}")
    X, Y = dataset['train']['reps'], dataset['train']['targets']

    # The full training set is materialised on ``device`` once, so batches
    # drawn from the loader need no further transfer.
    X = torch.tensor(X, dtype=torch.float32, device=device)
    Y = torch.tensor(Y, dtype=torch.float32, device=device)

    train_loader = DataLoader(
        TensorDataset(X, Y),
        batch_size=cfg.exp.batch_size,
        shuffle=True,
        num_workers=0,
    )

    model.to(device)
    # Make sure layers with mode-dependent behaviour are in training mode,
    # e.g. when the model was previously passed through an evaluation routine.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.exp.learning_rate)

    train_deltas = []

    pbar = tqdm(range(cfg.exp.num_epochs), desc="Training Epochs")
    for epoch in pbar:
        running_loss = 0.0
        num_batches = 0

        for batch_x, batch_y in train_loader:
            # NOTE(review): skew_batch_data is defined elsewhere; it appears to
            # return (possibly reordered) inputs/targets plus an index tensor
            # selecting the paired "t2" samples - confirm against its source.
            batch_x, batch_y, t2_idx = skew_batch_data(batch_x, batch_y, cfg.model.skew_direction)

            optimizer.zero_grad()

            if is_mlp:
                y_pred = model(batch_x)
                loss = torch.mean(torch.linalg.norm(y_pred - batch_y, dim=-1))
            else:
                t2_X = batch_x[t2_idx]
                t2_Y = batch_y[t2_idx]

                # Delta between each sample and its paired sample.
                deltas = get_deltas_gpu(batch_x, t2_X, cfg.model.similarity_type)

                if cfg.model.store_train_deltas and deltas.shape[0] > 0:
                    # Keep one random delta per batch.
                    delta_idx = random.randint(0, deltas.shape[0] - 1)
                    train_deltas.append(deltas[delta_idx].detach().cpu().numpy())

                # Predict the paired sample's target from the sample and delta.
                y2_pred = model(batch_x, deltas)
                loss = torch.mean(torch.linalg.norm(y2_pred - t2_Y, dim=-1))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Unweighted mean of the per-batch losses; NaN when the loader yielded
        # no batches (avoids a ZeroDivisionError on an empty training set).
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        pbar.set_postfix(loss=f'{epoch_loss:.6f}')

        # logging wandb
        if args.wandb_log:
            wandb.log({'train_loss': epoch_loss})

        # save checkpoint (file is named after the zero-based epoch index)
        if (epoch + 1) % 2000 == 0 and args.checkpoint_path:
            torch.save(model.state_dict(), os.path.join(args.checkpoint_path, f'{epoch}.pt'))

    # Same guard as the periodic checkpoints: without a checkpoint directory
    # there is nowhere to write, and os.path.join(None, ...) would raise.
    if args.checkpoint_path:
        torch.save(model.state_dict(), os.path.join(args.checkpoint_path, 'final.pt'))

    if args.train_deltas_path:
        # Context manager guarantees the file is flushed and closed.
        with open(args.train_deltas_path, 'wb') as deltas_file:
            pickle.dump(train_deltas, deltas_file)

    return model, train_deltas

def test_model(args, cfg, dataset, model, transducer, prefix=None):
    """Evaluate ``model`` one test sample at a time and save the metrics.

    For every test input an anchor is requested from
    ``transducer.choose_anchor`` and a delta between that anchor and the query
    is computed with ``get_deltas_gpu``. An ``'mlp'`` model predicts from the
    query alone; ``'bilinear'`` / ``'bilinear_scalardelta'`` models predict
    from the anchor and the delta.

    Args:
        args: Runtime options. Reads ``device`` and ``wandb_log``; also
            forwarded to ``save_results``.
        cfg: Config. Reads ``cfg.model.model_type``, ``use_dom_know_eval`` and
            ``similarity_type``.
        dataset: Mapping with ``'test_X'``, ``'test_Y'`` and ``'test_smiles'``.
        model: Trained predictor; switched to eval mode here.
        transducer: Object providing ``choose_anchor``.
        prefix: Passed through to ``calculate_metrics`` and ``save_results``.

    Returns:
        Dict with keys ``'preds'`` (array of predictions), ``'gt'`` (target
        tensor), ``'query_smiles'``, ``'anchor_idxs'`` and ``'anchor_smiles'``.
    """
    model.eval()

    raw_inputs = dataset['test_X']
    raw_targets = dataset['test_Y']
    query_smiles = dataset['test_smiles']
    inputs = torch.tensor(raw_inputs, dtype=torch.float32)
    targets = torch.tensor(raw_targets, dtype=torch.float32)

    # One sample per step, in dataset order.
    test_loader = DataLoader(
        TensorDataset(inputs, targets),
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )

    predictions = []
    anchor_indices = []
    anchor_smiles_per_query = []

    for query_x, _ in tqdm(test_loader, total=len(inputs)):
        query_x = query_x.to(args.device)

        anchor_x, anchor_idx, anchor_smi = transducer.choose_anchor(
            query_x,
            use_dom_know_eval=cfg.model.use_dom_know_eval,
            return_anchor=True,
        )

        delta = get_deltas_gpu(anchor_x, query_x, cfg.model.similarity_type)

        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            if cfg.model.model_type == 'mlp':
                output = model(query_x)
                y = output.cpu().detach().numpy()[0]
            elif cfg.model.model_type in ['bilinear', 'bilinear_scalardelta']:
                output = model(anchor_x, delta)
                y = output.cpu().detach().numpy()[0]

        predictions.append(y)
        anchor_indices.append(anchor_idx)
        anchor_smiles_per_query.append(anchor_smi)

    preds = {
        'preds': np.array(predictions),
        'gt': targets,
        'query_smiles': query_smiles,
        'anchor_idxs': anchor_indices,
        'anchor_smiles': anchor_smiles_per_query,
    }

    # Aggregate metrics over the whole test set.
    results = calculate_metrics(preds['gt'], preds['preds'], prefix=prefix)

    if args.wandb_log:
        wandb.summary.update(results)

    # Persist the metrics.
    save_results(args, results, prefix)

    return preds

