# adopted from https://github.com/seungeunrho/minimalRL

import gym
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from synthetic_env import synthetic_env

import argparse

# Parse arguments from the command line.
# NOTE: parsing happens at module level, so it also runs when this file is imported.
parser = argparse.ArgumentParser(description='Synthetic Succesor Feature Deep Q-learning')
parser.add_argument('--seed', default=0, type=int, help='seed')
parser.add_argument('--c', default=0.01, type=float, help='c')
parser.add_argument('--gamma', default=0.95, type=float, help='gamma')

args = parser.parse_args()

# simulation parameters
seed = args.seed
# Seed both torch and Python's `random` (the latter drives epsilon-greedy
# exploration and replay-buffer sampling).
torch.manual_seed(seed)
random.seed(seed)
print_interval = 10  # episodes between log lines; main() also syncs q_target on this interval
num_epi=1000  # episodes run per task in main()
phi_train_num_epi=1000  # only used to build the phi checkpoint file name in main()
skip_phi_train=True  # not read anywhere in this file

# agent hyperparameters

# Learning rate handed to the Adam optimizer in main().
dqn_lr = 1e-1 #1e-2 #1e-3 #1e-6
gamma         = args.gamma  # discount factor used in the TD target; also passed to the environment
buffer_limit  = 50000  # maximum number of transitions kept by ReplayBuffer
batch_size    = 32  # mini-batch size drawn per gradient step in train()
use_gpi = True  # only read by test_agent()
zero_shot = False
# zero_shot overrides use_gpi.
if zero_shot:
    use_gpi = False

# environment parameters (forwarded to synthetic_env)
state_space=10000
action_space=4
state_dim=10
phi_dim=10
n_tasks=2
# task=0
n_steps=20  # maximum number of environment steps per episode in main()
c = args.c  # forwarded to synthetic_env as `c` in main()

class ReplayBuffer():
    """Bounded FIFO store of transitions with uniform random mini-batch sampling.

    Each stored transition is a 5-tuple ``(s, a, r, s_prime, done_mask)``.
    Capacity comes from the module-level ``buffer_limit``; once full, the
    oldest transitions are dropped by the underlying deque.
    """

    def __init__(self):
        # Start out with a fresh, empty deque.
        self.reset()

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done_mask)`` transition."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct transitions and return them as batched tensors.

        Returns:
            Tuple ``(s, a, r, s_prime, done_mask)``. The two state tensors are
            cast to float; ``a``, ``r`` and ``done_mask`` each get a trailing
            axis of size 1 and keep whatever dtype ``torch.tensor`` infers.
        """
        picked = random.sample(self.buffer, n)

        # One column per field; scalars are wrapped in a list so that the
        # resulting tensors have shape (n, 1).
        states = [s for s, _, _, _, _ in picked]
        actions = [[a] for _, a, _, _, _ in picked]
        rewards = [[r] for _, _, r, _, _ in picked]
        next_states = [s_prime for _, _, _, s_prime, _ in picked]
        done_masks = [[done_mask] for _, _, _, _, done_mask in picked]

        return (
            torch.tensor(states, dtype=torch.float),
            torch.tensor(actions),
            torch.tensor(rewards),
            torch.tensor(next_states, dtype=torch.float),
            torch.tensor(done_masks),
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def reset(self):
        """Discard all stored transitions by replacing the deque with a new one."""
        self.buffer = collections.deque(maxlen=buffer_limit)

class Qnet(nn.Module):
    """Small MLP mapping a state vector to one Q-value per action.

    ``forward`` applies ``fc1 -> ReLU -> fc3``. ``fc2`` is constructed but not
    used by ``forward``.

    Args:
        state_dim: size of the input state vector.
        action_space: number of discrete actions (size of the output).
    """

    def __init__(self, state_dim, action_space):
        super().__init__()
        self.state_dim = state_dim
        self.action_space = action_space
        self.hidden_dim = 8

        # model layers
        self.fc1 = nn.Linear(self.state_dim, self.hidden_dim)
        # fc2 is unused in forward(); it is kept (and created in the same
        # order) so the parameter set and state_dict keys stay unchanged.
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.action_space)

    def forward(self, x):
        """Return the Q-values for ``x`` (last axis has size ``action_space``)."""
        hidden = F.relu(self.fc1(x))
        return self.fc3(hidden)

    def sample_action(self, obs, epsilon, q1):
        """Pick an action epsilon-greedily, acting greedily over two policies.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise the Q-values of ``q1`` and of this network are compared:
        the network whose best Q-value is larger is selected, and its argmax
        action is returned.

        Args:
            obs: state tensor for a single observation (the call site in this
                file passes a 1-D tensor).
            epsilon: exploration probability.
            q1: a second callable Q-network evaluated on the same ``obs``.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < epsilon:
            # Exploration: no network evaluation is needed at all.
            return random.randint(0, self.action_space - 1)

        # Action selection never needs gradients, so skip graph construction.
        with torch.no_grad():
            # Row 0 holds q1's values, row 1 this network's values.
            qs = torch.stack([q1(obs), self.forward(obs)])
        best_policy = qs.max(dim=1).values.argmax().item()
        return qs[best_policy].argmax().item()
            
def train(q, q_target, memory, optimizer, q1):
    """Run 10 gradient steps on ``q`` and return the mean loss.

    For every step a mini-batch of ``batch_size`` transitions is drawn from
    ``memory``. The bootstrap action for each next state is chosen greedily
    over two policies (``q_target`` and ``q1``): the policy whose best
    Q-value is larger is selected, and its argmax action is used. The TD
    target is ``r + gamma * q(s')[action] * done_mask`` and the loss is the
    MSE between ``q(s)[a]`` and that target.

    Args:
        q: network being trained.
        q_target: target copy of ``q``; used only for action selection here.
        memory: object whose ``sample(n)`` returns
            ``(s, a, r, s_prime, done_mask)`` tensors.
        optimizer: optimizer over ``q``'s parameters.
        q1: second Q-network; used only for action selection.

    Returns:
        The average loss over the gradient steps, as a Python float.
    """
    its = 10
    dqn_loss = 0.0
    for _ in range(its):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)

        q_a = q(s).gather(1, a)

        # Everything that feeds the target is a constant w.r.t. the loss.
        with torch.no_grad():
            # Stack along dim=1 so the layout is (batch, policy, action).
            # Stacking on dim=0 and reshaping would NOT transpose the data and
            # would pair each sample with Q-values of unrelated samples.
            next_qs = torch.stack([q_target(s_prime), q1(s_prime)], dim=1)
            best_policy = next_qs.max(dim=2).values.argmax(dim=1)
            rows = torch.arange(next_qs.size(0), device=next_qs.device)
            greedy_actions = next_qs[rows, best_policy].argmax(dim=1, keepdim=True)

            # NOTE(review): the bootstrap value is read from the online
            # network `q`, not from `q_target` — kept as in the original;
            # confirm this is intended.
            max_q_prime = q(s_prime).gather(1, greedy_actions)

        # done_mask is 0 for terminal transitions, which removes the bootstrap term.
        target = r + gamma * max_q_prime * done_mask
        loss = F.mse_loss(q_a, target)

        dqn_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return dqn_loss / its

# DEBUG
def _check_agent_grad(model):
    """Debug helper: print the gradient of every named parameter of ``model``.

    A parameter that has no gradient yet (``p.grad is None``) prints ``None``.
    """
    for name, p in model.named_parameters():
        # Test for a missing gradient explicitly instead of relying on a bare
        # `except`, which would also swallow unrelated errors and Ctrl-C.
        if p.grad is None:
            print(None)
        else:
            print(name, 'grad', p.grad.data)

def _check_agent_weights(model):
    """Debug helper: print the current value of every named parameter of ``model``."""
    for name, p in model.named_parameters():
        # No try/except: a bare `except` here would only hide real errors.
        print(name, 'weight', p.data)

def _check_agent_require_grad(model):
    """Debug helper: print the ``requires_grad`` flag of every named parameter."""
    for name, p in model.named_parameters():
        # The label names the quantity actually printed (it used to say
        # 'weight', copied from the weights helper).
        print(name, 'requires_grad', p.requires_grad)

def main():
    """Train a DQN on the synthetic environment, acting greedily over two networks.

    Builds the environment, an online network ``q`` with a target copy, and a
    second network ``q1`` whose weights are loaded from a checkpoint file.
    Task 0 is skipped; for task 1 the environment's ``phi`` weights are loaded
    from a checkpoint file. Every episode runs at most ``n_steps`` steps with
    an epsilon that decays linearly from 0.5 to 0 over the first 200 episodes.
    Training starts once the replay buffer holds more than 64 transitions.
    Every ``print_interval`` episodes the target network is synced and a
    progress line is printed.
    """
    torch.manual_seed(seed)
    env = synthetic_env(state_space=state_space,
                        action_space=action_space,
                        state_dim=state_dim,
                        phi_dim=phi_dim,
                        gamma=gamma,
                        n_tasks=n_tasks,
                        seed=0, c=c,
                        tildeP=False)
    q = Qnet(state_dim=state_dim, action_space=action_space)
    q_target = Qnet(state_dim=state_dim, action_space=action_space)
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()
    q1 = Qnet(state_dim=state_dim, action_space=action_space)
    q1_name = f'dqn_0.0-{gamma}-0_eps_0.5'
    q1.load_state_dict(torch.load(q1_name))

    optimizer = optim.Adam(q.parameters(), lr=dqn_lr)

    for task in range(n_tasks):
        if task == 0:
            continue

        if task == 1:
            phi_name = f"phi_{gamma}_{state_space}_{state_dim}_{action_space}_{phi_dim}_{phi_train_num_epi}_0_eps_0"
            env.phi.load_state_dict(torch.load(phi_name))

        score = 0.0
        cum_score = 0.0
        # Defined up front so the log line cannot hit an unbound name when the
        # buffer has not yet grown past the training threshold; NaN marks
        # "no training step has run yet".
        dqn_loss = float('nan')
        memory.reset()

        for n_epi in range(num_epi):
            epsilon = max(0.0, 0.5 - 0.5*(n_epi/200))
            s, _ = env.reset()
            done = False

            step_count = 0
            # The loop condition already stops on `done`, so no extra break is needed.
            while not done and step_count < n_steps:
                a = q.sample_action(torch.from_numpy(s).float(), epsilon, q1)
                s_prime, r, phi, done, info = env.step(a, task)
                # 0.0 on terminal transitions so train() drops the bootstrap term.
                done_mask = 0.0 if done else 1.0
                memory.put((s, a, r, s_prime, done_mask))
                s = s_prime

                score += r
                step_count += 1

            if memory.size() > 64:
                dqn_loss = train(q, q_target, memory, optimizer, q1)

            # Episode 0 is never logged, so there is a single logging path.
            if n_epi % print_interval == 0 and n_epi != 0:
                cum_score += score
                q_target.load_state_dict(q.state_dict())
                print("task : {}, n_episode : {}, score : {:.4f}, cum. score : {:.4f}, phi loss : {:.4f}, dqn loss : {:.4f}, gpi percent : {:.2f}%, n_buffer : {}, eps : {:.1f}%".format(
                    task, n_epi, score/(print_interval), cum_score/(n_epi+1), 0, dqn_loss, 0, memory.size(), epsilon*100))
                score = 0.0

        # Checkpoint saving is currently disabled; re-enable if needed:
        # torch.save(q.state_dict(), f"dqn_{c}-{seed}")
    env.close()

def test_agent():
    """Debug helper: build an environment and two ``SFQnet`` networks, then
    print one reset state and the network's output for it.

    NOTE(review): ``SFQnet`` is neither defined nor imported in the visible
    part of this file, so calling this function would raise ``NameError``
    unless the name is provided elsewhere — confirm. Nothing in the visible
    file calls this function.
    """
    # Unlike main(), the environment here is seeded with the CLI seed and is
    # built without the `c` argument.
    env = synthetic_env(state_space=state_space, 
                 action_space=action_space, 
                 state_dim=state_dim,
                 phi_dim=phi_dim, 
                 gamma=gamma,
                 n_tasks=n_tasks,
                 seed=seed,
                 tildeP=False)

    
    sfdqn = SFQnet(state_dim=state_dim, action_space=action_space, phi_dim=phi_dim, n_tasks=n_tasks, gpi=use_gpi)
    sfdqn_target = SFQnet(state_dim=state_dim, action_space=action_space, phi_dim=phi_dim, n_tasks=n_tasks)
    # Start the target network from the same weights as the online network.
    sfdqn_target.load_state_dict(sfdqn.state_dict())    

    # _check_agent_weights(sfdqn)
    print()

    # Single iteration: print the reset state and the network output for it.
    for i in range(1):
        s, _ = env.reset()
        print(s)
        # The second positional argument (0) is presumably a task index — TODO confirm.
        print(sfdqn.forward(torch.from_numpy(s).float(), 0))

# Run training only when executed as a script.
if __name__ == '__main__':
    main()