#!/usr/bin/env python
# coding: utf-8

# In[2]:


import pickle
import random
import pandas as pd
import numpy as np
random.seed(1234)
import sys
import time
rootpath="../"
from tqdm import tqdm
import argparse

import os
# In[3]:


from collections import Counter
import json
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from einops import rearrange, repeat
import random
from eval import candidate_ranking


class DataInput:
    """Single-pass mini-batch iterator over a list of samples.

    Every sample ``t`` is read positionally: ``t[0]`` -> ``u``, ``t[1]`` ->
    history sequence, ``t[2]`` -> ``i``, ``t[3]`` -> ``y`` and ``t[4]`` ->
    ``display`` flag.

    For the model names listed in ``_DISPLAYED_ONLY_MODELS`` a batch is built
    only from samples whose display flag equals 1; for every other model name
    the data is cut into consecutive slices of ``batch_size`` samples.

    Iterating yields ``(step, (u, i, y, hist_i, sl, display))`` where ``step``
    is the 1-based batch counter, ``hist_i`` is a zero-padded ``int64`` array
    of shape ``(len(batch), max(sl))`` and ``sl`` holds the history lengths.
    The iterator is not restartable: create a new instance for each epoch.
    """

    # Model names for which only displayed (flag == 1) samples are batched.
    _DISPLAYED_ONLY_MODELS = ('no_off_policy', 'no_off_policy_sigmoid', 'sigmoid')

    def __init__(self, data, batch_size, model_name):
        self.batch_size = batch_size
        self.data = data
        # Number of batches needed to cover the data (ceil division).
        self.epoch_size = len(self.data) // self.batch_size
        if self.epoch_size * self.batch_size < len(self.data):
            self.epoch_size += 1
        self.i = 0
        self.model_name = model_name
        # Scan cursor used by the displayed-only mode.
        self.start = 0

    def __iter__(self):
        return self

    def _next_displayed_batch(self):
        """Collect up to ``batch_size`` displayed samples starting at the cursor."""
        total = len(self.data)
        batch = []
        cursor = self.start
        while cursor < total and len(batch) < self.batch_size:
            sample = self.data[cursor]
            if sample[4] == 1:  # displayed ("seen") sample
                batch.append(sample)
            cursor += 1
        # Resume right after the last scanned sample. Advancing by
        # ``batch_size`` instead would rescan (and duplicate) samples whenever
        # undisplayed entries were skipped.
        self.start = cursor
        return batch

    def __next__(self):
        if self.i == self.epoch_size:
            raise StopIteration
        if self.model_name in self._DISPLAYED_ONLY_MODELS:
            ts = self._next_displayed_batch()
            if not ts:
                # No displayed sample left: stop instead of failing on max([]).
                raise StopIteration
        else:
            lo = self.i * self.batch_size
            ts = self.data[lo: min(lo + self.batch_size, len(self.data))]

        self.i += 1
        u = [t[0] for t in ts]
        i = [t[2] for t in ts]
        y = [t[3] for t in ts]
        sl = [len(t[1]) for t in ts]
        display = [t[4] for t in ts]

        # Right-pad every history with zeros up to the longest one in the batch.
        hist_i = np.zeros([len(ts), max(sl)], np.int64)
        for row, t in enumerate(ts):
            if sl[row]:
                hist_i[row, :sl[row]] = t[1]

        return self.i, (u, i, y, hist_i, sl, display)


# In[7]:


def getInfo():
    """Return a new metrics dictionary with every entry set to -999.

    The -999 value acts as a "not computed yet" placeholder for each of the
    validation/test metric names listed below.
    """
    metric_names = (
        'best_valid_recall', 'best_valid_ndcg',
        'this_time_valid_precision', 'this_time_valid_ndcg', 'this_time_valid_recall',
        'test_pos_recall', 'test_pos_precision', 'test_pos_ndcg',
        'test_click_recall', 'test_click_precision', 'test_click_ndcg',
        'valid_uauc', 'test_uauc',
    )
    return dict.fromkeys(metric_names, -999)

def loadData(data_path):
    """Load the dataset pickle at ``data_path`` and return its parts as a dict.

    The file must contain eight consecutively pickled objects, in this order:
    train set, valid set, test set, category list, the triple
    ``(user_count, item_count, cate_count)``, the ``interaction`` mapping,
    ``mask`` and ``all_item``.

    The result additionally contains ``mask_valid`` (always an empty list) and
    ``log_item``, which maps every key of ``interaction`` to the natural log of
    its value rounded to the nearest half step, plus 0.5.

    NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(data_path, 'rb') as f:
        (train_set, valid_set, test_set, cate_list,
         counts, interaction, mask, all_item) = (pickle.load(f) for _ in range(8))
    user_count, item_count, cate_count = counts

    def _half_step_log(value):
        log_value = math.log(value)
        whole = math.floor(log_value)
        # Some items would end up at 0 here, so 0.5 is added as an offset.
        return whole + round(log_value - whole) / 2 + 0.5

    log_item = {item: _half_step_log(value) for item, value in interaction.items()}
    mask_valid = []
    return {
        'train_set': train_set,
        'valid_set': valid_set,
        'test_set': test_set,
        'cate_list': cate_list,
        'user_count': user_count,
        'item_count': item_count,
        'cate_count': cate_count,
        'interaction': interaction,
        'mask': mask,
        'mask_valid': mask_valid,
        'all_item': all_item,
        'log_item': log_item,
    }


# In[8]:


def sequence_mask(lengths, max_len=None):
    """Build a boolean mask marking the valid positions of padded sequences.

    Args:
        lengths: 1-D integer tensor with one sequence length per row.
        max_len: number of mask columns; defaults to ``lengths.max()``.

    Returns:
        Boolean tensor of shape ``(len(lengths), max_len)`` where entry
        ``[b, t]`` is True exactly when ``t < lengths[b]``.
    """
    width = lengths.max() if max_len is None else max_len
    # One row of positions 0..width-1, shared (expanded) across all sequences.
    positions = torch.arange(width, device=lengths.device)
    position_grid = positions.expand(len(lengths), width)
    return position_grid < lengths.unsqueeze(1)


# In[9]:


class KuairecModel(torch.nn.Module):
    """User/item scoring network used as a softmax policy over items.

    An item is embedded as the concatenation of an item-id embedding and the
    embedding of its category (each of width ``h // 2``). A user is embedded
    from a user-id embedding and the length-masked mean of the embedded
    history items. A user/item pair is scored by an MLP applied to the
    element-wise product of the two embeddings.

    Args:
        data: dict providing ``user_count``, ``item_count``, ``cate_count`` and
            ``cate_list``; the category id of an item is taken as the argmax of
            its row in ``cate_list``.
        h: embedding width. The item and category halves are ``h // 2`` wide
            each, so ``h`` should be even for the shapes to line up.
    """

    def __init__(self, data, h=128):
        super().__init__()
        self.user_count = data['user_count']
        self.item_count = data['item_count']
        self.cate_count = data['cate_count']
        cate_list = torch.tensor(data['cate_list']).float()
        # Non-persistent buffer: follows ``model.to(device)`` so lookups do not
        # copy it to the device on every call, yet it stays out of
        # ``state_dict`` so existing checkpoints keep loading.
        self.register_buffer('cate_list_index', torch.argmax(cate_list, dim=1), persistent=False)
        self.item_embedder = nn.Embedding(self.item_count, h//2)
        self.cate_embedder = nn.Embedding(self.cate_count, h//2)
        self.user_embedder = nn.Embedding(self.user_count, h)
        self.h = h
        self.user_nn_0 = nn.Sequential(
            nn.Linear(self.h, self.h * 2),
            nn.ReLU(),
            nn.Linear(self.h * 2, self.h)
        )
        self.user_nn_1 = nn.Sequential(
            nn.Linear(2 * self.h, self.h),
            nn.ReLU(),
            nn.Linear(self.h, self.h)
        )
        self.user_item_nn = nn.Sequential(
            nn.Linear(h, 2 * h),
            nn.ReLU(),
            nn.Linear(2 * h, 1)
        )

    def getIemb(self, item_index):
        """Return the ``(len(item_index), h)`` embeddings of a 1-D item index tensor."""
        cate_index = self.cate_list_index.to(item_index.device)[item_index]
        return torch.cat([self.item_embedder(item_index), self.cate_embedder(cate_index)], dim=1)

    def getUemb(self, u_index, history, hist_length):
        """Return user embeddings of shape ``(B, h)``.

        Args:
            u_index: ``(B,)`` user indices.
            history: ``(B, T)`` padded item indices.
            hist_length: ``(B,)`` number of valid entries per history row.
        """
        B, T = history.shape
        u_emb_0 = self.user_embedder(u_index)
        history = rearrange(history, 'B T -> (B T)')
        h_emb = rearrange(torch.cat([self.item_embedder(history),
                            self.cate_embedder(self.cate_list_index.to(u_index.device)[history])],
                        dim=1), '(B T) H -> B T H', B=B)
        mask = repeat(sequence_mask(hist_length, max_len=h_emb.shape[1]).float(), 'B T -> B T H', H=self.h)
        # Masked mean over valid history positions; the epsilon guards empty histories.
        h_emb = (h_emb * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-12)
        h_emb = self.user_nn_0(h_emb)
        return self.user_nn_1(torch.cat([u_emb_0, h_emb], dim=1))

    def _score_pairs(self, u_emb, i_emb):
        """Score every user against every item; returns ``(len(u_emb), len(i_emb))``."""
        inter_feat = rearrange(u_emb[:, None, :] * i_emb[None, :, :], 'H1 H2 D -> (H1 H2) D')
        return rearrange(self.user_item_nn(inter_feat), '(H1 H2) 1 -> H1 H2', H1=len(u_emb))

    def get_action_prob(self, u_index, item_index, history, history_length, display, use_display):
        """Return, for each row, the softmax probability of its own item.

        The softmax for a user is taken over the items of the current batch
        only, and the diagonal (user ``k`` with item ``k``) is returned.

        Args:
            display: index applied to every input when ``use_display`` is true.
                NOTE(review): pass a *boolean* mask to filter rows. An integer
                (long) 0/1 tensor is interpreted by PyTorch as gather indices
                (rows 0 and 1 repeated), not as a mask.
            use_display: whether to apply ``display`` before scoring.
        """
        if use_display:
            u_index = u_index[display]
            history = history[display]
            history_length = history_length[display]
            item_index = item_index[display]
        u_emb = self.getUemb(u_index, history, history_length)
        i_emb = self.getIemb(item_index)
        u_i_scores = self._score_pairs(u_emb, i_emb)
        u_i_scores = u_i_scores - torch.amax(u_i_scores, dim=-1, keepdims=True).detach()
        all_probs = torch.softmax(u_i_scores, dim=-1)
        return torch.diagonal(all_probs)

    def get_all_items(self, device):
        """Return the ``(item_count, h)`` embedding matrix of every item on ``device``."""
        all_item_index = torch.arange(self.item_count, dtype=torch.long, device=device)
        return torch.cat([
            self.item_embedder(all_item_index),
            self.cate_embedder(self.cate_list_index.to(device))
        ], dim=-1)

    def evaluate_user(self, u_index, history, history_length):
        """Return raw scores of every user against all items as a NumPy array."""
        u_emb = self.getUemb(u_index, history, history_length)
        i_emb = self.get_all_items(u_index.device)
        u_i_scores = self._score_pairs(u_emb, i_emb)
        # detach() so the conversion also works when called outside no_grad().
        return u_i_scores.detach().cpu().numpy()


# In[10]:

# Command-line options:
#   --beta_model  'train' to fit the propensity (beta) model here, otherwise the
#                 checkpoint name (without 'Models/' prefix and '.pth' suffix)
#   --estimator   estimator name passed to loss_fn
#   --est_param   stored as params['lambda'] for loss_fn
#   --add_noise   subtract Pareto noise from the rewards
#   --batch_size  batch size of the theta training loop
parser = argparse.ArgumentParser()
parser.add_argument('--beta_model', type=str, default='train')
parser.add_argument('--estimator', type=str, default='ips')
parser.add_argument('--est_param', type=float, default=0)
parser.add_argument('--add_noise', action='store_true')
parser.add_argument('--batch_size', default=512, type=int)
args = parser.parse_args()

data = loadData('../data/Kuai_dataset.pkl')
print("Data Loaded.")
if args.beta_model == 'train':
    # Fit the beta model by minimising the negative log-probability of the
    # displayed items; its outputs are later used as propensities.
    beta_model = KuairecModel(data, h=128)
    n_epochs = 5
    lr = 0.005
    lr_decay = 0.9
    optimizer = torch.optim.Adam(beta_model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=lr_decay, verbose=True)
    criterion = nn.NLLLoss()
    device = 'cuda:0'
    _ = beta_model.to(device)
    batch_size = 512

    for epoch in range(n_epochs):
        for mode in ['train_set']:
            # define stat variables
            total_loss = 0
            total_samples = 0
            print_step = 0
            random.shuffle(data[mode])
            pbar = tqdm(DataInput(data[mode], batch_size=batch_size, model_name='capping'), total=len(data['train_set']) // batch_size)
            pbar.set_description(f"Loss={(total_loss / total_samples if total_samples > 0 else 0):.4f}")
            for _, uij in pbar:
                u, i, y, hist, sl, display = uij
                u = torch.tensor(u).long().to(device)
                i = torch.tensor(i).long().to(device)
                y = torch.tensor(y).float().to(device)
                hist = torch.tensor(hist).long().to(device)
                sl = torch.tensor(sl).long().to(device)
                display = torch.tensor(display).long().to(device)
                with torch.set_grad_enabled(mode == 'train_set'):
                    probs = beta_model.get_action_prob(u, i, hist, sl, display, use_display=False)
                # display acts as a 0/1 weight: undisplayed rows contribute no loss.
                loss = (-torch.log(probs + 1e-8) * display).mean()
                bsz = len(u)
                total_loss += loss.item() * bsz
                total_samples += bsz
                print_step += bsz
                if mode in ['train_set']:
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                # Decay the learning rate once per ~1M processed samples.
                if print_step > 1000000:
                    scheduler.step()
                    print_step = 0
                pbar.set_description(f"Loss={(total_loss / total_samples if total_samples > 0 else 0):.5f}")
            print(f"Loss value for '{mode}' = {total_loss / total_samples:.4f}")

    # The name is kept WITHOUT the 'Models/' prefix and '.pth' suffix because
    # the loader below adds both; including the prefix here made it look for
    # 'Models/Models/...' and fail right after training.
    beta_model_name = f'beta_lr_{lr}_n_{n_epochs}_lrdecay_{lr_decay}'
    os.makedirs('Models', exist_ok=True)
    torch.save(beta_model.state_dict(), f'Models/{beta_model_name}.pth')
else:
    beta_model_name = args.beta_model

beta_model = KuairecModel(data, h=128)
# Load on CPU first; the model is moved to the training device below.
beta_model.load_state_dict(torch.load(f'Models/{beta_model_name}.pth', map_location='cpu'))

# Setup of the learned policy (theta) and its optimisation hyper-parameters.
theta_model = KuairecModel(data, h=128)
n_epochs = 5
lr = 0.005
lr_decay = 1.0
optimizer = torch.optim.Adam(theta_model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=lr_decay, verbose=True)
criterion = nn.NLLLoss()
device = 'cuda:1'
_ = beta_model.to(device)
_ = theta_model.to(device)
batch_size = args.batch_size
_type = args.estimator
params = {
    "lambda": args.est_param
}

def loss_fn(_type, probs, prop, rewards, params=None, add_noise=False):
    """Compute the off-policy objective for the selected estimator.

    Args:
        _type: estimator name, one of ``'ips'``, ``'lse'``, ``'pm'``, ``'es'``,
            ``'ix'``, ``'snips'``, ``'os'`` or ``'ls'``.
        probs: probabilities of the logged actions under the learned policy.
        prop: propensities of the same actions (same shape as ``probs``).
        rewards: per-sample rewards (same shape as ``probs``). The tensor is
            never modified in place.
        params: dict holding ``'lambda'``; required by every estimator except
            ``'ips'``, ``'pm'`` and ``'snips'``.
        add_noise: if true, Pareto(1.1) noise is subtracted from the rewards.

    Returns:
        Scalar tensor with the estimator value.

    Raises:
        ValueError: for an unknown ``_type``, or when ``'lse'`` / ``'ls'`` is
            used with ``lambda == 0`` (both divide by lambda).
    """
    if add_noise:
        noise = torch.as_tensor(np.random.pareto(1.1, len(rewards)), device=rewards.device)
        # Out-of-place so the caller's tensor is untouched; cast back to keep
        # the dtype the in-place update used to preserve.
        rewards = (rewards - noise).to(rewards.dtype)

    if _type in ('lse', 'ls') and params['lambda'] == 0:
        raise ValueError(f"Estimator '{_type}' requires a non-zero lambda.")

    if _type == 'ips':
        return (probs / prop * rewards).mean()
    if _type == 'lse':
        return torch.log(torch.exp(params['lambda'] * probs / prop * rewards).mean() + 1e-8) / params['lambda']
    if _type == 'pm':
        renyi = (probs ** 2 / prop).mean() + 1e-8
        significance = torch.tensor(0.01)
        # Data-dependent mixing coefficient derived from the batch itself.
        params_lambda = torch.sqrt(torch.log(1/significance) / (3 * len(probs) * renyi))
        weights = probs / prop
        weights = weights / (1 - params_lambda + params_lambda * weights)
        return torch.mean(rewards * weights)
    if _type == 'es':
        return ((probs / prop ** params['lambda']) * rewards).mean()
    if _type == 'ix':
        return (probs / (prop + params['lambda']) * rewards).mean()
    if _type == 'snips':
        return torch.sum(probs / prop * rewards) / torch.sum(probs / prop)
    if _type == 'os':
        w = probs / prop
        w2 = w * w
        ops_w = (params['lambda'] * w) / (w2 + params['lambda'])
        return torch.mean(ops_w * rewards)
    if _type == 'ls':
        return -torch.mean((1 / params['lambda']) * torch.log(1 - ((params['lambda'] * rewards * probs) / (prop + 1e-8))))
    # Previously fell through and returned None, failing later in the caller.
    raise ValueError(f"Unknown estimator type: {_type!r}")

# In[20]:
# Collect the experiment indices already used by 'theta_<estimator>_<index>_...'
# files in results/ and pick the smallest free one.
os.makedirs('results', exist_ok=True)
os.makedirs('Models', exist_ok=True)
exps_ids = set()
for item in os.listdir('results'):
    parts = item.split('_')
    # Ignore unrelated files instead of crashing on names with too few fields
    # or a non-numeric index.
    if len(parts) > 2 and parts[0] == 'theta' and parts[1] == _type and parts[2].isdecimal():
        exps_ids.add(int(parts[2]))
print(exps_ids)
exp_index = 0
while exp_index in exps_ids:
    exp_index += 1
exp_id = f'{_type}_{exp_index}'
if args.add_noise:
    exp_id += '_add_noise'
no_improve = 0
best_result = 0
run_name = f'theta_{exp_id}_lr_{lr}_n_{n_epochs}_lrdecay_{lr_decay}_lambda_{params["lambda"]}'
# The context manager closes the log on every exit path, including sys.exit().
with open(f'results/{run_name}', 'w', encoding='utf-8') as f:
    for epoch in range(n_epochs):
        for mode in ['train_set']:
            total_loss = 0
            total_ips_loss = 0
            total_samples = 0
            print_step = 0
            random.shuffle(data[mode])
            pbar = tqdm(DataInput(data[mode], batch_size=batch_size, model_name='capping'), total=len(data['train_set']) // batch_size)
            pbar.set_description(f"Loss={(total_loss / total_samples if total_samples > 0 else 0):.4f}")
            for _, uij in pbar:
                u, i, y, hist, sl, display = uij
                u = torch.tensor(u).long().to(device)
                i = torch.tensor(i).long().to(device)
                y = torch.tensor(y).float().to(device)
                hist = torch.tensor(hist).long().to(device)
                sl = torch.tensor(sl).long().to(device)
                # Boolean mask of displayed rows. Indexing with a 0/1 *long*
                # tensor would gather rows 0 and 1 instead of filtering.
                display = torch.tensor(display).to(device) == 1
                n_shown = int(display.sum())
                if n_shown == 0:
                    continue
                with torch.set_grad_enabled(mode == 'train_set'):
                    probs_theta = theta_model.get_action_prob(u, i, hist, sl, display, use_display=True)
                with torch.set_grad_enabled(False):
                    propensities = beta_model.get_action_prob(u, i, hist, sl, display, use_display=False)[display]
                # Clip small propensities to bound the importance weights.
                propensities = torch.clamp(propensities, min=1e-3)
                # Rewards restricted to the same displayed rows as the probabilities.
                shown_y = y[display]
                loss = loss_fn(_type, probs_theta, propensities, -shown_y, params=params, add_noise=args.add_noise)
                if torch.isnan(loss).sum() > 0:
                    s = "NaN loss encountered. Stoping the learning process..."
                    print(s)
                    f.write(s + '\n')
                    sys.exit(1)
                ips_loss = torch.mean(-(probs_theta / propensities) * shown_y).item()
                # Both losses are means over the displayed rows, so weight and
                # normalise the running averages by that same count.
                total_loss += loss.item() * n_shown
                total_ips_loss += ips_loss * n_shown
                total_samples += n_shown
                print_step += len(u)
                if mode in ['train_set']:
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                # Validate (and possibly checkpoint) once per ~1M processed samples.
                if print_step > 1000000:
                    scheduler.step()
                    with torch.no_grad():
                        valid_result, valid_pos_result,valid_user_pred,valid_uauc = candidate_ranking(theta_model, data['mask'], data['valid_set'], data['all_item'], [1, 3, 5, 10, 100], data['log_item'],[], device=device)
                    if valid_pos_result[0][0] > best_result:
                        best_result = valid_pos_result[0][0]
                        s = '-' * 50
                        f.write(s + '\n')
                        print(s)
                        s = f"STEP: {epoch}:{total_samples}"
                        f.write(s + '\n')
                        print(s)
                        s = f"VAL: Precision: {valid_pos_result[0]}, Recall: {valid_pos_result[1]}, NDCG: {valid_pos_result[2]}, Lambda: {params['lambda']}"
                        f.write(s + '\n')
                        print(s)
                        torch.save(theta_model.state_dict(), f'Models/{run_name}.pth')
                        with torch.no_grad():
                            test_result, test_pos_result, test_user_pred,test_uauc = candidate_ranking(theta_model, data['mask'], data['test_set'],data['all_item'],
                                                                            [1, 3, 5, 10, 100], data['log_item'],data['mask_valid'],True, device=device)
                        s = f"TEST: Precision: {test_pos_result[0]}, Recall: {test_pos_result[1]}, NDCG: {test_pos_result[2]}, Lambda: {params['lambda']}"
                        f.write(s + '\n')
                        print(s)
                        no_improve = 0
                    else:
                        no_improve += 1
                        # Early stopping: five validations without improvement,
                        # but never during the first epoch.
                        if no_improve >= 5 and epoch >= 1:
                            s = f"No improvement after {no_improve} steps. Stoping the learning process..."
                            print(s)
                            print(f"Loss={(total_loss / total_samples if total_samples > 0 else 0):.5f}, IPS Loss={(total_ips_loss / total_samples if total_samples > 0 else 0):.5f}")
                            f.write(s + '\n')
                            sys.exit()
                    print_step = 0
                pbar.set_description(f"Loss={(total_loss / total_samples if total_samples > 0 else 0):.5f}, IPS Loss={(total_ips_loss / total_samples if total_samples > 0 else 0):.5f}")
            print(f"Loss={(total_loss / total_samples if total_samples > 0 else 0):.5f}, IPS Loss={(total_ips_loss / total_samples if total_samples > 0 else 0):.5f}")