import numpy as np
import os

import gym
import numpy as np
import numpy.typing as npt

from abc import ABC, abstractmethod
from tqdm import tqdm
from typing import List


class Colonel(ABC):
    """
    Abstract interface for a Blotto player.

    Subclasses must implement :meth:`act`; because ``act`` is abstract,
    ``Colonel`` itself cannot be instantiated.
    """

    # Declared here so every colonel exposes a ``strategy`` attribute.
    # Originally described as a probability distribution.
    # NOTE(review): other code in this file hands colonels integer troop
    # allocations rather than probabilities -- confirm the intended meaning.
    strategy: npt.ArrayLike

    @abstractmethod
    def act(self, state: npt.ArrayLike):
        """
        Choose a move for the given game ``state``.

        Must be overridden; the base implementation only raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

class SimpleColonel(Colonel):
    """
    Colonel with a fixed strategy that never changes.

    Parameters
    ----------
    strategy:
        The allocation to play every round, one entry per zone. Any
        array-like is accepted; inputs without a ``.shape`` attribute
        (lists, tuples) are converted to a numpy array.
    n:
        Stored as ``self.n``; not otherwise used by this class.
        NOTE(review): presumably the total number of troops -- confirm.
    k:
        Number of zones; must equal the length of ``strategy``.
    """
    def __init__(self, strategy: npt.ArrayLike, n: int, k: int):
        # The annotation promises "array-like", but the length check below
        # reads ``.shape``. Coerce only inputs that lack it, so objects that
        # already carry a shape are stored untouched, exactly as before.
        if not hasattr(strategy, "shape"):
            strategy = np.asarray(strategy)
        assert strategy.shape[0] == k, "Strategy must have one entry per zone"
        self.k = k
        self.n = n
        self.strategy = strategy

    def act(self, state: npt.ArrayLike):
        """Return the stored strategy; ``state`` is ignored."""
        return self.strategy
    
class BlottoGame(gym.Env):
    """
    Base class for a Blotto game.

    Blotto is a simultaneous-move game: every player submits one allocation
    of ``N`` troops across ``K`` zones, and each zone is credited to the
    player who committed the most troops to it.
    """
    def __init__(self, K: int, N: int, players: int = 2):
        """
        K: number of zones each player allocates troops to.
        N: number of troops every player must allocate in total.
        players: number of players; only 2 is currently accepted.
        """
        assert players == 2, "Currently only supports 2 players"
        self.K = K
        self.N = N
        self.players = players

    def close(self):
        """No-op: the game holds no resources."""
        pass

    def reset(self):
        """No-op: the game keeps no state between steps."""
        pass

    def render(self):
        """No-op: nothing is rendered."""
        pass

    def seed(self, seed=None):
        """
        No-op: scoring is deterministic, so ``seed`` is ignored.

        The optional argument is accepted so that gym-style calls such as
        ``env.seed(42)`` do not fail with a ``TypeError``.
        """
        pass

    def step(self, actions: List[List[int]]):
        """
        Score one simultaneous round.

        ``actions`` holds one allocation per player, each with one troop
        count per zone. Returns an array of length ``players`` with the
        number of zones each player won. Note that this is NOT the usual gym
        ``(observation, reward, done, info)`` tuple.
        """
        assert self.validate_actions(actions), "Invalid action"

        wins = np.zeros(self.players)

        # zip(*actions) regroups the allocations zone by zone.
        for K_i in zip(*actions):
            # Zero spread means every player committed the same number of
            # troops: the zone is tied and nobody is credited.
            if np.array(K_i).std() == 0:
                continue

            wins[np.argmax(K_i)] += 1

        return wins

    def state(self):
        """Return a dummy all-zero observation of length ``K``."""
        return np.zeros(self.K) # Returning dummy state for now

    def validate_actions(self, actions: List[List[int]]):
        """
        Check that ``actions`` has one allocation per player, that every
        allocation covers all ``K`` zones, and that it uses exactly ``N``
        troops.

        Returns True when all checks pass and raises ``AssertionError``
        otherwise. The checks are plain ``assert`` statements, so they are
        skipped when Python runs with ``-O``.
        """
        assert len(actions) == self.players, "The number of actions passed should equal the number of players"
        for p_i in actions:
            assert len(p_i) == self.K, "Players should pass an integer for each zone"
            assert sum(p_i) == self.N, f"Players have to assign all troops. Only assigned {sum(p_i)} troops"
        return True

class WeightedBlotto(BlottoGame):
    """
    Blotto variant in which each zone is worth its own payoff.

    ``payoff`` is indexed by zone position: ``payoff[i]`` is credited to the
    winner of zone ``i``. Entries beyond the ``K``-th are never read; a
    ``payoff`` shorter than ``K`` raises ``IndexError`` during ``step``.
    """
    def __init__(self, K: int, N: int, payoff: List, players: int = 2):
        super().__init__(K, N, players=players)

        self.payoff = payoff

    def step(self, actions: List[List[int]]) -> np.ndarray:
        """
        Score one simultaneous round using the per-zone payoffs.

        ``actions`` holds one allocation per player, each with one troop
        count per zone. Returns an array of length ``players`` holding the
        total payoff each player collected (not a single int, as the earlier
        annotation claimed). Tied zones pay nobody.
        """
        assert self.validate_actions(actions), "Invalid action"

        wins = np.zeros(self.players)

        # zip(*actions) regroups the allocations zone by zone.
        for i, K_i in enumerate(zip(*actions)):
            payoff = self.payoff[i]
            # Zero spread means all players committed the same troops: a tie,
            # so the zone's payoff is left unassigned.
            if np.array(K_i).std() == 0:
                continue

            wins[np.argmax(K_i)] += payoff

        return wins

class Tourney(ABC):
    """
    Abstract base class for tournaments between colonels.

    Inherits from ``ABC`` so that ``@abstractmethod`` on :meth:`run` is
    actually enforced: subclasses must implement ``run`` before they can be
    instantiated.
    """

    # Participants of the tournament. Written as a forward reference so the
    # annotation does not have to be resolved at class-creation time.
    agents: List["Colonel"]

    @abstractmethod
    def run(self):
        """
        Function to run a tournament
        """
        raise NotImplementedError

class RoundRobin(Tourney):
    """Tournament in which every pair of distinct agents plays one match."""

    def __init__(self, game: BlottoGame, agents: List[Colonel], rounds: int = 1):
        self.game = game
        self.agents = agents
        self.rounds = rounds

    def play_match(self, agent1: Colonel, agent2: Colonel, rounds: int = 1) -> dict:
        """
        Play ``rounds`` rounds between two agents on ``self.game``.

        Returns a dict whose key 0 holds the summed result of ``agent1`` and
        key 1 the summed result of ``agent2``.
        """
        totals = {0: 0, 1: 0}

        observation = self.game.state()
        for _ in range(rounds):
            moves = [agent1.act(observation), agent2.act(observation)]
            first, second = tuple(self.game.step(moves))
            totals[0] += first
            totals[1] += second
            # Refresh the observation for the next round.
            observation = self.game.state()

        return totals

    def run(self) -> npt.ArrayLike:
        """
        Run the round-robin tournament.

        Each unordered pair (i, j) with i < j is played once; the same match
        fills both ``[i, j]`` (what i scored against j) and ``[j, i]`` (what
        j scored against i). The diagonal stays zero.
        """
        count = len(self.agents)
        table = np.zeros((count, count))

        for row in tqdm(range(count)):
            home = self.agents[row]
            # Start past the diagonal: an agent never plays itself.
            for col in range(row + 1, count):
                outcome = self.play_match(home, self.agents[col], self.rounds)
                table[row, col] += outcome[0]
                table[col, row] += outcome[1]

        return table
   
    
def make_strat(alphas,s,n,seed):
    """
    Sample ``n`` integer-valued allocations that each sum to ``s``.

    Rows are drawn from a Dirichlet distribution with concentration
    ``alphas``, scaled by ``s`` and rounded to whole numbers. The returned
    array has one row per sample and one column per entry of ``alphas``; it
    keeps a float dtype.

    Side effect: reseeds numpy's global random state with ``seed``.

    Rows may contain negative entries (see the comment below); callers are
    expected to filter those out.
    """
    np.random.seed(seed)
    allocations = np.rint(np.random.dirichlet(alphas, n) * s)

    # Rounding can leave a row above or below ``s``. Put the whole
    # discrepancy on the first column; this can push that entry below zero.
    allocations[:, 0] += s - allocations.sum(axis=1)

    # Shuffle every row in place so the correction is not always applied to
    # the same zone.
    for row in allocations:
        np.random.shuffle(row)

    return allocations

def create_blotto_data(K: int, N: int, max_agents: int = 30000, payoff=None, seed: int = 42, game_type: str = "normal", n_samples: int = 150000):
    """
    Build a round-robin payoff dataset for a Blotto game.

    Parameters
    ----------
    K: number of zones.
    N: number of troops each player allocates.
    max_agents: upper bound on the number of strategies that take part.
    payoff: per-zone payoffs, used only when ``game_type == "weighted"``.
        ``None`` selects the historical default ``[1, 2, 3, 4]``.
    seed: seed forwarded to ``make_strat``.
    game_type: ``"weighted"`` builds a ``WeightedBlotto``; any other value
        builds a plain ``BlottoGame``.
    n_samples: number of raw strategies to sample before removing duplicates
        and invalid rows. The default is deliberately large so that the
        strategy space is well covered.

    Returns
    -------
    (A, Pwin, strats) where ``A[i][j]`` is what strategy ``i`` scored against
    strategy ``j``. ``Pwin[i][j]`` is 1 if ``i`` outscored ``j``, 0 if it was
    outscored, and 0.5 on a draw (diagonal included). ``strats`` holds the
    participating strategies, one per row.
    """
    if payoff is None:
        payoff = [1, 2, 3, 4]

    if game_type == "weighted":
        game = WeightedBlotto(K=K, N=N, payoff=payoff)
    else:
        game = BlottoGame(K=K, N=N)

    strats = make_strat([1] * K, s=N, n=n_samples, seed=seed)

    strats = np.unique(strats, axis=0)  # Drop duplicate strategies

    # Drop every row that contains a negative allocation.
    strats = np.delete(strats, np.where(strats < 0)[0], 0)

    # NOTE(review): np.unique returns rows in sorted order, so this keeps the
    # lexicographically smallest strategies rather than a random subset --
    # confirm that is intended.
    strats = strats[:max_agents]

    agents = [SimpleColonel(strat, n=N, k=K) for strat in strats]

    A = RoundRobin(game, agents).run()

    # Compare every score with its mirrored entry in a single vectorised
    # pass instead of a Python double loop: win -> 1, loss -> 0, draw -> 0.5.
    Pwin = np.where(A > A.T, 1.0, np.where(A < A.T, 0.0, 0.5))

    return A, Pwin, strats

def create_blotto_games(save_path, ks, payouts, n_s: list = None):
    """
    Generate weighted Blotto datasets and save them to disk.

    For every ``(K, payout)`` pair taken from ``zip(ks, payouts)`` and every
    troop count ``N`` in ``n_s``, a dataset is created and stored under the
    directory ``save_path + f"blotto_{K}_{payout}_{N}"`` as three files:

    * ``F.npy`` -- ``P - 0.5``
    * ``P.npy`` -- the win/draw/loss matrix
    * ``X.npy`` -- the strategies

    Combinations for which all three files already exist are skipped.

    ``save_path`` is used as a raw string prefix, so it should end with a
    path separator if it is meant to be a directory. ``n_s`` defaults to the
    troop counts 5..45 when it is ``None`` or empty.
    """
    if n_s is None or len(n_s) == 0:
        n_s = list(range(5, 46))

    for K, payout in zip(ks, payouts):
        for N in n_s:
            print(f"K={K} N={N}")
            out_dir = save_path + f"blotto_{K}_{payout}_{N}"
            F_path = os.path.join(out_dir, "F.npy")
            P_path = os.path.join(out_dir, "P.npy")
            X_path = os.path.join(out_dir, "X.npy")

            # Skip work that has already been completed.
            if all(os.path.isfile(path) for path in (F_path, P_path, X_path)):
                continue

            _, P, S = create_blotto_data(K, N, payoff=payout, game_type="weighted")

            # Also creates missing parent directories and tolerates an
            # already existing target directory.
            os.makedirs(out_dir, exist_ok=True)

            np.save(F_path, P - 0.5)
            np.save(P_path, P)
            np.save(X_path, S)
    