import copy
import os
import sys
import argparse
import torch
import numpy as np
import pickle as pkl
from a2c_ppo_acktr.model import Policy
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import tensorboardX
from tqdm import tqdm

from r_utils import env_rob_utils, common_utils
from overcooked_ai_py.agents.agent import AgentPair
from collections import defaultdict


# Name of the sub-directory (inside the agent path) where the fine-tuned
# student checkpoint is written by kd_fine_tune.
PREFIX = 'kd'


def forward(student_policy, data, device, v_slack):
    """Compute the distillation loss of ``student_policy`` on one batch.

    Args:
        student_policy: Model exposing ``base(obs, None, None)``, whose first
            two return values are used as the value estimate and the actor
            features, and ``dist(features)``, whose result has a ``.probs``
            attribute.
        data: Six-element batch ``(original_obs, disturbed_obs, act_prob,
            slack_act_probs, value, extend_value)``. ``disturbed_obs``,
            ``slack_act_probs`` and ``extend_value`` have their first two
            dimensions merged here, so each disturbed observation becomes its
            own batch row (assumes those are (batch, n_disturbances, ...) —
            TODO confirm against the dataset).
        device: Device every tensor of the batch is moved to.
        v_slack: Relative tolerance on the disturbed value target; an absolute
            error of up to ``v_slack * |extend_value|`` is not penalised.

    Returns:
        ``(loss_dict, loss)`` where ``loss_dict`` maps the four component
        names to Python floats (for logging) and ``loss`` is the tensor sum of
        the four components (for back-propagation).
    """
    original_obs, disturbed_obs, act_prob, slack_act_probs, value, extend_value = data
    # Merge the leading two dimensions of the per-disturbance tensors.
    disturbed_obs, slack_act_probs, extend_value = \
        disturbed_obs.flatten(0, 1), slack_act_probs.flatten(0, 1), extend_value.flatten(0, 1)
    original_obs, disturbed_obs, act_prob, slack_act_probs, value, extend_value = \
        original_obs.to(device).float(), disturbed_obs.to(device).float(), act_prob.to(device).float(), \
        slack_act_probs.to(device).float(), value.to(device).float(), extend_value.to(device).float()

    # kd original: match action distribution and value on the clean observation
    o_value, o_actor_features, _ = student_policy.base(original_obs, None, None)
    o_probs = student_policy.dist(o_actor_features).probs

    # F.kl_div expects log-probabilities as input and probabilities as target.
    o_kl_loss = F.kl_div(o_probs.log(), act_prob, reduction='batchmean')
    o_value_loss = F.l1_loss(o_value, value, reduction='mean')

    # kd disturbed: match the slackened targets on the disturbed observations
    d_value, d_actor_features, _ = student_policy.base(disturbed_obs, None, None)
    d_probs = student_policy.dist(d_actor_features).probs

    d_kl_loss = F.kl_div(d_probs.log(), slack_act_probs, reduction='batchmean')
    # The tolerance must be non-negative; without abs() a negative reference
    # value would turn the slack into an additional penalty.
    value_grace = v_slack * extend_value.abs()
    d_value_loss = F.l1_loss(d_value, extend_value, reduction='none')
    # Only the part of the error exceeding the tolerance contributes.
    d_value_loss = torch.clamp(d_value_loss - value_grace, 0).mean()

    loss_dict = {
        'o_kl_loss': o_kl_loss.detach().cpu().item(),
        'o_value_loss': o_value_loss.detach().cpu().item(),
        'd_kl_loss': d_kl_loss.detach().cpu().item(),
        'd_value_loss': d_value_loss.detach().cpu().item()
    }
    loss = o_kl_loss + o_value_loss + d_kl_loss + d_value_loss
    return loss_dict, loss


def kd_fine_tune(target_policy, train_loader, val_loader, train_args):
    """Fine-tune a deep copy of ``target_policy`` with the distillation loss.

    The copy (the "student") is optimised with Adam on ``train_loader``. After
    every epoch it is evaluated on ``val_loader``; whenever the mean total
    validation loss of the epoch improves on the best seen so far, the whole
    student module is saved with ``torch.save`` to
    ``<agent_path>/<PREFIX>/<exp_id>.pt``. Per-epoch mean losses are written
    to a tensorboardX log under ``./log/<layout>/<agent_name>_<exp_id>`` and
    printed.

    Args:
        target_policy: Policy to copy; the original object is not modified.
        train_loader: Iterable of training batches accepted by ``forward``.
        val_loader: Iterable of validation batches accepted by ``forward``.
            If it yields nothing, no checkpoint is written.
        train_args: Namespace providing ``agent_path``, ``agent_name``,
            ``exp_id``, ``device``, ``n_epoch``, ``layout``, ``lr`` and
            ``v_slack``.

    Returns:
        None.
    """
    agent_path = train_args.agent_path
    agent_name = train_args.agent_name
    exp_id = train_args.exp_id
    device = train_args.device
    n_epoch = train_args.n_epoch
    layout = train_args.layout
    lr = train_args.lr
    v_slack = train_args.v_slack

    logger = tensorboardX.SummaryWriter(logdir=f'./log/{layout}/{agent_name}_{exp_id}')
    os.makedirs(os.path.join(agent_path, PREFIX), exist_ok=True)
    save_fp = os.path.join(agent_path, PREFIX, f"{exp_id}.pt")

    # Only the copy is trained, so the caller's policy keeps its weights.
    student_policy = copy.deepcopy(target_policy)
    student_policy.to(device)

    best_loss = float('inf')
    best_epoch = -1

    optimizer = torch.optim.Adam(student_policy.parameters(), lr=lr)
    try:
        for epoch in range(n_epoch):
            train_epoch_loss = defaultdict(list)
            val_epoch_loss = defaultdict(list)
            val_total_loss = []

            student_policy.train()
            for data in tqdm(train_loader):
                optimizer.zero_grad()

                loss_dict, loss = forward(student_policy, data, device, v_slack)
                for k, v in loss_dict.items():
                    train_epoch_loss[k].append(v)

                loss.backward()
                optimizer.step()

            student_policy.eval()
            # No gradients are needed for validation; skipping graph
            # construction saves memory and time.
            with torch.no_grad():
                for data in tqdm(val_loader):
                    loss_dict, loss = forward(student_policy, data, device, v_slack)
                    for k, v in loss_dict.items():
                        val_epoch_loss[k].append(v)
                    val_total_loss.append(loss.item())

            # Model selection uses the epoch-level mean so a single lucky
            # batch cannot trigger a save, and best_loss is actually tracked.
            if val_total_loss:
                epoch_val_loss = float(np.mean(val_total_loss))
                if epoch_val_loss < best_loss:
                    best_loss = epoch_val_loss
                    best_epoch = epoch
                    torch.save(student_policy, save_fp)

            train_epoch_loss = {k: np.mean(v) for k, v in train_epoch_loss.items()}
            val_epoch_loss = {k: np.mean(v) for k, v in val_epoch_loss.items()}
            for k, v in train_epoch_loss.items():
                logger.add_scalar("train_" + k, v, epoch)
                # Validation metrics are absent when val_loader is empty.
                if k in val_epoch_loss:
                    logger.add_scalar("val_" + k, val_epoch_loss[k], epoch)
            print(epoch, train_epoch_loss, val_epoch_loss)

        print(f'best epoch: {best_epoch}, best val loss: {best_loss}')
    finally:
        # Flush pending events even if training is interrupted.
        logger.close()

    return


def main(train_args):
    """Build the distillation dataset and run the fine-tuning.

    Steps, in order: load the saved agent from ``agent_path``; create the
    environment helper for ``layout``/``horizon``; collect rollouts with the
    agent; turn them into observations and reference targets via
    ``common_utils.process_data_for_kd``; load the pickled start states from
    ``<agent_path>/start_states/<st_fn>`` and convert them to observation
    disturbances; wrap everything in a ``KDDataset``; split it randomly into
    train/validation parts according to ``val_p``; and call ``kd_fine_tune``
    on the agent's ``actor_critic``.

    Args:
        train_args: Namespace providing ``agent_path``, ``horizon``,
            ``data_n``, ``layout``, ``st_fn``, ``batch_size``, ``device``,
            ``verbose``, ``T``, ``slack_T`` and ``val_p``, plus the fields
            read by ``kd_fine_tune``.

    Returns:
        None.
    """
    agent_path = train_args.agent_path
    horizon = train_args.horizon
    data_n = train_args.data_n
    layout = train_args.layout
    st_fn = train_args.st_fn
    st_fp = os.path.join(agent_path, 'start_states', st_fn)
    bs = train_args.batch_size
    device = train_args.device
    verbose = train_args.verbose
    T = train_args.T
    slack_T = train_args.slack_T
    val_p = train_args.val_p

    target_agent, _, _ = env_rob_utils.load_saved_apag_agent(agent_path, horizon, deterministic=False, device=device)

    mdp_params = {
        "layout_name": layout,
    }
    env_params = {
        "horizon": horizon
    }
    ae = env_rob_utils.get_base_ae(mdp_params, env_params, None, None)

    rollouts = common_utils.prepare_data(data_n, target_agent, ae, mode='sp', verbose=verbose)

    obss, act_probs, slack_act_probs, values = common_utils.process_data_for_kd(
        rollouts, target_agent.actor_critic, ae, T, slack_T, bs, device)

    # NOTE: unpickling executes arbitrary code; only load start-state files
    # from a trusted source. The context manager closes the file handle.
    with open(st_fp, 'rb') as st_file:
        start_states = pkl.load(st_file)   # list of start states
    all_delta_s = common_utils.get_disturbed_obss(start_states, ae)
    all_delta_s = torch.from_numpy(all_delta_s)

    all_dataset = KDDataset(obss, act_probs, slack_act_probs, values, all_delta_s)
    n_data = all_dataset.n_data
    # The validation part takes whatever the truncated train size leaves over.
    train_size = int(n_data * (1 - val_p))
    val_size = n_data - train_size

    train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=True)

    kd_fine_tune(target_agent.actor_critic, train_loader, val_loader, train_args)
    return


class KDDataset(Dataset):
    """Dataset pairing each observation with all of its disturbed variants.

    Item ``idx`` is the list ``[original_obs, disturbed_obs, act_prob,
    slack_act_probs, value, o_value]`` where ``disturbed_obs`` is every row of
    ``all_delta_s`` added to the original observation, and
    ``slack_act_probs`` / ``o_value`` are the per-sample targets repeated once
    per disturbance so they line up with ``disturbed_obs``.
    """

    def __init__(self, obss, act_probs, slack_act_probs, values, all_delta_s):
        """Store the tensors; all but ``all_delta_s`` are indexed by sample.

        Args:
            obss: Observations; the first dimension is the number of samples.
            act_probs: Action-probability targets for the observations.
            slack_act_probs: Action-probability targets used for the
                disturbed observations.
            values: Value targets for the observations.
            all_delta_s: Disturbances; the first dimension is the number of
                disturbances and the rest must broadcast against one
                observation.
        """
        super().__init__()
        self.obss = obss
        self.act_probs = act_probs
        self.slack_act_probs = slack_act_probs
        self.values = values
        self.all_delta_s = all_delta_s
        self.n_data = self.obss.shape[0]
        self.n_adv_states = self.all_delta_s.shape[0]

        print(self.n_adv_states)

    def __len__(self):
        """Return the number of samples."""
        return self.n_data

    def __getitem__(self, idx):
        """Return the six-element training item for sample ``idx``."""
        original_obs = self.obss[idx]
        act_prob = self.act_probs[idx]
        # Repeat the per-sample targets once per disturbance.
        slack_act_probs = self.slack_act_probs[idx].unsqueeze(0)
        slack_act_probs = slack_act_probs.repeat((self.n_adv_states, 1))
        value = self.values[idx]
        o_value = value.unsqueeze(0)
        o_value = o_value.repeat((self.n_adv_states, 1))
        # Out-of-place add: allocates a fresh tensor (the stored disturbances
        # are never modified) and applies normal type promotion, so integer
        # disturbances can be combined with floating-point observations,
        # which an in-place "+=" on a clone would reject.
        disturbed_obs = self.all_delta_s + original_obs

        data = [original_obs, disturbed_obs, act_prob, slack_act_probs, value, o_value]
        return data


if __name__ == '__main__':
    def _flag(text):
        """Interpret a command-line string as a boolean.

        Only explicit "false-like" spellings map to False; every other
        non-empty string stays truthy, as it was when the raw string was used.
        """
        return text.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')

    parser = argparse.ArgumentParser()
    # agent_path and st_fn are joined into a file path in main(), so omitting
    # them can only fail later with a less helpful error.
    parser.add_argument('--agent_path', type=str, required=True)
    parser.add_argument('--agent_name', type=str)
    parser.add_argument('--data_n', type=int, default=10)
    parser.add_argument('--horizon', type=int, default=800)
    parser.add_argument('--batch_size', type=int, default=800)
    parser.add_argument('-d', '--device', type=str, default='cuda:0')
    parser.add_argument('-l', '--layout', type=str)
    parser.add_argument('--st_fn', type=str, required=True)
    parser.add_argument('--exp_id', default='')
    # Boolean options accept "--verbose", "--verbose True" and
    # "--verbose False"; previously "--verbose False" produced the truthy
    # string 'False'.
    parser.add_argument('--warm_start', type=_flag, nargs='?', const=True, default=False)
    parser.add_argument('--verbose', type=_flag, nargs='?', const=True, default=False)

    parser.add_argument('-T', type=float, default=1.5)
    parser.add_argument('--slack_T', type=float, default=2)
    parser.add_argument('--val_p', type=float, default=0.3)
    parser.add_argument('--n_epoch', type=int, default=20)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--v_slack', type=float, default=0.02)

    args = parser.parse_args()
    # Cap the number of intra-op CPU threads used by torch.
    torch.set_num_threads(4)

    main(args)
