import sys,os;sys.path.append(os.getcwd())
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter 
from baselines.common.utils import onehot
from bc_utils import get_file_names, sample_batch, sample_all_batch, read_one_file, convert_to_batch
from baselines.common.model import Model
from baselines.common.wrappers import BCWrapper
    
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# When True, Policy also builds, trains and evaluates with a second network
# (`model_adv`) in addition to the main behaviour-cloning model.
advantage = False

# Checkpoint path that Policy.load_q_model() reads into `model_adv`.
q_path = 'output/bc_model'

# Number of tag columns iterated over in the advantage loss of Policy.update().
n_sampled_traj = 5

# Number of heads iterated over in the advantage loss of Policy.update().
latent_dim = 3
    
class Policy(nn.Module):
    """Behaviour-cloning policy built around the project ``Model`` network.

    ``self.model`` is always trained with a class-weighted cross-entropy
    loss. When the module-level ``advantage`` flag is true, a second network
    ``self.model_adv`` is created as well; it is trained in :meth:`update`
    and becomes the network used for evaluation (``self.model_eval``).
    """

    # Per-class cross-entropy weights: 1 + 8 + 43 = 52 action classes.
    ACTION_CLASS_WEIGHTS = (0.1,) + (0.2,) * 8 + (1.2,) * 43

    def __init__(self) -> None:
        super().__init__()
        self.model = Model().to(device)
        # Network used by test(); replaced by model_adv when `advantage` is on.
        self.model_eval = self.model
        if advantage:
            self.model_adv = Model().to(device)
            self.model_eval = self.model_adv
        # Number of completed update() calls.
        self.update_step = 0

    def save_model(self, path):
        """Save the state dict of ``self.model`` to ``path``.

        Missing parent directories are created first when ``path`` is a
        filesystem path; file-like objects are passed to ``torch.save`` as is.
        """
        if isinstance(path, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(path))
            if parent:
                os.makedirs(parent, exist_ok=True)
        torch.save(self.model.state_dict(), path)

    def load_q_model(self):
        """Load the checkpoint at the module-level ``q_path`` into ``model_adv``.

        Raises:
            AttributeError: if this policy has no ``model_adv`` (it is only
                created when the module-level ``advantage`` flag is true).
        """
        model_adv = getattr(self, "model_adv", None)
        if model_adv is None:
            raise AttributeError(
                "Policy has no 'model_adv' to load into; it is only created "
                "when the module-level `advantage` flag is True"
            )
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model_adv.load_state_dict(torch.load(q_path, map_location=device))

    @staticmethod
    def _prepare_batch(states, actions, tags):
        """Convert one raw batch into tensors on ``device``.

        Returns a list of float32 state tensors, the actions as an int64
        column (one extra trailing dimension), and the tags as float32.
        """
        state_tensors = [
            torch.tensor(np.array(state), dtype=torch.float32, device=device)
            for state in states
        ]
        action_tensor = torch.tensor(
            actions, dtype=torch.long, device=device
        ).unsqueeze(1)
        tag_tensor = torch.tensor(tags, dtype=torch.float32, device=device)
        return state_tensors, action_tensor, tag_tensor

    @staticmethod
    def _accuracy(logits, targets):
        """Return the fraction of rows whose argmax over dim 1 equals the target."""
        predictions = logits.argmax(dim=1)
        return (predictions == targets).sum().item() / len(targets)

    def update(self, states, actions, tags):
        """Run one optimisation step on a batch.

        Args:
            states: iterable of array-likes; each is converted to a float32
                tensor and the resulting list is passed to the model.
            actions: integer action labels, one per sample.
            tags: array-like converted to float32. The advantage branch reads
                it as ``tags[:, j]`` for ``j < n_sampled_traj`` -- presumably
                shaped (batch, n_sampled_traj); confirm against sample_batch.

        Returns:
            ``(loss, acc)``: the scalar cross-entropy loss of ``self.model``
            and its accuracy on this batch, both measured before the step.
        """
        states, actions, tags = self._prepare_batch(states, actions, tags)
        targets = actions.squeeze(dim=1)
        weights = torch.tensor(
            self.ACTION_CLASS_WEIGHTS, dtype=torch.float32, device=device
        )

        # Original author's note: z is (n_sample_traj, latent_dim).
        z = self.model.select(states, actions, tags)
        logit_p = self.model.forward(states, tags, z)

        acc = self._accuracy(logit_p, targets)
        loss = F.cross_entropy(logit_p, targets, weight=weights)
        self.model.opt.zero_grad()
        loss.backward()
        self.model.opt.step()

        if advantage:
            head_logits = self.model_adv.forward_all_head(states)
            loss_adv = 0
            for i in range(latent_dim):
                # Per-sample loss of head i, shape (batch,).
                per_sample = F.cross_entropy(
                    head_logits[i], targets, weight=weights, reduction='none'
                )
                for j in range(n_sampled_traj):
                    # z is detached so this loss only trains model_adv.
                    coeff = z[j][i].detach()
                    loss_adv = loss_adv + (per_sample * tags[:, j]).mean() * coeff
            self.model_adv.opt.zero_grad()
            loss_adv.backward()
            self.model_adv.opt.step()

        self.update_step += 1
        return loss.item(), acc

    def test(self, states, actions, tags):
        """Measure accuracy on a batch without updating any parameters.

        ``z`` always comes from ``self.model.select``; the logits come from
        ``self.model_eval`` (``model_adv`` when ``advantage`` is on).

        Returns:
            ``(None, acc)`` -- no loss is computed here; the leading ``None``
            keeps the return shape parallel to :meth:`update`.
        """
        states, actions, tags = self._prepare_batch(states, actions, tags)

        # NOTE(review): the networks are not switched to eval() mode here;
        # confirm Model has no dropout / batch-norm layers that need it.
        with torch.no_grad():
            z = self.model.select(states, actions, tags)
            logit_p = self.model_eval.forward(states, tags, z)
            acc = self._accuracy(logit_p, actions.squeeze(dim=1))

        return None, acc
    
def main():
    """Train a ``Policy`` on the human-play data and report held-out accuracy.

    The files returned by ``get_file_names`` for each data directory are
    collected in order; the last ``int(0.1 * n) + 1`` of them form the test
    split and the rest the training split. Every tenth epoch a test batch is
    evaluated, the metrics are written to TensorBoard and a summary line is
    printed.

    Raises:
        SystemExit: if there are too few data files to form a non-empty
            training split and a non-empty test split.
    """
    human_data_dir = "./human_data"
    total_dirs = [
        "DATA_RELEASE_NEW_HANDLED_0",
        "DATA_RELEASE_NEW_HANDLED_1",
        "DATA_RELEASE_NEW_HANDLED_2",
    ]
    num_epochs = 100000
    eval_interval = 10

    file_pointers = []
    for dir_name in total_dirs:
        file_pointers.extend(get_file_names(f"{human_data_dir}/{dir_name}"))
    print(file_pointers)

    n_test = int(0.1 * len(file_pointers)) + 1
    train_file_pointers = file_pointers[:-n_test]
    test_file_pointers = file_pointers[-n_test:]
    if not train_file_pointers or not test_file_pointers:
        # With 0 or 1 files the train split is empty; fail here with a clear
        # message instead of deep inside sample_batch.
        raise SystemExit(
            f"need at least 2 data files under {human_data_dir}, "
            f"found {len(file_pointers)}"
        )

    tb_writer = SummaryWriter("./output/logs/")
    wrapper = BCWrapper({})
    policy = Policy()

    # NOTE(review): no checkpoint is ever written by this loop. If the trained
    # weights are needed, call policy.save_model(...) -- but check the target
    # path first, since "output/bc_model" is also what q_path points at.
    try:
        for epoch in range(num_epochs):
            states_batch, action_batch, tags = sample_batch(
                train_file_pointers, wrapper
            )
            loss, train_acc = policy.update(states_batch, action_batch, tags)

            if epoch % eval_interval != 0:
                continue

            with torch.no_grad():
                states_batch, action_batch, tags = sample_batch(
                    test_file_pointers, wrapper
                )
                _, test_acc = policy.test(states_batch, action_batch, tags)

            tb_writer.add_scalar('loss', loss, epoch)
            tb_writer.add_scalar('acc/train', train_acc, epoch)
            tb_writer.add_scalar('acc/test', test_acc, epoch)
            print(f"epoch:{epoch},loss:{loss},acc:{test_acc}")
    finally:
        # Flush and release the event file even if training is interrupted.
        tb_writer.close()


if __name__ == '__main__':
    main()