import random
import sys
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt

import models

# Device that the agents below move their data tensors to: the default CUDA device
# when one is available, otherwise the CPU.
# NOTE(review): the networks are not moved to this device in this module; callers
# presumably call `.to(dev)` on the agent — confirm.
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class Agent_MLP_GC_BC(nn.Module):
    """Goal-conditioned behaviour-cloning agent.

    An MLP maps the concatenation of a state and a goal state to logits over
    ``env.n_act`` actions. It is trained with cross-entropy to predict the
    recorded action from the recorded state and the sample's ``end_states``
    entry, and acts towards a single goal stored in ``self.g``.
    """

    def __init__(self, env, d_hidden, goal=None):
        super().__init__()
        self.H = env.H
        self.env = env
        # Input is [state, goal], hence twice the state dimension.
        self.policy_model = models.MLP([env.d_state*2]+d_hidden+[env.n_act])

        self.g = goal

        self.train()

    def set_goal(self, D):
        """Set the goal to the ``end_states`` entry of the highest-reward sample in ``D``."""
        idx = D['rewards'].argmax()
        self.g = D['end_states'][idx]

    def policy(self, s):
        """Sample one action (a Python int) for state ``s`` conditioned on ``self.g``.

        The goal is moved to the state's device and broadcast over any leading
        batch dimensions of ``s``. The module's train/eval mode is restored on exit.

        Raises:
            RuntimeError: if no goal has been set.
        """
        if self.g is None:
            raise RuntimeError("no goal set: pass `goal` to the constructor or call set_goal()")
        was_training = self.training
        with torch.no_grad():
            self.eval()
            s = s.to(dev)
            g = self.g.to(s.device).expand(s.shape[:-1] + self.g.shape[-1:])
            logits = self.policy_model(torch.cat((s, g), dim=-1))
            # `.item()` means this expects a single (unbatched) decision.
            a = torch.distributions.Categorical(logits=logits).sample().item()
            self.train(was_training)
        return a

    def policy_logits(self, s, g):
        """Return action logits for states ``s`` conditioned on goals ``g``."""
        return self.policy_model(torch.cat((s, g), dim=-1))

    def train_loss(self, D):
        """Cross-entropy between predicted logits and the recorded ``actions``.

        Each sample is conditioned on its own ``end_states`` entry (hindsight goal).
        """
        return F.cross_entropy(self.policy_logits(D['states'], D['end_states']), D['actions'])

    def train_agent(self, N=1000, N_batch=100, lr=2e-3, n_iter_max=30000, weight_decay=1e-3):
        """Collect data with a uniform random policy, fit by behaviour cloning, then evaluate.

        The data is split into ``N_batch`` batches. The first ``N_batch - 1`` are
        cycled for training and the last is held out to report a generalization loss.

        Args:
            N: amount of data requested from ``env.collect_data``.
            N_batch: number of batches (must be >= 2 because one is held out).
            lr, weight_decay: AdamW hyper-parameters.
            n_iter_max: number of gradient steps.

        Returns:
            Whatever ``self.env.eval(self)`` returns.

        Raises:
            ValueError: if ``N_batch < 2``.
        """
        if N_batch < 2:
            raise ValueError("N_batch must be >= 2: the last batch is held out for the generalization loss")

        n_act = self.env.n_act

        def uniform_policy(state):
            # Uniform over all actions the model can output (not just {0, 1}).
            return random.randrange(n_act)

        optimizer = opt.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        D, _ = self.env.collect_data(uniform_policy, N)
        D = {k: v.to(dev) for k, v in D.items()}
        Ds = batchify(D, N_batch)
        train_batches, held_out = Ds[:-1], Ds[-1]
        self.set_goal(D)
        self.train()
        print("training...")
        for it in range(n_iter_max):
            optimizer.zero_grad()
            loss = self.train_loss(train_batches[it % len(train_batches)])
            loss.backward()
            optimizer.step()
            if it % 100 == 0:
                sys.stdout.write("\r\033[K")
                with torch.no_grad():
                    loss_ = self.train_loss(held_out)
                print("\r\titeration : "+str(it)+"  \tloss : "+str(loss.item())+"  \tgeneralization loss : "+str(loss_.item()), end='')
        ok = self.env.eval(self)
        print()

        return ok


class Agent_Q(nn.Module):
    """Fitted Q-learning agent for environments with two actions, 0 and 1.

    ``Q`` is an MLP over the concatenation ``[state, action]`` that outputs one
    scalar. ``Q_old`` is a target network used to build bootstrapped targets; it
    is replaced by a copy of ``Q`` every ``n_iter_Q_update`` training steps.
    """

    def __init__(self, env, d_hidden):
        super().__init__()
        self.H = env.H
        self.env = env
        self.d_hidden = d_hidden
        self.Q_old = models.MLP([self.env.d_state+1]+d_hidden+[1])
        with torch.no_grad():
            # Zero the target network's last layer; presumably this makes the initial
            # bootstrapped targets equal to the raw rewards (depends on models.MLP).
            self.Q_old.lin[-1].weight *= 0.
            self.Q_old.lin[-1].bias *= 0.
        self.Q = models.MLP([self.env.d_state+1]+d_hidden+[1])
        self.epsilon = None
        self.train()

    @staticmethod
    def _q_both_actions(q_net, s):
        """Evaluate ``q_net`` at actions 0 and 1 for every state in ``s``.

        Builds inputs of shape ``(2, *s.shape[:-1], d_state + 1)`` (index 0 pairs the
        states with action 0, index 1 with action 1) on ``s``'s device and dtype,
        and returns ``q_net`` applied to them.
        """
        action_shape = s.shape[:-1] + (1,)
        sa_0 = torch.cat((s, torch.zeros(action_shape, device=s.device, dtype=s.dtype)), dim=-1)
        sa_1 = torch.cat((s, torch.ones(action_shape, device=s.device, dtype=s.dtype)), dim=-1)
        return q_net(torch.stack((sa_0, sa_1), dim=0))

    def policy(self, s):
        """Greedy action(s) under ``Q`` for state(s) ``s``, as a tensor of 0/1 indices.

        The module's train/eval mode is restored on exit.
        """
        was_training = self.training
        with torch.no_grad():
            self.eval()
            s = s.to(dev)
            q = self._q_both_actions(self.Q, s).squeeze(-1)
            self.train(was_training)
        return torch.argmax(q, dim=0)

    def epsilon_greedy_policy(self, s):
        """With probability ``self.epsilon`` pick uniform random action(s), else act greedily."""
        if random.random() < self.epsilon:
            # NOTE(review): this tensor is created on the CPU, while `policy` returns one on `dev`.
            return torch.randint(0, 2, s.shape[:-1])
        return self.policy(s)

    def train_loss(self, D, gamma=1.):
        """Mean-squared Bellman error of ``Q`` against targets from ``Q_old``.

        target = reward + gamma * max_a' Q_old(next_state, a')

        Args:
            D: dict with 'states', 'actions', 'rewards' and 'next_states' tensors.
            gamma: discount factor (default 1.0, i.e. undiscounted).

        NOTE(review): no terminal-state masking is applied; every sample bootstraps
        from its next state.
        """
        s = D['states']
        a = D['actions']
        r = D['rewards']
        s_next = D['next_states']

        state_action = torch.cat((s, a.unsqueeze(-1).to(s.dtype)), dim=-1)
        q = self.Q(state_action)

        with torch.no_grad():
            next_q_max = self._q_both_actions(self.Q_old, s_next).max(dim=0).values
            # Align rewards with Q's (B, 1) output: a (B,) reward tensor would otherwise
            # broadcast against (B, 1) into a (B, B) target.
            target = r.reshape(q.shape).to(q.dtype) + gamma * next_q_max

        return F.mse_loss(q, target)

    def train_agent(self, N=1000, N_batch=20, lr=1e-3, n_iter_max=50000, n_iter_Q_update=1000, epsilon=0.2, weight_decay=1e-2, gamma=1.):
        """Train ``Q`` by fitted Q-iteration with a growing replay buffer, then evaluate.

        Data starts from a uniform random policy. Every ``n_iter_Q_update`` steps,
        ``Q_old`` is replaced by a copy of ``Q`` and ``N`` more data are collected
        with the epsilon-greedy policy and appended to the replay batches.

        Returns:
            Whatever ``self.env.eval(self)`` returns.
        """
        self.epsilon = epsilon

        def uniform_policy(state):
            return int(random.random() > 0.5)

        optimizer = opt.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        D, _ = self.env.collect_data(uniform_policy, N)
        D = {k: v.to(dev) for k, v in D.items()}
        Ds = batchify(D, N_batch)
        print("training...")
        for it in range(n_iter_max):
            optimizer.zero_grad()
            loss = self.train_loss(Ds[it % len(Ds)], gamma)
            if it % 100 == 0:
                sys.stdout.write("\r\033[K")
                print("\r\titeration : "+str(it)+"  \tloss : "+str(loss.item()), end='', flush=True)
            loss.backward()
            optimizer.step()
            if (it + 1) % n_iter_Q_update == 0:
                self.Q_old = copy.deepcopy(self.Q)
                D, _ = self.env.collect_data(self.epsilon_greedy_policy, N)
                D = {k: v.to(dev) for k, v in D.items()}
                Ds += batchify(D, N_batch)
        print()
        ok = self.env.eval(self)
        print()

        return ok

class Agent_PPO(nn.Module):
    """Proximal Policy Optimization agent with separate actor and critic MLPs.

    The actor outputs logits over ``env.n_act`` actions and the critic outputs
    one scalar state value. Both output layers start at zero.
    """

    def __init__(self, env, d_hidden):
        super().__init__()
        self.H = env.H
        self.env = env
        self.d_hidden = d_hidden
        self.actor = models.MLP([self.env.d_state]+d_hidden+[env.n_act])
        self.critic = models.MLP([self.env.d_state]+d_hidden+[1])
        with torch.no_grad():
            self.actor.lin[-1].weight *= 0.
            self.actor.lin[-1].bias *= 0.
            self.critic.lin[-1].weight *= 0.
            self.critic.lin[-1].bias *= 0.
        self.train()

    def policy(self, s):
        """Sample one action (a Python int) from the actor for state ``s``.

        The module's train/eval mode is restored on exit.
        """
        was_training = self.training
        with torch.no_grad():
            self.eval()
            s = s.to(dev)
            a = torch.distributions.Categorical(logits=self.actor(s)).sample()
            self.train(was_training)
        return a.item()

    def policy_info(self, s, a):
        """Return ``(log pi(a|s), entropy of pi(.|s))`` under the current actor."""
        dist = torch.distributions.Categorical(logits=self.actor(s))
        return dist.log_prob(a), dist.entropy()

    def value(self, s):
        """Critic output for ``s`` (trailing dimension of size 1)."""
        return self.critic(s)

    def dataset_with_info(self, D):
        """Add behaviour-policy log-probs and advantages to ``D`` (in place) and return it.

        ``advantages = returns - V(s)`` with the critic's trailing size-1 dimension
        removed, so the advantages match the per-sample log-prob shape.
        """
        s = D['states']
        a = D['actions']
        with torch.no_grad():
            a_log_probs, _ = self.policy_info(s, a)
            values = self.value(s).squeeze(-1)
            # Reshape so a (B,) or (B, 1) returns tensor does not broadcast to (B, B).
            advantages = D['returns'].reshape(values.shape) - values
        D['action_log_probs'] = a_log_probs
        D['advantages'] = advantages

        return D

    def train_loss(self, D, epsilon, beta):
        """Clipped PPO surrogate + value MSE - ``beta`` * mean entropy.

        Args:
            D: batch produced by ``dataset_with_info`` (needs 'states', 'actions',
                'returns', 'action_log_probs', 'advantages').
            epsilon: clipping range for the probability ratio.
            beta: entropy-bonus coefficient.
        """
        s = D['states']
        a = D['actions']
        old_log_probs = D['action_log_probs']

        log_probs, entropy = self.policy_info(s, a)
        values = self.value(s).squeeze(-1)
        # Keep every per-sample term on the same (B,) shape to avoid (B, B) broadcasting.
        returns = D['returns'].reshape(values.shape).to(values.dtype)
        advantages = D['advantages'].reshape(log_probs.shape)

        ratios = torch.exp(log_probs - old_log_probs.detach())
        surrogate = torch.min(ratios * advantages, ratios.clamp(1 - epsilon, 1 + epsilon) * advantages)

        return -surrogate.mean() + F.mse_loss(values, returns) - beta * entropy.mean()

    def train_agent(self, N=1000, N_batch=10, lr=2e-4, n_iter_max=50, K=500, epsilon=0.1, beta=1e-3, weight_decay=1e-8):
        """Alternate on-policy data collection and ``K`` PPO gradient steps, then evaluate.

        Returns:
            Whatever ``self.env.eval(self)`` returns.
        """
        print("training...")
        for it in range(n_iter_max):
            D, R = self.env.collect_data(self.policy, N)
            D = {k: v.to(dev) for k, v in D.items()}
            D = self.dataset_with_info(D)
            Ds = batchify(D, N_batch)

            # NOTE: a fresh optimizer per outer iteration resets AdamW's moment
            # estimates after every data collection (kept from the original design).
            optimizer = opt.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

            for k in range(K):
                optimizer.zero_grad()
                loss = self.train_loss(Ds[k % len(Ds)], epsilon, beta)
                loss.backward()
                optimizer.step()

            sys.stdout.write("\r\033[K")
            print("\r\titeration : "+str(it)+"  \taverage returns : "+str(R), end='', flush=True)

        print()
        ok = self.env.eval(self)
        print()

        return ok


def batchify(D, n_batch, drop_last=True):
    """Shuffle a dataset of aligned tensors and split it into ``n_batch`` batches.

    Args:
        D: dict of tensors that share the same size along dim 0; the sample
            count is read from ``D['states']``.
        n_batch: number of batches to return.
        drop_last: if True (default), every batch holds ``n // n_batch`` samples
            and the ``n % n_batch`` leftover samples are discarded. If False, all
            samples are kept and batch sizes differ by at most one.

    Returns:
        A list of ``n_batch`` dicts with the same keys as ``D``.

    Raises:
        ValueError: if ``n_batch < 1``, or if there are fewer samples than
            batches (which would otherwise yield empty batches and NaN losses).
    """
    n = D['states'].size(0)
    if n_batch < 1:
        raise ValueError(f"n_batch must be >= 1, got {n_batch}")
    if n < n_batch:
        raise ValueError(f"cannot split {n} samples into {n_batch} non-empty batches")

    p = torch.randperm(n)
    if drop_last:
        size = n // n_batch
        index_chunks = [p[i * size:(i + 1) * size] for i in range(n_batch)]
    else:
        index_chunks = torch.tensor_split(p, n_batch)

    return [{k: v[idx] for k, v in D.items()} for idx in index_chunks]