import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
import math
from complexPyTorch.complexLayers import ComplexLinear
from complexPyTorch.complexFunctions import complex_relu
import argparse

# Module-wide compute device: the first CUDA GPU when CUDA is usable, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def modrelu(z, b=1.0):
    """Apply the modReLU nonlinearity: shrink the magnitude, keep the phase.

    Computes ``relu(|z| - b) * exp(1j * angle(z))``. Entries whose magnitude
    is at most ``b`` are mapped to zero. All other entries keep their phase
    and have their magnitude reduced by ``b``.

    Args:
        z: input tensor (typically complex).
        b: magnitude threshold, default 1.0. Anything that broadcasts
            against ``z.abs()`` is accepted, e.g. a per-feature tensor.

    Returns:
        Complex tensor with the broadcast shape of ``z`` and ``b``.
    """
    return F.relu(z.abs() - b) * torch.exp(1j * z.angle())

class Block(nn.Module):
    """Three-layer complex MLP.

    The layers are ComplexLinear -> complex_relu -> ComplexLinear ->
    complex_relu -> ComplexLinear, with no activation after the last layer.

    Args:
        input_dim: input size of the first ComplexLinear layer (``fc1``).
        hidden_dim: width of both hidden layers.
        output_dim: output size of the last ComplexLinear layer (``fc3``).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Names and creation order fc1 -> fc2 -> fc3 determine state_dict keys
        # and the order in which parameters are initialised.
        self.fc1 = ComplexLinear(input_dim, hidden_dim)
        self.fc2 = ComplexLinear(hidden_dim, hidden_dim)
        self.fc3 = ComplexLinear(hidden_dim, output_dim)

    def forward(self, x):
        """Run both hidden layers with complex_relu, then the linear output layer."""
        for hidden_layer in (self.fc1, self.fc2):
            x = complex_relu(hidden_layer(x))
        return self.fc3(x)

class Symmetric(nn.Module):
    """Permutation-invariant set function built from pairwise terms.

    Computes ``rho(prod_{i > j} [phi(x_i, x_j) + phi(x_j, x_i)])`` where
    ``x_i = x[:, i]`` and ``phi`` / ``rho`` are ``Block`` MLPs.

    Args:
        input_dim: width of one set element. ``phi`` sees two concatenated
            elements, hence its input width of ``2 * input_dim``.
        hidden_dim: hidden width of both ``phi`` and ``rho``.
        symmetric_dim: width of the pair embedding that is multiplied over
            all pairs and then fed to ``rho``.
        output_dim: output width of ``rho``.
    """

    def __init__(self, input_dim, hidden_dim, symmetric_dim, output_dim):
        super().__init__()
        # Creation order (phi, then rho) fixes parameter initialisation order.
        self.phi = Block(2 * input_dim, hidden_dim, symmetric_dim)
        self.rho = Block(symmetric_dim, hidden_dim, output_dim)

    def forward(self, x):
        """Evaluate the set function.

        Args:
            x: 3-D tensor whose dim 1 indexes set elements. Each ``x[:, i]``
                is concatenated with another along dim 1 before ``phi``.

        Returns:
            Output of ``self.rho`` applied to the product of the symmetrized
            pair terms.

        Raises:
            ValueError: if ``x`` is not 3-D (tuple unpacking), or if the set
                has fewer than two elements, so that there are no pairs.
        """
        _, set_size, _ = x.shape
        if set_size < 2:
            raise ValueError(
                "Symmetric needs at least two set elements to form a pair, "
                "got {}".format(set_size)
            )

        pair_terms = []
        for i in range(set_size):
            for j in range(i):
                x_i, x_j = x[:, i], x[:, j]
                # phi alone is not symmetric in its two arguments, so a
                # product over ordered pairs (i > j) would change when the set
                # is permuted. Summing both orderings makes each pair term
                # invariant under i <-> j, and the product over all unordered
                # pairs is then invariant under any permutation of the set.
                pair_terms.append(
                    self.phi(torch.cat([x_i, x_j], dim=1))
                    + self.phi(torch.cat([x_j, x_i], dim=1))
                )

        # Stack pair terms along a new dim 1, then multiply them together.
        pair_product = torch.prod(torch.stack(pair_terms, dim=1), dim=1)
        return self.rho(pair_product)
    
class MultiSlaterDeterminant(nn.Module):
    """Sum of ``anti_dim`` determinants of learned orbital matrices.

    Each of the ``anti_dim`` entries of ``self.orbitals`` is a ``Block``
    mapping ``input_dim`` features to ``n`` outputs. ``forward`` works as
    follows:
      * It stacks the outputs of all orbital blocks along dim 1.
      * It takes ``torch.det`` over the last two dimensions.
      * It sums the resulting determinants over dim 1.
    ``torch.det`` needs those last two dimensions to be square. Presumably
    ``x`` is (batch, n, input_dim), so that each block yields an (n, n) matrix
    per sample (TODO confirm against the data).

    Args:
        n: number of orbital outputs per block (also stored as ``self.n``).
        input_dim: per-element input width (also stored as ``self.input_dim``).
        hidden_dim: hidden width of each orbital ``Block``.
        anti_dim: number of determinants to sum.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        self.orbitals = nn.ModuleList(
            Block(input_dim, hidden_dim, n) for _ in range(anti_dim)
        )
        self.input_dim = input_dim
        self.n = n

    def forward(self, x):
        """Return the sum over dim 1 of the determinants of all orbital matrices."""
        orbital_matrices = torch.stack([orbital(x) for orbital in self.orbitals], dim=1)
        determinants = torch.det(orbital_matrices)
        return determinants.sum(dim=1)
    
class MultiJastrow(nn.Module):
    """Sum over k of ``det(orbital_k(x)) * jastrow_k(x)``.

    For each of the ``anti_dim`` terms the model holds two networks:
      * an orbital ``Block`` (``input_dim`` -> ``n``), whose output is
        passed to ``torch.det``;
      * a ``Symmetric`` pairwise network with a single output, whose value
        multiplies that determinant.
    The ``anti_dim`` products are summed over dim 1.

    Args:
        n: number of orbital outputs per block.
        input_dim: per-element input width.
        hidden_dim: hidden width of all sub-networks, and the symmetric width of
            each ``Symmetric``.
        anti_dim: number of determinant * factor terms.
    """

    def __init__(self, n, input_dim, hidden_dim, anti_dim):
        super().__init__()
        # All orbitals are created before all jastrows, which fixes parameter
        # initialisation order.
        self.orbitals = nn.ModuleList(
            [Block(input_dim, hidden_dim, n) for _ in range(anti_dim)]
        )
        self.jastrows = nn.ModuleList(
            [Symmetric(input_dim, hidden_dim, hidden_dim, 1) for _ in range(anti_dim)]
        )

    def forward(self, x):
        """Return sum_k det(orbital_k(x)) * jastrow_k(x) along dim 1."""
        # The unpacking rejects inputs that are not 3-D.
        _batch, _set_size, _features = x.shape

        determinants = torch.det(
            torch.stack([orbital(x) for orbital in self.orbitals], dim=1)
        )
        # Each jastrow returns a width-1 last dimension. Concatenating along
        # dim 1 gives one column per term, matching the determinants' layout.
        factors = torch.cat([jastrow(x) for jastrow in self.jastrows], dim=1)
        return (determinants * factors).sum(dim=1)
    
# def hard_function(x):
#     #Normalization?
#     n, d = x.shape
#     #r = 1. - 1./(8*n**4 + 8)
#     r = 0.95
    
#     J = 1.
#     for i in range(n):
#         for j in range(i):
#             J /= 1 - r**4 * x[i,0]**2 * x[j,0] ** 2
            
#             ###
#             #J *= r**4 - x[i,0]**2 * x[j,0] ** 2
#             ###
    
#     phi = np.zeros((n,n), dtype = 'complex_')
#     for i in range(n):
#         for j in range(n): #keep in mind this is zero indexed
#             z = x[i,0]
#             if j < n/2:
#                 phi[i,j] = r * z * (r*z) ** (n/2 - j - 1) * (1 + (r * z)**(4))**(j)
#             else:
#                 phi[i,j] = (r*z) ** (n - j - 1) * (1 + (r * z)**(4))**(j-n/2)
#     return J * np.linalg.det(phi) / np.sqrt(np.math.factorial(n))

def ComplexMSELoss(x, y):
    """Mean of the squared modulus of ``x - y``.

    Works for real or complex tensors. ``|x - y|**2`` is averaged over every
    element of the difference.

    NOTE(review): shapes are not checked, so ``x - y`` broadcasts. For example,
    a (batch,) prediction against a (batch, 1) target would silently average
    over a (batch, batch) grid.
    """
    residual = x - y
    return residual.abs().pow(2).mean()

def train(model, x, y, iterations, lr=0.005):
    """Fit ``model`` to ``(x, y)`` with full-batch Adam on ComplexMSELoss.

    Training runs for ``iterations`` steps, each using the whole ``x`` / ``y``.
    The model is put in train mode for the loop and left in eval mode
    afterwards.

    Args:
        model: ``nn.Module`` whose parameters are optimised.
        x: model input, passed to ``model`` unchanged on every step.
        y: target compared against ``model(x)``.
        iterations: number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        List of per-step loss values as Python floats.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    history = []
    for _ in range(iterations):
        # The forward pass does not touch .grad, so clearing gradients
        # first is equivalent to clearing them after the forward pass.
        optimizer.zero_grad()
        step_loss = ComplexMSELoss(model(x), y)
        step_loss.backward()
        optimizer.step()
        history.append(step_loss.item())

    model.eval()
    return history



if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='symmetric')

    parser.add_argument('--n', type=int, default=6, help='')
    parser.add_argument('--hidden_dim', type=int, default=20, help='')
    parser.add_argument('--anti_dim', type=int, default=15, help='')
    parser.add_argument('--iterations', type=int, default=10000)
    # Only referenced by the commented-out data generator below.
    parser.add_argument('--samples', type=int, default=10000)
    parser.add_argument('--lr', type=float, default = 0.0025)

    args = parser.parse_args()

    # Width of each set element. The commented-out generator below draws
    # inputs of shape (samples, n, d).
    d = 1

    # Reference only: presumably how the train_x{n}.pt / train_y{n}.pt files
    # were produced (needs hard_function, which is commented out above).
#     train_x = np.random.uniform(size = (args.samples, args.n, d))
#     train_x = np.exp(2 * np.pi * 1j * train_x)
#     train_x = train_x.astype(np.complex64)
#     train_y = np.array([hard_function(train_x[i]) for i in range(args.samples)])
#     train_y = train_y.astype(np.complex64)

#     train_x = torch.from_numpy(train_x).to(device)
#     train_y = torch.from_numpy(train_y).to(device)

    # Datasets are stored per system size as train_x{n}.pt / train_y{n}.pt.
    # Deriving the name from --n keeps the data consistent with the model
    # size. An n without a dataset fails with FileNotFoundError naming the
    # missing file, rather than falling back to another n's data.
    # map_location lets tensors saved from a GPU load on a CPU-only machine.
    # NOTE(review): torch.load can unpickle arbitrary objects unless
    # weights_only=True; only load trusted files.
    train_x = torch.load('train_x{}.pt'.format(args.n), map_location=device)
    train_y = torch.load('train_y{}.pt'.format(args.n), map_location=device)

    train_x = train_x.to(device)
    train_y = train_y.to(device)

    # anti_dim > 1: sum of anti_dim plain determinants. Otherwise: a single
    # determinant times a learned pairwise (Symmetric) factor.
    if args.anti_dim > 1:
        student = MultiSlaterDeterminant(args.n, d, args.hidden_dim, args.anti_dim).to(device)
    else:
        student = MultiJastrow(args.n, d, args.hidden_dim, 1).to(device)
    losses = train(student, train_x, train_y, args.iterations, args.lr)
    print(losses[::50])
    # With --iterations 0 there are no losses, and min() of an empty list raises.
    if losses:
        print(min(losses))