import os
import torch
import time
import numpy as np
import argparse
import data
from load_model import load_model
from transformers import GPT2TokenizerFast
import torch.nn.functional as F
import sampling
from torch.utils.data import DataLoader, DistributedSampler
from model import utils as mutils
import ot

def create_directory_if_not_exists(directory_path):
    """Make sure ``directory_path`` exists, creating it when it is absent.

    Missing parent directories are created as well (``os.makedirs``). A
    one-line status message is printed in either case. Returns ``None``.
    """
    already_present = os.path.exists(directory_path)
    if already_present:
        print(f"Directory '{directory_path}' already exists.")
        return

    os.makedirs(directory_path)
    print(f"Directory '{directory_path}' created.")

def cycle_loader(dataloader, sampler=None):
    """Yield items from ``dataloader`` forever, restarting it when exhausted.

    Before each pass over ``dataloader``, when ``sampler`` is supplied, its
    ``set_epoch`` method is called with a random integer drawn from
    ``[0, 100000)`` via ``np.random.randint``.

    Note that the generator never terminates on its own; an empty
    ``dataloader`` makes it loop without ever yielding.
    """
    reseed_each_pass = sampler is not None
    while True:
        if reseed_each_pass:
            epoch_seed = np.random.randint(0, 100000)
            sampler.set_epoch(epoch_seed)
        for batch in dataloader:
            yield batch

def _ot_match_databatch(model, sourcebatch, databatch, chunk=512):
    """Reorder ``databatch`` so that row i is the OT partner of ``sourcebatch[i]``.

    The cost between a source row and a data row is the squared Euclidean
    distance between their flattened ``model.vocab_embed`` outputs, clamped at
    zero and rescaled by the largest entry. The transport plan comes from
    ``ot.emd`` with uniform weights on both sides; every source row is paired
    with the data row carrying the largest mass in its row of the plan.

    Args:
        model: object exposing ``vocab_embed``, applied to both batches.
        sourcebatch: tensor whose leading dimension is the batch size B.
        databatch: tensor with the same leading dimension as ``sourcebatch``.
        chunk: number of source rows embedded per step while the (B, B) cost
            matrix is filled.

    Returns:
        ``databatch`` indexed by the matched row indices.
    """
    with torch.inference_mode():
        B = sourcebatch.shape[0]
        ot_device = sourcebatch.device

        # Embed the data batch once; its squared norms are reused per chunk.
        data_emb = model.vocab_embed(databatch).reshape(B, -1)
        data_norm = data_emb.pow(2).sum(1).view(1, -1)

        cost = torch.empty(B, B, dtype=torch.float32, device=ot_device)

        # Fill the cost matrix chunk by chunk so that only `chunk` source
        # embeddings are alive at a time.
        for start in range(0, B, chunk):
            stop = min(start + chunk, B)
            src_emb = model.vocab_embed(sourcebatch[start:stop]).reshape(stop - start, -1)
            src_norm = src_emb.pow(2).sum(1).view(-1, 1)

            # ||x||^2 + ||y||^2 - 2<x, y>, written straight into `cost`.
            norm_sum = src_norm + data_norm
            torch.addmm(norm_sum, src_emb, data_emb.T,
                        beta=1.0, alpha=-2.0, out=cost[start:stop])
            del src_emb, src_norm, norm_sum

        # Rounding can leave tiny negative distances.
        cost.clamp_(min=0)
        max_cost = cost.max()
        # Skip the rescale when every cost is zero; dividing would give NaN.
        if max_cost > 0:
            cost.div_(max_cost)

        a_np = np.full(B, 1.0 / B, dtype=np.float64)
        b_np = np.full(B, 1.0 / B, dtype=np.float64)
        plan_np = ot.emd(a_np, b_np, cost.double().cpu().numpy())
        plan = torch.from_numpy(plan_np).to(device=ot_device, dtype=torch.float32)

        # Column holding the largest mass in each row = matched data index.
        col_idx = plan.argmax(dim=1)
        return databatch[col_idx]


def J2(args, train_iter, device, noise, graph, model):
    """Accumulate the masked, time-weighted cross-entropy loss over batches.

    For each of ``args.no_batches`` batches drawn from ``train_iter`` a source
    batch is sampled with ``graph.sample_limit``, the data batch is re-paired
    with it through optimal transport (falling back to the loader's order if
    that step fails), and the batch is processed in mini-batches of at most
    512 rows. For each mini-batch a time ``t`` is drawn uniformly from
    ``[sampling_eps, 1 - sampling_eps)``, ``graph.sample_transition`` produces
    the perturbed batch, and the per-token cross-entropy of the model output
    against the data tokens is kept only where the perturbed token differs
    from the data token, divided by ``1 - t``, summed over the sequence and
    averaged over the mini-batch.

    Side effects: every 100 mini-batches the running ``exp(mean loss /
    args.length)`` is printed; every 1000 mini-batches the losses divided by
    ``args.length`` are saved under ``args.model_path + '_eval/'``. The run
    stops after 1500 mini-batches.

    Args:
        args: namespace providing ``no_batches``, ``length``, ``model_path``
            and ``dataset``.
        train_iter: iterator yielding mappings with an ``'input_ids'`` tensor.
        device: device the data batches are moved to.
        noise: unused; kept for interface compatibility.
        graph: object providing ``sample_limit`` and ``sample_transition``.
        model: model passed to ``mutils.get_score_fn``; must also expose
            ``vocab_embed`` for the optimal-transport pairing.

    Returns:
        None.
    """
    sampling_eps = 0.0001
    max_mini_batch_size = 512
    max_iterations = 1500
    use_optimal_transport = True

    # Stateless, so one instance serves every mini-batch.
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    losses = []
    j = 0
    finished = False

    for _ in range(args.no_batches):
        databatch = next(train_iter)['input_ids'].to(device)
        B, _seq_len = databatch.shape
        sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

        if use_optimal_transport:
            try:
                databatch = _ot_match_databatch(model, sourcebatch, databatch)
            except Exception as exc:
                # Best effort: keep the loader's pairing for this batch, but
                # report why, and let KeyboardInterrupt/SystemExit propagate.
                print('Error calculating optimal transport, continuing with independent sampling this batch')
                print(f'Optimal transport error: {exc!r}')

        # Cap the mini-batch at the batch size so batches smaller than 512
        # rows are still evaluated. Rows beyond the last full mini-batch are
        # dropped, as before.
        mini_batch_size = max(1, min(max_mini_batch_size, B))

        for minibatch in range(B // mini_batch_size):
            j += 1
            lo = minibatch * mini_batch_size
            hi = lo + mini_batch_size
            source_batch = sourcebatch[lo:hi]
            data_batch = databatch[lo:hi]

            t = (1 - 2 * sampling_eps) * torch.rand(data_batch.shape[0], device=data_batch.device) + sampling_eps
            perturbed_batch = graph.sample_transition(source_batch, data_batch, t[:, None])
            # NOTE(review): `train` receives the string 'True' (truthy), kept
            # as in the original; confirm whether a boolean was intended.
            log_score_fn = mutils.get_score_fn(model, train='True', sampling=False)
            log_score = log_score_fn(perturbed_batch, t)
            log_score = log_score.view(-1, log_score.shape[-1])

            # Only positions whose perturbed token differs from the data
            # token contribute to the loss.
            mask = (perturbed_batch != data_batch) * 1
            loss = criterion(log_score, data_batch.view(-1))
            loss = loss.view(perturbed_batch.shape[0], perturbed_batch.shape[1]) * mask / (1 - t[:, None])
            loss = loss.sum(-1).mean().float()

            losses.append(loss.cpu().numpy())

            exp_mean_loss = np.exp(np.array(losses).mean() / args.length)
            # NOTE(review): this only leaves the mini-batch loop; the outer
            # loop moves on to the next batch, as in the original.
            if np.isinf(exp_mean_loss):
                break
            if j % 100 == 0:
                print('iter:', j)
                print('mean:', exp_mean_loss)
            # NOTE(review): losses collected after the last multiple of 1000
            # are never written to disk.
            if j % 1000 == 0:
                np.save(args.model_path+'_eval/'+'lossesJ1_'+'dataset_'+args.dataset+'_length_'+str(args.length)+'.npy', np.array(losses)/args.length)
            if j == max_iterations:
                finished = True
                break
        if finished:
            break

def main():
    """Parse the command line, build the data stream and run ``J2``.

    Creates ``<model_path>_eval/`` for outputs, loads the ``train`` split via
    ``data.get_dataset``, wraps a ``DataLoader`` in ``cycle_loader``, loads the
    model with ``load_model`` and calls ``J2`` under ``torch.no_grad()``.

    Exits through ``parser.error`` (status 2) when ``--batch_size`` is not
    positive or ``--perturbed_points_nr`` is smaller than ``--batch_size``,
    since either case would leave no batches to evaluate.
    """
    parser = argparse.ArgumentParser(description="Generate some samples")
    parser.add_argument("--model_path", default="x", type=str)
    parser.add_argument("--dataset", default="lm1b", type=str)
    parser.add_argument("--length", default=128, type=int)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--perturbed_points_nr", type=int, default=128*(50*1024))
    parser.add_argument("--cache_dir", type=str, default='data')
    parser.add_argument("--J", type=str, default='2')
    args = parser.parse_args()

    # Validate up front instead of catching the division error afterwards,
    # which left `args.no_batches` undefined.
    if args.batch_size <= 0:
        parser.error('args.batch_size should be a positive integer')
    if args.perturbed_points_nr < args.batch_size:
        parser.error('args.perturbed_points_nr should be bigger than args.batch_size')
    args.no_batches = args.perturbed_points_nr // args.batch_size

    device = torch.device('cuda')

    create_directory_if_not_exists(args.model_path+'_eval/')
    train_set = data.get_dataset(args.dataset, "train", cache_dir=args.cache_dir, block_size=args.length)

    train_loader = cycle_loader(DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,   # Batches are read in dataset order (no shuffling)
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    ))
    train_iter = iter(train_loader)

    print(f"The size of the dataset: {len(train_set)}")

    with torch.no_grad():
        model, graph, noise = load_model(args.model_path, device)
        J2(args, train_iter, device, noise, graph, model)

    print('dataset:', args.dataset, 'model:', args.model_path)


if __name__ == "__main__":
    main()
