import os.path
import time
import yaml
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from kornia.augmentation import AugmentationSequential, RandomResizedCrop, RandomPlasmaShadow

from collections import deque

from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from .utils import layer_init, linear_schedule


class Config:
    """Default hyperparameters for the agent.

    Each value is stored as a plain instance attribute, so the defaults live
    in the instance ``__dict__`` and can be overridden attribute by attribute.
    """

    def __init__(self):
        defaults = {
            # Number of environment steps in the training loop.
            "max_steps": 2005000,
            # Capacity of the replay buffer.
            "buffer_size": 100000,
            # Step size of the Adam optimizer.
            "learning_rate": 1e-4,
            # Epsilon-greedy schedule: start value, end value and the
            # fraction of max_steps over which epsilon is scheduled.
            "start_e": 1,
            "end_e": 0.01,
            "exploration_fraction": 0.05,
            # No gradient updates until this many steps have elapsed.
            "learning_starts": 5000,
            # Number of transitions sampled per update.
            "batch_size": 32,
            # Discount factor of the TD target.
            "gamma": 0.99,
            # Interpolation factor of the target update (1 = hard copy).
            "tau": 1,
            # Steps between target-network updates.
            "target_network_freq": 1000,
            # Steps between gradient updates.
            "train_freq": 4,
            # Weight of the off-diagonal term in the Barlow Twins loss.
            "lambd": 0.0051,
            # Lower bound of the random-resized-crop scale range.
            "crop_scale": 0.8,
            # Weight of the Barlow Twins loss when added to the TD loss.
            "cl_cft": 1e-2,
        }
        for name, value in defaults.items():
            setattr(self, name, value)


def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    The entries come out in row-major order with the main diagonal skipped.

    Raises:
        AssertionError: if ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the last element and re-folding the remaining n*n - 1 values
    # into (n - 1) rows of n + 1 puts every remaining diagonal entry in
    # column 0, so slicing that column away leaves only off-diagonal values.
    without_last = x.flatten()[:-1]
    refolded = without_last.view(rows - 1, rows + 1)
    return refolded[:, 1:].flatten()


class CNNQNetwork(nn.Module):
    """Convolutional Q-network with a shared encoder and a Q-value head.

    The encoder takes 4-channel image batches; its first linear layer expects
    a flattened conv output of ``64 * 7 * 7`` features. A ``projector`` MLP is
    also registered on the module.
    """

    def __init__(self, action_shape):
        """Create the encoder, the Q head and the projector.

        Args:
            action_shape: number of discrete actions (width of the Q output).
        """
        super().__init__()

        # Layers are created in the order encoder -> Q head -> projector so
        # that parameter registration and random initialisation keep the
        # same sequence.
        encoder_layers = [
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(inplace=True),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(inplace=True),
        ]
        self.backbone = nn.Sequential(*encoder_layers)

        head_layers = [
            layer_init(nn.Linear(512, 128)),
            nn.ReLU(inplace=True),
            layer_init(nn.Linear(128, action_shape)),
        ]
        self.q = nn.Sequential(*head_layers)

        projector_layers = [
            nn.Linear(512 + action_shape, 512, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512, bias=False),
        ]
        self.projector = nn.Sequential(*projector_layers)

    def encode(self, x):
        """Scale pixel values by 1/255 and return the 512-d encoder features."""
        scaled = x / 255.0
        return self.backbone(scaled)

    def forward(self, x):
        """Return one Q-value per action for every observation in ``x``."""
        features = self.encode(x)
        return self.q(features)


class Policy:
    """DQN-style policy with an auxiliary Barlow-Twins-style loss.

    Holds an online Q-network, a target network initialised with the same
    weights, an Adam optimizer over the online network, and the kornia
    augmentation pipeline used by :meth:`barlow_twins`.
    """

    def __init__(self, args, observation_space, action_space):
        """Build the networks, optimizer, loss and augmentations.

        Args:
            args: configuration object; this class reads ``device``,
                ``learning_rate``, ``crop_scale``, ``gamma``, ``cl_cft``,
                ``lambd`` and ``tau`` from it.
            observation_space: space whose ``shape`` has three dimensions
                (pixel observations).
            action_space: discrete space; ``n`` is the number of actions.

        Raises:
            NotImplementedError: if the observation shape is not 3-D. Only
                the pixel/CNN network exists, so fail here with a clear
                message instead of an AttributeError further down.
        """
        self.args = args
        obs_shape = observation_space.shape
        action_shape = action_space.n
        self.action_shape = action_shape
        if len(obs_shape) != 3:
            raise NotImplementedError(
                f"Policy only supports 3-D (pixel) observations, got shape {obs_shape}"
            )
        self.q_network = CNNQNetwork(action_shape).to(self.args.device)
        self.target_network = CNNQNetwork(action_shape).to(self.args.device)

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.args.learning_rate)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.loss_func = nn.MSELoss().to(self.args.device)

        # Two independent samples of this pipeline give the two views used
        # by the Barlow Twins loss.
        self.aug = AugmentationSequential(
            RandomResizedCrop(size=(84, 84), scale=(self.args.crop_scale, 1)),
            RandomPlasmaShadow(roughness=(0.1, 0.7), p=1)
        )

    def select_action(self, obs):
        """Return the greedy action for each observation as a NumPy array."""
        # Acting never needs gradients; without this guard a call made outside
        # an enclosing no_grad context would build an autograd graph for
        # every environment step.
        with torch.no_grad():
            q_values = self.q_network(torch.Tensor(obs).to(self.args.device))
        return torch.argmax(q_values, dim=1).cpu().numpy()

    def learn(self, data):
        """Run one gradient step on the TD loss plus the weighted auxiliary loss.

        Args:
            data: sampled batch exposing ``observations``,
                ``next_observations``, ``actions``, ``rewards`` and ``dones``.

        Returns:
            Tuple ``(loss, mean_q, bt_loss)``: the combined loss tensor, the
            mean Q-value of the taken actions (float) and the auxiliary loss
            (float).
        """
        with torch.no_grad():
            target_max, _ = self.target_network(data.next_observations).max(dim=1)
            td_target = data.rewards.flatten() + self.args.gamma * target_max * (1 - data.dones.flatten())
        # squeeze(1) removes only the gathered action axis, so the batch axis
        # survives even when the batch holds a single transition.
        old_val = self.q_network(data.observations).gather(1, data.actions).squeeze(1)
        td_loss = self.loss_func(td_target, old_val)
        bt_loss = self.barlow_twins(data.observations)
        loss = td_loss + self.args.cl_cft * bt_loss

        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss, old_val.mean().item(), bt_loss.mean().item()

    def barlow_twins(self, obs):
        """Compute a Barlow-Twins-style loss between two augmented views.

        Both views are passed through the full Q-network; the loss pushes the
        diagonal of ``z1.T @ z2`` towards 1 and penalises the off-diagonal
        entries with weight ``args.lambd``.
        """
        # NOTE(review): the representations here are the raw Q-value outputs,
        # with no batch normalisation and no division of `c` by the batch
        # size, and `CNNQNetwork.projector` is never used. Confirm this is
        # the intended formulation.
        obs_anc = self.aug(obs.float())
        obs_pos = self.aug(obs.float())
        z1 = self.q_network(obs_anc)
        z2 = self.q_network(obs_pos)

        c = z1.T @ z2

        # The in-place ops act on a view of `c`'s diagonal. The off-diagonal
        # term below reads only the other entries, so it is unaffected.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.args.lambd * off_diag
        return loss

    def update_target(self):
        """Blend the online weights into the target network with factor ``tau``."""
        for target_network_param, q_network_param in zip(self.target_network.parameters(), self.q_network.parameters()):
            target_network_param.data.copy_(
                self.args.tau * q_network_param.data + (1.0 - self.args.tau) * target_network_param.data
            )

    def save_model(self, save_path):
        """Save the online network's ``state_dict`` to ``save_path``."""
        torch.save(self.q_network.state_dict(), save_path)


class Agent:
    """Training driver: owns the environments, replay buffer, policy and logger."""

    def __init__(self, args, envs, eval_env):
        """Merge ``args`` over the default :class:`Config` and build all components.

        Args:
            args: namespace-like object whose attributes override the
                defaults. It must also provide ``args_path``, ``tb_path``,
                ``num_envs`` and ``device``, which ``Config`` does not define.
            envs: vectorised training environments.
            eval_env: vectorised environment used by :meth:`eval`.
        """
        self.args = Config()
        self.args.__dict__.update(args.__dict__)
        # Persist the effective configuration next to the run outputs.
        with open(self.args.args_path, "w+") as file:
            yaml.dump(self.args, file)
        self.envs = envs
        self.eval_env = eval_env
        print(self.envs.single_observation_space)
        print(self.envs.single_action_space)

        self.writer = SummaryWriter(args.tb_path, flush_secs=2)
        self.policy = Policy(self.args, self.envs.single_observation_space, self.envs.single_action_space)
        self.buffer = ReplayBuffer(self.args.buffer_size,
                                   self.envs.single_observation_space, self.envs.single_action_space,
                                   self.args.device,
                                   n_envs=args.num_envs,
                                   optimize_memory_usage=True,
                                   handle_timeout_termination=False, )

    @torch.no_grad()
    def eval(self):
        """Act greedily in ``eval_env`` until an episode finishes.

        Returns:
            Tuple ``(episodic_return, episodic_length)`` of the first finished
            episode that carries ``"episode"`` statistics.
        """
        obs, _ = self.eval_env.reset()
        while True:
            action = self.policy.select_action(obs)
            obs, _, _, _, infos = self.eval_env.step(action)
            if "final_info" in infos:
                for info in infos["final_info"]:
                    # Entries for sub-environments that did not finish on
                    # this step can be None; skip them together with
                    # entries lacking episode statistics.
                    if info is None or "episode" not in info:
                        continue
                    return info['episode']['r'], info['episode']['l']

    def run(self):
        """Run the epsilon-greedy training loop for ``args.max_steps`` steps.

        Returns:
            Tuple ``(mean, std)`` of the episodic returns of up to the last
            ten finished training episodes.

        Raises:
            ValueError: raised by ``np.stack`` if no episode finished during
                the run.
        """
        eval_rewards = deque(maxlen=10)
        start_time = time.time()
        obs, _ = self.envs.reset()
        for global_step in range(self.args.max_steps):
            epsilon = linear_schedule(self.args.start_e, self.args.end_e,
                                      self.args.exploration_fraction * self.args.max_steps, global_step)
            # A single draw decides exploration for all environments at once.
            if random.random() < epsilon:
                actions = np.array([self.envs.single_action_space.sample() for _ in range(self.args.num_envs)])
            else:
                actions = self.policy.select_action(obs)
            next_obs, rewards, terminated, truncated, infos = self.envs.step(actions)
            if "final_info" in infos:
                for info in infos["final_info"]:
                    # Entries for sub-environments that did not finish on
                    # this step can be None; skip them together with
                    # entries lacking episode statistics.
                    if info is None or "episode" not in info:
                        continue
                    eval_rewards.append(info["episode"]["r"])
                    print(f"global_step={global_step}, "
                          f"episodic_reward={info['episode']['r']}, "
                          f"episodic_length={info['episode']['l']}, "
                          f"time_used={(time.time() - start_time)}")
                    self.writer.add_scalar("charts/episodic_reward", info["episode"]["r"], global_step)
                    self.writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                    self.writer.add_scalar("charts/epsilon", epsilon, global_step)
            # For truncated episodes, store the real final observation
            # reported in `infos` instead of the one returned by the step.
            real_next_obs = next_obs.copy()
            for idx, d in enumerate(truncated):
                if d:
                    real_next_obs[idx] = infos["final_observation"][idx]
            # Only `terminated` is stored as the done flag, so truncated
            # transitions still bootstrap from the next state in the TD target.
            self.buffer.add(obs, real_next_obs, actions, rewards, terminated, infos)
            obs = next_obs

            if global_step > self.args.learning_starts and global_step % self.args.train_freq == 0:
                data = self.buffer.sample(self.args.batch_size)
                loss, old_val, bt_loss = self.policy.learn(data)
                if global_step % 100 == 0:
                    self.writer.add_scalar("losses/td_loss", loss, global_step)
                    self.writer.add_scalar("losses/q_values", old_val, global_step)
                    self.writer.add_scalar("losses/bt_values", bt_loss, global_step)
                    self.writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            if global_step % self.args.target_network_freq == 0:
                self.policy.update_target()

        recent_returns = np.stack(eval_rewards)
        return recent_returns.mean(), recent_returns.std()
