import numpy as np
from bt_utils import build_pair_arrays
from bt_mle import mle_bt, NEWTON_TOL, NEWTON_MAX_ITERS


class BTState:
    """Running statistics for a Bradley–Terry model over ``n`` items.

    Holds per-pair comparison counts and outcome sums, per-item win/loss
    tallies, the progress of a round-robin warm-up schedule, and a cached
    MLE that is recomputed lazily after new observations arrive.
    """

    def __init__(self, n, min_round_robins=1):
        """Create an empty state.

        Args:
            n: Number of items being compared.
            min_round_robins: Minimum number of full passes over every pair
                before the warm-up phase is allowed to end.
        """
        self.n = n
        self.min_round_robins = min_round_robins
        # pair_to_idx is looked up with (i, j) normalised so that i < j
        # (see add_observation) and yields an index into the per-pair arrays.
        self.pair_i, self.pair_j, self.pair_to_idx = build_pair_arrays(n)
        self.num_pairs = len(self.pair_i)
        # N[k]: number of comparisons recorded for pair k.
        self.N = np.zeros(self.num_pairs)
        # S[k]: sum of outcomes for pair k, oriented so that y == 1 means the
        # lower-indexed item of the pair won.
        self.S = np.zeros(self.num_pairs)
        self.wins = np.zeros(n, dtype=np.int32)
        self.losses = np.zeros(n, dtype=np.int32)

        self.warmup_complete = False
        self.warmup_round = 0  # number of completed passes over all pairs
        self.warmup_pair_idx = 0  # next pair index within the current pass

        self._theta_cache = None
        self._dirty = True  # True when observations arrived since last fit

    def warmup_next_pair(self):
        """Return the next ``(i, j)`` pair in the round-robin schedule.

        Returns:
            A tuple of two Python ints, or ``None`` once warm-up is complete.
        """
        if self.warmup_complete:
            return None
        k = self.warmup_pair_idx
        return (int(self.pair_i[k]), int(self.pair_j[k]))

    def warmup_observe(self, i, j, y):
        """Record a warm-up comparison and advance the schedule.

        The outcome is stored via ``add_observation``; ``(i, j)`` is not
        checked against the pair returned by ``warmup_next_pair``. After each
        full pass over all pairs, warm-up ends if at least
        ``min_round_robins`` passes have been made and ``mle_exists()``
        holds; otherwise another pass begins.

        Returns:
            False if this observation completed warm-up, True otherwise.
        """
        self.add_observation(i, j, y)
        self.warmup_pair_idx += 1

        if self.warmup_pair_idx >= self.num_pairs:
            self.warmup_round += 1
            self.warmup_pair_idx = 0

            if self.warmup_round >= self.min_round_robins and self.mle_exists():
                self.warmup_complete = True
                return False
        return True

    def add_observation(self, i, j, y):
        """Record one comparison between items ``i`` and ``j``.

        Args:
            i, j: Item indices in either order; presumably distinct, since
                ``(i, i)`` is not expected to be a key of ``pair_to_idx``.
            y: Outcome from ``i``'s perspective; 1 means ``i`` won.
                NOTE(review): the win/loss tallies treat any value other than
                1 (after orientation) as a loss for the lower-indexed item,
                while ``S`` accumulates ``y`` unchanged.
        """
        # Normalise to i < j and flip the outcome so it stays i-centric.
        if i > j:
            i, j, y = j, i, 1 - y
        idx = self.pair_to_idx[(i, j)]
        self.N[idx] += 1
        self.S[idx] += y

        if y == 1:
            self.wins[i] += 1
            self.losses[j] += 1
        else:
            self.wins[j] += 1
            self.losses[i] += 1

        self._dirty = True

    def mle_exists(self):
        """Return True if the Bradley–Terry MLE is finite for the current data.

        Requiring each item to have at least one win and one loss is necessary
        but not sufficient. Consider two groups that each form a cycle
        internally, where every member of one group beats every member of the
        other. That data satisfies it, yet the winning group's strengths
        diverge. The exact condition (Zermelo 1929; Ford 1957) is that the
        directed "a beat b" graph is strongly connected: every item reaches
        every other item through a chain of wins.
        """
        # Cheap necessary condition; short-circuits the graph search.
        if not (np.all(self.wins >= 1) and np.all(self.losses >= 1)):
            return False
        if self.n < 2:
            return True
        # beat[a, b] is True when a has beaten b at least (partially) once.
        # Assumes S[k] is oriented towards pair_i[k], matching the i < j
        # normalisation in add_observation — TODO confirm in build_pair_arrays.
        beat = np.zeros((self.n, self.n), dtype=bool)
        beat[self.pair_i, self.pair_j] = self.S > 0
        beat[self.pair_j, self.pair_i] = (self.N - self.S) > 0
        # Node 0 reaches all nodes (forward) and all nodes reach node 0
        # (reverse graph) <=> the graph is strongly connected.
        return self._reaches_all(beat) and self._reaches_all(beat.T)

    @staticmethod
    def _reaches_all(adj):
        """Return True if every node is reachable from node 0 in ``adj``."""
        seen = np.zeros(adj.shape[0], dtype=bool)
        seen[0] = True
        frontier = seen.copy()
        while frontier.any():
            # Nodes adjacent to the current frontier that are not yet visited.
            frontier = adj[frontier].any(axis=0) & ~seen
            seen |= frontier
        return bool(seen.all())

    def get_mle(self, tol=NEWTON_TOL, max_iters=NEWTON_MAX_ITERS):
        """Return the Bradley–Terry strength estimates, refitting if needed.

        The fit is delegated to ``mle_bt`` and warm-started from the previous
        estimate. The result is cached until the next observation arrives;
        while the cache is valid, ``tol`` and ``max_iters`` are ignored. The
        cached array itself is returned, so callers should copy it before
        mutating it.

        NOTE(review): this does not consult ``mle_exists()``; behaviour when
        the MLE is not finite is whatever ``mle_bt`` does.
        """
        if not self._dirty and self._theta_cache is not None:
            return self._theta_cache

        self._theta_cache = mle_bt(
            self.n, self.N, self.S, self.pair_i, self.pair_j,
            theta_init=self._theta_cache, tol=tol, max_iters=max_iters
        )
        self._dirty = False
        return self._theta_cache

    def total_comparisons(self):
        """Return the total number of comparisons recorded so far."""
        return int(self.N.sum())
