import random
import sys
import copy
import time
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt

import models
import tree_search

dev = torch.device('cpu')

class Agent_AlphaZero(nn.Module):
    """AlphaZero-style agent: a policy/value network driven by a tree search.

    The module holds two copies of the network:

    * the live networks (``mlp``, ``policy_head``, ``value_head``), which are
      the ones used by :meth:`train_loss` and therefore the ones trained;
    * snapshots (``mlp_old``, ``policy_head_old``, ``value_head_old``), which
      are the ones evaluated by :meth:`forward`. They only change when
      :meth:`update_networks` is called.
    """

    def __init__(self, env, d_hidden, C=1., N_search=50, T=1.):
        """Build the networks and take the first snapshot.

        Args:
            env: Environment object. This class reads ``env.H``,
                ``env.d_state`` and ``env.n_act`` and calls
                ``env.collect_data`` / ``env.eval`` in :meth:`train_agent`.
            d_hidden: List of hidden widths. ``[env.d_state] + d_hidden`` is
                handed to ``models.MLP``; ``d_hidden[-1]`` is the input width
                of both heads.
            C: Search setting forwarded to ``tree_search.tree_search``.
            N_search: Search setting forwarded to ``tree_search.tree_search``.
            T: Search setting forwarded to ``tree_search.tree_search``.

        Note:
            :meth:`train_agent` overwrites ``C``, ``N_search`` and ``T`` with
            its own arguments.
        """
        super().__init__()
        self.H = env.H
        self.env = env
        self.C, self.N_search, self.T = C, N_search, T
        self.mlp = models.MLP([env.d_state]+d_hidden)
        self.policy_head = nn.Sequential(nn.ReLU(), nn.Linear(d_hidden[-1], env.n_act))
        self.value_head = nn.Sequential(nn.ReLU(), nn.Linear(d_hidden[-1], 1))

        # Both heads start with small weights and a zero bias. Going through
        # one helper keeps the two initialisations in sync (the value head's
        # bias used to be left untouched by a copy-paste slip).
        self._shrink_output_layer(self.policy_head[-1])
        self._shrink_output_layer(self.value_head[-1])

        self.mlp_old = None
        self.policy_head_old = None
        self.value_head_old = None
        self.update_networks()

    @staticmethod
    def _shrink_output_layer(layer, scale=1e-2):
        """Scale ``layer.weight`` by ``scale`` in place and zero ``layer.bias``."""
        with torch.no_grad():
            layer.weight.mul_(scale)
            layer.bias.zero_()

    def update_networks(self):
        """Replace the snapshots with deep copies of the live networks."""
        self.mlp_old = copy.deepcopy(self.mlp)
        self.policy_head_old = copy.deepcopy(self.policy_head)
        self.value_head_old = copy.deepcopy(self.value_head)

    def policy(self, s):
        """Choose an action for state ``s`` by sampling the search distribution.

        Puts the module in eval mode (it is not switched back here) and runs
        ``tree_search.tree_search`` with this agent under ``torch.no_grad()``.

        Returns:
            Tuple ``(a, a_probs)``: the sampled action index as a Python int
            and the distribution returned by the search as a list.
        """
        self.eval()
        with torch.no_grad():
            a_probs = tree_search.tree_search(self, s, self.env.n_act, self.N_search, self.T, self.C)
            a = torch.distributions.Categorical(a_probs).sample().item()

        return a, a_probs.tolist()

    def forward(self, s):
        """Evaluate the snapshot networks on ``s``.

        Returns:
            Tuple ``(act_probs, v)``: softmax of the snapshot policy head over
            the last dimension, and sigmoid of the snapshot value head.
        """
        h = self.mlp_old(s)
        act_probs = F.softmax(self.policy_head_old(h), dim=-1)
        v = torch.sigmoid(self.value_head_old(h))

        return act_probs, v

    def train_loss(self, D):
        """Compute the training loss of the live networks on a batch.

        Args:
            D: Dict with entries ``'states'`` (network input),
                ``'policy_data'`` (policy target) and ``'returns'`` (value
                target, compared against the value logit with its last
                dimension squeezed).

        Returns:
            Cross-entropy between the policy logits and ``'policy_data'``
            plus binary cross-entropy (with logits) between the value logit
            and ``'returns'``.
        """
        s = D['states']
        act_probs_target = D['policy_data']
        v_target = D['returns']

        h = self.mlp(s)
        act_probs_logits = self.policy_head(h)
        v_logit = self.value_head(h).squeeze(-1)

        policy_loss = F.cross_entropy(act_probs_logits, act_probs_target)
        value_loss = F.binary_cross_entropy_with_logits(v_logit, v_target)
        return policy_loss + value_loss

    def train_agent(self, N=1000, N_batch=100, lr=2e-3, n_iter_max=90000, n_iter_networks_update=30000, weight_decay=1e-3, T=1., C=1., N_search=50):
        """Alternate data collection and gradient steps, then evaluate.

        Every ``n_iter_networks_update`` steps (including step 0) the
        snapshots are refreshed, new data is collected with :meth:`policy`
        and split into mini-batches. Each step trains on one mini-batch,
        cycling through them in order.

        Args:
            N: Passed to ``env.collect_data``.
            N_batch: Number of mini-batches requested from ``batchify``.
            lr: AdamW learning rate.
            n_iter_max: Total number of gradient steps.
            n_iter_networks_update: Period, in steps, of the snapshot refresh
                and data collection.
            weight_decay: AdamW weight decay.
            T: Stored on the agent; overrides the constructor value.
            C: Stored on the agent; overrides the constructor value.
            N_search: Stored on the agent; overrides the constructor value.

        Returns:
            Whatever ``env.eval(self, simulator=True)`` returns.
        """
        print("C ", C)
        print("T ", T)
        optimizer = opt.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
        self.N_search, self.T, self.C = N_search, T, C
        Ds = []
        print("training...")
        for step in range(n_iter_max):
            if step % n_iter_networks_update == 0:
                self.update_networks()
                D, _ = self.env.collect_data(self.policy, N, simulator=True)
                gc.collect()
                D = {k: v.to(dev) for k, v in D.items()}
                Ds = batchify(D, N_batch)
            optimizer.zero_grad()
            loss = self.train_loss(Ds[step % len(Ds)])
            if step % 100 == 0:
                # Clear the current terminal line, then rewrite the progress.
                sys.stdout.write("\r\033[K")
                print("\r\titeration : "+str(step)+"  \tloss : "+str(loss.item()), end='', flush=True)
            loss.backward()
            optimizer.step()

        print()
        ok = self.env.eval(self, simulator=True)
        print()

        return ok

def batchify(D, n_batch):
    """Shuffle the samples of ``D`` and split them into mini-batches.

    Every sample ends up in exactly one batch and no batch is empty: if
    fewer samples than ``n_batch`` are available, one batch per sample is
    returned instead. When the sample count is not a multiple of the batch
    count, the first batches receive one extra sample.

    Args:
        D: Dict of tensors indexed by sample along dimension 0. Must contain
            ``'states'``, whose first dimension defines the sample count;
            every entry is indexed with the same permutation.
        n_batch: Requested number of mini-batches (not the batch size).

    Returns:
        List of at most ``n_batch`` dicts with the same keys as ``D``.

    Raises:
        ValueError: If ``n_batch`` is smaller than 1 or ``D`` has no samples.
    """
    if n_batch < 1:
        raise ValueError("n_batch must be at least 1, got " + str(n_batch))
    n = D['states'].size(0)
    if n == 0:
        raise ValueError("cannot build batches from an empty dataset")

    perm = torch.randperm(n)
    # Never ask for more batches than samples: that would yield empty
    # batches, and a mean-reduced loss over zero samples is NaN.
    n_chunks = min(n_batch, n)
    base, extra = divmod(n, n_chunks)

    Ds = []
    start = 0
    for i in range(n_chunks):
        stop = start + base + (1 if i < extra else 0)
        idx = perm[start:stop]
        Ds.append({k: v[idx] for k, v in D.items()})
        start = stop

    return Ds


if __name__ == '__main__':
    # Experiment driver: for each value of H, build n_tests fresh
    # environments, train a new agent on each with the default settings,
    # and add up what train_agent returns.
    import environment

    n_tests = 45
    Hs = list(range(5, 55, 5))
    print(Hs)
    print(n_tests)

    all_results = []
    for H in Hs:
        print()
        total_win = 0
        for trial in range(n_tests):
            print(f"Experiment {trial}", flush=True)

            env = environment.ENV(H)
            print("H = ")
            print(env.H)
            print("b = ")
            print(env.b, flush=True)

            agent = Agent_AlphaZero(env, [256, 256], N_search=50)
            agent = agent.to(dev)

            ok = agent.train_agent()
            print()
            print(ok, flush=True)
            print()
            total_win += ok

        success_rate = total_win/n_tests
        print("Fraction of successes " + str(success_rate))
        all_results.append(total_win)

    # Summary: raw totals per H, then the same totals divided by n_tests.
    print("all results")
    print(all_results)
    print([wins/n_tests for wins in all_results])
