import sys
import numpy as np
import torch
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import math
import matplotlib.pyplot as plt
import argparse
import os
import random
from pytorch_lightning.utilities.seed import seed_everything
from torch.utils.data import IterableDataset
import torch.distributed as dist
import torch.multiprocessing as mp

# Define encoder network
class CNNEncoder(nn.Module):
    """Convolutional encoder that maps each input stack to one embedding vector.

    Four convolutions (kernel ``k``, stride 4) progressively shrink the spatial
    size, and the final feature map is averaged over its spatial positions.

    Args:
        k: convolution kernel size shared by all four layers.
        hidden_size: output channels of the last convolution, i.e. the
            embedding size (default 512, the previously hard-coded value).
        in_channels: channels of the input (default 6, matching the datasets
            in this file).
    """

    def __init__(self, k, hidden_size=512, in_channels=6):
        super(CNNEncoder, self).__init__()
        stride = 4
        # Conv output size is floor((H - k + 2 * pad) / stride) + 1. With this
        # padding each layer maps H to H / stride whenever k - stride is even
        # and H is divisible by stride (e.g. k=48: 4096 -> 1024 -> 256 -> 64 -> 16).
        pad = (k - stride) // 2
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=k, stride=stride, padding=pad)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 64, kernel_size=k, stride=stride, padding=pad)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 256, kernel_size=k, stride=stride, padding=pad)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, hidden_size, kernel_size=k, stride=stride, padding=pad)
        self.bn4 = nn.BatchNorm2d(hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode a batch of inputs.

        Args:
            x: conv input, optionally wrapped in an extra leading dimension of
                size 1 (as added by ``DataLoader(batch_size=1)`` in the script),
                which is squeezed away. After that it is assumed to be
                (N, in_channels, H, W). NOTE(review): a plain 4-D batch with
                N == 1 would also lose its batch dim here -- confirm callers
                never pass that.

        Returns:
            Tensor of shape (N, hidden_size).
        """
        x = self.bn1(self.relu(self.conv1(x.squeeze(0))))
        x = self.bn2(self.relu(self.conv2(x)))
        x = self.bn3(self.relu(self.conv3(x)))
        # The last block has no ReLU before its BatchNorm, unlike the ones above.
        x = self.bn4(self.conv4(x))
        # Global average over all remaining spatial positions -> (N, hidden_size).
        return x.flatten(start_dim=2).mean(dim=-1)

# Legacy, wider discriminator (1024-512-256); the training code below uses Discriminator.
class oldDiscriminator(nn.Module):
    """Legacy four-layer MLP discriminator producing one probability per input row.

    Layer widths: ``input_dim -> 1024 -> 512 -> 256 -> 1``, with ReLU between
    layers and a final sigmoid.

    Args:
        input_dim: size of the last dimension of the input.
    """

    def __init__(self, input_dim):
        # Zero-argument super(): the previous super(Discriminator, self) named an
        # unrelated class and raised TypeError on every instantiation.
        super().__init__()
        self.input_dim = input_dim

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return sigmoid scores of shape ``x.shape[:-1] + (1,)``."""
        return self.fc(x)
#Discriminator
class Discriminator(nn.Module):
    """Three-layer MLP that scores each input vector with a sigmoid probability.

    Layer widths: ``input_dim -> 256 -> 128 -> 1``, with ReLU between layers
    and a final sigmoid, all stored in ``self.fc``.

    Args:
        input_dim: size of the last dimension of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim

        widths = (input_dim, 256, 128)
        layers = []
        # Hidden layers: Linear followed by ReLU for each consecutive width pair.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Output head: a single unit squashed into [0, 1].
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid scores of shape ``x.shape[:-1] + (1,)``."""
        probabilities = self.fc(x)
        return probabilities
# Define Contrastive Loss function
class ContrastiveLoss(nn.Module):
    """Cosine-similarity loss: pull positives toward the anchor, push negatives to orthogonality.

    loss = mean( m * (1 - cos(a, p)) + (2 - m) * |cos(a, n)| )

    where ``m`` is ``self.margin``. Despite its name, it weights the positive
    term; the negative term gets the complementary weight ``2 - m``.

    Args:
        margin: weight ``m`` of the positive term. The default 0 ignores the
            positive pair entirely.
    """

    def __init__(self, margin=0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.cossim = torch.cosine_similarity

    def forward(self, anchor, positive, negative):
        """Return the scalar loss for anchor/positive/negative embeddings.

        Similarities are computed with ``torch.cosine_similarity`` using its
        default ``dim``, then averaged over the remaining elements.
        """
        # Use the weight given at construction rather than a module-level global,
        # which only exists when this file is run as a script.
        pos_weight = self.margin
        pos_term = pos_weight * (1 - self.cossim(anchor, positive))
        neg_term = (2 - pos_weight) * torch.abs(self.cossim(anchor, negative))
        return torch.mean(pos_term + neg_term)

def stand_normalize2d(input_matrix):
    """Standardize each row (entry along dim 0) to zero mean and unit std.

    Every ``input_matrix[i]`` is shifted by the mean of all its elements and
    divided by their standard deviation (``torch.std`` default estimator). A
    new tensor is returned. The input is left untouched, and the result stays
    differentiable with respect to it.

    Args:
        input_matrix: tensor whose first dimension indexes rows. In the script
            it holds encoder embeddings, one row per sample.

    Returns:
        Tensor with the same shape, dtype and device as ``input_matrix``.
    """
    # One vectorized reduction per row instead of a Python loop over rows.
    # flatten(1) keeps an empty batch (0 rows) valid, unlike reshape(N, -1).
    if input_matrix.dim() > 1:
        rows = input_matrix.flatten(start_dim=1)
    else:
        rows = input_matrix.unsqueeze(1)
    mean = rows.mean(dim=1, keepdim=True)
    std = rows.std(dim=1, keepdim=True)
    return ((rows - mean) / std).reshape(input_matrix.shape)
def stand_normalize(input_matrix):
    """Standardize every ``input_matrix[i][j]`` slice in place.

    For each index pair over the first two dimensions, the slice is shifted by
    its own mean and divided by its own ``torch.std`` (default estimator).
    Slices are handled one at a time, so only slice-sized temporaries are
    allocated.

    Returns:
        ``input_matrix`` itself, modified in place.
    """
    for outer_idx in range(input_matrix.shape[0]):
        outer = input_matrix[outer_idx]  # view: writes below land in input_matrix
        for inner_idx in range(outer.shape[0]):
            plane = outer[inner_idx]
            outer[inner_idx] = (plane - plane.mean()) / plane.std()
    return input_matrix
def load_plot_hist(image, low, high,
                   save_path="/home/byzeng/project/weights-search/jpgs/testfft.jpg_h.jpg"):
    """Save a histogram of ``image``'s values, clipped to ``[low, high]``.

    Args:
        image: torch tensor of any shape; it is flattened. It is detached and
            moved to the CPU first, so CUDA tensors and tensors that require
            grad also work.
        low, high: clipping range, which is also the histogram range; 1000
            evenly spaced bin edges span it.
        save_path: output image file. Defaults to the previously hard-coded path.
    """
    values = np.clip(image.detach().cpu().flatten().numpy(), low, high)
    # Draw on a fresh figure and close it afterwards, so repeated calls neither
    # pile histograms onto the same axes nor keep figures alive.
    fig, ax = plt.subplots()
    try:
        ax.hist(values, bins=np.linspace(low, high, 1000), range=(low, high))
        fig.savefig(save_path)
    finally:
        plt.close(fig)


def load_normalize_testdata(path):
    """Load a ``.npy`` array and standardize each row (entry along dim 0).

    Each ``data[i]`` is replaced in place by ``(data[i] - mean) / std``, taken
    over all of its elements (``torch.std`` default estimator). ``torch.from_numpy``
    shares memory with the array returned by ``np.load``, and that array is not
    kept anywhere else.

    Args:
        path: filesystem path handed to ``np.load``.

    Returns:
        The normalized tensor with a new leading dimension of size 1.
    """
    data = torch.from_numpy(np.load(path))
    for row_idx in range(data.shape[0]):
        row = data[row_idx]
        data[row_idx] = (row - row.mean()) / row.std()
    return data.unsqueeze(0)
class MyDataset(torch.utils.data.Dataset):
    """Synthetic (anchor, positive, negative) triplets built from random matrix products.

    Every ``__getitem__`` call draws fresh Gaussian tensors (``idx`` is ignored),
    so each epoch yields ``len(self)`` new random items rather than fixed data.
    An item is a tuple of three tensors, each of shape ``self.sizes``:

    * anchor:   ``x @ q @ k @ y``
    * positive: ``xp @ qp @ kp @ xp^T``, where ``xp = x + noise * N(0, 1)`` and
      ``qp``, ``kp`` are ``q``, ``k`` plus ``qk_noise * N(0, 1)``
    * negative: ``xn @ qn @ kn @ xn^T`` from independent draws

    Each tensor is standardized per ``[b][c]`` slice with ``stand_normalize``.
    At the defaults every generated tensor has batch_size * 6 * 4096 * 4096
    elements, so an item is very memory-hungry.

    Args:
        batch_size: leading dimension of every generated tensor.
        batch_len: value returned by ``__len__``.
        noise: std of the perturbation added to ``x`` for the positive sample.
        channels: second dimension of the tensors (default 6).
        size: side length of the square matrices (default 4096).
        qk_noise: std of the perturbation added to ``q`` and ``k`` for the
            positive sample (default 0.1, the previously hard-coded value).
    """

    def __init__(self, batch_size, batch_len, noise, channels=6, size=4096, qk_noise=0.1):
        super(MyDataset, self).__init__()
        self.sizes = (batch_size, channels, size, size)
        self.len = batch_len
        self.noise = noise
        self.qk_noise = qk_noise
        self.i = 0  # not read by this class; kept for compatibility

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        sizes = self.sizes
        q = torch.randn(sizes)
        k = torch.randn(sizes)
        x = torch.randn(sizes)
        y = torch.randn(sizes)
        # NOTE(review): the anchor ends with an independent ``y``, whereas the
        # positive/negative samples here (and DisDataset's anchor) end with the
        # transpose of their ``x``. Confirm this asymmetry is intended.
        anchor_sample = x.matmul(q).matmul(k).matmul(y)
        # Drop each large tensor as soon as it is no longer needed to lower the
        # peak memory of this call; the random-draw order is unchanged.
        del y
        qp = q + self.qk_noise * torch.randn(sizes)
        kp = k + self.qk_noise * torch.randn(sizes)
        xp = x + self.noise * torch.randn(sizes)
        del q, k, x
        positive_sample = xp.matmul(qp).matmul(kp).matmul(xp.permute(0, 1, 3, 2))
        del qp, kp, xp
        qn = torch.randn(sizes)
        kn = torch.randn(sizes)
        xn = torch.randn(sizes)
        negative_sample = xn.matmul(qn).matmul(kn).matmul(xn.permute(0, 1, 3, 2))
        del qn, kn, xn
        return stand_normalize(anchor_sample), stand_normalize(positive_sample), stand_normalize(negative_sample)
class DisDataset(torch.utils.data.Dataset):
    """Synthetic samples ``x @ q @ k @ x^T`` fed to the discriminator phase.

    Every ``__getitem__`` call (``idx`` is ignored) draws fresh Gaussian ``q``,
    ``k`` and ``x`` of shape ``self.sizes``. It returns their product,
    standardized per ``[b][c]`` slice by ``stand_normalize``.

    Args:
        batch_size: leading dimension of the generated tensor.
        batch_len: value returned by ``__len__``.
        channels: second dimension of the tensor (default 6).
        size: side length of the square matrices (default 4096).
    """

    def __init__(self, batch_size, batch_len, channels=6, size=4096):
        super(DisDataset, self).__init__()
        self.sizes = (batch_size, channels, size, size)
        self.len = batch_len
        self.i = 0  # not read by this class; kept for compatibility

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        q = torch.randn(self.sizes)
        k = torch.randn(self.sizes)
        x = torch.randn(self.sizes)
        anchor_sample = x.matmul(q).matmul(k).matmul(x.permute(0, 1, 3, 2))
        # Release the large factors before normalizing to lower peak memory.
        del q, k, x
        return stand_normalize(anchor_sample)

def parse_args(argv=None):
    """Parse the command-line options of the encoder training script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with batchsize, batchlen, testtime, cudadevice,
        learningrate, noise, k and positiverate.
    """
    parser = argparse.ArgumentParser(description='Encoder Training Script')
    parser.add_argument('--batchsize', type=int, default=10, help='Batch size for training')
    parser.add_argument('--batchlen', type=int, default=10, help='Number of batches per epoch (dataset length)')
    # The training script reads args.testtime, so this option must be defined.
    parser.add_argument('--testtime', type=int, default=10, help='Test-time setting (read by the script, not used during training)')
    parser.add_argument('--cudadevice', type=str, default='4', help='CUDA device(s) to use (e.g., "0,1")')
    parser.add_argument('--learningrate', type=float, default=0.0001, help='Learning rate for optimization')
    parser.add_argument('--noise', type=float, default=0.3, help='Std of the noise added to x when building positive samples')
    parser.add_argument('--k', type=int, default=48, help='Convolution kernel size of the encoder')
    parser.add_argument('--positiverate', type=float, default=1.6, help='Weight of the positive term in the contrastive loss')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # To send printed output to a log file instead of the console, redirect
    # stdout here, e.g.:
    # log_file_path = "/home/byzeng/project/weights-search/encoder/encoder_gan6_output.log"
    # sys.stdout = open(log_file_path, "w")
    # NOTE(review): anomaly detection is a debugging aid and is usually slow;
    # consider disabling it for long runs.
    torch.autograd.set_detect_anomaly(True)
    args = parse_args()
    batch_size = args.batchsize
    cuda_device = args.cudadevice
    learning_rate = args.learningrate
    noise = args.noise
    k = args.k
    positive_rate = args.positiverate
    # --testtime may not be defined by parse_args; fall back to its old default
    # (10) instead of crashing with AttributeError at startup.
    testtime = getattr(args, 'testtime', 10)

    # Restrict the visible GPUs before anything gets a chance to touch CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
    os.environ['NCCL_P2P_DISABLE'] = '1'
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # Set the random seed.
    seed_everything(100)

    # Embedding size produced by the encoder and consumed by the discriminator.
    hiden_size = 512
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")

    # Encoder and discriminator, each with Adam and a cosine-annealed learning rate.
    encoder = CNNEncoder(k).cuda()
    optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0.000005, T_max=1500)
    discriminator = Discriminator(hiden_size).cuda()
    optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
    scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_D, eta_min=0.000005, T_max=1500)

    # Each dataset item is already a whole batch; DataLoader(batch_size=1) only
    # wraps it in a leading dimension of 1, which the encoder squeezes away.
    dataset = MyDataset(batch_size=batch_size, batch_len=args.batchlen, noise=noise)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
    Disdataset = DisDataset(batch_size=batch_size, batch_len=args.batchlen)
    disdataloader = torch.utils.data.DataLoader(Disdataset, batch_size=1)
    criterion = ContrastiveLoss(positive_rate)
    criterion_D = nn.BCELoss()

    # BCE targets never change, so build them once on the GPU.
    real_labels = torch.ones(batch_size, 1).cuda()
    fake_labels = torch.zeros(batch_size, 1).cuda()
    save_dir = "/home/byzeng/project/weights-search/newencoder/"

    for epoch in range(100):
        # Phase 1: train the discriminator to tell N(0, 1) vectors ("real") from
        # standardized encoder embeddings ("fake"). The encoder is not updated here.
        encoder.eval()
        discriminator.train()
        for i, anchor in enumerate(disdataloader):
            real_samples = torch.randn(batch_size, hiden_size)

            optimizer_D.zero_grad()
            real_outputs = discriminator(real_samples.cuda())
            loss_real = criterion_D(real_outputs, real_labels)

            # No encoder update in this phase, so skip building its autograd graph.
            with torch.no_grad():
                anchor_embedding = encoder(anchor.cuda())
            fake_samples = stand_normalize2d(anchor_embedding)
            fake_outputs = discriminator(fake_samples)
            loss_fake = criterion_D(fake_outputs, fake_labels)
            loss_D = loss_real + loss_fake
            loss_D.backward()
            optimizer_D.step()
            scheduler_D.step()
            print('[%d, %5d] discriminator loss: %.3f' %
                  (epoch, i + 1, loss_D.item()))

        # Phase 2: train the encoder with the contrastive loss plus an
        # adversarial term rewarding embeddings the discriminator labels real.
        discriminator.eval()
        encoder.train()
        for i, (anchor, positive, negative) in enumerate(dataloader):
            optimizer.zero_grad()
            anchor_embedding = encoder(anchor.cuda())
            positive_embedding = encoder(positive.cuda())
            negative_embedding = encoder(negative.cuda())
            loss_C = criterion(anchor_embedding, positive_embedding, negative_embedding)

            fake_samples = stand_normalize2d(anchor_embedding)
            fake_outputs = discriminator(fake_samples)
            loss_D = criterion_D(fake_outputs, real_labels)

            loss = loss_C + loss_D / 2
            loss.backward()
            optimizer.step()
            scheduler.step()
            print('[%d, %5d] encoder loss: %.3f loss_c:%.3f' %
                  (epoch, i + 1, loss.item(), loss_C.item()))

        # Checkpoint the whole encoder module every second epoch.
        if (epoch + 1) % 2 == 0:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(encoder, save_dir + "encoder_6gan4096_" + str(hiden_size) + "_" + str(epoch) + ".pth")

    print('Finished training')
    # Undo a stdout redirect only if one was installed above; closing the real
    # stdout would make every later print fail.
    if sys.stdout is not sys.__stdout__:
        sys.stdout.close()
        sys.stdout = sys.__stdout__
    # Test the encoder