import os
import torch
import time
import numpy as np
import argparse
import data
from load_model import load_model
from transformers import GPT2TokenizerFast
import torch.nn.functional as F
import sampling
from torch.utils.data import DataLoader, DistributedSampler
from model import utils as mutils
import ot

def create_directory_if_not_exists(directory_path):
    """Create ``directory_path`` (and any missing parents) unless it already exists.

    Prints whether the directory was created or was already present.

    Creation is attempted directly and ``FileExistsError`` is treated as
    "already exists". This avoids the check-then-create race in which two
    processes both see the path missing and the second ``os.makedirs`` call
    crashes.
    """
    try:
        os.makedirs(directory_path)
    except FileExistsError:
        print(f"Directory '{directory_path}' already exists.")
    else:
        print(f"Directory '{directory_path}' created.")

def cycle_loader(dataloader, sampler=None):
    """Yield items from ``dataloader`` endlessly, restarting it after each pass.

    If ``sampler`` is given, ``sampler.set_epoch`` is called before every pass
    with a random integer drawn from ``[0, 100000)``. A sampler whose ordering
    depends on the epoch therefore produces a different order on each pass.
    """
    while True:
        if sampler is not None:
            random_epoch = np.random.randint(0, 100000)
            sampler.set_epoch(random_epoch)
        for batch in dataloader:
            yield batch

def _ot_pair_indices(model, sourcebatch, databatch, chunk=32, ot_eps=0.0101):
    """Sample a re-pairing of source/data rows from an entropic OT plan.

    The cost between source row i and data row j is the squared Euclidean
    distance between their flattened ``model.vocab_embed`` embeddings,
    computed as ``|x|^2 + |y|^2 - 2<x, y>``. It is then clamped at 0 and
    divided by its maximum (when that maximum is positive). ``ot.sinkhorn``
    with uniform marginals and regularisation ``ot_eps`` turns the cost into
    a coupling. ``B`` (source, data) index pairs are drawn from that coupling
    with replacement.

    Args:
        model: object exposing ``vocab_embed``, applied to token batches.
        sourcebatch: 2-D token tensor of shape (B, L).
        databatch: token tensor with the same shape as ``sourcebatch``.
        chunk: number of source rows embedded at a time while filling the
            (B, B) cost matrix.
        ot_eps: entropic regularisation passed to ``ot.sinkhorn``.

    Returns:
        ``(source_indices, data_indices)``: two CPU int64 tensors of length B.
    """
    B = sourcebatch.shape[0]
    with torch.inference_mode():
        # Embed the data batch once; it is reused for every source chunk.
        data_emb = model.vocab_embed(databatch).reshape(B, -1)   # (B, D)
        data_sq = data_emb.pow(2).sum(1).view(1, -1)             # (1, B)

        cost = torch.empty(B, B, dtype=torch.float32, device=sourcebatch.device)
        for start in range(0, B, chunk):
            stop = min(start + chunk, B)
            src_emb = model.vocab_embed(sourcebatch[start:stop]).reshape(stop - start, -1)  # (c, D)
            src_sq = src_emb.pow(2).sum(1).view(-1, 1)                                      # (c, 1)
            # cost[start:stop] = |x|^2 + |y|^2 - 2 <x, y>, written straight into the view.
            torch.addmm(src_sq + data_sq, src_emb, data_emb.T,
                        beta=1.0, alpha=-2.0, out=cost[start:stop])
            del src_emb, src_sq

        # Rounding can make tiny distances slightly negative.
        cost.clamp_(min=0)
        cost_max = cost.max()
        if cost_max > 0:  # all-zero cost (identical rows) would otherwise give 0/0 = NaN
            cost.div_(cost_max)
        print('M shape:', cost.shape)

        marginal = np.ones(B) / B
        plan = torch.from_numpy(ot.sinkhorn(marginal, marginal.copy(), cost.cpu().numpy(), ot_eps))

        flat_indices = torch.multinomial(plan.flatten(), B, replacement=True)
        n_cols = plan.shape[1]
        return flat_indices // n_cols, flat_indices % n_cols


def _minibatch_loss(model, graph, source_batch, data_batch, sampling_eps):
    """Return the scalar J1 loss for one (source, data) minibatch.

    A time ``t`` is drawn uniformly from ``[sampling_eps, 1 - sampling_eps)``
    per row, and ``graph.sample_transition`` produces the perturbed tokens.
    ``p = softmax(log_score)`` is evaluated at the perturbed tokens. The
    per-position loss is

        -mask * (log(p[data] + 1e-10) + 1) + (1 - p[perturbed])

    where ``mask`` marks positions whose perturbed token differs from the data
    token. The loss is weighted by ``1 / (1 - t)``, summed over positions and
    averaged over rows.
    """
    t = (1 - 2 * sampling_eps) * torch.rand(data_batch.shape[0], device=data_batch.device) + sampling_eps
    perturbed_batch = graph.sample_transition(source_batch, data_batch, t[:, None])
    mask = perturbed_batch != data_batch

    # NOTE(review): ``train`` is passed as the (truthy) string 'True'. If
    # get_score_fn switches the model to training mode on it, this evaluation
    # runs with training-time behaviour such as dropout. Confirm intent.
    log_score_fn = mutils.get_score_fn(model, train='True', sampling=False)
    probs = torch.softmax(log_score_fn(perturbed_batch, t), dim=-1)

    # squeeze(-1) rather than squeeze() so a 1-row minibatch keeps its batch axis.
    p_data = torch.gather(probs, dim=-1, index=data_batch.unsqueeze(-1)).squeeze(-1)
    p_perturbed = torch.gather(probs, dim=-1, index=perturbed_batch.unsqueeze(-1)).squeeze(-1)

    loss1 = mask * (torch.log(p_data + 0.0000000001) + 1)
    loss2 = 1 - p_perturbed
    loss = -loss1 + loss2
    return (loss / (1 - t[:, None])).sum(-1).mean()


def J1(args, train_iter, device, noise, graph, model, *,
       use_optimal_transport=True, mini_batch_size=32, ot_chunk=32,
       ot_eps=0.0101, sampling_eps=0.0001):
    """Run the J1 loss evaluation over ``args.no_batches`` batches.

    For each batch:

    * ``'input_ids'`` is drawn from ``train_iter`` and moved to ``device``.
    * A source batch of the same shape comes from ``graph.sample_limit``.
    * When ``use_optimal_transport`` is set, rows are re-paired via
      ``_ot_pair_indices``. Any failure there falls back to the independent
      pairing for that batch.
    * The batch is split into minibatches of ``mini_batch_size`` rows (the
      last one may be shorter), and the loss from ``_minibatch_loss`` is
      recorded for each.

    Every 100 minibatches ``exp(mean_loss / args.length)`` is printed. Every
    1000 minibatches, and once more on exit, the losses divided by
    ``args.length`` are saved to
    ``<model_path>_eval/lossesJ1_dataset_<dataset>_length_<length>.npy``.
    Evaluation stops early once ``exp(mean_loss / args.length)`` overflows to
    infinity.

    Args:
        args: namespace providing ``no_batches``, ``length``, ``model_path``
            and ``dataset``.
        train_iter: iterator yielding dicts with an ``'input_ids'`` tensor.
        device: device the data batches are moved to.
        noise: unused; kept for call-site compatibility.
        graph: object providing ``sample_limit`` and ``sample_transition``.
        model: score model. It must also expose ``vocab_embed`` when OT is used.
        use_optimal_transport: re-pair each batch with entropic OT.
        mini_batch_size: rows per loss evaluation.
        ot_chunk: source rows embedded at once when building the OT cost.
        ot_eps: Sinkhorn regularisation.
        sampling_eps: margin keeping ``t`` away from 0 and 1.

    Returns:
        numpy array of the recorded per-minibatch losses divided by
        ``args.length``.
    """
    save_path = (f"{args.model_path}_eval/lossesJ1_dataset_{args.dataset}"
                 f"_length_{args.length}.npy")
    losses = []
    # Running total so the mean costs O(1) per step instead of O(len(losses)).
    loss_sum = 0.0

    def _save():
        scaled = np.array(losses) / args.length
        if losses:
            np.save(save_path, scaled)
        return scaled

    for _batch in range(args.no_batches):
        databatch = next(train_iter)['input_ids'].to(device)
        B, _ = databatch.shape
        sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

        if use_optimal_transport:
            try:
                src_idx, data_idx = _ot_pair_indices(
                    model, sourcebatch, databatch, chunk=ot_chunk, ot_eps=ot_eps)
                paired_source = sourcebatch[src_idx]
                paired_data = databatch[data_idx]
            except Exception as err:  # deliberate best-effort: keep independent pairing
                print('Error calculating optimal transport, continuing with independent sampling this batch')
                print(f'  cause: {err!r}')
            else:
                # Swap both batches together so they are never half re-paired.
                sourcebatch, databatch = paired_source, paired_data
                print('Optimal transport calculated successfully')

        for start in range(0, B, mini_batch_size):
            stop = start + mini_batch_size
            loss = _minibatch_loss(model, graph, sourcebatch[start:stop],
                                   databatch[start:stop], sampling_eps)
            losses.append(loss.detach().cpu().numpy())
            loss_sum += float(losses[-1])
            step = len(losses)

            with np.errstate(over='ignore'):
                running = np.exp(loss_sum / step / args.length)
            if np.isinf(running):
                # Leave both loops: further batches cannot bring the mean back.
                print('Running mean diverged to infinity; stopping evaluation')
                return _save()
            if step % 100 == 0:
                print('iter:', step)
                print('mean:', running)
            if step % 1000 == 0:
                _save()

    return _save()



def main():
    """Parse CLI options, build an endless train-split iterator, load the model and run J1.

    ``args.no_batches`` is set to ``--perturbed_points_nr // --batch_size``.
    The output directory ``<model_path>_eval/`` is created if missing.
    """
    parser = argparse.ArgumentParser(description="Generate some samples")
    parser.add_argument("--model_path", default="exp_local/openwebtext/x", type=str)
    parser.add_argument("--dataset", default="ptb", type=str)
    parser.add_argument("--length", default=128, type=int)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--perturbed_points_nr", type=int, default=128*(50*1024))
    parser.add_argument("--cache_dir", type=str, default='data')
    # NOTE(review): --J is parsed but never read; J1 is always run.
    parser.add_argument("--J", type=str, default='1')
    args = parser.parse_args()

    # Validate explicitly. Integer floor division never raises for a too-small
    # --perturbed_points_nr (it just yields 0 batches). A zero batch size would
    # otherwise leave args.no_batches unset.
    if args.batch_size <= 0:
        parser.error('--batch_size must be a positive integer')
    if args.perturbed_points_nr < args.batch_size:
        parser.error('--perturbed_points_nr should be at least --batch_size')
    args.no_batches = args.perturbed_points_nr // args.batch_size

    # Fall back to CPU so the script does not fail outright without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    create_directory_if_not_exists(args.model_path+'_eval/')
    train_set = data.get_dataset(args.dataset, "train", cache_dir=args.cache_dir, block_size=args.length)

    train_loader = cycle_loader(DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,   # no shuffling: batches are read in dataset order
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    ))
    train_iter = iter(train_loader)

    print(f"The size of the dataset: {len(train_set)}")

    with torch.no_grad():
        model, graph, noise = load_model(args.model_path, device)
        J1(args, train_iter, device, noise, graph, model)
if __name__ == "__main__":
    # Entry point: run the evaluation only when executed as a script, not on import.
    main()
