import math
from collections import defaultdict
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn.functional as F
import wandb


def get_random_P(order, batch_size, generator, device, dtype, interval=(0, 1)):
    """Draw an independent random transition table for every batch element.

    For each of the ``batch_size`` elements and each of the ``2**order``
    contexts, a value ``pk`` is drawn with ``torch.rand`` and rescaled
    linearly into ``interval``.

    Args:
        order: Markov order; the table has ``2**order`` rows.
        batch_size: Number of independent tables to draw.
        generator: ``torch.Generator`` forwarded to ``torch.rand``.
        device: Device on which the result is created.
        dtype: Floating dtype of the result.
        interval: Two-element sequence ``(low, high)`` with ``low < high``
            giving the range ``pk`` is rescaled into. The default is a tuple
            so it cannot be mutated across calls.

    Returns:
        Tensor of shape ``(batch_size, 2**order, 2)`` whose last axis holds
        ``[1 - pk, pk]``.
    """
    assert len(interval) == 2
    assert interval[0] < interval[1]

    low, high = interval
    pk = low + (high - low) * torch.rand(
        (batch_size, 2**order, 1), generator=generator, dtype=dtype, device=device
    )
    return torch.cat([1 - pk, pk], dim=2)

def empirical_est(x, y, order, beta=1, save_counts=False):
    """Compute running add-beta estimates of the next symbol for every context.

    Each window of ``order`` consecutive symbols of ``x`` is encoded as an
    integer (oldest symbol is the most significant bit). For every context
    ``i`` the values of ``y`` observed right after that context are collected
    in order, and the estimate before the k-th occurrence is
    ``(ones seen so far + beta) / (k + 2 * beta)``.

    Args:
        x: Input symbols with a leading batch dimension of size 1.
        y: Target symbols aligned with ``x`` (same number of elements).
        order: Context length; ``2**order`` contexts are tracked.
        beta: Positive smoothing constant of the estimator.
        save_counts: When true, also return the raw counts and totals.

    Returns:
        A list of ``2**order`` 1-D tensors of estimates. Entry ``i`` has one
        element per occurrence of context ``i`` (a single prior value
        ``beta / (2 * beta)`` when the context never occurs). With
        ``save_counts=True`` the tuple ``(estimates, counts, totals)`` is
        returned instead.
    """
    assert x.size(0) == 1
    assert beta > 0

    device = x.device
    # Flatten explicitly: a blanket squeeze() would also drop the sequence
    # axis when it has length 1, turning the tensors into 0-d scalars.
    x = x.float().reshape(-1)
    y = y.float().reshape(-1)
    powers = torch.tensor(
        [2**i for i in reversed(range(order))], dtype=torch.float32, device=device
    )
    # Sliding dot product with powers of two gives the context id per position;
    # reshape(-1) keeps it 1-D even when there is a single window.
    idx = F.conv1d(x.view(1, -1), powers.view(1, 1, -1)).reshape(-1)
    # y[order-1 + t] is the symbol that follows the window starting at t.
    followers = y[order-1:]

    est_vec = []
    counts = []
    totals = []
    for i in range(2**order):
        mask = (idx == i)
        # The last outcome is dropped: estimates are made *before* observing it.
        s = followers[mask][:-1]
        # Prepend 0 so that count[k] is the number of ones among the first k outcomes.
        count = F.pad(s.cumsum(0), (1, 0))
        total = torch.arange(len(s)+1, device=device)
        est_vec.append((count + beta) / (total + 2*beta))
        counts.append(count)
        totals.append(total)

    if save_counts:
        return est_vec, counts, totals
    return est_vec

def optimal_est(P, order, sequence_length, generator, extra_args):
    """Estimate the loss obtained by predicting with the true table ``P``.

    A batch of 4096 sequences is sampled with ``get_batch``; at every position
    the row of ``P`` selected by the last ``order`` input symbols is used as
    the predictive distribution, and the negative log-likelihood of the
    targets is returned.

    Args:
        P: Transition table indexed by the integer-encoded context; its
            second dimension is the number of output symbols.
        order: Context length.
        sequence_length: Length of the sampled sequences.
        generator: Random generator forwarded to ``get_batch``.
        extra_args: Configuration object forwarded to ``get_batch``.

    Returns:
        Scalar tensor with the mean negative log-likelihood.
    """
    tokens, targets = get_batch(P, order, sequence_length, 4096, generator, extra_args)
    target_device = P.device
    num_symbols = P.size(1)
    weights = torch.Tensor([2**k for k in reversed(range(order))]).to(target_device)

    probs = torch.zeros(tokens.size(0), tokens.size(1), num_symbols, device=target_device)
    if order > 1:
        # Positions without a full context yet get probability 0.5 per symbol.
        probs[:, :order-1, :] = 0.5*torch.ones(tokens.size(0), order-1, num_symbols, device=target_device)

    for pos in range(order-1, sequence_length):
        # Encode the last `order` symbols ending at `pos` as a row index of P.
        window = tokens[:, pos-order+1:pos+1].float()
        state = (window @ weights).to(int)
        probs[:, pos, :] = P[state]

    log_probs = torch.log(probs)
    flat_log_probs = log_probs.view(-1, log_probs.size(-1))
    return F.nll_loss(flat_log_probs, targets.view(-1), ignore_index=-1)

# Optimized Markov data generation
def get_batch(P, order, seq_length, batch_size, generator, extra_args, interval=(0, 1)):
    """Sample a batch of binary sequences and split it into inputs and targets.

    The first ``order`` symbols of every sequence are fair coin flips; the
    remaining ``seq_length - order + 1`` symbols come from
    ``get_batch_from_past``.

    Args:
        P: Transition table shared by all sequences, or ``None`` to let
            ``get_batch_from_past`` draw its own table(s).
        order: Context length of the chain.
        seq_length: Length of the returned input and target sequences.
        batch_size: Number of sequences.
        generator: ``torch.Generator`` used for all sampling.
        extra_args: Object providing ``device`` and ``dtype`` attributes.
        interval: Range forwarded to ``get_batch_from_past``; only relevant
            when ``P`` is ``None``.

    Returns:
        Tuple ``(x, y)`` of integer tensors of shape
        ``(batch_size, seq_length)``, where ``y`` is ``x`` shifted by one step.
    """
    device = extra_args.device
    data = torch.zeros(batch_size, seq_length+1, device=device)
    # Generate first k bits
    alpha = 0.5
    data[:, :order] = torch.bernoulli(alpha * torch.ones((batch_size, order), device=device), generator=generator)
    # Generate following bits. `P` is forwarded as-is (possibly None); an
    # identity check is used because `==` on a tensor is an element-wise op.
    data[:, order:] = get_batch_from_past(
        data[:, :order], P, order, seq_length-order+1, batch_size,
        generator, device, extra_args.dtype, interval=interval,
    )
    x = data[:, :seq_length].to(int)
    y = data[:, 1:].to(int)

    return x, y

def get_batch_from_past(past, P, order, seq_length, batch_size, generator, device, dtype, interval=(0, 1)):
    """Continue a batch of binary sequences for ``seq_length`` more symbols.

    The last ``order`` columns of ``past`` seed the chain. At every step the
    previous ``order`` symbols are encoded as an integer (most recent symbol
    is the least significant bit) that selects a row of the transition table,
    and the next symbol is sampled from that row with ``torch.multinomial``.

    Args:
        past: Tensor of shape ``(batch_size, >= order)`` with the seed symbols.
        P: Transition table of shape ``(2**order, 2)`` shared by the whole
            batch, or ``None`` to draw one random table per sequence with
            ``get_random_P``.
        order: Context length.
        seq_length: Number of new symbols to generate.
        batch_size: Number of sequences.
        generator: ``torch.Generator`` used for sampling.
        device: Device of the generated data.
        dtype: Dtype used for randomly drawn tables.
        interval: Range forwarded to ``get_random_P`` when ``P`` is ``None``.

    Returns:
        Tensor of shape ``(batch_size, seq_length)`` with the newly sampled
        symbols (the seed columns are not included).
    """
    if P is None:
        P = get_random_P(order, batch_size, generator, device, dtype, interval=interval)
    else:
        # Broadcast the shared table over the batch without copying it.
        P = P.unsqueeze(0).expand(batch_size, -1, -1)

    total_length = order + seq_length
    data = torch.zeros(batch_size, total_length, device=device)
    data[:, :order] = past[:, -order:]
    batch_indices = torch.arange(batch_size, device=device)
    powers = torch.tensor([2**i for i in reversed(range(order))], dtype=data.dtype, device=device)
    # Iterate up to the end of the buffer: stopping at `seq_length` would
    # leave the last `order` returned columns as zeros instead of samples.
    for i in range(order, total_length):
        # Extract the previous 'order' symbols for the entire batch
        prev_symbols = data[:, i-order:i]
        # Compute indices using the dot product with powers of 2
        idx = (prev_symbols @ powers).long()
        # Fetch next symbols from the transition matrix P for each batch in parallel
        next_symbols = torch.multinomial(P[batch_indices, idx], 1, generator=generator).squeeze(1)
        data[:, i] = next_symbols

    return data[:, order:]


@torch.no_grad()
def eval(model, P, order, sequence_length, batch_size, generator, extra_args, max_num_batches=24, ctx=nullcontext()):
    """Evaluate ``model`` on freshly sampled batches drawn with table ``P``.

    Args:
        model: Callable in evaluation mode; ``model(x, targets=y)`` must
            return a mapping with ``'loss'`` and ``'logits'`` entries.
        P: Transition table forwarded to ``get_batch``; must not be ``None``.
        order: Context length forwarded to ``get_batch``.
        sequence_length: Length of each sampled sequence.
        batch_size: Number of sequences per batch.
        generator: Random generator forwarded to ``get_batch``.
        extra_args: Configuration object forwarded to ``get_batch``.
        max_num_batches: Number of batches to average over.
        ctx: Context manager wrapped around the forward pass.

    Returns:
        Tuple ``(val_acc, val_loss, val_perplexity)`` of Python floats:
        mean argmax accuracy, mean loss, and ``exp(mean loss)``.
    """
    assert not model.training
    assert P is not None

    loss_list_val, acc_list = [], []

    for _ in range(max_num_batches):
        x, y = get_batch(P, order, sequence_length, batch_size, generator, extra_args)
        with ctx:
            outputs = model(x, targets=y)
        loss_list_val.append(outputs['loss'])
        acc_list.append((outputs['logits'].argmax(-1) == y).float().mean())

    val_acc = torch.stack(acc_list).mean().item()
    val_loss = torch.stack(loss_list_val).mean().item()
    # Use the exact exponential rather than a truncated literal for e.
    val_perplexity = math.exp(val_loss)

    return val_acc, val_loss, val_perplexity


@torch.no_grad()
def eval_probs(model, P, order, sequence_length, generator, extra_args, betas = None, input_seq=None, output_seq=None, ctx=nullcontext()):
    """Compare the model's per-context predictions with add-beta estimates.

    Uses the supplied ``input_seq``/``output_seq`` (truncated to
    ``sequence_length``) when both are given, otherwise samples a single
    sequence with ``get_batch``. The softmax of the model logits is grouped
    by the context formed by the last ``order`` input symbols.

    Args:
        model: Callable in evaluation mode; called as
            ``model(x, targets=y, save_weights=True)`` and expected to return
            a mapping with a ``'logits'`` entry.
        P: Transition table used when sampling; must not be ``None``.
        order: Context length.
        sequence_length: Number of positions to evaluate.
        generator: Random generator forwarded to ``get_batch``.
        extra_args: Configuration object; its ``device`` is used here.
        betas: Smoothing constants to compare against (defaults to ``[1]``).
        input_seq: Optional pre-made input sequence.
        output_seq: Optional pre-made target sequence.
        ctx: Context manager wrapped around the forward pass.

    Returns:
        Tuple ``(prob_vec, est_vec, beta_vec)``: the model's index-1
        probabilities per context, the ``empirical_est`` output with its
        default beta, and for every beta the summed L1 distance between the
        model probabilities and the corresponding estimates.
    """
    assert model.training == False
    assert P is not None
    beta_values = [1] if betas is None else betas

    sequences_given = (input_seq is not None) and (output_seq is not None)
    if sequences_given:
        x = input_seq[:, :sequence_length]
        y = output_seq[:, :sequence_length]
    else:
        x, y = get_batch(P, order, sequence_length, 1, generator, extra_args)

    # Get model estimation
    with ctx:
        outputs = model(x, targets=y, save_weights=True)
    probs = F.softmax(outputs['logits'], dim=-1)
    first_seq = x[0].float()
    # Skip the first order-1 positions, which have no complete context.
    aligned_probs = probs[0, order-1:]
    weights = torch.Tensor([2**k for k in reversed(range(order))]).to(extra_args.device)
    state_ids = F.conv1d(first_seq.view(1, -1), weights.view(1, 1, -1)).squeeze()
    num_states = 2**order
    # Column 1 of the softmax is the estimated p for each visited context.
    prob_vec = [aligned_probs[state_ids == state][:, 1] for state in range(num_states)]

    # Get empirical add-beta estimator
    est_vec = empirical_est(x, y, order)
    beta_vec = []
    for beta in beta_values:
        beta_est = empirical_est(x, y, order, beta=beta)
        total_err = sum(
            torch.linalg.norm(model_p - est_p, ord=1)
            for model_p, est_p in zip(prob_vec, beta_est)
        )
        beta_vec.append(total_err)

    return prob_vec, est_vec, beta_vec

@torch.no_grad()
def eval_conditions(model, extra_args, ctx=nullcontext()):
    """Run the model once on a fixed probe sequence with condition checks on.

    The probe is the prefix ``0, 0, 1, 1, 0`` followed by 251 zeros (256
    symbols in total), used as both input and targets. The call is made with
    ``check_conditions=True``; its result is discarded.

    Args:
        model: Callable in evaluation mode.
        extra_args: Object whose ``device`` attribute selects the device.
        ctx: Context manager wrapped around the forward pass.

    Returns:
        ``None``.
    """
    assert model.training == False

    prefix = torch.Tensor([[0,0,1,1,0]])
    zero_tail = torch.zeros(1, 251)
    probe = torch.cat((prefix, zero_tail), dim=1)
    probe = probe.to(int).to(extra_args.device)
    with ctx:
        model(probe, targets=probe, check_conditions=True)

    return None


def save_checkpoint(model, opt, scheduler, itr, ckpt_path, **extra_args):
    """Serialize the training state to ``ckpt_path`` with ``torch.save``.

    Args:
        model: Object whose ``state_dict()`` is stored under ``'model'``.
        opt: Object whose ``state_dict()`` is stored under ``'optimizer'``.
        scheduler: Object whose ``state_dict()`` is stored under ``'scheduler'``.
        itr: Value stored under ``'itr'``.
        ckpt_path: Destination passed to ``torch.save``.
        **extra_args: Additional entries merged into the checkpoint; on a key
            clash they replace the standard entries.
    """
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': opt.state_dict(),
        'scheduler': scheduler.state_dict(),
        'itr': itr,
    }
    # Merged last so caller-supplied entries take precedence.
    checkpoint.update(extra_args)

    torch.save(checkpoint, ckpt_path)
