# Adapted from https://github.com/seungeunrho/minimalRL

import gym
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from synthetic_env import synthetic_env


# Simulation parameters.
seed = 0
torch.manual_seed(seed)  # seed PyTorch's global RNG
random.seed(seed)  # seed Python's `random` module (exploration, replay sampling)
print_interval = 10  # episodes between progress reports
num_epi = 1000  # episodes run for every task other than task 0
phi_train_num_epi = 1000  # episodes run for task 0
skip_phi_train = True  # when True, task 0 is skipped entirely

# Agent hyperparameters.
phi_lr = 1e-4
sf_lr = 1.0  # other values tried: 1e-3, 1e-6
gamma = 0.95
buffer_limit = 50000
batch_size = 32
use_gpi = False
zero_shot = False
if zero_shot:
    # Zero-shot runs always have GPI switched off.
    use_gpi = False

# Environment parameters.
state_space = 10000
action_space = 4
state_dim = 10
phi_dim = 10
n_tasks = 2
n_steps = 20  # maximum number of environment steps per episode

class ReplayBuffer():
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Each stored transition is a 7-tuple
    ``(s, a, r, s_prime, phi, t, done_mask)``; ``phi`` must be a tensor
    (it is combined with ``torch.stack``), the remaining entries must be
    convertible with ``torch.tensor``.
    """

    def __init__(self, limit=None):
        """
        Parameters
        ----------
        limit : int, optional
            maximum number of transitions kept; once full, the oldest
            transition is dropped on every ``put``. Defaults to the
            module-level ``buffer_limit``.
        """
        # The global is only looked up when no explicit capacity is given.
        self.limit = buffer_limit if limit is None else limit
        self.buffer = collections.deque(maxlen=self.limit)

    def put(self, transition):
        """Append one transition, evicting the oldest one when full."""
        self.buffer.append(transition)

    def reset(self):
        """Discard every stored transition, keeping the same capacity."""
        self.buffer = collections.deque(maxlen=self.limit)

    def sample(self, n):
        """Draw ``n`` distinct transitions uniformly at random and batch them.

        Returns
        -------
        tuple of 7 tensors
            ``(s, a, r, s_prime, phi, t, done_mask)``. ``s`` and ``s_prime``
            are float tensors; ``a``, ``r``, ``t`` and ``done_mask`` each get
            a trailing dimension of size 1; ``phi`` is the stack of the stored
            phi tensors.

        Raises
        ------
        ValueError
            if ``n`` exceeds the number of stored transitions
            (raised by ``random.sample``).
        """
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, phi_lst, t_lst, done_mask_lst = [], [], [], [], [], [], []

        for s, a, r, s_prime, phi, t, done_mask in mini_batch:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            # torch.stack below allocates a fresh tensor, so no per-item copy
            # of phi is needed here.
            phi_lst.append(phi)
            t_lst.append([t])
            done_mask_lst.append([done_mask])

        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
               torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
               torch.stack(phi_lst), torch.tensor(t_lst), torch.tensor(done_mask_lst)

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

class SFQnet(nn.Module):
    """Per-task successor-feature (SF) networks with optional GPI action selection.

    One small MLP is kept per task. Each maps a state to an
    ``(action_space, phi_dim)`` block of successor features; Q-values for a
    task are those features multiplied by that task's reward weights
    ``w[task]``.
    """

    def __init__(self, state_dim, action_space, phi_dim, n_tasks, gpi=True, true_w=True, w=None):
        """
        Creates a new SFDQN agent.

        Parameters
        ----------
        state_dim : int
            dimension of a state feature vector
        action_space : int
            number of actions in the MDP
        phi_dim : int
            dimension of phi
        n_tasks : int
            number of tasks
        gpi : bool
            whether or not to use generalized policy improvement (GPI) when
            acting greedily in ``sample_action`` (default: True)
        true_w : bool
            whether or not to use the true reward weights of the MDP for each
            task. The flag is only stored; nothing in this class reads it
            (default: True)
        w : torch.Tensor, optional
            reward weights for each task, indexed as ``w[task]`` with
            ``phi_dim`` entries per task. When None, weights are drawn with
            ``torch.rand(n_tasks, phi_dim)`` (default: None)
        """
        super(SFQnet, self).__init__()
        self.state_dim = state_dim
        self.action_space = action_space
        self.phi_dim = phi_dim
        self.n_tasks = n_tasks
        self.hidden_dim = 8
        self.gpi = gpi
        self.true_w = true_w
        # NOTE: w is a plain attribute, not a parameter or registered buffer,
        # so it is not part of state_dict() and is not copied by
        # load_state_dict().
        self.w = w
        if self.w is None:
            self.w = torch.rand(n_tasks, phi_dim)

        # One fc1/fc2/fc3 stack per task. fc2 is constructed (and therefore
        # present in the parameters / state dict) but forward() skips it.
        self.sfqnet = nn.ModuleDict({
            'fc1': nn.ModuleList([nn.Linear(state_dim, self.hidden_dim) for _ in range(n_tasks)]),
            'fc2': nn.ModuleList([nn.Linear(self.hidden_dim, self.hidden_dim) for _ in range(n_tasks)]),
            'fc3': nn.ModuleList([nn.Linear(self.hidden_dim, self.action_space * self.phi_dim) for _ in range(n_tasks)])
        })

        # The task-0 networks are frozen: they never receive gradients.
        for layers in self.sfqnet.values():
            for p in layers[0].parameters():
                p.requires_grad = False

    def forward(self, x, task):
        """Return the successor features of ``x`` under ``task``'s network.

        ``x`` may be a single state or a batch of states; the result always
        has shape ``(batch, action_space, phi_dim)`` (batch is 1 for a single
        state).
        """
        x = F.relu(self.sfqnet['fc1'][task](x))
        x = self.sfqnet['fc3'][task](x)
        return x.reshape([-1, self.action_space, self.phi_dim])

    def getQ(self, state, task):
        """Compute Q-values of a single ``state`` using the weights of ``task``.

        Returns
        -------
        (int, torch.Tensor)
            For ``task == 0``: ``(0, q)`` with ``q`` of shape
            ``(action_space, 1)`` from the task-0 network only.
            Otherwise: ``(t, qs)`` where ``qs`` has shape
            ``(n_tasks, action_space, 1)`` (one slice per task network, all
            evaluated with ``w[task]``) and ``t`` is the index of the network
            holding the largest Q-value.
        """
        w_col = self.w[task].view(self.phi_dim, 1)
        if task == 0:
            sf_out = self.forward(state, 0)
            return 0, sf_out.view(self.action_space, self.phi_dim) @ w_col

        qs = torch.stack([
            self.forward(state, i).view(self.action_space, self.phi_dim) @ w_col
            for i in range(self.n_tasks)
        ])
        t = qs.max(dim=1).values.argmax().item()
        return t, qs

    def get_next_actions(self, states, tasks, current_task):
        """Pick greedy actions for a batch of states under ``w[current_task]``.

        Parameters
        ----------
        states : torch.Tensor
            batch of states, one row per sample
        tasks :
            unused; kept for interface compatibility
        current_task : int
            task whose reward weights are used

        Returns
        -------
        (torch.Tensor, list, torch.Tensor)
            ``(actions, [], tensor([0]))``. For ``current_task == 0`` the
            actions come from the task-0 network and have shape
            ``(batch, 1)``. Otherwise, for each sample the task network with
            the largest Q-value is selected and its greedy action is
            returned, giving shape ``(batch,)``. The last two entries are
            placeholders.
        """
        w_col = self.w[current_task].view(-1, 1)

        if current_task == 0:
            # (batch, action, phi) @ (phi, 1) -> (batch, action, 1)
            q = torch.matmul(self.forward(states, 0), w_col)
            return q.argmax(dim=1), [], torch.tensor([0])

        # Stack along dim=1 so the layout really is (batch, n_tasks, action).
        # Stacking on dim 0 and reshaping would interleave samples and tasks.
        qs = torch.stack(
            [torch.matmul(self.forward(states, i), w_col).squeeze(-1) for i in range(self.n_tasks)],
            dim=1,
        )
        # Per sample: the task network whose best action has the highest Q.
        ts = qs.max(dim=2).values.argmax(dim=1).detach()
        a1s = qs[torch.arange(qs.shape[0]), ts].argmax(dim=1)
        return a1s, [], torch.tensor([0])

    def sample_action(self, obs, epsilon, task):
        """Epsilon-greedy action selection for a single observation.

        Returns
        -------
        (int, int)
            the chosen action and the index of the task network it was taken
            from. A random action is reported with ``task`` itself. When
            ``self.gpi`` is False (or ``task == 0``) the greedy action always
            comes from the network of ``task``.
        """
        if random.random() < epsilon:
            return random.randint(0, self.action_space - 1), task

        t, qs = self.getQ(obs, task)
        if task == 0:
            return qs.view(-1).argmax().item(), t
        if not self.gpi:
            t = task
        return qs[t].argmax().item(), t

    def load_source_to_target(self):
        """Copy the task-0 network weights into the task-1 network and freeze them."""
        for layers in self.sfqnet.values():
            for p, p1 in zip(layers[0].parameters(), layers[1].parameters()):
                p1.data = p.data.clone()
                p1.requires_grad = False

            
def train(sfdqn, sfdqn_target, memory, optimizer_sfdqn, optimizer_phi, current_task, env):
    """Run 10 mini-batch updates of the successor-feature TD objective.

    For every mini-batch the regression target is
    ``phi(s, a) + gamma * SF_target(s', a') * done_mask`` where ``a'`` is
    chosen by ``sfdqn.get_next_actions`` and ``SF_target`` is evaluated with
    the target network.

    Parameters
    ----------
    sfdqn : SFQnet
        online network; provides ``forward`` and ``get_next_actions``
    sfdqn_target : SFQnet
        target network used for the bootstrapped successor features
    memory : ReplayBuffer
        source of mini-batches (``memory.sample(batch_size)``)
    optimizer_sfdqn :
        optimizer stepped when ``current_task != 0``
    optimizer_phi :
        optimizer stepped when ``current_task == 0``; in that case gradients
        flow into ``env.phi`` through the target
    current_task : int
        index of the task being trained
    env :
        environment exposing a callable ``phi``; ``env.phi(state)`` is indexed
        as ``[:, action, :]``

    Returns
    -------
    (float, float)
        mean phi loss and mean SF loss over the updates; the one that was not
        trained is 0.
    """
    phi_loss = 0
    sf_loss = 0
    its = 10
    for _ in range(its):
        # The stored reward and phi are not used: phi is recomputed below from
        # the current env.phi.
        s, a, _, s_prime, _, t, done_mask = memory.sample(batch_size)
        rows = torch.arange(s.shape[0])

        a_prime1, _, _ = sfdqn.get_next_actions(s_prime, t, current_task)

        # Batched forward passes; row i selects the SF of the action taken in
        # sample i. view(-1) accepts action tensors shaped (batch,) or (batch, 1).
        sfdqn_out1 = sfdqn.forward(s, current_task)[rows, a.view(-1)]
        sfdqn_target_out1 = sfdqn_target.forward(s_prime, current_task)[rows, a_prime1.view(-1)].detach()

        # env.phi is evaluated per sample; whether it accepts a batch is not
        # visible from here.
        phi = torch.stack([env.phi(s_)[:, a_, :].view(-1) for s_, a_ in zip(s, a)])
        if current_task != 0:
            # phi is only trained on task 0; keep gradients out of it otherwise.
            phi = phi.detach()

        target1 = (phi + gamma * sfdqn_target_out1 * done_mask).float()

        # Used as a context manager so the previous anomaly-detection mode is
        # restored afterwards instead of being forced to False.
        with torch.autograd.set_detect_anomaly(True):
            loss1 = F.mse_loss(target1, sfdqn_out1)
            if current_task == 0:
                optimizer_phi.zero_grad()
                loss1.backward(retain_graph=True)
                optimizer_phi.step()
                phi_loss += loss1.detach().item()
            else:
                optimizer_sfdqn.zero_grad()
                loss1.backward(retain_graph=True)
                optimizer_sfdqn.step()
                sf_loss += loss1.detach().item()

    return phi_loss/its, sf_loss/its

# DEBUG
def _check_agent_grad(model):
    """Print the gradient of every named parameter of ``model``.

    A parameter whose ``.grad`` has not been populated is reported as a bare
    ``None`` line.
    """
    for name, p in model.named_parameters():
        if p.grad is None:
            print(None)
        else:
            print(name, 'grad', p.grad.data)

def _check_agent_weights(model):
    """Print the current value of every named parameter of ``model``."""
    for name, p in model.named_parameters():
        print(name, 'weight', p.data)

def _check_agent_require_grad(model):
    """Print the ``requires_grad`` flag of every named parameter of ``model``."""
    for name, p in model.named_parameters():
        # Label the value as what it is (it was previously labelled 'weight').
        print(name, 'requires_grad', p.requires_grad)

def main():
    """Train the agent task by task on the synthetic environment.

    Task 0 (skipped when ``skip_phi_train`` is set) acts fully at random and
    saves ``env.phi`` to disk afterwards. Task 1 first loads that saved
    ``env.phi``; with ``zero_shot`` it copies the task-0 network into the
    task-1 network and does no training. A progress line is printed every
    ``print_interval`` episodes.
    """
    env = synthetic_env(state_space=state_space,
                        action_space=action_space,
                        state_dim=state_dim,
                        phi_dim=phi_dim,
                        gamma=gamma,
                        n_tasks=n_tasks,
                        seed=seed,
                        tildeP=False)

    # The online network is built before the target network.
    net_kwargs = dict(state_dim=state_dim, action_space=action_space, phi_dim=phi_dim, n_tasks=n_tasks)
    sfdqn = SFQnet(gpi=use_gpi, **net_kwargs)
    sfdqn_target = SFQnet(**net_kwargs)
    sfdqn_target.load_state_dict(sfdqn.state_dict())
    memory = ReplayBuffer()

    optimizer_sfdqn = optim.Adam(sfdqn.parameters(), lr=sf_lr)
    optimizer_phi = optim.Adam(env.phi.parameters(), lr=phi_lr)

    # File that holds the phi weights saved after task 0 and loaded for task 1.
    phi_checkpoint = f"phi_{gamma}_{state_space}_{state_dim}_{action_space}_{phi_dim}_{phi_train_num_epi}"
    report = "task : {}, n_episode : {}, score : {:.4f}, cum. score : {:.4f}, phi loss : {:.4f}, sf loss : {:.4f}, gpi percent : {:.2f}%, n_buffer : {}, eps : {:.1f}%"

    for task in range(n_tasks):
        if skip_phi_train and task == 0:
            continue

        if task == 1:
            env.phi.load_state_dict(torch.load(phi_checkpoint))
            if zero_shot:
                sfdqn.load_source_to_target()

        # Per-task running statistics.
        score = 0.0
        cum_score = 0.0
        gpi_percent = 0
        tot_step_count = 0
        phi_loss = 0
        sf_loss = 0
        memory.reset()

        episodes = phi_train_num_epi if task == 0 else num_epi
        for n_epi in range(episodes):
            # Task 0 explores uniformly; other tasks anneal linearly from 8% to 1%.
            epsilon = 1.0 if task == 0 else max(0.01, 0.08 - 0.01*(n_epi/200))
            s, _ = env.reset()

            # At most n_steps steps per episode, ending early once the env is done.
            for _step in range(n_steps):
                a, t = sfdqn.sample_action(torch.from_numpy(s).float(), epsilon, task)
                s_prime, r, phi, done, info = env.step(a, task)
                done_mask = 0.0 if done else 1.0
                memory.put((s, a, r, s_prime, phi.clone(), t, done_mask))
                s = s_prime

                score += r
                if task == 1:
                    # Count steps whose action came from another task's network.
                    gpi_percent += int(t != task)
                    tot_step_count += 1
                if done:
                    break

            if memory.size() > 64 and not zero_shot:
                phi_loss, sf_loss = train(sfdqn, sfdqn_target, memory, optimizer_sfdqn, optimizer_phi, task, env)

            if n_epi % print_interval != 0:
                continue

            if n_epi == 0:
                # First episode: print an all-zero line and leave score untouched.
                print(report.format(task, n_epi, 0., 0., 0., 0., 0., memory.size(), epsilon*100))
                continue

            gpi_percent_ = 0.0
            if task == 1:
                cum_score += score
                gpi_percent_ = gpi_percent/tot_step_count * 100
                if n_epi % (10 * print_interval) == 0:
                    # Refresh the target network from the online network.
                    sfdqn_target.load_state_dict(sfdqn.state_dict())
            print(report.format(task, n_epi, score/(print_interval), cum_score/(n_epi), phi_loss, sf_loss, gpi_percent_, memory.size(), epsilon*100))
            score = 0.0

        if task == 0:
            torch.save(env.phi.state_dict(), phi_checkpoint)

    env.close()

def test_agent():
    """Smoke check: print one reset state and the task-0 successor features for it."""
    env = synthetic_env(state_space=state_space,
                        action_space=action_space,
                        state_dim=state_dim,
                        phi_dim=phi_dim,
                        gamma=gamma,
                        n_tasks=n_tasks,
                        seed=seed,
                        tildeP=False)

    net_kwargs = dict(state_dim=state_dim, action_space=action_space, phi_dim=phi_dim, n_tasks=n_tasks)
    sfdqn = SFQnet(gpi=use_gpi, **net_kwargs)
    # The target network is built and synced but not used below.
    sfdqn_target = SFQnet(**net_kwargs)
    sfdqn_target.load_state_dict(sfdqn.state_dict())

    print()

    # Single reset, then show the raw state and the network output for it.
    s, _ = env.reset()
    print(s)
    features = sfdqn.forward(torch.from_numpy(s).float(), 0)
    print(features)

# Script entry point: run training when executed directly. Swap the calls
# below to run the test_agent() smoke check instead.
if __name__ == '__main__':
    main()
    # test_agent()