import argparse
import os
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from MPS import *
from data_process import DataHelper
import data_process as ED
import json
from GenMPS import MPS as GenMPS
import torch as tc


def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``0`` to a bool.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
    ``True`` (any non-empty string is truthy), so a flag could never be
    switched off from the command line.  Unrecognised spellings raise
    ``argparse.ArgumentTypeError`` so argparse reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') used to produce.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# ---- optimisation / training options ----
parser = argparse.ArgumentParser()
parser.add_argument('--max_epochs', type=int, default=80)
parser.add_argument('--batch_size', type=int, default=80)
parser.add_argument('--embd_feature', type=int, default=300)
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--weight_decay', action='store_true')
parser.add_argument('--num_class', type=int, default=2)
parser.add_argument('--decay_epoch', type=int, default=100)
parser.add_argument('--l2_reg', type=float, default=0.0000001)
parser.add_argument('--kernel_sizes', type=int, default=1)

# ---- tensor-network (GenMPS / MPS) options ----
parser.add_argument('--KLCross', type=int, default=0)
parser.add_argument('--activate_f', type=str, default=None)
parser.add_argument('--mps_init', type=str, default='No.1')
parser.add_argument('--pre_normalize_mps', type=int, default=1)
parser.add_argument('--Lagrangian_way', type=int, default=1)
parser.add_argument('--d1', type=int, default=2)
parser.add_argument('--d2', type=int, default=2)
parser.add_argument('--bond_dim2', type=int, default=1)
parser.add_argument('--bond_dim', type=int, default=20)
parser.add_argument('--mps_size', type=int, default=100)
# Boolean switches take an explicit value, e.g. ``--periodic_bc false``.
parser.add_argument('--adaptive_mode0', type=_str2bool, default=False)
parser.add_argument('--adaptive_mode', type=_str2bool, default=False)
parser.add_argument('--periodic_bc', type=_str2bool, default=True)
parser.add_argument('--cutoff', type=float, default=1e-6)
parser.add_argument('--merge_threshold', type=int, default=1040)

# ---- data, vocabulary and output locations ----
parser.add_argument('--save', type=str, default='Model/SUBJ')
parser.add_argument('--vocab', type=int, default=21323)
parser.add_argument('--train_len', type=int, default=9000)
parser.add_argument('--test_len', type=int, default=1000)
parser.add_argument('--save_every', type=int, default=20)
parser.add_argument('--log_every', type=int, default=1)
parser.add_argument('--data_file', type=str, default='data/SUBJ/')
args = parser.parse_args()
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_param_numbers(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is off are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


def feature_map(imgs):
    """Embed every entry ``v`` of ``imgs`` as the pair ``(sin(pi/2 * v), cos(pi/2 * v))``.

    Args:
        imgs: tensor of any shape.

    Returns:
        A tensor of shape ``imgs.shape + (2,)``; index 0 of the last axis
        holds the sine term and index 1 the cosine term.  The result lives
        on the same device as ``imgs`` and takes its floating dtype, so the
        function no longer depends on a module-level ``device`` or on the
        input being exactly 3-D.
    """
    phase = (math.pi / 2) * imgs
    # Stacking on a new trailing axis replaces the pre-allocated buffer and
    # the rank-specific slice assignments.
    return torch.stack((torch.sin(phase), torch.cos(phase)), dim=-1)

def norm(dim, type='group_norm'):
    """Create the normalisation layer for ``dim`` feature channels.

    ``type`` is accepted but never consulted: a ``nn.BatchNorm1d`` is
    returned for every value, including the default ``'group_norm'``.
    """
    layer = nn.BatchNorm1d(num_features=dim)
    return layer

class CNN_Text(nn.Module):
    """Text classifier: pretrained embedding -> batch norm -> ``tn`` -> softmax -> ``classficer``.

    Args:
        embed_num: vocabulary size.  Kept for interface compatibility; the
            embedding table's real size comes from the pretrained matrix.
        classficer: module applied last; its output is returned unchanged.
        tn: module called as ``tn(x, feature_map(x))`` on the normalised
            embeddings.
    """

    def __init__(self,embed_num, classficer,tn):
        super().__init__()
        # map_location lets an embedding file saved from a GPU run load on a
        # CPU-only host; the caller moves the whole model to its device.
        embed_file = torch.load('{}glove_pretrain_embed.pth'.format(args.data_file),
                                map_location='cpu')
        pretrain_embed = embed_file['pretrain']
        # from_pretrained is a classmethod: call it on the class instead of
        # first building (and randomly initialising) a throwaway
        # embed_num x embd_feature table.
        self.embed = nn.Embedding.from_pretrained(pretrain_embed)
        self.dropout = nn.Dropout(args.dropout)
        self.fc = classficer
        self.gensen = tn
        self.norm = norm(args.embd_feature)

    def forward(self, x):
        """Run the network on a batch of token indices.

        ``x`` is a 2-D index tensor.  The return value is whatever
        ``self.fc`` returns; the training loop in this script unpacks it as
        an ``(output, bonds)`` pair.
        """
        # BatchNorm1d normalises over dim 1, so move the embedding axis there
        # for the norm and then restore the original layout.
        embedded = self.embed(x).permute(0, 2, 1)
        x = self.norm(embedded).permute(0, 2, 1)
        feature_x = feature_map(x)
        x = self.dropout(F.relu(self.gensen(x, feature_x)))
        x = F.softmax(x, dim=2)

        logit = self.fc(x)
        return logit
    
def adjust_learning_rate(optimizer, current_epoch):
    """Multiply each param group's learning rate by ``0.5 ** frac``, in place.

    ``frac`` is ``(current_epoch - args.decay_epoch) / args.max_epochs``.
    The factor is applied to whatever rate is currently stored in the
    optimizer, so successive calls compound.
    """
    epochs_past_decay = float(current_epoch - args.decay_epoch)
    shrink_factor = math.pow(0.5, epochs_past_decay / args.max_epochs)
    for group in optimizer.param_groups:
        current_lr = group['lr']
        group['lr'] = current_lr * shrink_factor

if __name__ == '__main__':
    # Script entry point: build the model, train for args.max_epochs epochs,
    # evaluate after every epoch, checkpoint the best model, and finally dump
    # the training log and append the best accuracy to acc.txt.

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save, exist_ok=True)
    batch_size = args.batch_size
    max_epochs = args.max_epochs
    database_path = args.data_file
    vocab = args.vocab
    sequence_max_length = args.mps_size
    train_len = args.train_len
    test_len = args.test_len

    with open('{}wordmap.json'.format(database_path), 'r') as wordmap_file:
        vocab_dict = json.load(wordmap_file)
    data_helper = DataHelper(vocab_dict, sequence_max_length=sequence_max_length)

    # First tensor-network stage, applied to the normalised embeddings.
    tn = GenMPS(input_dim=args.embd_feature, output_dim=args.d2, bond_dim=args.bond_dim2,
          adaptive_mode=args.adaptive_mode0, periodic_bc=args.periodic_bc, cutoff=args.cutoff, merge_threshold=args.merge_threshold, drout = args.dropout).to(device)

    # Classifier stage; its feature_dim matches the first stage's output_dim.
    mps = MPS(input_dim=args.mps_size - args.kernel_sizes + 1, output_dim=args.num_class, bond_dim=args.bond_dim, feature_dim = args.d2,
          adaptive_mode=args.adaptive_mode, periodic_bc=args.periodic_bc, cutoff=args.cutoff, merge_threshold=args.merge_threshold).to(device)
    model = CNN_Text(vocab, mps,tn).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_reg)

    # Per-epoch history written to log.pkl at the end of the run.
    train_loss_all = []
    epoch_time_all = []
    train_loss_speed = []
    accuracy_all = []
    print('Trained model has {} parameters'.format(get_param_numbers(model)))
    max_acc = 0
    print("Starting...")

    for epoch in range(max_epochs):
        print('epoch:',epoch)
        if epoch > args.decay_epoch:
            adjust_learning_rate(optimizer, epoch)
        train_losses = []
        t_start = time.time()
        # NOTE(review): the dataset is reloaded every epoch; hoist this out of
        # the loop if load_dataset is deterministic -- confirm in data_process.
        train_data, train_label, test_data, test_label = data_helper.load_dataset(database_path,train_len,test_len)
        # NOTE(review): max_epochs is forwarded as batch_iter's third argument.
        # If that argument is a pass count, each outer epoch walks the data
        # max_epochs times -- confirm against DataHelper.batch_iter.
        train_batches = data_helper.batch_iter(np.column_stack((train_data, train_label)), batch_size, max_epochs)
        train_size = train_data.shape[0] // batch_size
        test_size = test_data.shape[0] // batch_size

        for train_data_b, label in tqdm(train_batches, leave=False, total=train_size):
            train_data_b = torch.from_numpy(train_data_b).to(device).long()
            # reshape(-1) flattens like squeeze() but keeps a batch of one
            # sample 1-D instead of collapsing it to a 0-d tensor.
            label = torch.from_numpy(label).reshape(-1).to(device).long()
            optimizer.zero_grad()
            output, bonds = model(train_data_b)
            loss = criterion(output, label)

            if args.KLCross == 1:
                # Entropy of softmax(bonds) over dim 0, added to the loss.
                # log_softmax avoids the 0 * log(0) = NaN that
                # log(softmax(x)) produces when a probability underflows.
                log_probs = F.log_softmax(bonds, dim=0)
                entropy = -torch.sum(log_probs.exp() * log_probs)
                loss = loss + entropy
            train_losses.append(loss.item())

            loss.backward()
            optimizer.step()

        # Only the training pass is timed; evaluation below is excluded.
        epoch_time_all.append(time.time() - t_start)
        train_loss_all.append(train_losses)
        mean_train_loss = np.mean(train_losses)
        print('Train loss: {:.4f}'.format(mean_train_loss))
        train_loss_speed.append(mean_train_loss)
        test_loader = data_helper.batch_iter(np.column_stack((test_data, test_label)), batch_size, max_epochs)

        model.eval()
        print("Testing...")
        correct = 0

        with torch.no_grad():
            for data, target in tqdm(test_loader, total=test_size):
                data = torch.from_numpy(data).to(device).long()
                target = torch.from_numpy(target).reshape(-1).to(device).long()

                output, _ = model(data)
                correct += torch.sum((torch.argmax(output, dim=1) == target).long()).item()
        accuracy = correct * 100 / test_data.shape[0]
        print("Accuracy: {}%".format(np.round(accuracy, 3)))
        accuracy_all.append(np.round(accuracy, 3))
        model.train()
        # Checkpoint only when the test accuracy improves on the best so far.
        if accuracy > max_acc:
            max_acc = accuracy
            torch.save({'state_dict': model.state_dict(),'optimizer_state_dict':optimizer.state_dict()}, os.path.join(args.save, 'model.pth'))

    torch.save({'accuracy': accuracy_all,
          'train_loss': train_loss_all,
          'epoch_time': epoch_time_all}, os.path.join(args.save, 'log.pkl'))
    print(max_acc)
    # Append (not overwrite) so results from repeated runs accumulate.
    with open("acc.txt", "a") as file:
        file.write(str(max_acc) + "\n")