import os
import torch
import time
import numpy as np
import argparse
import data
from load_model import load_model
from transformers import GPT2TokenizerFast
import torch.nn.functional as F
import sampling
from torch.utils.data import DataLoader, DistributedSampler
from model import utils as mutils
import ot

def create_directory_if_not_exists(directory_path):
    """Make sure ``directory_path`` exists, creating intermediate directories as needed.

    Prints a one-line message saying whether the directory was created or
    already existed (any existing path, file or directory, counts as existing).
    """
    if os.path.exists(directory_path):
        print(f"Directory '{directory_path}' already exists.")
        return
    os.makedirs(directory_path)
    print(f"Directory '{directory_path}' created.")

def cycle_loader(dataloader, sampler=None):
    """Endlessly re-iterate ``dataloader``, yielding its batches one at a time.

    Before each full pass, if a ``sampler`` is supplied, its ``set_epoch`` is
    called with a random integer drawn from ``np.random.randint(0, 100000)``.
    """
    while True:
        if sampler is not None:
            epoch = np.random.randint(0, 100000)
            sampler.set_epoch(epoch)
        for batch in dataloader:
            yield batch

def _ot_embedding_weights(model, ot_metric):
    """Return the token-embedding matrix used to build the OT cost for ``ot_metric``.

    Unwraps a wrapper exposing ``.module`` if present, then returns
    ``vocab_embed.embedding`` when that attribute exists, else ``vocab_embed.weight``.

    Raises:
        NotImplementedError: if ``ot_metric`` is not ``'input_embedding_l2'``.
    """
    if ot_metric != 'input_embedding_l2':
        raise NotImplementedError(f"OT metric '{ot_metric}' is not supported. Use 'input_embedding_l2'.")
    raw_model = model.module if hasattr(model, 'module') else model
    if hasattr(raw_model.vocab_embed, 'embedding'):
        return raw_model.vocab_embed.embedding
    return raw_model.vocab_embed.weight


def _ot_pairing_indices(sourcebatch, databatch, W, device, chunk=512):
    """Return a long tensor ``idx`` so that ``databatch[idx][i]`` is the OT partner of ``sourcebatch[i]``.

    Both inputs are ``(B, L)`` token-id tensors. Every row is embedded with ``W``
    and flattened; the cost between source row i and data row j is
    ``-<src_i, data_j> / L``. An exact OT plan with uniform marginals is solved
    with ``ot.emd`` and each source row takes the data row with the most plan mass.

    Why a negative dot product for 'input_embedding_l2': with
    ``||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>`` and uniform marginals, any
    cost term depending only on the row or only on the column adds the same
    constant to every feasible plan, so the L2 cost and ``-<x, y>`` share the
    same optimal plan. The double-centering and non-negative shift below use
    the same invariance and only condition the numbers handed to the solver.
    """
    B, L = databatch.shape
    with torch.inference_mode():
        # (B, L) -> (B, L, hidden) -> (B, L * hidden)
        target_emb_flat = F.embedding(databatch, W).view(B, -1)
        M = torch.empty(B, B, dtype=torch.float32, device=device)
        # Fill the cost matrix in row chunks to bound the size of the source embeddings in memory.
        for s in range(0, B, chunk):
            e = min(s + chunk, B)
            src_emb_flat = F.embedding(sourcebatch[s:e], W).view(e - s, -1)
            M[s:e].copy_(-torch.mm(src_emb_flat, target_emb_flat.t()))
        cost = M.double().cpu().numpy() / L

    cost = cost - cost.mean(axis=1, keepdims=True) - cost.mean(axis=0, keepdims=True) + cost.mean()
    if cost.min() < 0:
        cost -= cost.min()

    a = np.full(B, 1.0 / B, dtype=np.float64)
    b = np.full(B, 1.0 / B, dtype=np.float64)
    plan = ot.emd(a, b, cost, numItermax=100000)
    # ot.emd (network simplex) returns a vertex of the uniform-marginal polytope,
    # i.e. (1/B) * a permutation matrix, so a row-wise argmax recovers the matching.
    return torch.as_tensor(plan.argmax(axis=1), dtype=torch.long, device=device)


def J2(args, train_iter, device, graph, model):
    """Evaluate the model's masked denoising loss on OT-paired (source, data) batches.

    For each of ``args.no_batches`` batches from ``train_iter``:

    1. Sample a source batch with ``graph.sample_limit``.
    2. Reorder the data rows so each source row is paired with a data row by
       exact optimal transport over input-embedding costs. If the OT step raises,
       the error is printed and the unpaired order is used.
    3. Evaluate in mini-batches of ``min(512, B)`` rows. The loss is the
       cross-entropy of ``log_score`` against the clean token, counted only
       where the perturbed token differs from it and weighted by ``1 / (1 - t)``.

    The running ``exp(mean_loss / args.length)`` is printed every 100 steps.
    ``losses / args.length`` is saved to
    ``<model_path>_eval/lossesJ1_dataset_<dataset>_length_<length>.npy`` every
    1000 steps and once more when evaluation ends. Evaluation stops after 1500
    mini-batches, or as soon as the running value is non-finite (inf or NaN).

    Args:
        args: namespace providing ``ot_metric``, ``no_batches``, ``length``,
            ``model_path`` and ``dataset``.
        train_iter: iterator yielding dicts with an ``'input_ids'`` tensor of shape (B, L).
        device: torch device to evaluate on.
        graph: object providing ``sample_limit(shape)`` and
            ``sample_transition(source, data, t)``.
        model: the score model (possibly wrapped, exposing ``.module``).

    Raises:
        NotImplementedError: if ``args.ot_metric`` is unsupported.
    """
    sampling_eps = 0.0001
    mini_batch_cap = 512
    max_steps = 1500

    # Validate the OT metric up front and fetch the embedding weights once.
    W = _ot_embedding_weights(model, args.ot_metric)

    # Loop-invariant, so built once instead of per mini-batch.
    # NOTE(review): train='True' is a (truthy) string; presumably the bool was meant. If it
    # switches the model to training mode, dropout would be active during evaluation — confirm.
    log_score_fn = mutils.get_score_fn(model, train='True', sampling=False)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    # NOTE(review): file name says 'J1' although this is J2; kept as-is for downstream readers.
    save_path = (args.model_path + '_eval/' + 'lossesJ1_' + 'dataset_' + args.dataset
                 + '_length_' + str(args.length) + '.npy')

    losses = []
    loss_sum = 0.0  # running sum, so the running mean is O(1) per step
    j = 0
    done = False

    for _ in range(args.no_batches):
        databatch = next(train_iter)['input_ids'].to(device)
        B, L = databatch.shape
        sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

        try:
            databatch = databatch[_ot_pairing_indices(sourcebatch, databatch, W, device)]
        except NotImplementedError:
            raise
        except Exception as e:
            # Best effort: keep evaluating with the unpaired databatch.
            print(f'Error calculating optimal transport: {e}')

        # Cap at 512 rows but never exceed B, so batches smaller than 512 are still evaluated.
        mini_batch_size = min(mini_batch_cap, B)
        for k in range(B // mini_batch_size):
            j += 1
            rows = slice(k * mini_batch_size, (k + 1) * mini_batch_size)
            source_batch = sourcebatch[rows]
            data_batch = databatch[rows]

            # t ~ U[eps, 1 - eps]
            t = (1 - 2 * sampling_eps) * torch.rand(data_batch.shape[0], device=data_batch.device) + sampling_eps
            perturbed_batch = graph.sample_transition(source_batch, data_batch, t[:, None])
            log_score = log_score_fn(perturbed_batch, t)

            token_loss = criterion(log_score.view(-1, log_score.shape[-1]), data_batch.reshape(-1))
            # Only positions whose perturbed token differs from the clean token contribute.
            mask = (perturbed_batch != data_batch).to(token_loss.dtype)
            loss = token_loss.view_as(perturbed_batch) * mask / (1 - t[:, None])
            loss = loss.sum(-1).mean().float()

            losses.append(loss.cpu().numpy())
            loss_sum += float(losses[-1])
            running = np.exp(loss_sum / len(losses) / args.length)

            # A non-finite running mean never recovers: stop the whole evaluation, not just this batch.
            if not np.isfinite(running):
                done = True
                break
            if j % 100 == 0:
                print('iter:', j)
                print('mean:', running)
            if j % 1000 == 0:
                np.save(save_path, np.array(losses) / args.length)
            if j == max_steps:
                done = True
                break
        if done:
            break

    # Persist everything computed since the last periodic save (e.g. steps 1001-1500).
    if losses:
        np.save(save_path, np.array(losses) / args.length)

def main():
    """Parse CLI arguments, load the model and run the J2 OT-paired loss evaluation.

    Exits with an argparse usage error if ``--batch_size`` is not positive or
    ``--perturbed_points_nr`` is smaller than ``--batch_size`` (which would
    otherwise yield zero evaluation batches).
    """
    parser = argparse.ArgumentParser(description="Evaluate OT-paired losses of a trained model")
    parser.add_argument("--model_path", default="x", type=str)
    # Evaluation dataset; its "test" split is loaded below.
    parser.add_argument("--dataset", default="lambada", type=str)
    parser.add_argument("--length", default=128, type=int)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--perturbed_points_nr", type=int, default=128*(50*1024))
    parser.add_argument("--cache_dir", type=str, default='data')
    # NOTE(review): --J is parsed but not read anywhere in this file.
    parser.add_argument("--J", type=str, default='2')
    # Only 'input_embedding_l2' is implemented by J2.
    parser.add_argument("--ot_metric", type=str, default='input_embedding_l2',
                        choices=['input_embedding_l2'])

    args = parser.parse_args()
    # Validate before dividing so a bad configuration fails immediately with a clear message
    # instead of silently running zero batches or leaving args.no_batches unset.
    if args.batch_size <= 0:
        parser.error('args.batch_size should be positive')
    if args.perturbed_points_nr < args.batch_size:
        parser.error('args.perturbed_points_nr should be at least args.batch_size')
    args.no_batches = args.perturbed_points_nr // args.batch_size
    device = torch.device('cuda')

    create_directory_if_not_exists(args.model_path + '_eval/')

    # The "test" split is evaluated; the train_* names only mirror J2's parameter name.
    train_set = data.get_dataset(args.dataset, "test", cache_dir=args.cache_dir, block_size=args.length)

    train_loader = cycle_loader(DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    ))
    train_iter = iter(train_loader)

    print(f"The size of the dataset: {len(train_set)}")

    with torch.no_grad():
        model, graph = load_model(args.model_path, device)
        J2(args, train_iter, device, graph, model)

    print('dataset:', args.dataset, 'model:', args.model_path)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
