import torch
import numpy as np
from torch_networks import MLP, copy_model
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset, DataLoader
import time
import random
from os.path import join
from os import makedirs


def poleval(data_fname,
            data_id,
            save_folder,
            max_steps=50_000,
            mini_batch_size=256,
            lr=3e-4,
            hsize=(256, 256, 256),
            mode='DouBel',  # or 'FQI', 'BRM', 'TTS', 'BC'
            copy_freq=100,
            lam_brm=.1,
            seed=0,
            device='cpu',
            ndat=5000,
            discount=.99,
            save_frequency=50,
            ):
    """Train a state-value network on a fixed transition dataset and log its accuracy.

    The dataset is a dict stored in ``data_fname`` (read with ``np.load(...).item()``)
    whose entries ``'obs'``, ``'act'``, ``'rwd'``, ``'nobs'``, ``'terminated'`` and
    ``'ret'`` are each sliced as ``[:ndat, :]``, so every entry must be at least 2-D.
    ``'ret'`` is the reference against which value predictions are scored in
    ``loss/mse_trueq``; ``'act'`` is loaded but not used by any loss.

    Training objective per ``mode``, with ``gamma_t = discount * (1 - terminated)``,
    ``V = v_fun``, ``V_old`` a frozen copy of ``V`` refreshed every ``copy_freq``
    steps, and target ``y = r + gamma_t * V_old(s')``:

    * ``'BRM'``:    ``mean((r + gamma_t * V(s') - V(s))**2)``, gradient through both terms.
    * ``'FQI'``:    ``MSE(V(s), y)``.
    * ``'DouBel'``: ``MSE(v_fun.lin(phi(s)), y)`` plus ``lam_brm`` times a Bellman
      residual computed with the last layer's weight/bias detached, so that term
      only trains the features ``phi``.
    * ``'TTS'``:    ``MSE(v_fun.lin(phi(s).detach()), y)`` plus ``lam_brm`` times a
      Bellman residual through the separate linear head ``aux_lin``; the last
      layer's learning rate is ``10 * lr``.
    * ``'BC'``:     ``MSE(v_fun.lin(phi(s)), y)`` plus auxiliary losses predicting the
      reward and ``V_old``'s next-state features from ``phi(s)``.

    Every ``save_frequency`` steps the whole dataset is evaluated and the scalars
    ``stat/dead_features``, ``loss/bell_error`` and ``loss/mse_trueq`` are written
    to a TensorBoard log in
    ``save_folder/data_id/mode/cf{copy_freq}/lam{lam_brm}/seed{seed}``.

    Args:
        data_fname: path of the pickled dataset file passed to ``np.load``.
        data_id: identifier used as a component of the log directory.
        save_folder: root directory for the TensorBoard logs.
        max_steps: exact number of gradient steps to perform.
        mini_batch_size: mini-batch size; an incomplete final batch is dropped.
        lr: Adam learning rate (see ``'TTS'`` for the last-layer multiplier).
        hsize: hidden-layer widths of the MLP; ``hsize[-1]`` is the feature size.
        mode: training objective, one of the modes listed above.
        copy_freq: number of steps between refreshes of the frozen copy.
        lam_brm: weight of the Bellman-residual term in ``'DouBel'`` and ``'TTS'``.
        seed: seed for ``random``, ``numpy`` and ``torch``.
        device: torch device for all tensors and modules.
        ndat: number of leading transitions used from the dataset.
        discount: discount factor.
        save_frequency: number of steps between evaluations/log writes.

    Returns:
        None. Results are only written to the TensorBoard log.

    Raises:
        ValueError: if ``mode`` is unknown, ``copy_freq`` or ``save_frequency`` is
            smaller than 1, or the data holds fewer rows than ``mini_batch_size``
            (no full mini-batch could ever be drawn).
    """
    valid_modes = ('DouBel', 'FQI', 'BRM', 'TTS', 'BC')
    # Validate before touching the filesystem so a typo fails fast and cleanly.
    if mode not in valid_modes:
        raise ValueError(f'unknown mode {mode!r}; expected one of {valid_modes}')
    if copy_freq < 1:
        raise ValueError(f'copy_freq must be >= 1, got {copy_freq}')
    if save_frequency < 1:
        raise ValueError(f'save_frequency must be >= 1, got {save_frequency}')

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.set_num_threads(1)

    # loading data
    # NOTE(review): allow_pickle=True unpickles the file, which can execute
    # arbitrary code -- only load dataset files from a trusted source.
    data = np.load(data_fname, allow_pickle=True).item()

    def as_tensor(key, dtype):
        """Return the first ``ndat`` rows of ``data[key]`` as a tensor on ``device``."""
        return torch.tensor(data[key][:ndat, :], dtype=dtype, device=device)

    obs = as_tensor('obs', torch.float)
    act = as_tensor('act', torch.long)
    rwd = as_tensor('rwd', torch.float)
    nobs = as_tensor('nobs', torch.float)
    ter = as_tensor('terminated', torch.float)
    true_v = as_tensor('ret', torch.float)

    train_set = DataLoader(TensorDataset(obs, act, rwd, nobs, ter), shuffle=True,
                           batch_size=mini_batch_size, drop_last=True)
    # With drop_last=True an undersized dataset yields no batches at all, which
    # would otherwise make the training loop below spin forever.
    if len(train_set) == 0:
        raise ValueError(f'dataset has {obs.shape[0]} rows, fewer than '
                         f'mini_batch_size={mini_batch_size}')

    sdim = obs.shape[1]
    v_fun = MLP([sdim] + list(hsize) + [1], torch.nn.ReLU()).to(device)
    old_v_fun = MLP([sdim] + list(hsize) + [1], torch.nn.ReLU()).to(device)
    copy_model(old_v_fun, v_fun)

    fname = join(save_folder, f'{data_id}', mode, f'cf{copy_freq}', f'lam{lam_brm}', f'seed{seed}')
    print('log dir', fname)
    makedirs(fname, exist_ok=True)
    logger = SummaryWriter(fname)

    mse = torch.nn.MSELoss()
    print('number of parameters of the v_fun', sum(p.numel() for p in v_fun.parameters()))
    n_feat = hsize[-1]
    lr_lin_mult = 10. if mode == 'TTS' else 1.

    # Auxiliary heads; only 'TTS' (aux_lin) and 'BC' (r_fun, feat_fun) use them,
    # but they are always built so the RNG stream and optimizer layout match.
    r_fun = torch.nn.Linear(n_feat, 1).to(device)
    feat_fun = torch.nn.Linear(n_feat, n_feat).to(device)
    aux_lin = torch.nn.Linear(n_feat, 1, bias=True).to(device)
    optim = torch.optim.Adam([
        {'params': [p for f in v_fun.f[:-1] for p in f.parameters()], 'lr': lr},
        {'params': list(v_fun.f[-1].parameters()), 'lr': lr * lr_lin_mult},
        {'params': aux_lin.parameters(), 'lr': lr},
        {'params': list(r_fun.parameters()) + list(feat_fun.parameters()), 'lr': lr},
    ])

    def training_loss(o, r, no, t):
        """Return the scalar training loss of one mini-batch for the selected ``mode``."""
        # Discount factor, zeroed on terminal transitions.
        gamma_t = discount * (1 - t)

        if mode == 'BRM':
            v = v_fun(o)
            vnext = v_fun(no)
            return (r + gamma_t * vnext - v).pow(2).mean()

        if mode == 'FQI':
            with torch.no_grad():
                vnext_on = old_v_fun(no)
            return mse(v_fun(o), r + gamma_t * vnext_on)

        fo = v_fun.get_features(o)

        if mode == 'DouBel':
            fno = v_fun.get_features(no)
            # Detached last-layer parameters: the residual term trains features only.
            w = v_fun.f[-1].weight.detach().t()
            b = v_fun.f[-1].bias.detach()
            loss_br = (r + gamma_t * (fno @ w + b) - (fo @ w + b)).pow(2).mean()
            with torch.no_grad():
                vnext_on = old_v_fun(no)
            return mse(v_fun.lin(fo), r + gamma_t * vnext_on) + lam_brm * loss_br

        if mode == 'TTS':
            fno = v_fun.get_features(no)
            v_aux = aux_lin(fo)
            vnext_aux = aux_lin(fno)
            loss_br = (r + gamma_t * vnext_aux - v_aux).pow(2).mean()
            with torch.no_grad():
                vnext_on = old_v_fun(no)
            # Features are detached here: the regression loss only trains the last layer.
            return mse(v_fun.lin(fo.detach()), r + gamma_t * vnext_on) + lam_brm * loss_br

        # mode == 'BC' (guaranteed by the validation at the top).
        rpred = r_fun(fo)
        fpred = feat_fun(fo)
        loss_r = mse(rpred, r)
        with torch.no_grad():
            ofno = old_v_fun.get_features(no)
            vnext_on = old_v_fun.lin(ofno)
        loss_f = (fpred - ofno).pow(2).sum(1).mean()
        loss_v = mse(v_fun.lin(fo), r + gamma_t * vnext_on)
        return loss_v + loss_r + loss_f

    def log_eval_metrics(step):
        """Evaluate ``v_fun`` on all loaded transitions and write the metrics to TensorBoard."""
        with torch.no_grad():
            fo = v_fun.get_features(obs)
            fno = v_fun.get_features(nobs)
            vtargs = rwd + discount * (1 - ter) * v_fun.lin(fno)
            # Features with zero standard deviation over the whole dataset.
            nb_dead_features = (fo.std(0) == 0).sum().item()
            logger.add_scalar('stat/dead_features', nb_dead_features, step)
            vvals = v_fun.lin(fo)
            logger.add_scalar('loss/bell_error', mse(vvals, vtargs).item(), step)
            logger.add_scalar('loss/mse_trueq', mse(vvals, true_v).item(), step)

    grad_step = 0
    try:
        while grad_step < max_steps:
            for o, _a, r, no, t in train_set:
                loss = training_loss(o, r, no, t)
                optim.zero_grad()
                loss.backward()
                optim.step()
                if grad_step % save_frequency == 0:
                    log_eval_metrics(grad_step)
                grad_step += 1
                if grad_step % copy_freq == 0:
                    copy_model(old_v_fun, v_fun)
                # Stop mid-epoch so exactly max_steps updates are performed.
                if grad_step >= max_steps:
                    break
    finally:
        # Flush buffered TensorBoard events even if training raises.
        logger.close()


