import torch
from torch.nn import functional as F
import numpy as np
from copy import deepcopy
import wandb

from tree_exp.utils import solve_contract, create_wandb_tags_and_config
from tree_exp.env import TreeContractEnv
from tree_exp.replay import ReplayMemoryContract
from networks import QNetworkFC, ClassifierNetworkFC
from policy import PolicyLinProg
from config import DTYPE, DEVICE


class PolicyLinProgTree(PolicyLinProg):
    """Contract-design Q-learning policy for ``TreeContractEnv``.

    Two Q-networks are trained jointly from a shared replay buffer:

    * ``Q_principal`` -- the principal's action values. Its greedy action is
      used as the recommended action (double-DQN targets, see ``_train``).
    * ``Q_agent`` -- the agent's action values *excluding* the immediate
      payment (its target is ``-cost + gamma * (expected next payment +
      next agent value)``). Together with the per-action outcome
      distributions, it is passed to ``solve_contract`` to obtain one
      payment per outcome (a contract) for the recommended action.

    If ``outcome_dist_known`` is False, the per-action outcome distributions
    are predicted by ``dist_network``, which is trained with cross-entropy on
    observed outcomes. Otherwise they are read from ``env.get_node_info``.

    If ``val_optimal`` is True, ``train`` runs ``_validate`` instead. That
    path trains a separate agent network ``Q_agent_val`` on (state, contract)
    inputs, using contracts built from the environment's own node data, and
    reports how often the action it chooses matches the recommended action.
    """

    def __init__(self,
                 env: TreeContractEnv,
                 hid_size=32,
                 n_hid_layers=0,
                 gamma=1,
                 lr_start=1e-2, lr_end=1e-3,
                 eps_start=1, eps_end=0,
                 batch_size=32,
                 n_batches=1000,
                 n_interactions=4,
                 n_warm_start_batches=100,
                 target_update_freq=100,
                 log_freq=100,
                 log_wandb=False,
                 delta=0,
                 scheduling_speed=1.1,
                 verbose=False,
                 outcome_dist_known=True,
                 lr_dist_start=3e-4, lr_dist_end=1e-4,
                 buffer_size=int(1e4),
                 val_optimal=False
                 ):
        """Set up the networks, the replay buffer and (optionally) a wandb run.

        Most arguments are forwarded unchanged to ``PolicyLinProg``. The ones
        handled here are:

        Args:
            log_freq: call ``log`` / ``log_val`` every ``log_freq`` iterations.
            verbose: if True, ``log`` prints every tree state instead of only
                the root state.
            outcome_dist_known: if False, learn the outcome distributions
                with ``dist_network`` instead of querying the environment.
            lr_dist_start, lr_dist_end: learning-rate schedule of the
                distribution network. Only used when the distribution is
                learned.
            buffer_size: capacity passed to ``ReplayMemoryContract``.
            val_optimal: if True, ``train`` runs ``_validate`` instead of
                ``_train``.
        """
        super().__init__(env, hid_size, n_hid_layers, gamma, lr_start, lr_end, eps_start, eps_end, batch_size,
                         n_batches, n_interactions, n_warm_start_batches, target_update_freq,
                         log_wandb, delta, scheduling_speed)

        # One network input per state component, one network output per action.
        self.inp_size = len(env.state_space.nvec)
        self.out_size = env.action_space.n
        self.outcome_dist_known = outcome_dist_known
        if not self.outcome_dist_known:
            # Must be set before init_networks() below, which reads self.lr_dist.
            self.lr_dist_start = self.lr_dist = lr_dist_start
            self.lr_dist_end = lr_dist_end
            # Geometric decay factor applied once per schedule() call.
            # max(..., 1) avoids a division by zero when no batches remain
            # after warm start.
            n_scheduled_batches = max(n_batches - n_warm_start_batches, 1)
            self.lr_dist_mult = (lr_dist_end / lr_dist_start) ** (scheduling_speed / n_scheduled_batches)

        self.buffer = ReplayMemoryContract(buffer_size)
        self.init_networks()
        self.log_freq = log_freq

        self.loss_dist = 0
        self.verbose = verbose
        self.val_optimal = val_optimal

        self.init_wandb()

    def init_networks(self):
        """Create the base-class networks plus, if needed, the outcome-distribution network."""
        super().init_networks()
        if not self.outcome_dist_known:
            self.init_dist()

    def init_principal(self):
        """Create the principal Q-network, its Adam optimizer and a frozen target copy."""
        self.Q_principal = QNetworkFC(self.inp_size, self.hid_size, self.out_size, self.n_hid_layers)
        self.opt_principal = torch.optim.Adam(self.Q_principal.parameters(), self.lr_start)
        self.Q_principal_target = QNetworkFC(self.inp_size, self.hid_size, self.out_size, self.n_hid_layers)
        self.Q_principal_target.load_state_dict(self.Q_principal.state_dict())
        self.Q_principal_target.eval()

    def init_agent(self):
        """Create the agent Q-network, its Adam optimizer and a frozen target copy."""
        self.Q_agent = QNetworkFC(self.inp_size, self.hid_size, self.out_size, self.n_hid_layers)
        self.opt_agent = torch.optim.Adam(self.Q_agent.parameters(), self.lr_start)
        self.Q_agent_target = QNetworkFC(self.inp_size, self.hid_size, self.out_size, self.n_hid_layers)
        self.Q_agent_target.load_state_dict(self.Q_agent.state_dict())
        self.Q_agent_target.eval()

    def init_agent_val(self):
        """Create the validation agent Q-network, whose input is a state concatenated with a contract."""
        self.Q_agent_val = QNetworkFC(self.inp_size + self.env.n_outcomes, self.hid_size, self.out_size, self.n_hid_layers)
        self.opt_agent_val = torch.optim.Adam(self.Q_agent_val.parameters(), self.lr_start)
        self.Q_agent_target_val = QNetworkFC(self.inp_size + self.env.n_outcomes, self.hid_size, self.out_size, self.n_hid_layers)
        self.Q_agent_target_val.load_state_dict(self.Q_agent_val.state_dict())
        self.Q_agent_target_val.eval()

    def init_dist(self):
        """Create the outcome-distribution classifier and its optimizer.

        The classifier presumably predicts one outcome distribution per
        action. The gathers in ``_train`` index its output as
        (batch, action, outcome).
        """
        self.dist_network = ClassifierNetworkFC(self.inp_size, self.hid_size, self.out_size, self.env.n_outcomes,
                                                self.n_hid_layers)
        self.opt_dist = torch.optim.Adam(self.dist_network.parameters(), self.lr_dist)

    def init_wandb(self):
        """Start a wandb run configured from this policy, if wandb logging is enabled."""
        if self.log_wandb:
            config, tags = create_wandb_tags_and_config(self)
            wandb.init(project="rl_contracts", config=config, tags=tags, reinit=True)

    def process_state(self, state):
        """Convert a raw observation into a ``(1, ...)`` tensor of ``DTYPE`` on ``DEVICE``."""
        return torch.FloatTensor(state).to(DTYPE).to(DEVICE).unsqueeze(0)

    def get_dists(self, states):
        """Return ``(logits, dists)`` from ``dist_network``.

        The logits keep their gradient (they are used for the cross-entropy
        loss). The distributions are detached because they only feed the
        contract LP and the Q-targets.
        """
        logits, dists = self.dist_network(states)
        return logits, dists.detach()

    def _get_exp_values(self, dists, values):
        """Return the expectation of ``values`` under ``dists``, reduced over the last axis."""
        exp_values = (dists * values).sum(-1)
        return exp_values

    @torch.no_grad()
    def get_contracts(self, states, dists, q_values_agent, actions, dones=None):
        """Solve the contract LP (``solve_contract``) for every row of a batch.

        Args:
            states: batch of states, one row per sample.
            dists: per-sample outcome distributions, indexed (action, outcome)
                per row, as in the gathers in ``_train``.
            q_values_agent: per-sample agent Q-values, one per action.
            actions: the action to incentivise for each sample. Flattened
                before use.
            dones: optional per-sample terminal flags. Terminal rows get an
                all-zero contract. When omitted, no row is treated as
                terminal.

        Returns:
            Tensor of shape ``(batch, env.n_outcomes)`` with dtype ``DTYPE``
            on ``DEVICE``.
        """
        if dones is None:
            dones = torch.zeros((states.shape[0],))
        contracts = np.zeros((q_values_agent.shape[0], self.env.n_outcomes))
        states, dists, q_values_agent, actions, dones = \
            states.cpu().numpy(), dists.cpu().numpy(), q_values_agent.cpu().numpy(), actions.cpu().numpy(), dones.cpu().numpy()
        actions = actions.reshape(-1)
        # Solve the LP once per distinct (state, action) pair in the batch.
        # This assumes rows with the same state carry the same dists and
        # Q-values, which holds when both come from the same network/env
        # lookup.
        cache = {}
        for i, (state, dist, q_value_agent, action, done) in \
                enumerate(zip(states, dists, q_values_agent, actions, dones)):
            if done:
                continue
            key = tuple(np.append(state, action))
            if key not in cache:
                cache[key] = solve_contract(dist, q_value_agent, action, delta=self.delta).x
            contracts[i] = cache[key]
        return torch.from_numpy(contracts).to(DTYPE).to(DEVICE)

    @torch.no_grad()
    def get_action_and_contract_val(self, state, env, done=False):
        """Return ``(recommended_action, contract)`` for a single batched state.

        With ``val_optimal``, the agent Q-values, outcome distributions and
        recommended action are read from ``env``. This assumes
        ``env.current_node`` is the node of ``state``. Otherwise they come
        from the learned networks (and from ``env`` for the distributions
        when ``outcome_dist_known``).

        If ``done`` is True, zero placeholders are returned instead.
        """
        if done:
            # The contract placeholder lives on DEVICE, like the contracts
            # returned by get_contracts().
            return (torch.zeros((state.shape[0], 1)),
                    torch.zeros((state.shape[0], self.env.n_outcomes), dtype=DTYPE, device=DEVICE))

        if self.val_optimal:
            q_values_a = torch.from_numpy(env.current_node.q_values_a).unsqueeze(0)
            dists, _, _, action_recommend, _, _, _ = env.get_node_info(state)
        else:
            action_recommend = self.act(state, eps=0)
            q_values_a = self.Q_agent(state)
            if self.outcome_dist_known:
                dists, _, _, _, _, _, _ = env.get_node_info(state)
            else:
                _, dists = self.get_dists(state)

        contract = self.get_contracts(state, dists, q_values_a, action_recommend.unsqueeze(0))

        return action_recommend, contract

    def update_dist(self, loss_dist):
        """Take one optimizer step on the outcome-distribution network."""
        self.opt_dist.zero_grad()
        loss_dist.backward()
        self.opt_dist.step()

    def schedule(self):
        """Advance the base-class schedules and decay the distribution learning rate.

        The distribution learning rate is floored at ``lr_dist_end``.
        """
        super().schedule()
        if not self.outcome_dist_known:
            self.lr_dist = max(self.lr_dist * self.lr_dist_mult, self.lr_dist_end)
            for g in self.opt_dist.param_groups:
                g['lr'] = self.lr_dist

    def train(self):
        """Run ``_train`` (or ``_validate`` when ``val_optimal``) and close the wandb run."""
        if self.val_optimal:
            self._validate()
        else:
            self._train()
        # Close the wandb run in both modes, so the validation run is
        # finished too.
        if self.log_wandb:
            wandb.finish()

    def _train(self):
        """Main training loop for the principal, agent (and distribution) networks.

        Each of the ``n_batches + 1`` iterations does the following:

        1. Collect ``n_interactions`` epsilon-greedy environment steps into
           the replay buffer.
        2. After warm start, take one double-DQN gradient step for the
           principal and the agent (plus a cross-entropy step for
           ``dist_network`` when distributions are learned).
        3. Update the schedules, sync the target networks every
           ``target_update_freq`` iterations, and log every ``log_freq``
           iterations.
        """
        state, _ = self.env.reset()
        state = self.process_state(state)
        for _ in range(self.n_batches + 1):
            self.iteration += 1

            for _ in range(self.n_interactions):
                # Read self.eps on every step so that any decay applied by
                # schedule() takes effect. A copy taken once before the loop
                # would freeze exploration at its initial value.
                action = self.act(state, eps=self.eps)

                next_state, reward, terminated, truncated, info = self.env.step(action.item())
                done = terminated or truncated
                next_state = self.process_state(next_state)

                self.buffer.push(state.squeeze(0).cpu(), action.cpu(), reward, info['cost'], info['outcome'], done, next_state.squeeze(0).cpu())

                state = next_state
                if done:
                    state, _ = self.env.reset()
                    state = self.process_state(state)

            if self.iteration < self.n_warm_start_batches:
                continue

            transitions = self.buffer.sample(self.batch_size)
            states, actions, rewards, costs, outcomes, dones, next_states = transitions

            q_values_p = self.Q_principal(states)
            q_values_a = self.Q_agent(states)
            with torch.no_grad():
                next_q_values_p = self.Q_principal_target(next_states)
                next_q_values_a = self.Q_agent_target(next_states)
                # Double DQN: the online principal network selects the next
                # action, and the target networks evaluate it.
                next_actions = self.get_best_actions(self.Q_principal(next_states))

            if self.outcome_dist_known:
                dists, _, _, _, _, _, _ = self.env.get_node_info(states)
                next_dists, _, _, _, _, _, _ = self.env.get_node_info(next_states)
            else:
                logits, dists = self.get_dists(states)
                with torch.no_grad():
                    _, next_dists = self.get_dists(next_states)

            contracts = self.get_contracts(states, dists, q_values_a, actions)
            next_contracts = self.get_contracts(next_states, next_dists, next_q_values_a, next_actions, dones)

            q_values_p = q_values_p.gather(-1, actions).squeeze(-1)
            q_values_a = q_values_a.gather(-1, actions).squeeze(-1)

            # Select each sample's (outcome-distribution) row for its next action.
            next_dists = next_dists.gather(1, next_actions.unsqueeze(-1).expand(-1, -1, self.env.n_outcomes)).squeeze(1)
            next_q_values_p = next_q_values_p.gather(-1, next_actions).squeeze(-1)
            next_q_values_a = next_q_values_a.gather(-1, next_actions).squeeze(-1)
            next_exp_payments = self._get_exp_values(next_dists, next_contracts)

            with torch.no_grad():
                # The principal receives the reward minus the payment for the
                # realised outcome.
                payments = contracts.gather(-1, outcomes).squeeze(-1)
                targets_p = rewards - payments + self.gamma * (1 - dones) * next_q_values_p

                # Q_agent excludes the immediate payment, so the next-step
                # value adds the expected next payment back in.
                next_q_values_a = next_exp_payments + next_q_values_a
                targets_a = -costs + self.gamma * (1 - dones) * next_q_values_a

            loss_p = self.get_loss(q_values_p, targets_p)
            self.update_principal(loss_p)
            self.loss_principal = self.mixing_coef * self.loss_principal + (1 - self.mixing_coef) * loss_p.item()

            loss_a = self.get_loss(q_values_a, targets_a)
            self.update_agent(loss_a)
            self.loss_agent = self.mixing_coef * self.loss_agent + (1 - self.mixing_coef) * loss_a.item()

            if not self.outcome_dist_known:
                # Cross-entropy of the observed outcome under the predicted
                # distribution of the action actually taken.
                logits = logits.gather(1, actions.unsqueeze(-1).expand(-1, -1, self.env.n_outcomes)).squeeze(1)
                loss_dist = F.cross_entropy(logits, outcomes.squeeze(-1))
                self.update_dist(loss_dist)
                self.loss_dist = self.mixing_coef * self.loss_dist + (1 - self.mixing_coef) * loss_dist.item()

            self.schedule()

            if self.iteration % self.target_update_freq == 0:
                self.Q_principal_target.load_state_dict(self.Q_principal.state_dict())
                self.Q_agent_target.load_state_dict(self.Q_agent.state_dict())

            if self.iteration % self.log_freq == 0:
                self.log()

    def _validate(self):
        """Train ``Q_agent_val`` against the contracts from ``get_action_and_contract_val``.

        The agent's reward is the payment for the realised outcome minus its
        cost. After warm start, each iteration takes one double-DQN step on
        (state, contract) inputs.
        """
        state, _ = self.env.reset()
        state = self.process_state(state)
        _, contract = self.get_action_and_contract_val(state, self.env)
        for _ in range(self.n_batches + 1):
            self.iteration += 1

            for _ in range(self.n_interactions):
                action = self.act_val(state, contract, eps=self.eps)

                next_state, _, terminated, truncated, info = self.env.step(action.item())
                done = terminated or truncated
                next_state = self.process_state(next_state)
                _, next_contract = self.get_action_and_contract_val(next_state, self.env, done)

                # Agent's reward: payment for the realised outcome minus its
                # action cost.
                payment = contract[0, info['outcome']].item()
                reward = payment - info['cost']

                self.buffer.push(state.squeeze(0).cpu(), action.cpu(), reward, contract.cpu(), done, next_state.squeeze(0).cpu(), next_contract.cpu())

                state, contract = next_state, next_contract
                if done:
                    state, _ = self.env.reset()
                    state = self.process_state(state)
                    _, contract = self.get_action_and_contract_val(state, self.env)

            if self.iteration < self.n_warm_start_batches:
                continue

            transitions = self.buffer.sample(self.batch_size)
            states, actions, rewards, contracts, dones, next_states, next_contracts = transitions
            states_contracts = self._cat_states_contracts(states, contracts)
            next_states_contracts = self._cat_states_contracts(next_states, next_contracts)

            q_values = self.Q_agent_val(states_contracts)
            q_values = q_values.gather(-1, actions).squeeze(-1)
            with torch.no_grad():
                next_q_values = self.Q_agent_target_val(next_states_contracts)
                # Double DQN: the online network selects, the target network
                # evaluates.
                next_actions = self.get_best_actions(self.Q_agent_val(next_states_contracts))
                next_q_values = next_q_values.gather(-1, next_actions).squeeze(-1)
                targets = rewards + self.gamma * (1 - dones) * next_q_values

            loss = self.get_loss(q_values, targets)
            self.update_agent_val(loss)
            self.loss_agent_val = self.mixing_coef * self.loss_agent_val + (1 - self.mixing_coef) * loss.item()

            self.schedule()

            if self.iteration % self.target_update_freq == 0:
                self.Q_agent_target_val.load_state_dict(self.Q_agent_val.state_dict())

            if self.iteration % self.log_freq == 0:
                self.log_val()

    @torch.no_grad()
    def sample_episodes(self, n_episodes=1000):
        """Estimate average discounted principal and agent returns by greedy rollouts.

        Rollouts run on a deep copy of the environment, so ``self.env`` is
        not stepped. At each step the contract for the greedily chosen action
        is offered.

        Returns:
            ``(return_p, return_a)`` averaged over ``n_episodes``. Both are 0
            when ``n_episodes`` is 0.
        """
        env = deepcopy(self.env)
        return_p = return_a = 0
        for _ in range(n_episodes):
            discount = 1
            state, _ = env.reset()
            done = False
            while not done:
                state = self.process_state(state)

                action = self.act(state, eps=0)
                q_values_a = self.Q_agent(state)

                if self.outcome_dist_known:
                    dists, _, _, _, _, _, _ = self.env.get_node_info(state)
                else:
                    _, dists = self.get_dists(state)

                contracts = self.get_contracts(state, dists, q_values_a, action.unsqueeze(0))

                state, reward, terminated, truncated, info = env.step(action.item())
                done = terminated or truncated

                cost = info['cost']
                payment = contracts[0, info['outcome']].item()

                return_p += (reward - payment) * discount
                return_a += (payment - cost) * discount
                discount *= self.gamma

        n_done = max(n_episodes, 1)
        return return_p / n_done, return_a / n_done

    @torch.no_grad()
    def sample_episodes_val(self, n_episodes=1000):
        """Roll out ``act_val`` against the validation contracts on a copy of the environment.

        Returns:
            ``(return_p, return_a, accuracy)``. The returns are average
            discounted returns per episode. ``accuracy`` is the fraction of
            steps where the chosen action equals the recommended action.
            All three are 0 when nothing is sampled.
        """
        env = deepcopy(self.env)
        return_p = return_a = accuracy = n_steps = 0
        for _ in range(n_episodes):
            discount = 1
            state, _ = env.reset()
            done = False
            while not done:
                n_steps += 1
                state = self.process_state(state)

                action_recommend, contract = self.get_action_and_contract_val(state, env)
                action = self.act_val(state, contract, eps=0)

                state, reward, terminated, truncated, info = env.step(action.item())
                done = terminated or truncated

                cost = info['cost']
                payment = contract[0, info['outcome']].item()

                return_p += (reward - payment) * discount
                return_a += (payment - cost) * discount
                accuracy += int(action_recommend.item() == action.item())
                discount *= self.gamma

        n_done = max(n_episodes, 1)
        return return_p / n_done, return_a / n_done, accuracy / max(n_steps, 1)

    @torch.no_grad()
    def eval_utilities(self, root):
        """Recursively evaluate the greedy policy on the subtree rooted at ``root``.

        At each node, the greedy action is evaluated with two contracts:

        * the contract solved from the *learned* agent Q-values and
          distributions;
        * a "corrected" contract, solved from ``root.F`` and the agent's
          continuation values obtained by evaluating the subtree under the
          corrected contracts.

        Returns:
            ``(utility_p, utility_a, utility_p_corrected, utility_a_corrected,
            n_opt_actions)``. The last entry counts the subtree nodes whose
            greedy action equals ``node.opt_action``.
        """
        # Utilities of the child reached by each outcome (zero for missing children).
        child_p = np.zeros((root.n_outcomes,), dtype=np.float32)
        child_a = np.zeros((root.n_outcomes,), dtype=np.float32)
        child_p_corrected = np.zeros((root.n_outcomes,), dtype=np.float32)
        child_a_corrected = np.zeros((root.n_outcomes,), dtype=np.float32)
        n_opt_actions = 0

        if root.children is not None:
            for i, child in enumerate(root.children):
                if child is not None:
                    child_p[i], child_a[i], child_p_corrected[i], child_a_corrected[i], opt_ac = self.eval_utilities(child)
                    n_opt_actions += opt_ac

        # Per-action discounted continuation values. root.F is presumably the
        # (action x outcome) outcome-probability matrix. The agent also bears
        # the action cost.
        q_values_p = self.gamma * (root.F * child_p.reshape(1, -1)).sum(-1)
        q_values_a = self.gamma * (root.F * child_a.reshape(1, -1)).sum(-1) - root.costs
        q_values_p_corrected = self.gamma * (root.F * child_p_corrected.reshape(1, -1)).sum(-1)
        q_values_a_corrected = self.gamma * (root.F * child_a_corrected.reshape(1, -1)).sum(-1) - root.costs

        state = self.process_state(root.state)
        action = self.act(state, eps=0)
        q_values_a_pred = self.Q_agent(state)
        if self.outcome_dist_known:
            dists, _, _, _, _, _, _ = self.env.get_node_info(state)
        else:
            _, dists = self.get_dists(state)
        dists = dists.squeeze(0).cpu().numpy()
        q_values_a_pred = q_values_a_pred.squeeze(0).cpu().numpy()
        action = action.item()

        contract = solve_contract(dists, q_values_a_pred, action, delta=self.delta).x
        contract_corrected = solve_contract(root.F, q_values_a_corrected, action, delta=self.delta).x

        utility_p = (root.F[action] * (root.rewards - contract)).sum() + q_values_p[action]
        utility_a = (root.F[action] * contract).sum() + q_values_a[action]
        utility_p_corrected = (root.F[action] * (root.rewards - contract_corrected)).sum() + q_values_p_corrected[action]
        utility_a_corrected = (root.F[action] * contract_corrected).sum() + q_values_a_corrected[action]
        return utility_p, utility_a, utility_p_corrected, utility_a_corrected, n_opt_actions + int(action == root.opt_action)

    @torch.no_grad()
    def log(self):
        """Print training diagnostics and optionally send them to wandb.

        Evaluates the greedy policy on the whole tree via ``eval_utilities``.
        For the root state (or every state when ``verbose``), it then prints
        the environment's optimal action, contract and values next to the
        predicted ones. Only the first printed state's values are sent to
        wandb.
        """
        utility_p, utility_a, utility_p_corrected, utility_a_corrected, n_opt_actions = self.eval_utilities(self.env.root)
        accuracy = n_opt_actions / self.env.n_states

        if self.verbose:
            states = self.env.get_all_states()
        else:
            states = self.env.get_root_state()
        q_values_p = self.Q_principal(states)
        q_values_a = self.Q_agent(states)
        true_dists, costs, _, true_actions, true_contracts, true_utility_p, true_utility_a = self.env.get_node_info(states)
        if self.outcome_dist_known:
            dists = true_dists
        else:
            _, dists = self.get_dists(states)
        actions = self.get_best_actions(q_values_p)
        contracts = self.get_contracts(states, dists, q_values_a, actions)

        q_values_p = q_values_p.gather(-1, actions).squeeze(-1)
        q_values_a = q_values_a.gather(-1, actions).squeeze(-1)
        dists_action = dists.gather(1, actions.unsqueeze(-1).expand(-1, -1, self.env.n_outcomes)).squeeze(1)
        exp_payments = self._get_exp_values(dists_action, contracts)
        # Q_agent excludes the immediate payment, so add the expected payment
        # back in for display.
        q_values_a = (exp_payments + q_values_a)

        print(f'\nIteration: {self.iteration}')
        print(f'Losses principal: {round(self.loss_principal, 3)}, '
              f'agent: {round(self.loss_agent, 3)}, '
              f'dist: {round(self.loss_dist, 3)}; '
              f'Utility principal: {round(utility_p, 3)}, '
              f'agent: {round(utility_a, 3)}; '
              f'Corrected Utility principal: {round(utility_p_corrected, 3)}, '
              f'agent: {round(utility_a_corrected, 3)}; '
              f'Accuracy: {round(accuracy, 3)}')

        root_state = True
        states, actions, contracts, q_values_a, q_values_p, dists = \
            (states.cpu().numpy(), actions.cpu().numpy(), contracts.cpu().numpy(), q_values_a.cpu().numpy().astype(np.float64),
             q_values_p.cpu().numpy().astype(np.float64), dists.cpu().numpy().astype(np.float64))
        true_dists, true_actions, true_contracts, true_utility_p, true_utility_a = \
            (true_dists.cpu().numpy().astype(np.float64), true_actions.cpu().numpy(), true_contracts.cpu().numpy(),
             true_utility_p.cpu().numpy().astype(np.float64), true_utility_a.cpu().numpy().astype(np.float64))
        for s, a, c, q_a, q_p, d, t_a, t_c, t_u_a, t_u_p, t_d in \
                zip(states, actions, contracts, q_values_a, q_values_p, dists, true_actions, true_contracts, true_utility_a, true_utility_p, true_dists):
            print(f'state: {s} '
                  f'OPTIMAL: '
                  f'action {t_a.item()}, '
                  f'contract {t_c.round(3)}, '
                  f'Q-value P {t_u_p.round(3)}, '
                  f'Q-value A {t_u_a.round(3)}, '
                  f'dists {t_d.round(2).tolist()}')
            print(f'state: {s} '
                  f'PREDICT: '
                  f'action {a.item()}, '
                  f'contract {c.round(3)}, '
                  f'Q-value P {q_p.round(3)}, '
                  f'Q-value A {q_a.round(3)}, '
                  f'dists {d.round(2).tolist()}')

            if self.log_wandb and root_state:
                wandb.log(
                    {'train': {
                        'loss_p': self.loss_principal,
                        'loss_a': self.loss_agent,
                        'loss_dist': self.loss_dist,
                        'utility_p': utility_p,
                        'utility_a': utility_a,
                        'utility_p_corr': utility_p_corrected,
                        'utility_a_corr': utility_a_corrected,
                        'utility_p_opt': t_u_p,
                        'utility_a_opt': t_u_a,
                        'utility_p_pred': q_p,
                        'utility_a_pred': q_a,
                        'accuracy': accuracy,
                    }},
                    step=self.iteration,
                    commit=True
                )
                root_state = False

    @torch.no_grad()
    def log_val(self):
        """Print validation metrics from ``sample_episodes_val`` and optionally send them to wandb."""

        return_p, return_a, accuracy = self.sample_episodes_val()

        print(f'\nIteration: {self.iteration}')
        print(f'Loss agent: {round(self.loss_agent_val, 3)}, '
              f'Return principal: {round(return_p, 3)}, '
              f'Return agent: {round(return_a, 3)}; '
              f'Accuracy: {round(accuracy, 3)}')

        if self.log_wandb:
            wandb.log(
                {'val': {
                    'loss': self.loss_agent_val,
                    'return_p': return_p,
                    'return_a': return_a,
                    'accuracy': accuracy,
                }},
                step=self.iteration,
                commit=True
            )
