import random
import torch
from tqdm import tqdm


def get_eps_step(eps):
    """Return how much to lower ``eps`` before the next purification trial.

    The step size grows with ``eps`` so that small thresholds are explored
    more finely:

    * ``eps <= 0.015``  -> 0.001
    * ``eps <= 0.15``   -> 0.01
    * ``eps <= 0.5``    -> 0.05
    * otherwise         -> 0.1

    Every upper bound is inclusive. Values for which no comparison holds
    (including NaN) fall through to the largest step.
    """
    for upper_bound, step in ((0.015, 0.001), (0.15, 0.01), (0.5, 0.05)):
        if eps <= upper_bound:
            return step
    return 0.1
      
def _sample_synthetic_proofs(original_proof, synthetic, label):
    """Return a random subset of ``original_proof`` that always contains ``label``.

    At most ``synthetic`` candidates are kept. When there are no more than
    ``synthetic`` candidates, one is dropped at random. The sample size is
    clamped at 0 so that an empty candidate list does not make
    ``random.sample`` raise on a negative size.
    """
    if len(original_proof) > synthetic:
        k = synthetic
    else:
        k = max(len(original_proof) - 1, 0)
    sampled = random.sample(original_proof, k)
    if label not in sampled:
        sampled.append(label)
    return sampled


def _prune_proofs_by_margin(original_proof, pruned_cols_per_slot):
    """Return a NEW list of the proofs whose slot values avoid every pruned class.

    ``pruned_cols_per_slot[i]`` lists the class indices that slot ``i`` may no
    longer take. A proof ``p`` is dropped as soon as ``p[i]`` equals one of
    them. The input list is copied, never mutated.
    """
    kept = list(original_proof)
    for slot, cols in enumerate(pruned_cols_per_slot):
        if cols:
            # A list (not a set) is used for membership so that ``==`` semantics
            # are preserved even if proof entries are 0-d tensors rather than ints.
            kept = [p for p in kept if p[slot] not in cols]
    return kept


def purification(dataset, eps_0, eps_max, collate_fn, batch_size, sum_n, model, device, synthetic=None):
    """Prune each sample's candidate proofs using the model's output margins.

    The dataset is iterated in order (``shuffle=False``) in batches shaped either
    ``(data, target, labels, proofs)`` or ``(data, target, proofs)``. For each of
    the ``sum_n`` inputs ``data[i]`` the model is evaluated under
    ``torch.no_grad()``.

    Modes:

    * ``synthetic`` is a positive number: each sample's proofs are replaced by a
      random subset (see ``_sample_synthetic_proofs``) that always includes the
      ground-truth label. Batches must then carry labels.
    * Otherwise: a proof ``p`` is removed when, for some slot ``i``, the score of
      class ``p[i]`` in ``model(data[i])`` trails that row's top score by more
      than ``eps``. If ``synthetic`` is not None (i.e. ``<= 0``) and the
      ground-truth label was pruned, it is added back.

    If a full pass changes nothing compared with ``dataset.get_proofs()`` and
    ``0 < eps < eps_max``, ``eps`` is lowered by ``get_eps_step(eps)`` (which
    prunes more aggressively) and the pass is repeated. ``eps`` only ever
    decreases, so the ``eps_max`` check effectively only gates the first retry.

    Args:
        dataset: Provides ``get_proofs()`` and ``update_proofs_samplewise()``
            and is consumed by a ``DataLoader``.
        eps_0: Initial margin threshold.
        eps_max: Retries only happen while ``eps < eps_max``.
        collate_fn: Passed to the ``DataLoader``.
        batch_size: ``DataLoader`` batch size.
        sum_n: Number of model inputs per batch element (``data[0..sum_n-1]``).
        model: Evaluated once per input. It is switched to ``eval()`` mode and
            left in that mode.
        device: Device the inputs are moved to.
        synthetic: Optional synthetic-sampling size (see Modes).

    Returns:
        Whatever ``dataset.update_proofs_samplewise`` returns for the new
        per-sample proof lists.

    Raises:
        ValueError: if synthetic sampling is requested but a batch has no labels.
    """
    eps = eps_0
    baseline = dataset.get_proofs()
    loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=batch_size,
        shuffle=False,
    )
    use_synthetic = synthetic is not None and synthetic > 0
    trial = 1
    with torch.no_grad():
        model.eval()
        while True:
            proof_changes = 0
            og_num_proofs = 0
            gt_removed = 0
            total_processed = 0
            purified = []
            progress = tqdm(loader, f"Purifying Dataset (Trial {trial}, eps={eps}, changes={proof_changes}, gt_removed={gt_removed})")
            for batch in progress:
                if len(batch) == 4:
                    data, _target, labels, proofs = batch
                else:
                    data, _target, proofs = batch
                    labels = None
                n_samples = len(data[0])
                outputs = [model(data[x].to(device)) for x in range(sum_n)]

                too_far = None
                if not use_synthetic:
                    # One vectorised margin test per slot, moved to the CPU once
                    # per batch: True where a class trails the row's top score by
                    # more than eps.
                    too_far = [
                        ((out.max(dim=1, keepdim=True).values - out) > eps).cpu()
                        for out in outputs
                    ]

                for bidx in range(n_samples):
                    original_proof = proofs[bidx]
                    label = None if labels is None else list(labels[bidx])
                    if use_synthetic:
                        if label is None:
                            raise ValueError(
                                "synthetic proof sampling requires batches that include labels"
                            )
                        kept = _sample_synthetic_proofs(original_proof, synthetic, label)
                    else:
                        pruned_cols = [
                            too_far[i][bidx].nonzero(as_tuple=True)[0].tolist()
                            for i in range(sum_n)
                        ]
                        kept = _prune_proofs_by_margin(original_proof, pruned_cols)
                        proof_changes += len(original_proof) - len(kept)
                        og_num_proofs += len(original_proof)
                        if label is not None and label not in kept:
                            gt_removed += 1
                            if synthetic is not None:
                                # ``kept`` is always a fresh list, so this never
                                # mutates the loader's/dataset's proof list.
                                kept.append(label)
                    total_processed += 1
                    purified.append(kept)

                progress.set_description(f"Purifying Dataset (Trial {trial}, eps={eps}, changes={proof_changes}/{og_num_proofs}, gt_removed={gt_removed}/{total_processed})")

            if baseline == purified and eps < eps_max and eps > 0:
                # Round to cancel float drift from repeated subtraction: without
                # it, eps can land just above a bucket boundary of get_eps_step
                # (e.g. slightly above 0.15) and take a larger step than intended.
                eps = round(eps - get_eps_step(eps), 10)
                trial += 1
            else:
                print(f"Trial {trial} completed with {proof_changes} changes")
                break
        return dataset.update_proofs_samplewise(purified)