import torch
import numpy as np
import pickle
import os
import torchvision
import random
cpath = os.path.dirname(__file__)

import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


# Number of label classes. ConvNet uses it as the output size of its final
# linear layer, and data_gen emits the labels 0 and 1.
num_class = 2

def random_real_unitary_matrix(size):
    """
    Return the orthonormal factor of the QR decomposition of a random matrix.

    A ``size`` x ``size`` matrix with standard-normal entries is drawn and
    factorised with ``torch.linalg.qr``; only the ``Q`` factor is returned.

    Args:
        size (int): Number of rows and columns of the returned matrix.

    Returns:
        torch.Tensor: The ``Q`` factor, of shape ``(size, size)``.
    """
    gaussian = torch.randn(size, size)
    # The triangular factor R is not needed, so only Q is kept.
    return torch.linalg.qr(gaussian).Q

def _sample_class(num, feature_entry, feature_size, noise_std, zeroed_entries, unitary, label):
    """Build ``num`` labelled two-patch samples that all carry the same feature."""
    dim = noise_std.shape[0]

    # The signal patch is identical for every sample of the class, so it is
    # built and transformed once instead of once per sample.
    signal = torch.zeros(dim)
    signal[feature_entry] = feature_size
    signal = torch.matmul(unitary, signal)

    samples = []
    for _ in range(num):
        noise = noise_std * torch.randn(dim)
        # The noise patch never carries energy on either feature coordinate.
        for entry in zeroed_entries:
            noise[entry] = 0
        noise = torch.matmul(unitary, noise)

        # Put the signal patch first or second with equal probability.
        if random.randint(0, 1) == 0:
            ordered = (signal, noise)
        else:
            ordered = (noise, signal)
        samples.append([torch.unsqueeze(torch.cat(ordered, dim=0), 0), torch.tensor(label)])
    return samples


def data_gen(feature_level, dim, noise_level, data_noise_level, num, unitary,
             noise_entries=(99, 96)):
    """
    Generate synthetic two-class train and test sets of two-patch samples.

    Every sample is the concatenation, in random order, of a signal patch and
    a noise patch, each of length ``dim``. The signal patch is zero except for
    ``feature_level`` at the class's feature coordinate. The noise patch is
    Gaussian with per-coordinate scale ``data_noise_level`` (``noise_level`` at
    the class's entry of ``noise_entries``) and is zeroed on both feature
    coordinates. Both patches are multiplied by ``unitary`` before being
    concatenated. The two feature coordinates are drawn with ``random.randint``
    and printed.

    The order of calls to ``random`` and ``torch.randn`` matches the original
    implementation, so seeded runs produce the same data.

    Args:
        feature_level: Value written at the feature coordinate of the signal patch.
        dim (int): Length of a single patch.
        noise_level: Noise scale used on the class-specific noise coordinate.
        data_noise_level: Noise scale used on every other coordinate.
        num (int): Number of samples per class in each of the two splits.
        unitary (torch.Tensor): Matrix applied to every patch with
            ``torch.matmul`` (presumably ``(dim, dim)`` -- confirm with caller).
        noise_entries (tuple[int, int]): Coordinates that receive
            ``noise_level`` for class 0 and class 1 respectively. The default
            reproduces the previously hard-coded indices 99 and 96.

    Returns:
        tuple: ``(dataset_train, dataset_test, feature_entries)``. Each dataset
        is a list of ``[x, y]`` pairs with ``x`` of shape ``(1, 2 * dim)`` and
        ``y`` a scalar label tensor (class 0 samples first, then class 1);
        ``feature_entries`` is the list of the two feature coordinates.

    Raises:
        IndexError: If an entry of ``noise_entries`` is not a valid index into
            a vector of length ``dim``.
    """
    noise_entry_1, noise_entry_2 = noise_entries
    for entry in (noise_entry_1, noise_entry_2):
        if not -dim <= entry < dim:
            raise IndexError(
                f"noise entry {entry} is out of range for patches of length {dim}"
            )

    feature_entry_1 = random.randint(0, dim - 1)
    feature_entry_2 = random.randint(0, dim - 1)

    print(feature_entry_1)
    print(feature_entry_2)

    var_1 = data_noise_level * torch.ones(dim)
    var_2 = data_noise_level * torch.ones(dim)
    var_1[noise_entry_1] = noise_level
    var_2[noise_entry_2] = noise_level

    feature_entries = [feature_entry_1, feature_entry_2]
    # (feature coordinate, per-coordinate noise scale, label) for each class.
    class_specs = (
        (feature_entry_1, var_1, 0),
        (feature_entry_2, var_2, 1),
    )

    dataset_train = []
    for feature_entry, noise_std, label in class_specs:
        dataset_train.extend(
            _sample_class(num, feature_entry, feature_level, noise_std,
                          feature_entries, unitary, label)
        )

    dataset_test = []
    for feature_entry, noise_std, label in class_specs:
        dataset_test.extend(
            _sample_class(num, feature_entry, feature_level, noise_std,
                          feature_entries, unitary, label)
        )

    return dataset_train, dataset_test, feature_entries




#x1 = torch.zeros(100)
#x2 = torch.ones(100)*1
#x = torch.unsqueeze(torch.unsqueeze(torch.cat((x1,x2),dim=0),0),0).cuda()
#print(x)

# Length of one input patch; ConvNet uses it as both the kernel size and the
# stride of its convolution layer.
dim = 100
# Module-level filter count; here it is only used to derive num_neurons_02.
num_neurons = 32
# Integer copy of num_neurons.
num_neurons_02 = int(num_neurons)
class ConvNet(nn.Module):
    """
    Single-convolution classifier followed by a linear read-out.

    The convolution has one input channel and ``num_neurons`` filters whose
    kernel size and stride both equal the module-level ``dim``, so each filter
    is applied to consecutive, non-overlapping windows of length ``dim``. The
    linear layer takes ``2 * num_neurons`` inputs, i.e. it expects the
    convolution to produce two window positions per filter, and outputs
    ``num_class`` scores.
    """

    def __init__(self, num_neurons):
        super().__init__()
        window = dim
        self.conv1 = nn.Conv1d(
            in_channels=1,
            out_channels=num_neurons,
            kernel_size=window,
            stride=window,
            padding=0,
        )
        self.relu = nn.ReLU()
        self.fc = nn.Linear(in_features=2 * num_neurons, out_features=num_class)

    def forward(self, x):
        """Return class scores for ``x`` (presumably ``(batch, 1, 2 * dim)`` -- confirm with caller)."""
        activations = self.relu(self.conv1(x))
        # Collapse the (filter, position) axes into one feature axis per sample.
        flattened = activations.view(activations.size(0), -1)
        return self.fc(flattened)

def model_setup(num_neurons=32, device=None):
    """
    Build a ConvNet whose only trainable parameter is the convolution weight.

    The final linear layer is fixed by hand: row 0 of its weight is ones on the
    first ``num_neurons`` inputs and zeros on the remaining ``num_neurons``;
    row 1 is the opposite. Both bias vectors are set to zero. The linear
    weight and both biases are marked ``requires_grad = False``.

    The hand-set rows are sized from the ``num_neurons`` argument (the previous
    version used a module-level constant fixed at 32, which failed with a
    shape mismatch for any other width).

    Args:
        num_neurons (int): Number of convolution filters.
        device: Device to place the model on. ``None`` selects CUDA when it is
            available and the CPU otherwise.

    Returns:
        ConvNet: The configured model on ``device``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ConvNet(num_neurons).to(device)

    ones = torch.ones(num_neurons, device=device)
    zeros = torch.zeros(num_neurons, device=device)

    with torch.no_grad():
        # Fixed read-out: each class sums a disjoint half of the flattened features.
        model.fc.weight[0] = torch.cat((ones, zeros), 0)
        model.fc.weight[1] = torch.cat((zeros, ones), 0)
        # set biases to zeros
        model.conv1.bias.zero_()
        model.fc.bias.zero_()

    # Freeze everything except the convolution weight.
    model.conv1.bias.requires_grad = False
    model.fc.bias.requires_grad = False
    model.fc.weight.requires_grad = False

    total_params = sum(p.numel() for p in model.parameters())
    print('Total number of parameters:', total_params)
    return model


def model_train_eval(train_loader, test_loader, feature_entries, dim, noise_scale, unitary, gpu=1,
                     *, net=None, loss_fn=None, opt=None, epochs=None):
    """
    Train with per-sample clipped, noise-perturbed gradients and evaluate each epoch.

    For every training batch the gradient of each sample is computed
    separately, clipped to a total norm of 2, and summed. The sum is divided by
    the batch size, Gaussian noise with standard deviation ``noise_scale`` is
    added, and the result is used for one optimizer step. After each epoch the
    model is evaluated on ``test_loader`` and loss/accuracy figures are printed.

    Fixes relative to the previous version:
      * Gradient accumulators are created fresh for every batch. They used to
        be created once per epoch and were also aliased as ``param.grad``, so
        the previous batch's average and noise leaked into the next batch.
      * Only parameters with ``requires_grad=True`` receive a gradient and
        noise. Frozen parameters used to get a zero gradient plus noise and
        were then moved by the optimizer step.
      * ``clip_grad_norm_`` replaces the deprecated ``clip_grad_norm``.
      * Losses are accumulated as Python floats weighted by batch size, so the
        printed train and test losses are both per-sample averages.
      * Data and noise are placed on the model's device instead of calling
        ``.cuda()`` unconditionally.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        feature_entries: Unused; kept for call compatibility.
        dim: Unused; kept for call compatibility.
        noise_scale (float): Standard deviation of the noise added to the
            averaged gradient of every trainable parameter.
        unitary: Unused; kept for call compatibility (it previously fed a
            value that was computed and then discarded).
        gpu: Kept for call compatibility. Tensors now follow the device of the
            model's parameters regardless of this flag.
        net: Model to train. Defaults to the module-level ``model``.
        loss_fn: Loss function. Defaults to the module-level ``criterion``.
        opt: Optimizer. Defaults to the module-level ``optimizer``.
        epochs (int): Number of epochs. Defaults to the module-level ``num_epochs``.

    Returns:
        float: Test accuracy after the final epoch (0.0 if no test sample was
        seen or no epoch was run).
    """
    # Fall back to the module-level objects the original implementation relied on.
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    epochs = num_epochs if epochs is None else epochs

    device = next(net.parameters()).device
    trainable = [p for p in net.parameters() if p.requires_grad]
    test_acc = 0.0

    for epoch in range(epochs):
        net.train()
        begin_time = time.time()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            batch_size = y.size(0)

            # Fresh accumulators per batch so nothing carries over between batches.
            grad_sums = [torch.zeros_like(p) for p in trainable]
            for i in range(batch_size):
                opt.zero_grad()
                sample_loss = loss_fn(net(x[i:i + 1]), y[i:i + 1])
                sample_loss.backward()
                # Bound the influence of each individual sample.
                torch.nn.utils.clip_grad_norm_(trainable, 2)
                for grad_sum, p in zip(grad_sums, trainable):
                    if p.grad is not None:
                        grad_sum += p.grad.detach()

            for grad_sum, p in zip(grad_sums, trainable):
                grad_sum /= batch_size
                grad_sum += torch.randn_like(grad_sum) * noise_scale
                p.grad = grad_sum

            # Training metrics are taken before the step, as in the original.
            with torch.no_grad():
                outputs = net(x)
                batch_loss = loss_fn(outputs, y).item()
            _, predicted = torch.max(outputs, 1)
            # Assumes loss_fn averages over the batch (the nn.CrossEntropyLoss default).
            train_loss += batch_loss * batch_size
            train_total += batch_size
            train_correct += (predicted == y).sum().item()

            opt.step()

        net.eval()

        test_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)

                count = targets.size(0)
                test_loss += loss_fn(outputs, targets).item() * count
                correct += (predicted == targets).sum().item()
                total += count

        print(total)
        end_time = time.time()
        test_acc = correct / total if total else 0.0

        print('epoch', epoch)
        print('train loss', train_loss / max(train_total, 1))
        print('train acc', train_correct / max(train_total, 1))
        print('test loss', test_loss / max(total, 1))
        print('test acc', test_acc)

        print('Total test acc: ', test_acc)
        print('time', end_time - begin_time)
    return test_acc



#noise_level = 30
#feature_level = 0.2

# Sweep a grid_size x grid_size grid: row i sets the feature strength and
# column j sets the gradient-noise scale. The final test accuracy of every
# cell is stored in acc_t, saved to 'acc.pt' and shown as a heat map.
# (The grid size was previously stored in a variable named `iter`, which
# shadowed the builtin of the same name.)
grid_size = 10

acc_t = torch.zeros(grid_size, grid_size)
num = 100
dim = 100
data_noise_level = 0.02
noise_level = 0.02
batch_size = 100
num_epochs = 20
U = torch.eye(dim)
# model_train_eval reads `model`, `criterion`, `optimizer` and `num_epochs`
# from module scope, so these names must stay module-level globals.
criterion = nn.CrossEntropyLoss()
for i in range(grid_size):
    for j in range(grid_size):
        # Re-seed in every cell so each run starts from the same random state.
        random.seed(1213)
        torch.manual_seed(1822)
        np.random.seed(2)

        feature_level = 1 + i * 2
        noise_scale = 0.0001 + j * 0.4
        model = model_setup()
        optimizer = optim.SGD(model.parameters(), lr=0.05)

        dataset_train, dataset_test, feature_entries = data_gen(
            feature_level, dim, noise_level, data_noise_level, num, U
        )

        train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
        acc = model_train_eval(train_loader, test_loader, feature_entries, dim, noise_scale, U)
        acc_t[i, j] = acc
        print(acc)

print(acc_t)
torch.save(acc_t, 'acc.pt')


# Rows of the image (vertical axis) follow i, columns (horizontal axis) follow j.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(acc_t, cmap='viridis')
ax.set_title('Colormap Plot')
ax.set_xlabel('X')
ax.set_ylabel('Y')

plt.show()


