import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import argparse
import os
import pandas as pd
from collections import deque
import time

from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torchquantum as tq

class QFCModel(tq.QuantumModule):
    """
    Quantum-classical model for Fashion-MNIST classification.

    Pipeline: flatten -> linear projection to ``2 ** num_qubits`` features ->
    L2-normalise -> amplitude-encode into a ``num_qubits`` register -> apply the
    searched variational layers -> measure Pauli-Z on every wire -> small MLP
    head -> log-probabilities over ``num_classes``.

    Args:
        matrix: Architecture description indexed as ``matrix[layer][qubit][gate]``;
            ``gate`` 0/1/2 selects RX/RY/RZ and a value of 1 means that trainable
            rotation is present on that qubit in that layer.
        num_qubits: Number of wires in the quantum register.
        num_classes: Number of output classes.
        input_dim: Number of features per sample after flattening
            (default ``28 * 28``, i.e. one Fashion-MNIST image).
    """

    class QLayer(tq.QuantumModule):
        """One variational layer: optional RX/RY/RZ per qubit, then a ring of CNOTs."""

        def __init__(self, matrix, num_qubits, layer_num):
            super().__init__()
            self.n_wires = num_qubits
            self.matrix = matrix

            # Keys keep the original naming scheme so saved state_dicts still load.
            self.gates = nn.ModuleDict()
            # Wire index of each entry in ``self.gates`` (same insertion order), so
            # forward() does not have to parse it back out of the key string.
            self._gate_wires = []
            gate_types = (("rx", tq.RX), ("ry", tq.RY), ("rz", tq.RZ))
            for i in range(self.n_wires):
                key = f"qlayer{layer_num}_qubit{i}_gate_"
                for col, (suffix, gate_cls) in enumerate(gate_types):
                    if matrix[i][col] == 1:
                        self.gates[key + suffix] = gate_cls(has_params=True, trainable=True)
                        self._gate_wires.append(i)

            self.entanglers = nn.ModuleList(
                [tq.CNOT(has_params=False, trainable=False) for _ in range(self.n_wires)]
            )

        def forward(self, qdev: tq.QuantumDevice):
            """Apply this layer's rotations, then the CNOT ring, to ``qdev`` in place."""
            for gate, wire in zip(self.gates.values(), self._gate_wires):
                gate(qdev, wires=wire)
            # A CNOT needs two distinct wires; with one qubit the ring would
            # degenerate to CNOT(0, 0), so there is nothing to entangle.
            if self.n_wires < 2:
                return
            for i in range(self.n_wires):
                self.entanglers[i](qdev, wires=[i, (i + 1) % self.n_wires])

    def __init__(self, matrix, num_qubits, num_classes, input_dim=28 * 28):
        super().__init__()
        self.n_wires = num_qubits
        self.num_classes = num_classes
        self.num_layers = len(matrix)
        self.pre_fc = nn.Linear(input_dim, 2 ** self.n_wires)
        self.encoder = tq.AmplitudeEncoder()
        self.q_layers = nn.ModuleList(
            [self.QLayer(matrix[i], num_qubits, i) for i in range(self.num_layers)]
        )
        self.measure = tq.MeasureAll(tq.PauliZ)

        # One measured value per wire feeds the classical head.
        in_dim = self.n_wires
        self.head = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """
        Classify a batch.

        Args:
            x: Batch tensor whose non-batch dimensions flatten to ``input_dim``
                values (for the default, presumably (B, 1, 28, 28) images).

        Returns:
            Log-probabilities of shape (B, num_classes).
        """
        bsz = x.shape[0]
        qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=bsz, device=x.device)
        x = self.pre_fc(x.reshape(bsz, -1))
        # L2-normalise so the 2 ** n_wires features form a unit-norm amplitude vector.
        x = F.normalize(x, p=2, dim=1)
        self.encoder(qdev, x)
        for layer in self.q_layers:
            layer(qdev)
        x = self.measure(qdev)
        x = self.head(x)
        return F.log_softmax(x, dim=1)

class ArchitectureEnvironment:
    """
    Environment for RL-based quantum architecture search.

    The architecture is a binary array of shape (num_layers, num_qubits, 3), where
    the last axis is the RX/RY/RZ slot. An episode visits every slot once, in
    row-major order (layer, then qubit, then gate).

    State: the flattened architecture concatenated with a one-hot encoding of
        the slot currently being decided (all zeros once the episode is done).
    Action: 0, 1 or 2 places the gate at the current slot (they are equivalent:
        the gate type is fixed by the slot, not by the action); 3 leaves it empty.
    Reward: 0 on every intermediate step. When the architecture is complete, a
        shaped score in [0, 1] derived from the accuracy of a briefly trained
        QFCModel. Accuracy is measured on batches from ``eval_dataloader``, which
        defaults to the training ``dataloader`` -- so by default it is NOT a
        held-out test accuracy.

    Args:
        num_qubits: Qubits per layer.
        num_layers: Number of variational layers.
        dataloader: Iterable of (data, target) batches used for training.
        device: torch device used to train/evaluate candidates.
        eval_dataloader: Optional loader used to measure accuracy
            (defaults to ``dataloader``).
        train_epochs: Passes over the training loader per candidate.
        train_batches: Maximum batches per training pass.
        eval_batches: Maximum batches used to measure accuracy.
        num_classes: Number of output classes of the candidate model.
    """

    # Actions accepted by step(): 0, 1, 2 (place gate) and 3 (skip).
    NUM_ACTIONS = 4

    def __init__(self, num_qubits, num_layers, dataloader, device, eval_dataloader=None,
                 train_epochs=8, train_batches=30, eval_batches=20, num_classes=10):
        self.num_qubits = num_qubits
        self.num_layers = num_layers
        self.dataloader = dataloader
        # Falling back to the training loader preserves the original behaviour.
        self.eval_dataloader = dataloader if eval_dataloader is None else eval_dataloader
        self.device = device
        self.train_epochs = train_epochs
        self.train_batches = train_batches
        self.eval_batches = eval_batches
        self.num_classes = num_classes
        self.total_positions = num_qubits * num_layers * 3  # 3 gate slots per qubit per layer
        self.reset()

    def reset(self):
        """Start building a new (empty) architecture and return the initial state."""
        self.current_architecture = np.zeros((self.num_layers, self.num_qubits, 3), dtype=int)
        self.current_position = 0
        self.done = False
        return self.get_state()

    def get_state(self):
        """Return the float32 state vector: flattened architecture + one-hot position."""
        arch_flat = self.current_architecture.flatten().astype(np.float32)
        position_encoding = np.zeros(self.total_positions, dtype=np.float32)
        if self.current_position < self.total_positions:
            position_encoding[self.current_position] = 1.0
        return np.concatenate([arch_flat, position_encoding]).astype(np.float32)

    def step(self, action):
        """
        Decide the current slot and advance.

        Args:
            action: 0, 1 or 2 to place the slot's gate, 3 to leave it empty.

        Returns:
            (next_state, reward, done, info) in the classic Gym layout.

        Raises:
            ValueError: if ``action`` is outside ``range(NUM_ACTIONS)``.
        """
        if self.done:
            return self.get_state(), 0, True, {}

        action = int(action)
        if not 0 <= action < self.NUM_ACTIONS:
            raise ValueError(f"action must be in [0, {self.NUM_ACTIONS}), got {action}")

        layer, remaining = divmod(self.current_position, self.num_qubits * 3)
        qubit, gate = divmod(remaining, 3)
        if action < 3:
            self.current_architecture[layer, qubit, gate] = 1

        self.current_position += 1

        if self.current_position >= self.total_positions:
            self.done = True
            reward = self.evaluate_architecture()
        else:
            reward = 0  # No reward until the architecture is complete

        return self.get_state(), reward, self.done, {}

    def evaluate_architecture(self):
        """
        Train the completed architecture briefly and return its shaped reward.

        An empty architecture scores 0. Any exception raised while building,
        training or evaluating the candidate is reported and scored 0, so a
        single broken candidate does not abort the whole search.
        """
        if not self.current_architecture.any():
            return 0.0  # Empty architecture gets 0 reward

        try:
            model = QFCModel(
                self.current_architecture.tolist(), self.num_qubits, self.num_classes
            ).to(self.device)
            self._train_model(model)
            normalized_acc = self._measure_accuracy(model)
        except Exception as e:  # deliberate best-effort: failed candidates get 0
            print(f"Error evaluating architecture: {e}")
            return 0.0

        complexity_ratio = int(self.current_architecture.sum()) / self.total_positions
        return self._shape_reward(normalized_acc, complexity_ratio)

    def _train_model(self, model):
        """Train ``model`` in place with Adam on up to ``train_batches`` batches per epoch."""
        optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
        model.train()
        for _ in range(self.train_epochs):
            for batch_idx, (data, target) in enumerate(self.dataloader):
                if batch_idx >= self.train_batches:
                    break
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                loss = F.nll_loss(model(data), target)
                loss.backward()
                optimizer.step()

    def _measure_accuracy(self, model):
        """Return the fraction (0..1) of correct predictions on up to ``eval_batches`` batches."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.eval_dataloader):
                if batch_idx >= self.eval_batches:
                    break
                data, target = data.to(self.device), target.to(self.device)
                pred = model(data).argmax(dim=1)
                correct += (pred == target).sum().item()
                total += target.size(0)
        return correct / total if total > 0 else 0.0

    @staticmethod
    def _shape_reward(normalized_acc, complexity_ratio):
        """
        Combine accuracy (0..1) and gate usage (0..1) into a reward clamped to [0, 1].

        Above 60% accuracy, sparser circuits earn a bonus of up to 0.1; at or
        below it, denser circuits pay a penalty of up to 0.05.
        """
        if normalized_acc > 0.6:
            reward = normalized_acc + 0.1 * (1 - complexity_ratio)
        else:
            reward = normalized_acc - 0.05 * complexity_ratio
        return max(0.0, min(1.0, reward))

class DQNAgent:
    """
    Deep Q-Network agent for quantum architecture search.

    Holds an online Q-network, a target network that is synchronised on demand
    via :meth:`update_target_network`, an experience-replay buffer of up to
    10 000 transitions, and an epsilon-greedy exploration rate that decays
    multiplicatively each time :meth:`replay` performs an update.

    Args:
        state_size: Length of the state vector.
        action_size: Number of discrete actions.
        lr: Adam learning rate for the online network.
        gamma: Discount factor for TD targets.
        epsilon: Initial exploration probability.
        epsilon_decay: Multiplicative decay applied after each replay update.
        epsilon_min: Epsilon is no longer decayed once it reaches this value.
    """

    def __init__(self, state_size, action_size, lr=1e-3, gamma=0.99, epsilon=1.0,
                 epsilon_decay=0.998, epsilon_min=0.05):
        self.state_size = state_size
        self.action_size = action_size
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.memory = deque(maxlen=10000)
        self.batch_size = 32

        self.q_network = self.build_network()
        self.target_network = self.build_network()
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

        self.update_target_network()

    def build_network(self):
        """Build the MLP used for Q-function approximation (state -> one Q-value per action)."""
        return nn.Sequential(
            nn.Linear(self.state_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, self.action_size)
        )

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer (oldest entries are evicted when full)."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """
        Choose an action with an epsilon-greedy policy.

        The greedy branch runs the Q-network in eval mode (dropout off) and
        without autograd, so the choice is deterministic for a given state;
        the network's previous train/eval mode is restored afterwards.
        """
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_size)

        state_tensor = torch.as_tensor(np.asarray(state, dtype=np.float32)).unsqueeze(0)
        was_training = self.q_network.training
        self.q_network.eval()
        try:
            with torch.no_grad():
                q_values = self.q_network(state_tensor)
        finally:
            self.q_network.train(was_training)
        return int(q_values.argmax(dim=1).item())

    def replay(self):
        """
        Perform one DQN update on a random minibatch, then decay epsilon.

        Does nothing (and does not decay epsilon) until the buffer holds at
        least ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack into single arrays first: building a tensor from a list of
        # ndarrays element by element is very slow in PyTorch.
        states = torch.as_tensor(np.stack(states), dtype=torch.float32)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64))
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        next_states = torch.as_tensor(np.stack(next_states), dtype=torch.float32)
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32))

        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            # Target network stays in eval mode, so its dropout does not add noise to the targets.
            next_q_values = self.target_network(next_states).max(dim=1).values
        target_q_values = rewards + self.gamma * next_q_values * (1.0 - dones)

        loss = F.mse_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target_network(self):
        """Copy the online network's weights into the target network and keep it in eval mode."""
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()

def get_fashion_mnist_dataloader(batch_size=64, root='./data', num_workers=2):
    """
    Build a shuffled DataLoader over the Fashion-MNIST training split.

    The dataset is downloaded into ``root`` if it is not already present.
    Images are converted to tensors and normalised with mean 0.2860 and std
    0.3530 (presumably the Fashion-MNIST training-set statistics).

    Args:
        batch_size: Samples per batch.
        root: Directory where torchvision stores / looks up the dataset.
        num_workers: Worker processes used by the DataLoader.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),
    ])

    dataset = datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    # Pinned host memory only speeds up host->GPU copies; without CUDA it just warns.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

def rl_architecture_search(num_qubits, num_layers, episodes, device):
    """
    Perform RL-based quantum architecture search.

    Each episode builds one complete architecture slot by slot with a DQN
    agent; the environment's terminal reward scores it.

    Args:
        num_qubits: Qubits per layer of the candidate circuits.
        num_layers: Variational layers per candidate circuit.
        episodes: Number of architectures to build and score.
        device: torch device used to train/evaluate candidates.

    Returns:
        ``(best_architecture, scores, agent)``: a copy of the highest-scoring
        architecture array (``None`` only when ``episodes <= 0``), the list of
        per-episode total rewards, and the trained agent.
    """
    print(f"Starting RL-based QAS with {num_qubits} qubits, {num_layers} layers")

    dataloader = get_fashion_mnist_dataloader(batch_size=32)  # Smaller batch for speed

    env = ArchitectureEnvironment(num_qubits, num_layers, dataloader, device)
    state_size = env.get_state().shape[0]
    action_size = 4  # RX, RY, RZ, None

    agent = DQNAgent(state_size, action_size)

    scores = []
    # Start below any possible reward so the first episode is always recorded;
    # starting at 0 left best_architecture as None when every episode scored 0.
    best_score = float("-inf")
    best_architecture = None

    for episode in range(episodes):
        state = env.reset()
        total_reward = 0.0
        done = False

        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

        scores.append(total_reward)

        # One replay update per episode, plus a second one per episode once
        # more than 11 episodes of experience have been collected.
        agent.replay()
        if episode > 10:
            agent.replay()

        if episode % 20 == 0:
            agent.update_target_network()

        if total_reward > best_score:
            best_score = total_reward
            best_architecture = env.current_architecture.copy()

        print(f"Episode {episode + 1}/{episodes}, Score: {total_reward:.4f}, "
              f"Epsilon: {agent.epsilon:.3f}, Best: {best_score:.4f}")

    return best_architecture, scores, agent

def _final_train_epoch(model, loader, optimizer, device):
    """Run one full training epoch; return (mean batch loss, train accuracy in %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        correct += (output.argmax(dim=1) == target).sum().item()
        total += target.size(0)
    mean_loss = running_loss / max(len(loader), 1)
    accuracy = 100.0 * correct / total if total else 0.0
    return mean_loss, accuracy


def _final_test_accuracy(model, loader, device):
    """Return the accuracy (in %) of ``model`` over every batch of ``loader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            correct += (model(data).argmax(dim=1) == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total if total else 0.0


def evaluate_final_architecture(architecture, num_qubits, device, epochs=50, out_dir="RL_QAS_results"):
    """
    Fully train the given architecture and log its per-epoch performance.

    Trains a QFCModel on the complete Fashion-MNIST training split with Adam
    and a cosine learning-rate schedule, evaluates on the full test split
    after every epoch, and writes the history to
    ``<out_dir>/final_training_history.csv``.

    Args:
        architecture: Array-like of shape (num_layers, num_qubits, 3) with 0/1 entries.
        num_qubits: Number of qubits of the model.
        device: torch device to train on.
        epochs: Number of training epochs (must be >= 1).
        out_dir: Output directory; created if missing.

    Returns:
        Test accuracy (in %) after the final epoch.

    Raises:
        ValueError: if ``architecture`` is None or ``epochs < 1``.
    """
    if architecture is None:
        raise ValueError("architecture is None; the search produced no candidate")
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Create the output directory up front so a bad path fails before hours of training.
    os.makedirs(out_dir, exist_ok=True)

    print(f"Evaluating final architecture with {epochs} epochs of training...")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,))
    ])

    # download=True is a no-op when the files already exist, and lets this
    # function run standalone.
    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

    arch_matrix = np.asarray(architecture).tolist()
    model = QFCModel(arch_matrix, num_qubits, 10).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    training_history = {
        'epoch': [],
        'train_loss': [],
        'train_acc': [],
        'test_acc': [],
        'lr': []
    }

    for epoch in range(epochs):
        epoch_train_loss, epoch_train_acc = _final_train_epoch(model, train_loader, optimizer, device)
        # LR actually used for this epoch (read before the scheduler advances).
        current_lr = optimizer.param_groups[0]['lr']
        epoch_test_acc = _final_test_accuracy(model, test_loader, device)

        training_history['epoch'].append(epoch + 1)
        training_history['train_loss'].append(epoch_train_loss)
        training_history['train_acc'].append(epoch_train_acc)
        training_history['test_acc'].append(epoch_test_acc)
        training_history['lr'].append(current_lr)

        scheduler.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {epoch_train_loss:.4f}, "
                  f"Train Acc: {epoch_train_acc:.2f}%, Test Acc: {epoch_test_acc:.2f}%, LR: {current_lr:.6f}")

    history_path = os.path.join(out_dir, "final_training_history.csv")
    pd.DataFrame(training_history).to_csv(history_path, index=False)
    print(f"Training history saved to {history_path}")

    final_test_acc = training_history['test_acc'][-1]
    print(f"Final Test Accuracy: {final_test_acc:.2f}%")

    return final_test_acc

def main():
    """Parse CLI arguments, run the RL search, fully evaluate the best architecture and save results."""
    parser = argparse.ArgumentParser(description="RL-based Quantum Architecture Search for Fashion-MNIST")
    parser.add_argument("--num_qubits", type=int, default=4, help="Number of qubits")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of quantum layers")
    parser.add_argument("--episodes", type=int, default=100, help="Number of RL episodes")
    parser.add_argument("--final_epochs", type=int, default=50,
                        help="Training epochs for the final evaluation of the best architecture")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--out_dir", type=str, default="RL_QAS_results", help="Output directory")

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(args.out_dir, exist_ok=True)

    start_time = time.perf_counter()
    best_architecture, scores, _agent = rl_architecture_search(
        args.num_qubits, args.num_layers, args.episodes, device
    )
    search_time = time.perf_counter() - start_time

    print(f"\nRL Search completed in {search_time:.2f} seconds")

    # Persist the search curve before the long final training run so it
    # survives a failure there.
    results_df = pd.DataFrame({
        'episode': list(range(len(scores))),
        'score': scores
    })
    results_df.to_csv(os.path.join(args.out_dir, "rl_training_scores.csv"), index=False)

    if best_architecture is None:
        print("No architecture was produced (are --episodes >= 1?); skipping final evaluation.")
        return

    print("Best architecture found:")
    print(best_architecture)

    final_accuracy = evaluate_final_architecture(
        best_architecture, args.num_qubits, device, epochs=args.final_epochs, out_dir=args.out_dir
    )

    arch_df = pd.DataFrame({
        'Architecture': [str(best_architecture.tolist())],
        'Final_Accuracy': [final_accuracy],
        'Search_Time': [search_time],
        'Episodes': [args.episodes]
    })
    arch_df.to_csv(os.path.join(args.out_dir, "best_architecture.csv"), index=False)

    print(f"\nResults saved to {args.out_dir}/")
    print(f"Best architecture test accuracy: {final_accuracy:.2f}%")


if __name__ == "__main__":
    main()
