import os
from typing import List
import pickle
import torch

from ddlm.sampler.early_exit import LogStrategy, Strategy, NoStrategy
from ddlm.sampler.euler import (
    get_sigmas_karras,
    sample_euler,
)
from tqdm.auto import tqdm, trange
import wandb


@torch.no_grad()
def generate_with_conditioning_mask(
    texts: List[str],
    conditioning_mask: torch.BoolTensor,
    tokenizer,
    model,
    time_wrapping,
    continuation_number,
    length: int = 64,
    num_steps: int = 100,
    prefix_length: int = 0,
    rho: float = 8,
    batch_size: int = 256,
    device="cuda:0",
    use_time_wrapping: bool = True,
    t_min: float = 1,
    t_max: float = 10,
    self_conditioning: bool = True,
    s_churn: float = 0.0,
    simplified_inputs: bool = False,
    renormalization: bool = False,
    initial_noise_scale: float = 1.0,
    timedelta: float = 0.0,
    artifact=None,
    interpolate: bool = False,
    strategy: Strategy = NoStrategy,
    outputs_sweep_name: str = "",
    download_ouputs: bool = False,
):
    """Sample token sequences batch by batch and decode every returned step.

    The sigma schedule is built once with ``get_sigmas_karras``, saved to
    ``sigmas.pt`` in the working directory and registered either with
    ``wandb.save`` or, when ``artifact`` is given, with ``artifact.add_file``.
    ``texts`` is then processed in chunks of ``batch_size``: each chunk is
    tokenized (padded/truncated to ``length``), the per-step token ids are
    obtained from ``sample_euler`` (or loaded from a pickle when
    ``download_ouputs`` is true), and for every step the positions where the
    conditioning mask is ``False`` are overwritten with the step's tokens
    while the remaining positions keep the tokenized prompt.

    Args:
        texts: Input strings to tokenize.
        conditioning_mask: Boolean mask repeated along the batch dimension;
            ``True`` positions keep the prompt tokens, ``False`` positions
            are filled from the sampler output.
        tokenizer: Callable returning a mapping with ``"input_ids"``; must
            also provide ``decode`` and ``bos_token_id``.
        model: Forwarded to ``sample_euler``.
        time_wrapping: Passed to ``get_sigmas_karras`` as ``tw_model`` when
            ``use_time_wrapping`` is true, otherwise ``None`` is passed.
        continuation_number: Forwarded to ``sample_euler``.
        length: ``max_length`` used for tokenization.
        num_steps: Number of sigmas requested; in download mode it is also
            matched against ``run.config["num_steps"]``.
        prefix_length: Only used in download mode, matched against
            ``run.config["prefix_length"]``.
        rho, timedelta: Forwarded to ``get_sigmas_karras``.
        batch_size: Number of texts per batch.
        device: Device for the sigmas, input ids and mask.
        use_time_wrapping: See ``time_wrapping``.
        t_min, t_max: Used as ``sigma_min``/``sigma_max`` for the schedule
            and as ``s_tmin``/``s_tmax`` for ``sample_euler``.
        self_conditioning, s_churn, simplified_inputs, renormalization,
        initial_noise_scale, interpolate, strategy: Forwarded unchanged to
            ``sample_euler``.
        artifact: Optional object with ``add_file``; when ``None`` the sigma
            file is registered through ``wandb.save`` instead.
        outputs_sweep_name: Sweep identifier given to ``wandb.Api().sweep``
            in download mode.
        download_ouputs: When true, skip sampling, download the files of the
            matching sweep runs and read the steps from
            ``/app/encoded.pickle``.

    Returns:
        A pair ``(outputs_steps, metrics)``. ``outputs_steps`` is the result
        of ``merge_batches`` over the per-batch step outputs, each step being
        ``(decoded_texts, full_token_ids, generated_token_ids)``. ``metrics``
        is whatever the last batch produced (the second value returned by
        ``sample_euler``, or ``{"observed_steps": 0}`` in download mode); it
        is an empty dict when ``texts`` is empty.
    """
    sigmas = get_sigmas_karras(
        n=num_steps,
        sigma_min=t_min,
        sigma_max=t_max,
        rho=rho,
        tw_model=time_wrapping if use_time_wrapping else None,
        device=device,
        timedelta=timedelta,
    )

    torch.save(sigmas, "sigmas.pt")
    if artifact is None:
        wandb.save("sigmas.pt")
    else:
        artifact.add_file("sigmas.pt")

    print(f"Going to generate {len(texts)} sequences...")
    outputs_steps_batches = []
    # Bound up front so that an empty ``texts`` returns cleanly instead of
    # raising UnboundLocalError at the final ``return``.
    metrics = {}
    for batch_i, start_idx in enumerate(range(0, len(texts), batch_size)):
        tokenized = tokenizer(
            texts[start_idx : start_idx + batch_size],
            max_length=length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].to(device)
        c_conditioning_mask = conditioning_mask.repeat(input_ids.size(0), 1).to(device)

        if not download_ouputs:
            encoded_steps, metrics = sample_euler(
                model=model,
                sigmas=sigmas,
                input_ids=input_ids,
                conditioning_mask=c_conditioning_mask,
                batch_index=batch_i,
                disable=True,
                s_tmin=t_min,
                s_tmax=t_max,
                s_churn=s_churn,
                self_conditioning=self_conditioning,
                initial_noise_scale=initial_noise_scale,
                simplified_inputs=simplified_inputs,
                renormalization=renormalization,
                interpolate=interpolate,
                strategy=strategy,
                continuation_number=continuation_number,
            )
        else:
            # NOTE(review): this branch runs once per batch, so the download
            # and the pickle load are repeated for every batch.
            metrics = {"observed_steps": 0}
            api = wandb.Api()
            sweep = api.sweep(outputs_sweep_name)

            for run in sweep.runs:
                if (run.config["prefix_length"] == prefix_length) and (run.config["num_steps"] == num_steps):
                    for run_file in run.files():
                        run_file.download(replace=True)

            # SECURITY: pickle.load executes arbitrary code from the file;
            # only point ``outputs_sweep_name`` at sweeps you trust.
            with open("/app/encoded.pickle", 'rb') as f:
                encoded_steps = pickle.load(f)

        c_conditioning_mask = c_conditioning_mask.cpu()
        generated_positions = ~c_conditioning_mask
        num_sequences = input_ids.size(0)
        # One device->host transfer per batch; every step works on its own
        # clone so the prompt tensor is never aliased (``Tensor.cpu()``
        # returns the very same object when the tensor already lives on CPU).
        prompt_ids_cpu = input_ids.cpu()

        outputs_steps = []
        for step_encoded in encoded_steps:
            if isinstance(step_encoded, torch.Tensor):
                # Avoids the copy-construct warning of torch.tensor(tensor).
                step_ids = step_encoded.detach().to(device="cpu", dtype=torch.long)
            else:
                step_ids = torch.tensor(step_encoded, dtype=torch.long).cpu()

            encoded_generation = step_ids[generated_positions]
            generated_texts_encoded = prompt_ids_cpu.clone()
            generated_texts_encoded[generated_positions] = encoded_generation

            full_token_ids = generated_texts_encoded.tolist()
            decoded_texts = [tokenizer.decode(ids) for ids in full_token_ids]
            # Sanity check reusing the decoded list instead of decoding twice.
            assert len(decoded_texts) == num_sequences

            print(f"Already generated {len(decoded_texts)} sequences")

            generated_token_ids = encoded_generation.view(num_sequences, -1).tolist()

            if start_idx == 0:
                # NOTE(review): kept from the original; the generated
                # positions are overwritten for every step above, so this
                # write does not show up in the returned outputs.
                input_ids[generated_positions] = tokenizer.bos_token_id

            outputs_steps.append((decoded_texts, full_token_ids, generated_token_ids))

        outputs_steps_batches.append(outputs_steps)
    outputs_steps = merge_batches(outputs_steps_batches)

    return outputs_steps, metrics


def merge_batches(outputs_steps_batches):
    """Merge per-batch step outputs into one list of per-step outputs.

    Args:
        outputs_steps_batches: A list with one entry per batch. Each entry is
            a list with one item per step, and each item is an indexable
            triple ``(generations, encoded, encoded_generations)`` of lists.

    Returns:
        A list with one 3-tuple per step. Each tuple element is a new list
        holding the corresponding lists of all batches concatenated in batch
        order. If batches contain different numbers of steps, the result is
        truncated to the shortest batch; an empty input yields ``[]``.
    """
    merged_steps = []
    # zip(*...) regroups the batch-major input into step-major tuples and
    # stops at the batch with the fewest steps.
    for step_across_batches in zip(*outputs_steps_batches):
        merged_steps.append(
            tuple(
                # Flatten in a single pass; sum(lists, []) would re-copy the
                # accumulated list once per batch (quadratic).
                [item for batch_step in step_across_batches for item in batch_step[field]]
                for field in range(3)
            )
        )
    return merged_steps


