import sys
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch

from agent.memory.DreamerMemory import DreamerMemory
from agent.models.DreamerModel import DreamerModel
from agent.optim.loss import model_loss, actor_loss, value_loss, actor_rollout
from agent.optim.utils import advantage
from environments import Env
from networks.dreamer.action import Actor
from networks.dreamer.critic import Critic


def orthogonal_init(tensor, gain=1):
    """Fill `tensor` in place with a (semi-)orthogonal matrix scaled by `gain`.

    The tensor is treated as a 2-D matrix of shape ``(tensor.size(0), -1)``.
    A standard-normal matrix of that shape is drawn and the left singular
    vectors of its reduced SVD are used as the orthonormal basis, so the
    result has orthonormal rows when ``rows < cols`` and orthonormal columns
    otherwise (before scaling by `gain`).

    Args:
        tensor: Tensor with at least two dimensions; modified in place. It
            must be viewable as ``(rows, cols)`` because the result is
            written through ``view_as``.
        gain: Multiplicative factor applied to the orthogonal matrix.

    Returns:
        The same `tensor` object that was passed in.

    Raises:
        ValueError: If `tensor` has fewer than two dimensions.
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    # Nothing to initialise; also avoids indexing/dividing by an empty
    # leading dimension below (mirrors torch.nn.init.orthogonal_).
    if tensor.numel() == 0:
        return tensor

    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new(rows, cols).normal_(0, 1)

    # Decompose the "tall" orientation so the reduced SVD yields
    # min(rows, cols) orthonormal vectors of length max(rows, cols).
    wide = rows < cols
    if wide:
        flattened.t_()

    # Reduced SVD: `u` has orthonormal columns. (The original comment called
    # this a QR factorisation, but the code has always used an SVD.)
    u, _, _ = torch.svd(flattened, some=True)
    if wide:
        u.t_()

    # At this point `u` always has shape (rows, cols), in both orientations.
    with torch.no_grad():
        tensor.view_as(u).copy_(u)
        tensor.mul_(gain)
    return tensor


def initialize_weights(mod, scale=1.0, mode='ortho'):
    """Re-initialise every parameter of `mod` that has two or more dimensions.

    Parameters with fewer than two dimensions (e.g. biases) are left as-is.

    Args:
        mod: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        scale: Gain forwarded to ``orthogonal_init``; only used when
            ``mode == 'ortho'``.
        mode: ``'ortho'`` applies ``orthogonal_init``; ``'xavier'`` applies
            ``torch.nn.init.xavier_uniform_``. Any other value leaves all
            parameters untouched.
    """
    for param in mod.parameters():
        weights = param.data
        if weights.dim() < 2:
            continue
        if mode == 'ortho':
            orthogonal_init(weights, gain=scale)
        elif mode == 'xavier':
            torch.nn.init.xavier_uniform_(weights)


class DreamerLearner:
    """Per-agent Dreamer learner.

    Holds one world model, one actor and one critic (plus a target copy of
    the critic) for every agent reported by the environment, a shared replay
    buffer, and the optimizers used to train them from collected rollouts.
    """

    # Number of samples per actor/critic gradient step inside `train_agent`.
    PPO_MINIBATCH_SIZE = 2000

    def __init__(self, config, env_config):
        """Build networks, optimizers, the replay buffer and start a wandb run.

        Args:
            config: Learner configuration object; the upper-case attributes
                read throughout this class (DEVICE, FEAT, ACTION_SIZE, ...)
                must be present on it.
            env_config: Object whose ``create_env()`` returns an environment
                exposing ``n_agents``.
        """
        self.config = config
        self.env = env_config.create_env()
        self.n_agents = self.env.n_agents
        self.models = [DreamerModel(config).to(config.DEVICE).eval() for _ in range(self.n_agents)]
        self.actors = [Actor(config.FEAT, config.ACTION_SIZE, config.ACTION_HIDDEN, config.ACTION_LAYERS).to(config.DEVICE) for _ in range(self.n_agents)]
        self.critics = [Critic(config.FEAT, config.HIDDEN).to(config.DEVICE) for _ in range(self.n_agents)]

        for model, actor, critic in zip(self.models, self.actors, self.critics):
            initialize_weights(model, mode='xavier')
            initialize_weights(actor, mode='xavier')
            initialize_weights(critic, mode='xavier')

        # Snapshot the target critics only AFTER the critics have been
        # re-initialised; copying earlier would leave each target critic
        # with weights unrelated to its online critic until the first
        # TARGET_UPDATE sync in `train_agent`.
        self.old_critics = [deepcopy(critic) for critic in self.critics]

        self.replay_buffers = DreamerMemory(config.CAPACITY, config.SEQ_LENGTH, config.ACTION_SIZE, config.IN_DIM, config.STATE_DIM, config.DEVICE, config.ENV_TYPE, 2)

        self.entropy = config.ENTROPY
        self.step_count = -1
        self.cur_update = 1
        self.accum_samples = 0
        self.total_samples = 0
        self.init_optimizers()
        # NOTE(review): this overwrites the env-derived agent count that sized
        # the lists above. `step` re-derives the value from the rollout shape,
        # so it only matters to readers of `n_agents` before the first `step`.
        # Kept for compatibility -- confirm whether it is still needed.
        self.n_agents = 2
        Path(config.LOG_FOLDER).mkdir(parents=True, exist_ok=True)
        # wandb is imported lazily and published as a module global so that
        # `train_agent` can log without importing it at module load time.
        global wandb
        import wandb
        wandb.init(dir=config.LOG_FOLDER)

    def init_optimizers(self):
        """Create one Adam optimizer per agent for the model, actor and critic."""
        self.model_optimizers = [torch.optim.Adam(model.parameters(), lr=self.config.MODEL_LR) for model in self.models]
        self.actor_optimizers = [torch.optim.Adam(actor.parameters(), lr=self.config.ACTOR_LR, weight_decay=0.00001) for actor in self.actors]
        self.critic_optimizers = [torch.optim.Adam(critic.parameters(), lr=self.config.VALUE_LR) for critic in self.critics]

    def params(self):
        """Return CPU copies of every agent's model, actor and critic state dict.

        Returns:
            Dict with keys ``'models'``, ``'actors'`` and ``'critics'``; each
            value is a list (indexed by agent) of ``{name: tensor}`` dicts
            whose tensors have been moved to the CPU.
        """
        return {
            'models': [{k: v.cpu() for k, v in model.state_dict().items()} for model in self.models],
            'actors': [{k: v.cpu() for k, v in actor.state_dict().items()} for actor in self.actors],
            'critics': [{k: v.cpu() for k, v in critic.state_dict().items()} for critic in self.critics]
        }

    def step(self, rollouts, train_agent_id):
        """Store new rollouts and, once enough data is available, train one agent.

        Args:
            rollouts: Mapping with at least the keys ``'observation'``,
                ``'action'``, ``'reward'``, ``'done'``, ``'fake'`` and
                ``'last'``; ``'avail_action'`` is optional. The second-to-last
                dimension of ``rollouts['action']`` is taken as the number of
                agents.
            train_agent_id: Index of the agent whose model, actor and critic
                are updated in this call.
        """
        # The rollout shape is the source of truth for the agent count.
        self.n_agents = rollouts['action'].shape[-2]

        n_new = len(rollouts['action'])
        self.accum_samples += n_new
        self.total_samples += n_new

        for agent_id in range(self.n_agents):
            self.replay_buffers.append(agent_id, rollouts['observation'], rollouts['action'], rollouts['reward'], rollouts['done'], rollouts['fake'], rollouts['last'], rollouts.get('avail_action'))

        self.step_count += 1

        # Train only after N_SAMPLES new samples have accumulated ...
        if self.accum_samples < self.config.N_SAMPLES:
            return

        # ... and the buffer holds at least MIN_BUFFER_SIZE entries.
        if len(self.replay_buffers) < self.config.MIN_BUFFER_SIZE:
            return

        self.accum_samples = 0
        sys.stdout.flush()

        for _ in range(self.config.MODEL_EPOCHS):
            samples = self.replay_buffers.sample(self.config.MODEL_BATCH_SIZE)
            self.train_model(samples, train_agent_id)

        for _ in range(self.config.EPOCHS):
            samples = self.replay_buffers.sample(self.config.BATCH_SIZE)
            self.train_agent(samples, train_agent_id)

    def train_model(self, samples, agent_id):
        """Run one world-model gradient step for `agent_id` on `samples`.

        The model is switched to train mode for the update and back to eval
        mode afterwards.
        """
        model = self.models[agent_id]
        model.train()
        loss = model_loss(self.config, model, samples['observation'], samples['action'], samples['av_action'], samples['reward'], samples['done'], samples['fake'], samples['last'])
        self.apply_optimizer(self.model_optimizers[agent_id], model, loss, self.config.GRAD_CLIP)
        model.eval()

    def train_agent(self, samples, agent_id):
        """Update the actor and critic of `agent_id` from rollouts built on `samples`.

        `actor_rollout` produces actions, available actions, the old policy,
        features and returns; these are then consumed in shuffled minibatches
        for ``PPO_EPOCHS`` passes. The entropy coefficient is annealed after
        every actor step, and for Flatland the target critic is refreshed
        every ``TARGET_UPDATE`` updates.
        """
        actor = self.actors[agent_id]
        critic = self.critics[agent_id]
        is_flatland = self.config.ENV_TYPE == Env.FLATLAND

        # StarCraft rolls out with the live critic; every other env type uses
        # the target (old) critic.
        rollout_critic = critic if self.config.ENV_TYPE == Env.STARCRAFT else self.old_critics[agent_id]
        actions, av_actions, old_policy, agent_imag_feat, returns = actor_rollout(samples['observation'], samples['action'], samples['last'], self.models[agent_id], actor, rollout_critic, self.config)
        adv = returns.detach() - critic(agent_imag_feat, actions).detach()
        if self.config.ENV_TYPE == Env.STARCRAFT:
            # presumably normalises the advantages -- see agent.optim.utils.advantage
            adv = advantage(adv)
        wandb.log({'Agent{}_Returns'.format(agent_id): returns.mean()})

        batch = self.PPO_MINIBATCH_SIZE
        for epoch in range(self.config.PPO_EPOCHS):
            inds = np.random.permutation(actions.shape[0])
            for start in range(0, len(inds), batch):
                self.cur_update += 1
                idx = inds[start:start + batch]
                loss = actor_loss(agent_imag_feat[idx], actions[idx], av_actions[idx] if av_actions is not None else None, old_policy[idx], adv[idx], actor, self.entropy)
                self.apply_optimizer(self.actor_optimizers[agent_id], actor, loss, self.config.GRAD_CLIP_POLICY)
                self.entropy *= self.config.ENTROPY_ANNEALING
                val_loss = value_loss(critic, actions[idx], agent_imag_feat[idx], returns[idx])
                # Log the losses for roughly one in twenty minibatches.
                if np.random.randint(20) == 9:
                    wandb.log({'Agent{}_val_loss'.format(agent_id): val_loss, 'Agent{}_actor_loss'.format(agent_id): loss})

                self.apply_optimizer(self.critic_optimizers[agent_id], critic, val_loss, self.config.GRAD_CLIP_POLICY)
                if is_flatland and self.cur_update % self.config.TARGET_UPDATE == 0:
                    self.old_critics[agent_id] = deepcopy(critic)

    def apply_optimizer(self, opt, model, loss, grad_clip):
        """Backpropagate `loss`, clip `model`'s gradient norm to `grad_clip`, and step `opt`."""
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        opt.step()

    def _set_requires_grad(self, agent_id, flag):
        """Set ``requires_grad`` to `flag` on the actor, critic and model of `agent_id`."""
        for module in (self.actors[agent_id], self.critics[agent_id], self.models[agent_id]):
            for param in module.parameters():
                param.requires_grad = flag

    def fix_agent(self, agent_id):
        """Fix the parameters of the given agent (agent_id) by disabling gradient calculation."""
        self._set_requires_grad(agent_id, False)

    def unfix_agent(self, agent_id):
        """Unfix the parameters of the given agent (agent_id) by enabling gradient calculation."""
        self._set_requires_grad(agent_id, True)