import torch
import fire
from datasets import load_dataset
import json, os
from ss_utils import _landmark_rows_leverage, _landmark_p_to_full_krr, _estimate_projected_grads_krrf

def get_lr(dset):
    """Look up the fine-tuning learning rate associated with a dataset name.

    Args:
        dset: Dataset identifier; one of 'gsm8k', 'sql' or 'viggo'.

    Returns:
        The learning rate (float) configured for that dataset.

    Raises:
        ValueError: If `dset` is not one of the known dataset names.
    """
    known_rates = (
        ('gsm8k', 6e-6),
        ('sql', 3e-5),
        ('viggo', 4e-5),
    )
    for name, rate in known_rates:
        if dset == name:
            return rate
    raise ValueError(f'Unknown dataset name: {dset}')

def get_np_from_method(method):
    """Return the point budget encoded in a method name, as a percentage.

    The budget is written as a `_np<N>` or `_npoints<N>` token inside the
    method name, e.g. both 'ss_cluster_np10_shortwarm' and
    'ss_cluster_npoints10_shortwarm' yield 10. If no such token is present,
    the default of 20 (percent) is used. The chosen value is printed.

    Args:
        method: Method name string to scan.

    Returns:
        The percentage as an int.
    """
    import re
    # Accept the short `_np<N>` spelling as well as `_npoints<N>`: only the long
    # form used to match, so a name such as 'ss_cluster_np10' silently fell back
    # to the 20% default.
    match = re.search(r'_np(?:oints)?(\d+)', method)
    out = int(match.group(1)) if match else 20  # default to 20% of the data
    print(f'Using {out}% of the data as npoints')
    return out

def parse_lmbda(lmbda, losses, apx_losses, r):
    """Turn a lambda specification into the value used to weight the residual term.

    Args:
        lmbda: Either the string 'auto', the string 'lambda90', or any other
            value, which is returned unchanged.
        losses: True per-example losses (tensor); only needed for 'lambda90'.
        apx_losses: Approximated per-example losses (tensor).
        r: Per-example residual norms (tensor).

    Returns:
        For 'auto', the ratio of mean approximated loss to mean residual.
        For 'lambda90', the 0.9 quantile of |losses - apx_losses| / (r**2 + 1e-8).
        Otherwise `lmbda` itself.

    Raises:
        ValueError: If 'lambda90' is requested and `losses` is None.
    """
    if lmbda == 'auto':
        mean_apx = apx_losses.mean().item()
        mean_residual = r.mean().item()
        return mean_apx / mean_residual

    if lmbda == 'lambda90':
        if losses is None:
            raise ValueError('true losses are required for lambda90')
        # Approximation error per example, scaled by the squared residual;
        # the epsilon keeps the division finite where the residual is zero.
        scaled_error = (losses - apx_losses).abs() / (r**2 + 1e-8)
        return scaled_error.quantile(0.9).item()

    return lmbda

def _load_losses(method, grads_path, losses_path, device):
    """Load the per-example scores that the selection methods treat as losses.

    When the method name contains 'gradnorm', the row-wise norm of the stored
    gradients is used; otherwise the stored losses are loaded directly.
    """
    if 'gradnorm' in method:
        grads = torch.load(grads_path, map_location=device)
        return grads.norm(dim=1)
    return torch.load(losses_path, map_location=device)


def _kmeans_center_samples(embds_np, num_clusters, seed, device):
    """Run k-means on the embeddings and map every centroid to its nearest sample.

    Returns:
        (embds, cluster_labels, center_sample_ids): the embeddings as a tensor
        on `device`, the per-sample cluster labels as returned by
        `fit_predict`, and for each cluster the index of the sample closest
        to its centroid.
    """
    # Imported lazily so that methods which do not cluster never need cuML.
    from cuml.cluster import KMeans

    print(f'Using {num_clusters} clusters')
    kmeans = KMeans(n_clusters=num_clusters, random_state=seed, n_init=10)
    cluster_labels = kmeans.fit_predict(embds_np)
    cluster_centers = kmeans.cluster_centers_

    # Map centers to the closest actual samples.
    embds = torch.tensor(embds_np, device=device)
    cluster_centers = torch.tensor(cluster_centers, device=device)
    dists = torch.cdist(embds, cluster_centers)  # (n_samples, n_clusters)
    center_sample_ids = torch.argmin(dists, dim=0)  # (n_clusters,)
    return embds, cluster_labels, center_sample_ids


def _cluster_probs(embds, cluster_labels, center_sample_ids, losses, lmbda):
    """Sampling distribution from cluster-representative losses plus a distance term.

    Each sample's loss is approximated by the loss of its cluster's
    representative sample; `r` is the embedding distance to that representative.
    """
    center_losses = losses[center_sample_ids]
    apx_losses = center_losses[cluster_labels]
    r = (embds - embds[center_sample_ids[cluster_labels]]).norm(dim=1)  # (n_samples,)
    lmbda = parse_lmbda(lmbda, losses, apx_losses, r)
    s = apx_losses + lmbda * r
    return s / s.sum()


def _lowrank_probs(embds, losses, num_landmarks, lmbda):
    """Sampling distribution from landmark-based loss estimates plus a residual term.

    Landmarks are chosen by `_landmark_rows_leverage`; the loss estimates and
    projected embeddings come from the `ss_utils` helpers, whose exact
    semantics are defined there.
    """
    landmark_idx = _landmark_rows_leverage(embds, num_landmarks)
    landmark_losses = losses[landmark_idx]
    projected_embds = _estimate_projected_grads_krrf(embds, landmark_idx, embds[landmark_idx])
    projected_losses = _landmark_p_to_full_krr(landmark_idx, embds, landmark_losses)
    r = (embds - projected_embds).norm(dim=1)  # (n_samples,)
    lmbda = parse_lmbda(lmbda, losses, projected_losses, r)
    s = projected_losses + lmbda * r
    return s / s.sum()


def _sample_with_weights(p, size):
    """Draw `size` indices from distribution `p` and compute their importance weights.

    Returns numpy arrays `(indices, weights)` with `weights = 1 / (p[indices] * size)`.
    """
    # NOTE(review): torch.multinomial samples WITHOUT replacement by default,
    # whereas 1 / (p * size) is the usual with-replacement importance weight.
    # Behaviour is kept as-is; confirm this is intended.
    indices = torch.multinomial(p, size).detach().cpu().numpy()
    weights = 1 / (p[indices] * size).detach().cpu().numpy()
    return indices, weights


def _split_mix_lmbda(lmbda):
    """Split the `ss_mix` lambda spec into `(lmbda_cluster, lmbda_lowrank)`.

    Accepts a '<cluster>-<lowrank>' string, a single token, or a plain number;
    a single value is used for both parts. Each token is either a float
    literal or one of the keywords understood by `parse_lmbda`.

    Raises:
        ValueError: If a token is not numeric/keyword, or more than two
            '-'-separated tokens are given.
    """
    if not isinstance(lmbda, str):
        # e.g. the numeric default `lmbda=0`, which has no `.split`.
        return lmbda, lmbda

    def _coerce(token):
        return token if token in ('auto', 'lambda90') else float(token)

    tokens = lmbda.split('-')
    if len(tokens) == 1:
        value = _coerce(tokens[0])
        return value, value
    if len(tokens) != 2:
        raise ValueError(
            f"ss_mix expects lmbda as '<cluster>-<lowrank>', got: {lmbda}"
        )
    return _coerce(tokens[0]), _coerce(tokens[1])


def main(
    out_path=None,
    weights_path=None,
    dset='gsm8k',
    size=2000,
    method='uniform',
    seed=42,
    device='cuda:0',
    lmbda=0,
    train_default=True
):
    """Select a weighted subset of a training set, save it, and optionally fine-tune on it.

    The subset is written to `out_path` as JSON lines and the per-example
    weights to `weights_path` via `torch.save`. Pre-computed tensors are read
    from `{dset}_losses2.pt`, `{dset}_grad2.pt` and `{dset}_gtr_embds.pt` /
    `{dset}_bert_embds.pt` in the working directory, as required by the method.

    Args:
        out_path: Destination .jsonl file; derived from the other arguments
            under `./data/{dset}/` when None.
        weights_path: Destination for the weights; derived likewise when None.
        dset: 'gsm8k', 'sql' or 'viggo'.
        size: Number of examples to select (replaced by the dataset length for
            method 'full').
        method: Selection method. Apart from the exact name 'full', branches
            are chosen by substring, checked in this order: 'uniform', 'ss18',
            'leverage', 'ss_cluster', 'ss_lowrank', 'ss_ideal_grad_dir',
            'ss_mix', 'ss_2mix'. Optional substrings: 'gtr' (use GTR
            embeddings), 'gradnorm' (use gradient norms as losses),
            `_np<N>` / `_npoints<N>` (budget, see `get_np_from_method`),
            'shortwarm' (adds WARMUP=5 to the training command).
        seed: Seed for torch's RNG and for k-means.
        device: Torch device on which the tensors are processed.
        lmbda: Weight of the residual term, or 'auto' / 'lambda90' (see
            `parse_lmbda`). For 'ss_mix' a '<cluster>-<lowrank>' pair.
        train_default: When truthy, run `finetune_subgsm.sh` on the result.

    Raises:
        ValueError: For an unknown dataset name or method.
    """
    losses_path = f'{dset}_losses2.pt'
    grads_path = f'{dset}_grad2.pt'

    if 'gtr' in method:
        embds_path = f'{dset}_gtr_embds.pt'
    else:
        embds_path = f'{dset}_bert_embds.pt'

    print(f'Using embeddings from {embds_path}')

    if out_path is None:
        out_path = f'./data/{dset}/subset_{size}_{method}_lambda{lmbda}_seed{seed}.jsonl'
    if weights_path is None:
        weights_path = f'./data/{dset}/subset_{size}_{method}_lambda{lmbda}_seed{seed}_weights.pt'

    # `dirname` is '' for a bare filename and os.makedirs('') raises, so fall
    # back to the current directory. A custom weights_path may also live in a
    # directory that does not exist yet.
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    os.makedirs(os.path.dirname(weights_path) or '.', exist_ok=True)

    if dset == 'sql':
        dataset = load_dataset('json', data_files="./data/sql/train.jsonl", split="train")
    elif dset == 'viggo':
        dataset = load_dataset('GEM/viggo', split='train')
    elif dset == 'gsm8k':
        dataset = load_dataset('gsm8k', 'main', split='train')
    else:
        raise ValueError(f'Unknown dataset name: {dset}')

    if method == 'full':
        size = len(dataset)
        indices = list(range(size))
        weights = torch.ones(size)  # every example kept, so all weights are 1

    elif 'uniform' in method:
        # Seed the draw so that `seed` (which is part of the output filename)
        # actually determines the subset.
        torch.manual_seed(seed)
        indices = torch.randint(0, len(dataset), (size,))  # with replacement
        weights = (len(dataset) / size) * torch.ones(size)

    elif 'ss18' in method:
        # One cluster per requested example; keep each cluster's representative.
        embds_np = torch.load(embds_path).numpy()
        torch.manual_seed(seed)
        _, _, center_sample_ids = _kmeans_center_samples(embds_np, size, seed, device)
        indices = center_sample_ids.detach().cpu().numpy()
        weights = (len(dataset) / size) * torch.ones(size)

    elif 'leverage' in method:
        embds = torch.load(embds_path, map_location=device)
        torch.manual_seed(seed)
        indices = _landmark_rows_leverage(embds, size)
        weights = (len(dataset) / size) * torch.ones(size)

    elif 'ss_cluster' in method:
        embds_np = torch.load(embds_path).numpy()
        torch.manual_seed(seed)
        num_clusters = int(len(dataset) * get_np_from_method(method) / 100)
        embds, cluster_labels, center_sample_ids = _kmeans_center_samples(
            embds_np, num_clusters, seed, device
        )
        losses = _load_losses(method, grads_path, losses_path, device)
        p = _cluster_probs(embds, cluster_labels, center_sample_ids, losses, lmbda)
        indices, weights = _sample_with_weights(p, size)

    elif 'ss_lowrank' in method:
        embds = torch.load(embds_path, map_location=device)
        torch.manual_seed(seed)
        losses = _load_losses(method, grads_path, losses_path, device)
        num_landmarks = int(len(dataset) * get_np_from_method(method) / 100)
        p = _lowrank_probs(embds, losses, num_landmarks, lmbda)
        indices, weights = _sample_with_weights(p, size)

    elif 'ss_ideal_grad_dir' in method:
        # This branch previously never seeded the RNG, unlike the other
        # sampling branches.
        torch.manual_seed(seed)
        grads = torch.load(grads_path, map_location=device).float()
        grads = grads / grads.norm(dim=1, keepdim=True)  # normalized gradients

        num_landmarks = int(len(dataset) * get_np_from_method(method) / 100)
        landmark_idx = _landmark_rows_leverage(grads, num_landmarks)
        projected_grads = _estimate_projected_grads_krrf(grads, landmark_idx, grads[landmark_idx])
        projected_losses = projected_grads.norm(dim=1)

        r = (grads - projected_grads).norm(dim=1)  # (n_samples,)
        # No true losses are available here, so 'lambda90' is rejected by parse_lmbda.
        resolved_lmbda = parse_lmbda(lmbda, None, projected_losses, r)
        s = projected_losses + resolved_lmbda * r
        p = s / s.sum()
        indices, weights = _sample_with_weights(p, size)

    elif 'ss_mix' in method:
        lmbda_cluster, lmbda_lowrank = _split_mix_lmbda(lmbda)

        embds_np = torch.load(embds_path).numpy()
        torch.manual_seed(seed)
        budget = int(len(dataset) * get_np_from_method(method) / 100)
        embds, cluster_labels, center_sample_ids = _kmeans_center_samples(
            embds_np, budget, seed, device
        )
        losses = _load_losses(method, grads_path, losses_path, device)
        p_cluster = _cluster_probs(
            embds, cluster_labels, center_sample_ids, losses, lmbda_cluster
        )

        # Re-seed so the landmark step does not depend on any RNG use during
        # clustering. Embeddings and losses are reused rather than reloaded.
        torch.manual_seed(seed)
        p_lowrank = _lowrank_probs(embds, losses, budget, lmbda_lowrank)

        # Equal-weight mixture of the two sampling distributions.
        p = 0.5 * p_cluster + 0.5 * p_lowrank
        indices, weights = _sample_with_weights(p, size)

    elif 'ss_2mix' in method:
        # Project the embeddings onto the landmarks first, then cluster the
        # projected embeddings.
        embds = torch.load(embds_path, map_location=device)
        torch.manual_seed(seed)
        budget = int(len(dataset) * get_np_from_method(method) / 100)
        landmark_idx = _landmark_rows_leverage(embds, budget)
        projected_embds = _estimate_projected_grads_krrf(embds, landmark_idx, embds[landmark_idx])
        embds_np = projected_embds.detach().cpu().numpy()

        torch.manual_seed(seed)
        embds, cluster_labels, center_sample_ids = _kmeans_center_samples(
            embds_np, budget, seed, device
        )
        losses = _load_losses(method, grads_path, losses_path, device)
        p = _cluster_probs(embds, cluster_labels, center_sample_ids, losses, lmbda)
        indices, weights = _sample_with_weights(p, size)

    else:
        raise ValueError(f'Invalid method: {method}')

    # Hand `Dataset.select` plain Python ints when a branch produced a tensor
    # (possibly on an accelerator device).
    if isinstance(indices, torch.Tensor):
        indices = indices.detach().cpu().tolist()

    dataset = dataset.select(indices)
    with open(out_path, 'w') as f:
        for example in dataset:
            f.write(json.dumps(example) + '\n')

    torch.save(weights, weights_path)

    if train_default:
        command = f'bash finetune_subgsm.sh SUBSET_PATH={out_path} WEIGHTS_PATH={weights_path} SEED={seed} DATASET={dset} LR={get_lr(dset)}'
        if 'shortwarm' in method:
            command += ' WARMUP=5'
        print(command)
        os.system(command)


# Script entry point: hand `main` to python-fire, which builds the command-line
# interface from its parameters.
if __name__ == '__main__':
    fire.Fire(main)