import argparse
import os
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from MPS import *
from data_process import DataHelper
import data_process as ED
import json
from GenMPS import MPS as GenMPS
import torch as tc


def _str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` is a trap: ``bool("False")`` and
    ``bool("0")`` are both True, so such flags could never be disabled from
    the command line. This converter understands the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, matching what bool('') produced before.
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
# Training / optimisation settings.
parser.add_argument('--max_epochs', type=int, default=80)
parser.add_argument('--batch_size', type=int, default=80)
parser.add_argument('--embd_feature', type=int, default=300)
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--dropout', type=float, default=0.1)
# NOTE(review): not referenced by the code shown here (Adam uses --l2_reg).
parser.add_argument('--weight_decay', action='store_true')
parser.add_argument('--num_class',type=int,default=2)
parser.add_argument('--decay_epoch', type=int, default=100)
parser.add_argument('--l2_reg', type=float, default=0.0000001)
parser.add_argument('--kernel_sizes', type=int, default=1)


# Tensor-network (GenMPS / MPS) hyper-parameters.
parser.add_argument('--KLCross', type=int, default=0)
parser.add_argument('--activate_f', type=str, default=None)
parser.add_argument('--mps_init', type=str, default='No.1')
parser.add_argument('--pre_normalize_mps', type=int, default=1)
parser.add_argument('--Lagrangian_way', type=int, default=1)
parser.add_argument('--d1', type=int, default=2)
parser.add_argument('--d2', type=int, default=10) 
parser.add_argument('--bond_dim2', type=int, default=1)  
parser.add_argument('--bond_dim', type=int, default=20)
parser.add_argument('--mps_size', type=int, default=100) 
# Boolean options use _str2bool so that e.g. "--periodic_bc False" really disables them.
parser.add_argument('--adaptive_mode0', type=_str2bool, default=False)
parser.add_argument('--adaptive_mode', type=_str2bool, default=False)
parser.add_argument('--periodic_bc', type=_str2bool, default=True)
parser.add_argument('--cutoff', type=float, default=1e-6)
parser.add_argument('--merge_threshold', type=int, default=1040)


# Data locations, checkpointing and logging.
parser.add_argument('--save', type=str, default='Model_Test/SUBJ')    
parser.add_argument('--vocab', type=int, default=21323)
parser.add_argument('--train_len', type=int, default=9000)
parser.add_argument('--test_len', type=int, default=1000)  
parser.add_argument('--save_every', type=int, default=20)
parser.add_argument('--log_every', type=int, default=1)
parser.add_argument('--data_file', type=str, default='data/SUBJ/')
args = parser.parse_args()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_param_numbers(model):
    """Return how many trainable scalar parameters ``model`` has.

    Only tensors with ``requires_grad=True`` are counted, so frozen weights
    are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


def feature_map(imgs):
    """Embed every scalar ``v`` of ``imgs`` as ``[sin(pi/2 * v), cos(pi/2 * v)]``.

    Args:
        imgs: tensor of any rank. In this script it is the normalised word
            embedding batch built in ``CNN_Text.forward``.

    Returns:
        Tensor of shape ``imgs.shape + (2,)`` allocated on the module-level
        ``device`` with the default dtype. ``[..., 0]`` holds the sine
        component and ``[..., 1]`` the cosine component.
    """
    half_pi = math.pi / 2
    # Allocate straight on the target device instead of building a host
    # tensor and copying it over with .to(device).
    embedded_data = torch.empty(list(imgs.shape) + [2], device=device)
    # Ellipsis indexing works for inputs of any rank, not only 3-D ones.
    embedded_data[..., 0] = torch.sin(half_pi * imgs)
    embedded_data[..., 1] = torch.cos(half_pi * imgs)
    return embedded_data

def norm(dim, type='group_norm'):
    """Create the normalisation layer applied to embedded sequences.

    Args:
        dim: number of features (channels) to normalise.
        type: accepted for API compatibility but currently ignored. A
            ``BatchNorm1d`` is returned whatever its value, including the
            ``'group_norm'`` default.

    Returns:
        An ``nn.BatchNorm1d`` over ``dim`` features.
    """
    layer = nn.BatchNorm1d(num_features=dim)
    return layer

class CNN_Text(nn.Module):
    """Text classifier: pretrained (frozen) embedding -> BatchNorm -> ``tn`` -> ``classficer``.

    Args:
        embed_num: vocabulary size. Kept for API compatibility; the embedding
            table's real size comes from the pretrained weight tensor.
        classficer: module applied last to produce the output (an ``MPS``
            in this script).
        tn: feature module called as ``tn(x, feature_map(x))`` (a ``GenMPS``
            in this script).
    """

    def __init__(self, embed_num, classficer, tn):
        super(CNN_Text, self).__init__()
        # The pretrained matrix is read from the 'pretrain' key of
        # <data_file>/glove_pretrain_embed.pth. map_location='cpu' lets a file
        # saved from a GPU run load on CPU-only hosts; in this script the whole
        # model is moved to the target device after construction.
        embed_path = os.path.join(args.data_file, 'glove_pretrain_embed.pth')
        embed_file = torch.load(embed_path, map_location='cpu')
        pretrain_embed = embed_file['pretrain']
        # from_pretrained is a classmethod; calling it on the class avoids
        # allocating a throw-away randomly initialised nn.Embedding first.
        # freeze defaults to True, so the embedding weights are not trained.
        self.embed = nn.Embedding.from_pretrained(pretrain_embed)
        self.dropout = nn.Dropout(args.dropout)
        self.fc = classficer
        self.gensen = tn
        self.norm = norm(args.embd_feature)

    def forward(self, x):
        """Run a batch of token-id sequences through the network.

        Args:
            x: LongTensor of token ids, presumably (batch, seq_len) -- TODO confirm.

        Returns:
            Whatever ``self.fc`` returns for the processed features.
        """
        # BatchNorm1d normalises dim 1, so move the embedding axis (last) there
        # for the norm and then move it back.
        x = self.norm(self.embed(x).permute(0, 2, 1)).permute(0, 2, 1)
        feature_x = feature_map(x)
        x = self.dropout(F.relu(self.gensen(x, feature_x)))
        x = F.softmax(x, dim=2)

        logit = self.fc(x)
        return logit
    
def adjust_learning_rate(optimizer, current_epoch):
    """Decay every param group's learning rate by ``0.5 ** frac``.

    ``frac = (current_epoch - args.decay_epoch) / args.max_epochs``, clamped
    at 0. Before ``args.decay_epoch`` the factor is therefore 1 and the rate
    is left alone. Previously a negative ``frac`` made the "shrink" factor
    greater than 1 and increased the learning rate.

    Note: the factor multiplies the *current* ``lr``, so repeated calls compound.

    Args:
        optimizer: a ``torch.optim`` optimizer whose ``param_groups`` are updated in place.
        current_epoch: epoch index used to compute the decay exponent.
    """
    frac = max(0.0, float(current_epoch - args.decay_epoch) / args.max_epochs)
    shrink_factor = math.pow(0.5, frac)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor

if __name__ == '__main__':
    # Evaluate a saved checkpoint (<save>/model.pth) on the test split.
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    batch_size = args.batch_size
    max_epochs = args.max_epochs
    database_path = args.data_file
    vocab = args.vocab
    sequence_max_length = args.mps_size
    train_len = args.train_len
    text_len = args.test_len

    # os.path.join tolerates a --data_file given without a trailing slash.
    with open(os.path.join(database_path, 'wordmap.json'), 'r', encoding='utf-8') as j:
        vocab_dict = json.load(j)
    data_helper = DataHelper(vocab_dict, sequence_max_length=sequence_max_length)

    tn = GenMPS(input_dim=args.embd_feature, output_dim=args.d2, bond_dim=args.bond_dim2, 
          adaptive_mode=args.adaptive_mode0, periodic_bc=args.periodic_bc, cutoff=args.cutoff, merge_threshold=args.merge_threshold, drout = args.dropout).to(device)

    mps = MPS(input_dim=args.mps_size - args.kernel_sizes + 1, output_dim=args.num_class, bond_dim=args.bond_dim, feature_dim = args.d2,
          adaptive_mode=args.adaptive_mode, periodic_bc=args.periodic_bc, cutoff=args.cutoff, merge_threshold=args.merge_threshold).to(device)
    model = CNN_Text(vocab, mps,tn).to(device)
    print("Starting...")

    train_data, train_label, test_data, test_label = data_helper.load_dataset(database_path,train_len,text_len)
    test_size = test_data.shape[0] // batch_size
    # NOTE(review): max_epochs is passed as batch_iter's third argument. If it
    # is an epoch count, the test set is iterated repeatedly. Accuracy below is
    # normalised by the labels actually scored, so it stays correct either way;
    # confirm against DataHelper.batch_iter.
    test_loader = data_helper.batch_iter(np.column_stack((test_data, test_label)), batch_size, max_epochs)
    # map_location lets a checkpoint saved on GPU be evaluated on a CPU-only host.
    checkpoint = torch.load(os.path.join(args.save, 'model.pth'), map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    print("Testing...")
    num_items = 0
    correct = 0

    with torch.no_grad():
        for _, (data, target) in tqdm(enumerate(test_loader), total=test_size):
            data = torch.from_numpy(data).to(device).long()
            target = torch.from_numpy(target).squeeze().to(device).long()

            output, _ = model(data)             
            predictions = torch.argmax(output, dim=1)
            correct += torch.sum((predictions == target).long()).item()
            num_items += target.numel()
    # Divide by the number of labels actually scored rather than the dataset
    # size, and avoid dividing by zero when the loader yields nothing.
    accuracy = correct * 100 / num_items if num_items else 0.0
    print("Accuracy: {}%".format(np.round(accuracy, 3)))
    model.train()