import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pickle
import torch.nn.functional as F
import os
from torch.utils.data import random_split
import numpy as np
import random
import torch
import torchvision
import copy
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import Dataset
# NOTE: random seeds are configured through set_seed(), defined below.


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# (To force CPU execution while debugging, assign "cpu" to `device` instead.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def set_seed(seed):
    """Seed every random number generator used here and request deterministic kernels.

    Args:
        seed: Integer seed passed to Python's ``random``, NumPy and PyTorch
            (CPU generator, current CUDA device and all CUDA devices).

    Side effects:
        Configures the cuDNN flags, sets the ``PYTHONHASHSEED`` and
        ``CUBLAS_WORKSPACE_CONFIG`` environment variables, and turns on
        ``torch.use_deterministic_algorithms``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = True

    os.environ.update(
        PYTHONHASHSEED=str(seed),
        CUBLAS_WORKSPACE_CONFIG=':4096:8',
    )

    torch.use_deterministic_algorithms(True)

# Seed used when CustomDataset draws its random pairs, for the random
# train/eval split, and as a prefix of the result file names.
SEEED = 1
class CustomDataset(Dataset):
    """Pairs every image of a labelled dataset with a random action and a signal image.

    Each stored sample is the 5-tuple
    ``(image, action_one_hot, y_image, label, y_label)``:

    * ``image`` and ``label`` are taken unchanged from the wrapped dataset;
    * ``action_one_hot`` is a one-hot tensor of a uniformly drawn class;
    * ``y_label`` is ``(label + 1) % num_classes`` when the drawn class equals
      ``label`` and ``(label - 1) % num_classes`` otherwise;
    * ``y_image`` is a randomly chosen image of the wrapped dataset whose
      label is ``y_label``.

    At most 6000 samples are kept per label and collection stops as soon as
    60000 samples have been stored.
    """

    def __init__(self, mnist_dataset):
        """Seed the NumPy/Python RNGs with ``SEEED`` and build the sample list.

        Args:
            mnist_dataset: Indexable dataset yielding ``(image, label)`` pairs
                with integer labels in ``range(10)``.
        """
        np.random.seed(SEEED)
        random.seed(SEEED)
        self.mnist_dataset = mnist_dataset
        self.num_classes = 10
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.partitioned_indices = self.partition_dataset()
        self.post_load_setup()

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.final_dataset)

    def __getitem__(self, idx):
        """Return the stored 5-tuple at position ``idx``."""
        return self.final_dataset[idx]

    def partition_dataset(self):
        """Return one list of dataset indices per class label."""
        buckets = [[] for _ in range(self.num_classes)]
        for position, (_, class_id) in enumerate(self.mnist_dataset):
            buckets[class_id].append(position)
        return buckets

    def get_image_with_label(self, label):
        """Return a random ``(image, label)`` pair whose label is ``label``."""
        candidates = self.partitioned_indices[label]
        picked_image, picked_label = self.mnist_dataset[random.choice(candidates)]
        return picked_image, picked_label

    def restrict_dataset_size(self):
        """Build the list of 5-tuples described in the class docstring.

        The random action and the signal image are drawn for every visited
        sample, including samples that are then dropped because their label
        already reached the per-label cap.
        """
        samples = []
        kept_per_label = [0] * self.num_classes
        for idx in range(len(self.mnist_dataset)):
            image, label = self.mnist_dataset[idx]

            # Draw the random action first, then the signal image, so the
            # random stream is consumed in a fixed order.
            action = random.randint(0, self.num_classes - 1)
            action_one_hot = torch.zeros(self.num_classes)
            action_one_hot[action] = 1

            shift = 1 if action == label else -1
            y_label = (label + shift) % self.num_classes
            y_image, _ = self.get_image_with_label(y_label)

            # Keep the sample only while its label holds fewer than 6000 entries.
            if kept_per_label[label] < 6000:
                samples.append((image, action_one_hot, y_image, label, y_label))
                kept_per_label[label] += 1

            # Stop once 60,000 samples have been collected.
            if len(samples) == 60000:
                break

        return samples

    def post_load_setup(self):
        """(Re)generate ``self.final_dataset`` from the wrapped dataset."""
        self.final_dataset = self.restrict_dataset_size()

    def __setstate__(self, state):
        """Restore the attributes in ``state`` and regenerate the sample list."""
        self.__dict__.update(state)
        self.post_load_setup()


class PolicyNetwork(nn.Module):
    """Small CNN mapping a single-channel image to a probability vector over 10 classes.

    Two 3x3 convolutions (each followed by ReLU and 2x2 max-pooling) feed a
    128-unit hidden layer and a 10-way softmax. The flattening step assumes
    32 x 7 x 7 features per sample (e.g. 28x28 inputs).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class probabilities (each row sums to 1) for the batch ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.view(-1, 32 * 7 * 7)
        hidden = F.relu(self.fc1(flat))
        return F.softmax(self.fc2(hidden), dim=1)

# Define the CNN model
class CNN(nn.Module):
    """Two-input CNN: stacks images ``x`` and ``y`` as channels and outputs 10 probabilities.

    The PyTorch CPU and CUDA generators are re-seeded with ``seed`` before the
    layers are created, so the initial weights depend only on ``seed``.
    """

    def __init__(self, seed):
        super().__init__()
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        self.conv1 = nn.Conv2d(2, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x, y):
        """Return per-class probabilities for the image pair ``(x, y)``.

        ``x`` and ``y`` are concatenated along dimension 1 (channels), so
        together they must provide the two input channels of ``conv1``.
        """
        stacked = torch.cat((x, y), dim=1)
        features = self.pool(F.relu(self.conv1(stacked)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = features.view(-1, 32 * 7 * 7)
        hidden = F.relu(self.fc1(flat))
        return F.softmax(self.fc2(hidden), dim=1)

# Module-level image-to-tensor transform (attached to the MNIST test split below).
transform = transforms.Compose([transforms.ToTensor()])

def custom_collate_fn(batch):
    """Collate CustomDataset samples, converting both images of each sample to tensors.

    Args:
        batch: Sequence of 5-tuples
            ``(image, labels_one_hot, y_image, label, y_label)``.

    Returns:
        ``(images, labels_one_hot, y_images, labels, y_label)`` where the first
        three entries are stacked tensors and the last two are tuples holding
        the per-sample values unchanged.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    # Transpose the batch: one tuple per field instead of one tuple per sample.
    raw_images, one_hot_rows, raw_y_images, labels, y_label = zip(*batch)

    images = torch.stack([to_tensor(picture) for picture in raw_images])
    labels_one_hot = torch.stack(one_hot_rows)
    y_images = torch.stack([to_tensor(picture) for picture in raw_y_images])

    return images, labels_one_hot, y_images, labels, y_label

def custom_collate_fn_image(batch):
    """Collate CustomDataset samples, converting and normalizing both images of each sample.

    Identical in structure to ``custom_collate_fn`` except that every image is
    additionally normalized with mean 0.1307 and standard deviation 0.3081.

    Args:
        batch: Sequence of 5-tuples
            ``(image, labels_one_hot, y_image, label, y_label)``.

    Returns:
        ``(images, labels_one_hot, y_images, labels, y_label)`` where the first
        three entries are stacked tensors and the last two are tuples holding
        the per-sample values unchanged.
    """
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Transpose the batch: one tuple per field instead of one tuple per sample.
    raw_images, one_hot_rows, raw_y_images, labels, y_label = zip(*batch)

    images = torch.stack([normalize(picture) for picture in raw_images])
    labels_one_hot = torch.stack(one_hot_rows)
    y_images = torch.stack([normalize(picture) for picture in raw_y_images])

    return images, labels_one_hot, y_images, labels, y_label


# MNIST splits. No transform is attached to the training split (the collate
# functions convert its images later); the test split uses `transform`.
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

num_repeats = 1
custom_train_dataset = CustomDataset(train_dataset)

# Materialize `num_repeats` passes over the custom dataset into one flat list.
repeated_datasets = [custom_train_dataset] * num_repeats
concatenated_data = []
for dataset in repeated_datasets:
    concatenated_data += list(dataset)

def eval_model_set(model, dataloader, n1=None, n2=None):
    """Return the smoothed accuracy of ``model`` over ``dataloader``.

    The result is
    ``(correct + 0.1 * n1 + 0.1 * n2) / (total + n1 + n2)``, i.e. the measured
    accuracy blended with ``n1 + n2`` extra samples scored at 10% accuracy.
    NOTE(review): presumably 10% stands for chance level on the samples that
    were consumed by the earlier training stages -- confirm with the author.

    Args:
        model: Module called as ``model(images)``; it is switched to eval mode.
        dataloader: Iterable of 5-tuples ``(images, _, _, labels, _)`` where
            ``labels`` is a sequence of ints or a tensor.
        n1: Number of extra samples of the first kind. Defaults to the
            module-level ``N1``.
        n2: Number of extra samples of the second kind. Defaults to the
            module-level ``N2``.

    Returns:
        The smoothed accuracy as a float.
    """
    if n1 is None:
        n1 = N1
    if n2 is None:
        n2 = N2

    # Feed the data to wherever the model actually lives; fall back to the
    # module-level device only for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    total_samples = 0
    correct_predictions = 0
    model.eval()
    with torch.no_grad():
        for images, _, _, labels, _ in dataloader:
            images = images.to(target_device)
            # as_tensor accepts tuples and tensors alike without copy-construct warnings.
            labels = torch.as_tensor(labels).to(target_device)

            outputs = model(images)
            _, predicted_labels = torch.max(outputs, 1)

            total_samples += labels.size(0)
            correct_predictions += (predicted_labels == labels).sum().item()

    accuracy = (correct_predictions + 0.1 * n1 + 0.1 * n2) / (total_samples + n1 + n2)
    return accuracy

def eval_model_test_set(model, dataloader):
    """Return the plain accuracy of ``model`` over ``dataloader``.

    Args:
        model: Module called as ``model(images)``; it is switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches where ``labels``
            is a tensor or a sequence of ints.

    Returns:
        ``correct / total`` as a float.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    # Feed the data to wherever the model actually lives; fall back to the
    # module-level device only for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    total_samples = 0
    correct_predictions = 0
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(target_device)
            # Labels usually arrive as a tensor already; as_tensor avoids the
            # copy (and the warning) that torch.tensor(tensor) would cause.
            labels = torch.as_tensor(labels).to(target_device)

            outputs = model(images)
            _, predicted_labels = torch.max(outputs, 1)

            total_samples += labels.size(0)
            correct_predictions += (predicted_labels == labels).sum().item()

    accuracy = correct_predictions / total_samples
    return accuracy

def _ramp_reward(prob, lower, upper, width):
    """Map a probability to a reward: 0 up to ``lower``, 1 from ``upper`` on, linear in between.

    When ``width`` is 0 the two bounds coincide, the middle branch can never
    be taken and no division happens.
    """
    if lower < prob < upper:
        return (prob - lower) / width
    if prob >= upper:
        return 1
    return 0


def modify_training_set(model, dataloader, th, th_w):
    """Attach a model-derived reward to every sample yielded by ``dataloader``.

    For each sample the model is evaluated on ``(image, y_image)`` and the
    probability it assigns to the sample's action (the position of the 1 in
    ``labels_one_hot``) is turned into a reward by a ramp centred on ``th``
    with total width ``th_w``: 0 at or below ``th - th_w / 2``, 1 at or above
    ``th + th_w / 2``, and linear in between.

    Args:
        model: Module called as ``model(images, y_images)`` that returns one
            row of per-class scores per sample.
        dataloader: Iterable of batches
            ``(images, labels_one_hot, y_images, labels, y_label)``; each row
            of ``labels_one_hot`` is assumed to be one-hot.
        th: Centre of the reward ramp.
        th_w: Width of the reward ramp; 0 gives a hard threshold at ``th``.

    Returns:
        A sequential ``DataLoader`` (batch size 64) over tuples
        ``(image, labels_one_hot, reward, y_image, label, y_label)`` whose
        tensors live on the CPU.
    """
    # Feed the data to wherever the model actually lives; fall back to the
    # module-level device only for parameter-free models.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    lower = th - th_w / 2
    upper = th + th_w / 2

    modified_dataset = []
    with torch.no_grad():
        for images, labels_one_hot, y_images, labels, y_label in dataloader:
            # images: x, y_images: signal, labels_one_hot: action
            images = images.to(target_device)
            y_images = y_images.to(target_device)
            outputs = model(images, y_images).cpu()

            # Position of the 1 in every one-hot row (first maximum).
            action_indices = labels_one_hot.argmax(dim=1).tolist()

            reward = torch.tensor([
                _ramp_reward(outputs[idx][action_indices[idx]], lower, upper, th_w)
                for idx in range(len(labels))
            ])

            modified_dataset.extend(
                (image.cpu(), one_hot, r.item(), y_image.cpu(), sample_label, sample_y_label)
                for image, one_hot, y_image, sample_label, sample_y_label, r
                in zip(images, labels_one_hot, y_images, labels, y_label, reward)
            )

    modified_dataloader = DataLoader(modified_dataset, batch_size=64, sampler=SequentialSampler(modified_dataset))
    return modified_dataloader


# ---- Experiment configuration ----
N1 = 5000  # size of the first random split
N2 = 5000  # size of the second random split
th_list_center = [0.25, 0.33, 0.5]
th_list_width = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35]
IK_threshold = [0.08]
runs = 4

torch.use_deterministic_algorithms(True)
set_seed(SEEED)

# The first two splits hold N1 and N2 samples; the third receives the rest.
split_sizes = [N1, N2, len(concatenated_data) - (N1 + N2)]
train_dataset_1, train_dataset_2, train_dataset_3 = random_split(concatenated_data, split_sizes)

# Every loader walks its data in a fixed order; only the third split is
# collated with the normalizing collate function.
train_loader_1 = DataLoader(
    train_dataset_1,
    batch_size=64,
    sampler=SequentialSampler(train_dataset_1),
    collate_fn=custom_collate_fn,
)
train_loader_2 = DataLoader(
    train_dataset_2,
    batch_size=64,
    sampler=SequentialSampler(train_dataset_2),
    collate_fn=custom_collate_fn,
)
train_loader_3 = DataLoader(
    train_dataset_3,
    batch_size=64,
    sampler=SequentialSampler(train_dataset_3),
    collate_fn=custom_collate_fn_image,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    sampler=SequentialSampler(test_dataset),
)


def _train_reward_model(seed, ik_th):
    """Train a fresh ``CNN(seed)`` on ``train_loader_1`` and return it.

    The network is fitted with an MSE loss against the one-hot action of each
    sample. Training runs for at most 100 epochs and stops early once the
    mean batch loss of an epoch drops below ``ik_th``. A checkpoint is written
    every 5 epochs.
    """
    model_path = "model_checkpoint_unify_more.pth"
    num_epochs = 100

    model = CNN(seed).to(device)
    # model.load_state_dict(torch.load(model_path))
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)

    for epoch in range(num_epochs):
        total_loss = 0
        step = 0
        for images, labels_one_hot, y_images, _, _ in train_loader_1:
            images = images.to(device)
            labels_one_hot = labels_one_hot.to(device)
            y_images = y_images.to(device)

            outputs = model(images, y_images)
            loss = criterion(outputs, labels_one_hot)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            step += 1
            if step % 50 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(train_loader_1)}], Loss: {loss.item():.4f}")

        average_loss = total_loss / step
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}")
        if average_loss < ik_th:
            break

        # Save the model every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), model_path)
            print("Model saved.")

    return model


def _train_policy(modified_training_set):
    """Train policy networks on the rewarded data and return the best snapshot.

    Two independently initialised ``PolicyNetwork`` instances are trained for
    100 epochs each, maximising ``sum(reward * prob_of_taken_action * 10)``
    per batch. The returned model is a deep copy of the network taken after
    the epoch with the highest mean batch reward; if no epoch exceeds 0, a
    freshly initialised network is returned.
    """
    model_path = "policy_more.pth"
    num_epochs = 100
    num_restarts = 2

    final_model = PolicyNetwork().to(device)
    best_average_reward = 0

    for _ in range(num_restarts):
        policy_net = PolicyNetwork().to(device)
        optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
        for epoch in range(num_epochs):
            total_loss = 0
            step = 0
            for images, labels_one_hot, reward, _, _, _ in modified_training_set:
                images = images.to(device)
                reward = reward.to(device)

                prob = policy_net(images)
                # Probability the policy assigns to the action that was taken.
                selected_probs = torch.sum(labels_one_hot.to(device) * prob, dim=1)
                loss = -torch.sum(reward * selected_probs * 10)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                step += 1
                if step % 20 == 0:
                    print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(modified_training_set)}], Reward: {-loss.item():.4f}")

            # The loss is the negated reward, so flip the sign back.
            average_reward = -total_loss / step
            print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_reward:.4f}")

            # Save the policy being trained every 5 epochs (not the reward model).
            if (epoch + 1) % 5 == 0:
                torch.save(policy_net.state_dict(), model_path)
                print("Model saved.")
            if average_reward > best_average_reward:
                best_average_reward = average_reward
                final_model = copy.deepcopy(policy_net)

    return final_model


def _write_summary(th, th_w, ik_th, accuracies, test_accuracies):
    """Write mean and standard deviation of both accuracy lists to a file under ``final/``."""
    acc_final = np.mean(accuracies)
    std1 = np.std(accuracies)
    acc2_final = np.mean(test_accuracies)
    std2 = np.std(test_accuracies)

    folder_path = "final"
    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    file_path_3 = os.path.join(
        folder_path,
        f"{SEEED}_Avg_Final_{N1}_{N2}_{th}_{th_w}_{ik_th}_Transform.txt",
    )
    with open(file_path_3, "w") as file:
        file.write(f"accuracy: {acc_final} ({std1})\n")
        file.write(f"Test accuracy: {acc2_final} ({std2})\n")


# Sweep every (early-stop threshold, ramp centre, ramp width) combination and
# average the resulting accuracies over `runs` seeds.
for ik_th in IK_threshold:
    for th in th_list_center:
        for th_w in th_list_width:
            print(f"SEED:{SEEED}, N1:{N1}, N2={N2}, th={th}, th_w={th_w}, ik_th={ik_th}")
            acccc = []
            acccc2 = []
            for seed in range(runs):
                set_seed(seed)

                model = _train_reward_model(seed, ik_th)

                modified_training_set = modify_training_set(model, train_loader_2, th, th_w)
                print("Reward Calculated")

                final_model = _train_policy(modified_training_set)

                ## Evaluating the model
                final_model.eval()
                with torch.no_grad():
                    acc = eval_model_set(final_model, train_loader_3)
                    acc2 = eval_model_test_set(final_model, test_loader)
                acccc.append(acc)
                acccc2.append(acc2)

            _write_summary(th, th_w, ik_th, acccc, acccc2)

