import psutil
import os
import torch.nn.functional as F
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings
from torchstat import stat
import pathlib
from torch.optim.lr_scheduler import CosineAnnealingLR
import yaml
# temp = pathlib.PosixPath
# pathlib.PosixPath = pathlib.WindowsPath
import gym
import time
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')
from collections import namedtuple
from PIL import Image
from torchinfo import summary
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchstat import stat
from thop import profile

# Resident set size of this process (MiB) at start-up. The training loop
# subtracts it from the current RSS to log memory growth per step.
before_mem = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024

# Seed every RNG the script draws from (python, numpy, torch CPU and CUDA)
# and ask cuDNN for deterministic kernels.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.benchmark = False # This can slow down training
torch.backends.cudnn.deterministic = True

# Custom environment id; it must already be registered with gym before this
# line runs. `.unwrapped` strips gym's wrappers and exposes the raw env.
env = gym.make('GridWorld-v1').unwrapped
# print(env.action_space)
BATCH_SIZE = 32        # transitions sampled per optimisation step in learn()
GAMMA = 0.999          # discount factor applied to the bootstrapped next-state value
TARGET_UPDATE = 100    # target_net is synced with policy_net every this many episodes


# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One replay-memory record; next_state is None for terminal transitions.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

from tensorboardX import SummaryWriter
# One TensorBoard run directory per logged quantity.
writer = SummaryWriter("runs/tt/reward")    # final reward of finished episodes
writer0 = SummaryWriter("runs/tt/q_value")  # normalised Q-value of a fixed reference frame
writer1 = SummaryWriter("runs/tt/loss")     # loss returned by learn()
writer2 = SummaryWriter("runs/tt/ps_mem")   # process memory growth in MiB



class ReplayMemory(object):
    """Fixed-capacity circular buffer of ``Transition`` records.

    The buffer grows until it holds ``capacity`` items; after that every new
    record overwrites the oldest one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index of the slot the next push() writes to

    def push(self, *args):
        """Saves a transition."""
        buffer = self.memory
        if len(buffer) < self.capacity:
            # Still filling up: open a new slot for the record written below.
            buffer.append(None)
        buffer[self.position] = Transition(*args)
        # Advance the write cursor, wrapping to the start once the end is hit.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

def autopad(k, p=None):  # kernel, padding
    """Return the padding to use for kernel size ``k``.

    An explicit ``p`` is returned untouched. Otherwise the padding is half the
    kernel size (integer division): a single int for an int kernel, or a list
    with one entry per element for an iterable kernel.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [dim // 2 for dim in k]

class Conv(nn.Module):
    """Standard convolution block: Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; ``None`` lets ``autopad`` pick ``k // 2``.
        g: convolution groups.
        act: ``True`` (default) uses SiLU, an ``nn.Module`` instance is used
            as the activation as-is, anything else disables the activation.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        # Honour `act`: previously it was accepted but ignored, so callers
        # passing act=False or a custom module still silently got SiLU.
        if act is True:
            self.act = nn.SiLU(inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and the activation, in that order."""
        return self.act(self.bn(self.conv(x)))


class Bottleneck(nn.Module):
    """Standard bottleneck: 1x1 Conv down to hidden width, then 3x3 Conv to ``c2``.

    The input is added back onto the result only when ``shortcut`` is set and
    the input/output channel counts match.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # width between the two convolutions
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Run both convolutions and add the residual when enabled."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out


class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions.

    Two parallel 1x1 Convs project the input to the hidden width. One branch
    then goes through ``n`` Bottlenecks, the other is left as-is. The branches
    are concatenated on the channel axis and fused by a final 1x1 Conv.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        hidden = int(c2 * e)  # channels per branch
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)  # act=FReLU(c2)
        stack = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*stack)

    def forward(self, x):
        """Fuse the bottleneck branch with the plain projection branch."""
        deep = self.m(self.cv1(x))
        shallow = self.cv2(x)
        return self.cv3(torch.cat((deep, shallow), dim=1))


class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.

    A 1x1 Conv halves the channels, the same stride-1 max-pool is applied three
    times in a row, and the four feature maps (un-pooled plus three pooled) are
    concatenated and fused by a second 1x1 Conv.
    """

    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        hidden = c1 // 2  # channels after the first projection
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Return the fused multi-scale pooled features."""
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            # Each entry is the max-pool of the previous one.
            pyramid = [x]
            for _ in range(3):
                pyramid.append(self.m(pyramid[-1]))
            return self.cv2(torch.cat(pyramid, 1))



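# One row of three numbers per YOLOv5 size (n / s / m / l / x). The "n" row
# equals the DQN defaults n_neuron1=7, n_neuron2=6, n_neuron3=6, so these are
# presumably the (n_neuron1, n_neuron2, n_neuron3) values tried for each size
# -- TODO confirm.
# n 7  6 6
# s 14 8 8
# m 28 14 8
# l 48 24 12
# x 72 36 18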
class DQN(nn.Module):
    """Q-network: YOLOv5-l style backbone plus a small conv/linear head.

    Args:
        h: height in pixels of the images the network will be fed.
        w: width in pixels of the images the network will be fed.
        outputs: number of Q-values (one per action) produced per image.
        n_features, n_neuron1, n_neuron2, n_neuron3: unused by the current
            architecture; kept so existing call sites keep working.

    Raises:
        ValueError: if ``h`` or ``w`` is too small to survive the strided
            convolutions.
    """

    # Backbone variants this file has been run with (stem/stage widths and
    # C3 repeat counts); only the "l" row is built below. Switching variant
    # means changing the layer list AND the classifier's input channels.
    #   variant  widths                  C3 repeats
    #   n        16-32-64-128-256        1, 2, 3, 1
    #   s        32-64-128-256-512       1, -, -, 1   (middle two C3 were disabled)
    #   m        48-96-192-384-768       2, 4, 6, 2
    #   l        64-128-256-512-1024     3, 6, 9, 3
    #   x        80-160-320-640-1280     4, 8, 12, 4

    @staticmethod
    def _conv_out(size, kernel, stride, padding=0):
        """Length of one spatial axis after a Conv2d with the given geometry."""
        return (size + 2 * padding - kernel) // stride + 1

    @classmethod
    def _head_size(cls, size):
        """Length of one spatial axis after the backbone and the classifier."""
        size = cls._conv_out(size, 6, 2, 2)      # stem Conv(k=6, s=2, p=2)
        for _ in range(4):                       # four Conv(k=3, s=2), autopad -> p=1
            size = cls._conv_out(size, 3, 2, 1)
        for _ in range(2):                       # two unpadded classifier convs
            size = cls._conv_out(size, 3, 2)
        return size

    def __init__(self, h, w, outputs=4, n_features=102400, n_neuron1=7, n_neuron2=6, n_neuron3=6):
        super(DQN, self).__init__()

        # yolov5l backbone. C3 and SPPF keep the spatial size; every strided
        # Conv halves it (rounding up).
        self.model = nn.Sequential(
            Conv(3, 64, 6, 2, 2),
            Conv(64, 128, 3, 2),
            C3(128, 128, n=3),
            Conv(128, 256, 3, 2),
            C3(256, 256, n=6),
            Conv(256, 512, 3, 2),
            C3(512, 512, n=9),
            Conv(512, 1024, 3, 2),
            C3(1024, 1024, n=3),
            SPPF(1024, 1024, 5)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(512, 128, kernel_size=3, stride=2),
            nn.ReLU(),
        )

        # Size the first linear layer from the real input resolution rather
        # than hard-coding 512, which is only correct for 400x400 inputs
        # (128 channels * 2 * 2).
        feat_h = self._head_size(h)
        feat_w = self._head_size(w)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                "input of %dx%d is too small for the DQN backbone" % (h, w))
        self.head1 = nn.Linear(128 * feat_h * feat_w, 64)
        self.head2 = nn.Linear(64, outputs)

    def forward(self, x):
        """Map a batch of images to one row of ``outputs`` Q-values each."""
        x = self.model(x)
        x = self.classifier(x)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.head1(x))
        x = self.head2(x)

        return x







# Preprocessing pipeline applied by trans() to every rendered frame:
# tensor -> PIL image -> bicubic resize -> tensor.
# NOTE(review): Resize is given a single int; per the torchvision docs that
# scales the shorter image side to 400 px and keeps the aspect ratio, so the
# output is only 400x400 when the rendered frame is square -- confirm.
resize = T.Compose([T.ToPILImage(),
                    T.Resize(400, interpolation=Image.BICUBIC),
                    T.ToTensor()])  # Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor

def trans(scn):
    """Turn one rendered frame into a batched tensor on ``device``.

    The axes of ``scn`` are reordered with ``(2, 0, 1)`` (channels-last to
    channels-first), the values are divided by 255 as float32, the module-level
    ``resize`` pipeline is applied and a leading batch axis is added.
    Assumes ``scn`` is an H x W x C array with values in 0..255 -- TODO confirm
    against the environment's render output.
    """
    channels_first = scn.transpose((2, 0, 1))
    scaled = np.ascontiguousarray(channels_first, dtype=np.float32) / 255
    frame = resize(torch.from_numpy(scaled))
    # BCHW
    return frame.unsqueeze(0).to(device)

def select_action(state, episode, epsilon_coefficient=0.4):
    """Pick an action epsilon-greedily from ``policy_net``.

    Epsilon is ``epsilon_coefficient * 0.99 ** (episode - 100)`` with a floor
    of 0.0001, so it exceeds ``epsilon_coefficient`` before episode 100 and
    decays afterwards.

    Returns:
        A 1x1 ``torch.long`` tensor holding the chosen action index.
    """
    # if episode < 100:
    #     return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
    epsilon = max(epsilon_coefficient * (0.99 ** (episode - 100)), 0.0001)
    greedy = epsilon <= np.random.uniform(0, 1)
    if not greedy:
        # Explore: uniformly random action.
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
    # Exploit: index of the largest Q-value for this state.
    with torch.no_grad():
        q_values = policy_net(state)
        best = q_values.max(1)[1]
        return best.view(1, 1)

def learn():
    """Run one optimisation step of ``policy_net`` on a replay mini-batch.

    Samples ``BATCH_SIZE`` transitions from ``memory``, builds the one-step
    target ``reward + GAMMA * max_a Q_target(next_state, a)`` (just ``reward``
    for terminal transitions) and minimises the Huber loss between that target
    and ``policy_net``'s Q-value for the action actually taken.

    Returns:
        The loss as a Python float, or ``0`` when the memory does not yet hold
        ``BATCH_SIZE`` transitions and no update was made.
    """
    if len(memory) < BATCH_SIZE:
        return 0
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    # A terminal transition is stored with next_state=None; those keep a
    # next-state value of 0 below.
    non_final_next_states = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    # Q(s_t, a_t): evaluate every action, then pick the column of the action
    # that was taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the target network. The guard matters: torch.cat raises
    # on an empty list, which happens when every sampled transition is terminal.
    # no_grad avoids building an autograd graph for a value used as a constant.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_next_states:
        with torch.no_grad():
            next_batch = torch.cat(non_final_next_states).to(device)
            next_state_values[non_final_mask] = target_net(next_batch).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    loss_val = loss.item()
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Clip every gradient element to [-1, 1] before the update.
    for param in policy_net.parameters():
        if param.grad is not None:
            param.grad.clamp_(-1, 1)
    optimizer.step()
    return loss_val


# Legal-move table for a size x size grid whose cells are numbered 0..size*size-1.
# Key "<cell>_<action>" maps to the cell reached; a move that would leave the
# grid has no entry.
size = 10
transfer = dict()

# Action 0: one full row back (cell - size); impossible from the first row.
transfer.update(
    (str(cell) + '_0', cell - size) for cell in range(size, size * size))

# Action 1: one full row forward (cell + size); impossible from the last row.
transfer.update(
    (str(cell) + '_1', cell + size) for cell in range(size * (size - 1)))

# Action 2: previous cell in the same row; impossible from the first column.
transfer.update(
    (str(cell) + '_2', cell - 1)
    for cell in range(1, size * size) if cell % size != 0)

# Action 3: next cell in the same row; impossible from the last column.
transfer.update(
    (str(cell) + '_3', cell + 1)
    for cell in range(size * size) if (cell + 1) % size != 0)

print(transfer)


# Training budget: at most num_episodes episodes of at most
# max_number_of_steps environment steps each.
num_episodes = 5000
max_number_of_steps = 600

# Early stop: training ends once the mean of the last
# num_consecutive_iterations recorded episode results reaches
# goal_average_steps.
goal_average_steps = 65
num_consecutive_iterations = 50
last_time_steps = np.zeros(num_consecutive_iterations)
# Render one frame up front. Its shape is passed to the networks, and the
# frame itself is reused as a fixed input for the Q-value logging in the
# training loop.
env.reset(0)
# while 1:
image = env.render(mode='rgb_array')
# plt.imshow(image)
init_image = trans(image)
# plt.figure()
# plt.imshow(init_image.cpu().squeeze(0).permute(1, 2, 0).numpy(),
#            interpolation='none')
# plt.show()
_, _, image_height, image_width = init_image.shape
print("height: ", image_height)
print("width: ", image_width)

# Get number of actions from gym action space
n_actions = len(env.actions)

# policy_net is the network being trained; it also selects actions.
policy_net = DQN(image_height, image_width, n_actions).to(device)
# torchstat report for the model.
# NOTE(review): the input size is hard-coded to 3x400x400 and does not follow
# image_height / image_width.
stat(policy_net, (3, 400, 400))

# target_net starts as a copy of policy_net and is kept in eval mode; learn()
# uses it to compute the bootstrapped next-state values.
target_net = DQN(image_height, image_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.01)
memory = ReplayMemory(1000000)

# Most recent transition that earned reward 100; replayed into memory at the
# start of later episodes. All None until such a transition has been seen.
goal_state, goal_action, goal_next_state, goal_reward = None, None, None, None
timer = time.time()  # wall-clock start, reported when the early-stop goal is met
ss = 0               # total environment steps so far; x-axis for most scalars

# Main training loop: one outer iteration per episode, one inner iteration per
# environment step.
# num_episodes = 100
for i_episode in range(num_episodes):
    # Initialize the environment and state
    print(i_episode)
    env.reset(i_episode)
    episode_reward = 0
    # Render AFTER the reset so the first observation shows this episode's
    # starting position rather than a frame left over from the previous one.
    image = env.render(mode='rgb_array')
    state = trans(image)
    pos = 0
    # pos_flag[p] == 1 once a transition leaving cell p has been stored this
    # episode; only the first transition per cell is kept.
    pos_flag = [0] * (size * size)
    # Re-inject the last goal-reaching transition so it stays well represented.
    # Compare against None explicitly instead of relying on tensor truthiness.
    if goal_reward is not None:
        for _ in range(10):
            memory.push(goal_state, goal_action, goal_next_state, goal_reward)
    for t in range(max_number_of_steps):
        ss += 1
        info = psutil.virtual_memory()
        ps_mem = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 - before_mem
        total_mem = info.total / 1024 / 1024
        percent_mem = info.percent
        writer2.add_scalar("memory", ps_mem, ss)
        print(ps_mem, total_mem, percent_mem)

        # Monitor the Q-value of action 0 on the fixed reference frame,
        # standardised across actions. no_grad keeps this probe from building
        # an autograd graph on every step.
        with torch.no_grad():
            Q_value = policy_net(init_image).cpu().numpy()[0]
        q_std = np.std(Q_value)
        # All-equal Q-values give std 0; log 0 instead of dividing by zero.
        q_norm = 0.0 if q_std == 0 else (Q_value[0] - np.mean(Q_value)) / q_std
        writer0.add_scalar("q_value", q_norm, i_episode)

        # Select and perform an action
        action = select_action(state, i_episode)
        print("action: ", action)

        action_idx = action.item()
        key = "%d_%d" % (pos, action_idx)
        print("key=", key)
        # The chosen move would leave the grid: redraw until it is legal.
        while key not in transfer:
            action_idx = random.randint(0, 3)
            key = "%d_%d" % (pos, action_idx)

        # Always hand the environment a plain int (previously a tensor or an
        # int depending on whether the redraw above ran).
        next_pos, reward, done, _ = env.step(action_idx)
        action = torch.tensor([[action_idx]], dtype=torch.int64)

        reward = torch.tensor([reward], device=device)

        # Observe new state. The frame must be rendered AFTER the step;
        # rendering before it made next_state a copy of the pre-step frame and
        # left the agent acting on an observation one step out of date.
        if not done:
            image = env.render(mode='rgb_array')
            next_state = trans(image)
        else:
            next_state = None

        if reward == 100:
            print(state, action, next_state, reward)

        # Store the transition in memory
        if not pos_flag[pos]:
            if reward == 100:
                # Remember the goal transition and over-sample it.
                goal_state, goal_action, goal_next_state, goal_reward = state, action, next_state, reward
                for _ in range(30):
                    memory.push(state, action, next_state, reward)
            memory.push(state, action, next_state, reward)
            pos_flag[pos] = 1
        pos = next_pos

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network) once
        # the warm-up episodes are over.
        if i_episode > 100:
            loss = learn()
            writer1.add_scalar("loss", loss, ss)
        episode_reward += reward.item()

        # time.sleep(0.5)

        if done:
            print(
                'Episodes: %d, steps: %d, episode_reward：%d, score: %f' % (i_episode, t + 1, reward.item(), last_time_steps.mean()))
            # Finished episodes are scored by their final reward.
            last_time_steps = np.hstack((last_time_steps[1:], [reward.item()]))
            writer.add_scalar("reward_show", reward.item(), ss)
            break

    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

    # Step budget exhausted without finishing. Testing `done` rather than
    # `t == max_number_of_steps - 1` stops an episode that ends exactly on its
    # last allowed step from being recorded twice.
    if not done:
        print('Episodes: %d, steps: %d, episode_reward：%d, score: %f' % (i_episode, t + 1, episode_reward, last_time_steps.mean()))
        last_time_steps = np.hstack((last_time_steps[1:], [episode_reward]))

    if last_time_steps.mean() >= goal_average_steps:
        print('Time %d s, episodes %d.' % (time.time() - timer, i_episode))
        break

print('all step: ', ss)
print('Complete')
# Single shutdown point for both the normal and the early-stop exit.
env.close()
# Close the TensorBoard writers so buffered scalars are written to disk.
for _writer in (writer, writer0, writer1, writer2):
    _writer.close()

