from typing import Any, Union
from agents.base import AbstractAgent

import torch
import numpy as np

from utils.env_utils import preprocess_lending_obs, preprocess_college_admission_obs

from fair_gym import AcceptRejectAction, LendingEnv, CollegeAdmissionEnv
from torch.utils.data import DataLoader, TensorDataset


class CollegeClassifier(AbstractAgent):
    """
    A binary accept/reject classifier trained on environment rollouts.

    Training data is gathered by accepting every applicant in the wrapped
    environment (a ``LendingEnv`` or a ``CollegeAdmissionEnv``) and recording
    the reward obtained for each state. A small MLP ending in a sigmoid is
    then fitted to predict that outcome, and ``act`` accepts whenever the
    predicted probability exceeds 0.5.
    """

    def __init__(
        self,
        state_dim: int,
        env: Union[LendingEnv, CollegeAdmissionEnv],
        hidden_width: int = 256,
        rollout_size: int = 10000,
        n_epochs: int = 10,
        device: torch.device = torch.device("cpu"),
    ):
        """
        Build the classifier network and its optimizer.

        Args:
            state_dim: Number of input features of the first linear layer.
            env: Environment used by ``collect_data`` to gather training data.
            hidden_width: Width of the two hidden layers.
            rollout_size: Number of environment steps collected per
                ``collect_data`` call.
            n_epochs: Number of passes over the collected data in ``train``.
            device: Device holding the model and the training tensors.
        """
        # NOTE(review): AbstractAgent.__init__ is not invoked here; confirm
        # the base class does not need initialisation.
        self.env = env
        self.rollout_size = rollout_size
        self.n_epochs = n_epochs
        self.device = device

        # Two hidden layers followed by a sigmoid, so the network outputs a
        # probability rather than a logit.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(state_dim, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, 1),
            torch.nn.Sigmoid(),
        ).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def act(self, state: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """
        Choose an action for a single state.

        Args:
            state: Features of one applicant; the model must produce exactly
                one probability for it.

        Returns:
            ``AcceptRejectAction.ACCEPT.value`` if the predicted probability
            is above 0.5, otherwise ``AcceptRejectAction.REJECT.value``.
            NOTE(review): the ``np.ndarray`` annotation does not describe the
            enum value that is actually returned; kept for compatibility.
        """
        # as_tensor handles arrays and tensors alike and pins dtype/device,
        # unlike the legacy torch.Tensor(...) constructor.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            prob = self.model(state).item()

        if prob > 0.5:
            return AcceptRejectAction.ACCEPT.value
        return AcceptRejectAction.REJECT.value

    def _encode_obs(self, preprocessor_fn, obs) -> torch.Tensor:
        """Preprocess a raw observation into a float32 state on ``self.device``."""
        # The preprocessors return a 4-tuple; only the first element (the
        # state features) is used here.
        state, _, _, _ = preprocessor_fn(obs)
        return torch.as_tensor(state, dtype=torch.float32, device=self.device)

    def collect_data(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Roll out an accept-everyone policy and record (state, reward) pairs.

        Returns:
            A tuple ``(states, rewards)`` where ``states`` stacks the
            ``rollout_size`` preprocessed states and ``rewards`` holds the
            reward received after accepting in each of those states.

        Raises:
            ValueError: If the unwrapped environment is neither a
                ``LendingEnv`` nor a ``CollegeAdmissionEnv``.
        """
        if isinstance(self.env.unwrapped, LendingEnv):
            preprocessor_fn = preprocess_lending_obs
        elif isinstance(self.env.unwrapped, CollegeAdmissionEnv):
            preprocessor_fn = preprocess_college_admission_obs
        else:
            raise ValueError("Environment not supported.")

        states, rewards = [], []

        obs, _ = self.env.reset()
        state = self._encode_obs(preprocessor_fn, obs)

        for _ in range(self.rollout_size):
            # Accept all applicants so that an outcome is observed for
            # every state.
            next_obs, reward, terminated, truncated, _ = self.env.step(
                AcceptRejectAction.ACCEPT.value
            )

            states.append(state)
            rewards.append(reward)

            # The terminal observation is never used as a training input,
            # so on episode end we start over from a fresh reset instead.
            if terminated or truncated:
                next_obs, _ = self.env.reset()
            state = self._encode_obs(preprocessor_fn, next_obs)

        return torch.stack(states), torch.as_tensor(rewards, dtype=torch.float32)

    def train(self) -> None:
        """
        Collect a fresh rollout and fit the classifier on it.

        Runs ``n_epochs`` passes of mini-batch (size 64) Adam updates with a
        binary cross-entropy loss and prints the mean batch loss per epoch.
        """
        # Collect the data
        states, rewards = self.collect_data()
        states = states.to(self.device)
        rewards = rewards.to(self.device)

        # Change -1 rewards to 0.
        # NOTE(review): assumes rewards are exactly -1 or +1 so that labels
        # land in {0, 1} as BCELoss requires - confirm for both environments.
        labels = (rewards + 1) / 2

        # The dataset tensors already live on self.device, so batches need
        # no further transfer.
        dataset = TensorDataset(states, labels)
        dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

        # Set the model to training mode
        self.model.train()

        loss_fn = torch.nn.BCELoss()

        # Training loop
        for epoch in range(self.n_epochs):
            epoch_loss = 0.0
            for batch_states, batch_labels in dataloader:
                self.optimizer.zero_grad()
                # The network ends in a sigmoid, so these are probabilities.
                probs = self.model(batch_states)
                # Labels are (batch,); add a trailing axis to match the
                # (batch, 1) model output.
                loss = loss_fn(probs, batch_labels.unsqueeze(1))
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            # Print the average loss for the epoch
            print(
                f"Epoch {epoch+1}/{self.n_epochs}, Loss: {epoch_loss/len(dataloader):.4f}"
            )
