import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import os
from tqdm import tqdm

# collections.deque holds the observation history in the evaluation loop
import collections


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the single hidden layer.
        output_dim: Size of each output vector.
    """

    def __init__(self, input_dim=60, hidden_dim=128, output_dim=9):
        super().__init__()
        # Layer order (and hence state_dict keys net.0 / net.2) must stay fixed:
        # FinetuneWrapper reads self.net[2].weight directly.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to shape (..., output_dim)."""
        out = self.net(x)
        return out

class FinetuneWrapper:
    """Residual fine-tuning of a frozen pretrained model with a small MLP.

    The frozen ``pretrained_model`` is queried once per batch. Its output is
    averaged over dim 1 to give a "pretrained action". The MLP is trained to
    predict the correction ``target.mean(dim=1) - pretrained_action``. At
    inference that predicted correction is added back onto the pretrained
    action.
    """

    def __init__(self, pretrained_model, mlp_model=None, lambda_omega=0.05, epsilon_A=1e-6):
        """
        Args:
            pretrained_model: Pretrained model (an ``nn.Module``). It is put in
                eval mode and its parameters are frozen.
            mlp_model: MLP model for fine-tuning (optional, defaults to
                ``MLP()``). Its last linear layer must be ``mlp_model.net[2]``.
            lambda_omega: Physics-informed regularization coefficient.
            epsilon_A: Numerical stability constant for the regularizer.
        """
        self.pretrained_model = pretrained_model
        self.pretrained_model.eval()
        # eval() only changes layer behaviour (dropout/BN); freeze weights too.
        self.pretrained_model.requires_grad_(False)

        # Use the device the pretrained model lives on instead of a global
        # ``device`` that only exists when this file runs as a script.
        first_param = next(pretrained_model.parameters(), None)
        self.device = first_param.device if first_param is not None else torch.device('cpu')

        # Initialize MLP; a caller-supplied MLP is moved to the same device.
        mlp = mlp_model if mlp_model is not None else MLP()
        self.mlp = mlp.to(self.device)
        self.optimizer = optim.Adam(self.mlp.parameters(), lr=1e-3)
        self.loss_fn = nn.MSELoss()

        # Physics-informed regularization parameters
        self.lambda_omega = lambda_omega
        self.epsilon_A = epsilon_A

    def compute_diagonal_dominance_regularization(self):
        """Compute the diagonal-dominance regularization term.

        With A the weight of the MLP's last linear layer (output_dim x hidden_dim):
            omega = sum_i (|A_ii| - sum_{j!=i} |A_ij|) / (sum_j |A_ij| + epsilon_A)
        Rows with no diagonal entry (i >= hidden_dim) use |A_ii| = 0.

        Returns:
            Scalar tensor ``lambda_omega * relu(omega)``. It is differentiable
            w.r.t. A.
        """
        A = self.mlp.net[2].weight  # shape: (output_dim, hidden_dim)
        abs_A = torch.abs(A)
        row_sums = abs_A.sum(dim=1)

        # diagonal() yields min(rows, cols) entries; zero-pad so it lines up with
        # row_sums even when output_dim > hidden_dim.
        diag = abs_A.diagonal()
        diagonal_terms = nn.functional.pad(diag, (0, row_sums.shape[0] - diag.shape[0]))
        off_diagonal_sums = row_sums - diagonal_terms

        relative_dominance = (diagonal_terms - off_diagonal_sums) / (row_sums + self.epsilon_A)
        omega = relative_dominance.sum()

        # NOTE(review): relu(omega) penalizes *positive* omega, i.e. it
        # discourages diagonal dominance. If the intent is to encourage it,
        # this should be relu(-omega). Confirm against the method's derivation.
        return self.lambda_omega * torch.relu(omega)

    def _pretrained_action(self, x_img):
        """Run the frozen pretrained model once (no grad).

        Returns:
            (obs_cond, pretrain_action): the flattened observation condition,
            and the pretrained output averaged over dim 1.
        """
        x_img = x_img.to(self.device)
        with torch.no_grad():
            obs_cond = x_img.flatten(start_dim=1)
            # One noise sample (pred_horizon=16, action_dim=9) shared across the batch.
            noise = torch.rand(1, 16, 9, device=self.device)
            x0 = noise.expand(x_img.shape[0], -1, -1)
            # Simplified single-call inference - a faithful implementation should
            # follow the pretrained model's own sampling procedure.
            pretrain_output = self.pretrained_model(
                x0, torch.tensor([0.5], device=self.device), global_cond=obs_cond
            )
            pretrain_action = pretrain_output.mean(dim=1)  # Simplified processing
        return obs_cond, pretrain_action

    def predict(self, x_img):
        """Predict actions: pretrained action plus the MLP's learned correction.

        Args:
            x_img: Observation batch; flattened from dim 1 to form the
                condition (60 features for the default MLP).

        Returns:
            Tensor of shape (batch, action_dim), without gradient tracking.
        """
        self.mlp.eval()
        obs_cond, pretrain_action = self._pretrained_action(x_img)
        with torch.no_grad():
            pred_delta = self.mlp(obs_cond)
        # The MLP is trained on (target - pretrained) in train_step, so its
        # output is a residual that must be added back to the pretrained action.
        return pretrain_action + pred_delta

    def train_step(self, x_img, target_action):
        """Run one optimizer step on the MLP.

        Args:
            x_img: Observation batch (see ``predict``).
            target_action: Target action sequence; averaged over dim 1.

        Returns:
            The total loss (MSE + regularization) as a Python float.
        """
        self.mlp.train()
        self.optimizer.zero_grad()

        # Pretrained model generates results (no gradient tracking)
        obs_cond, pretrain_action = self._pretrained_action(x_img)

        # Residual target the MLP should learn
        delta = target_action.to(self.device).mean(dim=1) - pretrain_action

        pred_delta = self.mlp(obs_cond)
        mse_loss = self.loss_fn(pred_delta, delta)
        physics_loss = self.compute_diagonal_dominance_regularization()

        total_loss = mse_loss + physics_loss
        total_loss.backward()
        self.optimizer.step()

        return total_loss.item()

    def save_models(self, path):
        """Save the state dicts of both models to ``path``."""
        torch.save({
            'pretrained_state_dict': self.pretrained_model.state_dict(),
            'mlp_state_dict': self.mlp.state_dict()
        }, path)

    @classmethod
    def load_models(cls, pretrained_model, path):
        """Load a checkpoint written by ``save_models`` and build a wrapper.

        The checkpoint is loaded onto CPU first, so GPU-saved checkpoints open
        on CPU-only hosts. ``__init__`` then moves the MLP to the pretrained
        model's device. MLP layer sizes are inferred from the saved weights
        when possible.
        """
        checkpoint = torch.load(path, map_location='cpu')
        pretrained_model.load_state_dict(checkpoint['pretrained_state_dict'])

        mlp_state = checkpoint['mlp_state_dict']
        if 'net.0.weight' in mlp_state and 'net.2.weight' in mlp_state:
            hidden_dim, input_dim = mlp_state['net.0.weight'].shape
            output_dim = mlp_state['net.2.weight'].shape[0]
            mlp = MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
        else:
            mlp = MLP()
        mlp.load_state_dict(mlp_state)

        return cls(pretrained_model, mlp)

# Usage example
if __name__ == "__main__":
    import argparse
    from unet import ConditionalUnet1D
    from kitchen_lowdim_dataset import KitchenLowdimDataset

    # Argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to pretrained (or finetuned) model checkpoint')
    parser.add_argument('--dataset_dir', type=str,
                        default='./data/kitchen',
                        help='Path to kitchen dataset')
    parser.add_argument('--mode', type=str, choices=['train', 'test'], default='train',
                        help='Run in train or test mode')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='Number of training epochs (train mode only)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size for training (train mode only)')
    args = parser.parse_args()

    # Kept at module scope: other code in this file may look up a global ``device``.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Load pretrained model
    pretrained_model = ConditionalUnet1D(
        input_dim=9,
        global_cond_dim=60
    ).to(device)

    # Load checkpoint (handle multiple formats). map_location lets a checkpoint
    # saved on GPU be opened on a CPU-only machine.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    mlp_state = None
    if 'pretrained_state_dict' in checkpoint:
        # Format written by FinetuneWrapper.save_models: restore the fine-tuned
        # MLP as well, so evaluation uses the trained correction instead of a
        # freshly initialised MLP.
        pretrained_model.load_state_dict(checkpoint['pretrained_state_dict'])
        mlp_state = checkpoint.get('mlp_state_dict')
    elif 'noise_pred_net' in checkpoint:
        # Original format
        pretrained_model.load_state_dict(checkpoint['noise_pred_net'])
    elif 'model' in checkpoint:
        # Simple model format
        pretrained_model.load_state_dict(checkpoint['model'])
    else:
        # Direct state_dict format
        pretrained_model.load_state_dict(checkpoint)

    # 2. Create finetune wrapper
    finetuner = FinetuneWrapper(pretrained_model)
    if mlp_state is not None:
        finetuner.mlp.load_state_dict(mlp_state)

    if args.mode == 'train':
        # 3. Prepare data (only needed for training)
        dataset = KitchenLowdimDataset(
            dataset_dir=args.dataset_dir,
            horizon=16,
        )
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
        num_batches = len(dataloader)

        print(f"Starting training for {args.epochs} epochs...")
        for epoch in range(args.epochs):
            total_loss = 0.0
            for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}"):
                x_img = batch['obs'][:, :1].to(device)  # obs_horizon=1
                target_action = batch['action'].to(device)
                total_loss += finetuner.train_step(x_img, target_action)

            # Guard against an empty dataset (len(dataloader) == 0).
            avg_loss = total_loss / max(num_batches, 1)
            print(f"Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}")

            # Save model periodically
            if (epoch + 1) % 100 == 0:
                save_path = f'finetuned_model_epoch{epoch+1}.pth'
                finetuner.save_models(save_path)
                print(f"Model saved to {save_path}")

        # Save final model after training
        save_path = 'finetuned_model_final.pth'
        finetuner.save_models(save_path)
        print(f"Final model saved to {save_path}")

    # 4. Evaluate in the kitchen environment (reference flow_kitchen.py).
    # This runs in test mode and also right after training.
    from diffusion_policy.env.kitchen.v0 import KitchenAllV0
    max_steps = 280
    env = KitchenAllV0(use_abs_action=False)
    test_start_seed = 10000
    n_test = 5  # Test 5 seeds
    obs_horizon = 1  # Observation window size set by flow_kitchen.py

    total_rewards = []

    for test_idx in range(n_test):
        seed = test_start_seed + test_idx
        env.seed(seed)
        obs = env.reset()

        obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
        episode_reward = 0
        done = False
        step_idx = 0

        with tqdm(total=max_steps, desc=f"Testing seed {seed}") as pbar:
            while not done:
                # Prepare input: (obs_horizon, obs_dim) -> batch of one
                x_img = torch.from_numpy(np.stack(obs_deque)).to(device, dtype=torch.float32)

                # Use finetuner to predict action
                with torch.no_grad():
                    action = finetuner.predict(x_img.unsqueeze(0))[0].cpu().numpy()

                # Execute action
                obs, reward, done, info = env.step(action)
                obs_deque.append(obs)
                episode_reward += reward

                # Update progress
                step_idx += 1
                pbar.update(1)
                pbar.set_postfix(reward=reward)

                if step_idx >= max_steps or episode_reward == 4:
                    done = True

        # Record result
        total_rewards.append(episode_reward)
        print(f"Seed {seed}: Reward={episode_reward}")

    # Output statistical results
    print("\nTest Results:")
    print(f"Average Reward: {np.mean(total_rewards):.2f}")

    # 5. Save model
    save_path = 'finetuned_model.pth'
    finetuner.save_models(save_path)
    print(f"\nModel saved to {save_path}")
