import torch


class PolicyLinProg:
    """
    Base class to inherit from.

    Stores the hyper-parameters and the exploration (eps) / learning-rate schedules
    shared by policies that train three Q-networks: ``Q_principal``, ``Q_agent`` and
    ``Q_agent_val`` (the last one is fed states concatenated with contracts). The
    networks, their target copies and their optimizers are created by subclasses by
    overriding the ``init_*`` hooks; ``train``, ``log`` and ``init_wandb`` must also
    be implemented by subclasses.
    """
    # Not used in this base class; presumably the network type subclasses build.
    network_class = None
    # Not used in this base class; presumably a soft target-network update
    # coefficient -- confirm against subclasses.
    mixing_coef = .99

    def __init__(self,
                 env,
                 hid_size=32,
                 n_hid_layers=0,
                 gamma=1,
                 lr_start=1e-2, lr_end=1e-3,
                 eps_start=1, eps_end=0,
                 batch_size=32,
                 n_batches=1000,
                 n_interactions=4,
                 n_warm_start_batches=100,
                 target_update_freq=100,
                 log_wandb=False,
                 delta=0,
                 scheduling_speed=1.,
                 ):
        """
        Store hyper-parameters and precompute the per-step schedule factors.

        Args:
            env: Environment object; only stored here.
            hid_size: Stored for subclasses (presumably hidden-layer width).
            n_hid_layers: Stored for subclasses (presumably number of hidden layers).
            gamma: Stored for subclasses (presumably the discount factor).
            lr_start, lr_end: Initial and final learning rate. ``schedule`` multiplies
                the rate by a constant factor each step, clamped below at ``lr_end``.
            eps_start, eps_end: Initial and final exploration rate. ``schedule``
                subtracts a constant each step, clamped below at ``eps_end``.
            batch_size: Stored for subclasses.
            n_batches: Total number of training batches.
            n_interactions: Stored for subclasses.
            n_warm_start_batches: Batches excluded from the schedule length; the
                schedules span ``n_batches - n_warm_start_batches`` steps (at least 1).
            target_update_freq: Stored for subclasses.
            log_wandb: Stored for subclasses.
            delta: Stored as ``self.delta``; its meaning is defined by subclasses.
            scheduling_speed: Multiplier on the per-step schedule change. With a
                value ``s`` the end values are reached after 1/s of the scheduled steps.
        """
        self.env = env
        self.inp_size = None  # presumably filled in by subclasses
        self.hid_size = hid_size
        self.out_size = None  # presumably filled in by subclasses
        self.n_hid_layers = n_hid_layers
        self.gamma = gamma

        # Clamp to at least one step. Otherwise n_batches == n_warm_start_batches
        # raises ZeroDivisionError, and a negative difference flips both schedules,
        # so schedule() would grow lr and eps without bound.
        n_scheduled = max(n_batches - n_warm_start_batches, 1)

        self.lr_start = self.lr = lr_start
        self.lr_end = lr_end
        # Geometric decay: after n_scheduled / scheduling_speed steps lr == lr_end.
        self.lr_mult = (lr_end / lr_start) ** (scheduling_speed / n_scheduled)
        self.eps_start = self.eps = eps_start
        self.eps_end = eps_end
        # Linear decay: after n_scheduled / scheduling_speed steps eps == eps_end.
        self.eps_deduct = scheduling_speed * (eps_start - eps_end) / n_scheduled
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.n_interactions = n_interactions
        self.n_warm_start_batches = n_warm_start_batches
        self.target_update_freq = target_update_freq
        self.iteration = 0

        # Running loss values; only initialized here, updated by subclasses.
        self.loss_principal = 0
        self.loss_agent = 0
        self.loss_agent_val = 0

        self.delta = delta

        self.log_wandb = log_wandb

    def init_networks(self):
        """Run the three ``init_*`` hooks (principal, agent, agent-val)."""
        self.init_principal()
        self.init_agent()
        self.init_agent_val()

    def init_principal(self):
        """Placeholder hook: set the principal network, its target and optimizer to None."""
        self.Q_principal = None
        self.Q_principal_target = None
        self.opt_principal = None

    def init_agent(self):
        """Placeholder hook: set the agent network, its target and optimizer to None."""
        self.Q_agent = None
        self.Q_agent_target = None
        self.opt_agent = None

    def init_agent_val(self):
        """Placeholder hook: set the agent-val network, its target and optimizer to None."""
        self.Q_agent_val = None
        self.Q_agent_val_target = None
        self.opt_agent_val = None

    def init_wandb(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def process_state(self, state):
        """Convert ``state`` (scalar, sequence, array or tensor) to a float32 tensor."""
        # The legacy torch.FloatTensor constructor interprets an int argument as a
        # size (FloatTensor(3) is an uninitialized 3-element tensor, not tensor(3.))
        # and rejects Python floats; as_tensor converts the value itself.
        return torch.as_tensor(state, dtype=torch.float32)

    def act(self, states, eps=0):
        """
        Return element ``[1]`` of ``Q_principal(states, sample=True, eps=eps)``.

        Assumes the subclass network accepts ``sample``/``eps`` keywords and puts the
        sampled actions at index 1 of its output -- confirm against ``network_class``.
        """
        return self.Q_principal(states, sample=True, eps=eps)[1]

    def act_val(self, states, contracts, eps=0):
        """Like ``act`` but queries ``Q_agent_val`` on states concatenated with contracts."""
        return self.Q_agent_val(self._cat_states_contracts(states, contracts), sample=True, eps=eps)[1]

    def _cat_states_contracts(self, states, contracts):
        """Concatenate ``states`` and ``contracts`` along the last dimension."""
        return torch.cat([states, contracts], dim=-1)

    @torch.no_grad()
    def get_best_actions(self, q_values):
        """Return the argmax over the last dimension, keeping that dimension (size 1)."""
        return q_values.argmax(-1, keepdim=True)

    def get_loss(self, q_values, targets):
        """Huber loss (PyTorch default threshold and mean reduction)."""
        return torch.nn.functional.huber_loss(q_values, targets)

    @staticmethod
    def _optimize(opt, loss):
        """Run one zero_grad / backward / step cycle of ``opt`` on ``loss``."""
        opt.zero_grad()
        loss.backward()
        opt.step()

    def update_principal(self, loss):
        """Take one optimizer step for the principal network."""
        self._optimize(self.opt_principal, loss)

    def update_agent(self, loss):
        """Take one optimizer step for the agent network."""
        self._optimize(self.opt_agent, loss)

    def update_agent_val(self, loss):
        """Take one optimizer step for the agent-val network."""
        self._optimize(self.opt_agent_val, loss)

    def _optimizers(self):
        """Return the optimizers that exist, skipping ones never set or left as None."""
        names = ('opt_principal', 'opt_agent', 'opt_agent_val')
        return [opt for opt in (getattr(self, name, None) for name in names) if opt is not None]

    def _apply_lr(self):
        """Write ``self.lr`` into every param group of every existing optimizer."""
        for opt in self._optimizers():
            for group in opt.param_groups:
                group['lr'] = self.lr

    def schedule(self):
        """
        Advance eps and lr by one scheduled step and push lr into all optimizers.

        The ``max`` clamps assume decaying schedules (``eps_end <= eps_start`` and
        ``lr_end <= lr_start``).
        """
        self.eps = max(self.eps - self.eps_deduct, self.eps_end)
        self.lr = max(self.lr * self.lr_mult, self.lr_end)
        # Includes opt_agent_val, so all three networks follow the same schedule.
        self._apply_lr()

    def reset_schedules(self):
        """Restore lr and eps to their start values and push lr into the optimizers."""
        self.lr = self.lr_start
        self.eps = self.eps_start
        # Without this the optimizers keep the decayed rate until the next schedule().
        self._apply_lr()

    def train(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    # NOTE: a subclass override does not inherit this decorator; re-apply
    # @torch.no_grad() on the override if gradient tracking must stay off.
    @torch.no_grad()
    def log(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError
