import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import matplotlib.pyplot as plt
import json
import os
from PIL import Image
import torchvision.transforms as transforms
import easyocr
import numpy as np
import argparse
from Levenshtein import distance as levenshtein_distance

# --- Optimization schedule ---
batch_size = 8
num_iters = 2000
learning_rate = 1e-3

# --- Sequence / patch geometry ---
# NOTE(review): block_size and patch_size are not referenced anywhere in the
# visible code of this file — confirm whether they are still needed.
block_size = 1024
patch_size = 4

# --- Model width / depth ---
n_embd = 128
n_head = 2
# NOTE(review): only n_layers_latent is passed to GlyphDecodeModel in the
# visible code; n_layers_encoder and n_layers_decoder are currently unused.
n_layers_encoder = 2
n_layers_latent = 4
n_layers_decoder = 2
dropout = 0.1

# Byte-level vocabulary: characters are encoded as ord(c) % 256 elsewhere in this file.
vocab_size = 256

# torchvision preprocessing for 3-channel images: resize to 32x32, convert to a
# float tensor, then normalize each channel with the common ImageNet mean/std.
# NOTE(review): this pipeline is not referenced in the visible code —
# GlyphDataset and restore_text load grayscale images themselves.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class LayerNorm(nn.Module):
    """Layer normalization over the last ``ndim`` features with an optional bias.

    Args:
        ndim: Size of the normalized (last) dimension.
        bias: When False, no bias parameter is created and ``self.bias`` is None.
    """

    def __init__(self, ndim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        # Normalize over a shape equal to the weight's, i.e. the trailing ndim features.
        normalized_shape = self.weight.shape
        return F.layer_norm(
            x,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )

class Block(nn.Module):
    """Pre-norm transformer layer: self-attention and a GELU MLP, each added residually.

    No attention mask is passed, so every position attends to every other
    position. The attention module is built with ``batch_first=True``, so inputs
    are expected as ``(batch, seq, n_embd)``.

    Args:
        n_embd: Embedding width.
        n_head: Number of attention heads (must divide ``n_embd``).
        dropout: Dropout used inside attention and after the MLP.
    """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        self.ln1 = LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(embed_dim=n_embd, num_heads=n_head, dropout=dropout, batch_first=True)
        self.ln2 = LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Normalize once and reuse the result as query, key and value.
        h = self.ln1(x)
        # The attention weights are discarded, so don't ask MHA to compute/average them.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class LatentTransformer(nn.Module):
    """Stack of ``n_layers`` ``Block`` layers followed by a final LayerNorm.

    Args:
        n_embd: Embedding width.
        n_head: Attention heads per block.
        n_layers: Number of stacked blocks.
        dropout: Dropout passed to every block.
    """

    def __init__(self, n_embd, n_head, n_layers, dropout):
        super().__init__()
        layers = [Block(n_embd, n_head, dropout) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = LayerNorm(n_embd)

    def forward(self, x):
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.ln_f(hidden)


class ImageEncoder(nn.Module):
    """Embed glyph images with EasyOCR's pretrained CNN feature extractor.

    The extractor is kept frozen: it runs under ``torch.no_grad()`` and stays in
    eval mode even when the surrounding model is switched to training mode.
    Its spatially averaged features are mapped to ``n_embd`` by a trainable
    linear projection.

    Args:
        n_embd: Output embedding width.
    """

    def __init__(self, n_embd):
        super().__init__()
        self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
        recognizer = self.reader.recognizer
        # EasyOCR only wraps the recognizer in DataParallel on its GPU path; on CPU the
        # bare model has no `.module`, so unwrap only when the wrapper is present.
        recognizer = getattr(recognizer, 'module', recognizer)
        self.feature_extractor = recognizer.FeatureExtraction
        # NOTE(review): 256 must equal the extractor's output channel count — confirm
        # for the EasyOCR model in use.
        self.projection = nn.Linear(256, n_embd)

        self.feature_extractor.eval()

        if torch.cuda.is_available():
            self.feature_extractor = self.feature_extractor.cuda()
            self.projection = self.projection.cuda()

    def train(self, mode=True):
        """Set training mode for this module but keep the frozen extractor in eval mode.

        Without this, ``model.train()`` would put the extractor's normalization
        layers back into training mode, and they would update their running
        statistics even under ``no_grad``.
        """
        super().train(mode)
        self.feature_extractor.eval()
        return self

    def forward(self, images):
        """Return ``(N, n_embd)`` embeddings for a 4-D image batch ``(N, C, H, W)``.

        3-channel inputs are reduced to one luminance channel with the
        0.299/0.587/0.114 weights. NOTE(review): inputs are passed to the
        extractor in whatever value range they arrive in (the loaders in this
        file produce [0, 1]); confirm this matches what the EasyOCR weights expect.
        """
        # Follow the module's own device instead of assuming CUDA.
        device = self.projection.weight.device
        if images.device != device:
            images = images.to(device)

        if images.shape[1] == 3:
            images = 0.299 * images[:, 0:1] + 0.587 * images[:, 1:2] + 0.114 * images[:, 2:3]

        with torch.no_grad():
            features = self.feature_extractor(images)
            # Global average over the spatial dimensions -> (N, channels).
            features = features.mean(dim=(2, 3))

        return self.projection(features)

class GlyphDecoder(nn.Module):
    """Turn hidden states into per-position character logits.

    Applies ``n_layers`` ``Block`` layers, a final LayerNorm, and a linear head
    producing ``vocab_size`` logits per position.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layers, dropout):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            Block(n_embd, n_head, dropout) for _ in range(n_layers)
        )
        self.ln_f = LayerNorm(n_embd)
        self.char_proj = nn.Linear(n_embd, vocab_size)

    def forward(self, x):
        h = x
        for blk in self.transformer_blocks:
            h = blk(h)
        return self.char_proj(self.ln_f(h))

class ProjectionFusion(nn.Module):
    """Blend text and image features position-wise through a learned sigmoid gate.

    Each modality goes through its own Linear-LayerNorm-GELU-Linear-Dropout
    projection. A gate computed from both projections selects, per channel,
    ``gate * image + (1 - gate) * text``. The blend then passes through a final
    Linear-LayerNorm-Dropout stage.

    When ``image_mask`` is given, masked-out positions get a zero image
    projection. The text term there is still scaled by ``(1 - gate)``.
    """

    @staticmethod
    def _make_projection(n_embd, dropout):
        # Two-layer per-modality projection: Linear -> LayerNorm -> GELU -> Linear -> Dropout.
        return nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.LayerNorm(n_embd),
            nn.GELU(),
            nn.Linear(n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order (image, text, fusion, gate) so that
        # seeded parameter initialization is reproducible.
        self.image_projection = self._make_projection(n_embd, dropout)
        self.text_projection = self._make_projection(n_embd, dropout)
        self.fusion_layer = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.LayerNorm(n_embd),
            nn.Dropout(dropout),
        )
        self.modality_gate = nn.Sequential(
            nn.Linear(2 * n_embd, n_embd),
            nn.Sigmoid(),
        )

    def forward(self, text_features, image_features, image_mask=None):
        txt = self.text_projection(text_features)
        img = self.image_projection(image_features)

        if image_mask is not None:
            # Broadcast the per-position mask over the feature channels.
            img = img * image_mask.unsqueeze(-1)

        gate = self.modality_gate(torch.cat([txt, img], dim=-1))
        blended = gate * img + (1 - gate) * txt
        return self.fusion_layer(blended)

class GlyphDecodeModel(nn.Module):
    """Restore damaged character sequences from byte codes and optional glyph images.

    Pipeline: byte embedding + learned absolute position embedding, then
    (optionally) gated fusion with per-character image embeddings, then a
    LatentTransformer, then a GlyphDecoder producing ``vocab_size`` logits per
    position.

    Args:
        vocab_size: Number of character classes.
        n_embd: Embedding width.
        n_head: Attention heads per transformer block.
        n_layers: Depth used for BOTH the latent transformer and the decoder.
        dropout: Dropout rate used throughout.
        max_positions: Length of the learned position table; longer sequences
            are rejected. The default of 100 keeps existing checkpoints loadable.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layers, dropout, max_positions=100):
        super().__init__()
        self.char_embedding = nn.Embedding(vocab_size, n_embd)
        self.image_encoder = ImageEncoder(n_embd)
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_positions, n_embd))

        self.fusion = ProjectionFusion(n_embd, dropout)

        self.transformer = LatentTransformer(n_embd, n_head, n_layers, dropout)
        self.decoder = GlyphDecoder(vocab_size, n_embd, n_head, n_layers, dropout)
        # apply() visits every registered submodule, including the pretrained OCR
        # feature extractor inside image_encoder. NOTE(review): any nn.Linear there
        # would be re-initialized too — confirm the extractor has none.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, damaged_chars, char_images=None, targets=None):
        """Compute logits (and cross-entropy loss when ``targets`` is given).

        Args:
            damaged_chars: Long tensor ``(B, T)`` of character codes.
            char_images: Optional ``(B, T, C, H, W)`` images, one per character
                position. Slots whose pixel sum is not positive (e.g. all-zero
                padding) are masked out of the fusion.
            targets: Optional long tensor ``(B, T)`` of target codes.

        Returns:
            ``(logits, loss)`` with logits ``(B, T, vocab_size)``; ``loss`` is None
            without targets.

        Raises:
            ValueError: If ``T`` exceeds the position table or the image batch
                shape does not match ``(B, T)``.
        """
        B, T = damaged_chars.shape
        max_positions = self.pos_embedding.shape[1]
        if T > max_positions:
            raise ValueError(f"sequence length {T} exceeds max_positions={max_positions}")
        pos = self.pos_embedding[:, :T, :]

        char_emb = self.char_embedding(damaged_chars) + pos

        if char_images is not None:
            img_B, img_T, C, H, W = char_images.shape
            if (img_B, img_T) != (B, T):
                raise ValueError(
                    f"char_images batch/length {(img_B, img_T)} does not match damaged_chars {(B, T)}"
                )
            flat_images = char_images.reshape(-1, C, H, W)

            # Positions whose image sums to <= 0 (e.g. zero padding) are treated as missing.
            img_mask = (flat_images.sum(dim=(1, 2, 3)) > 0).float().reshape(B, T)

            img_embs = self.image_encoder(flat_images).reshape(B, T, -1) + pos

            x = self.fusion(char_emb, img_embs, img_mask)
        else:
            x = char_emb

        x = self.transformer(x)

        logits = self.decoder(x)

        if targets is not None:
            loss = F.cross_entropy(logits.reshape(B * T, -1), targets.reshape(-1))
        else:
            loss = None

        return logits, loss

class GlyphDataset:
    """Map-style dataset of (damaged text, original text, per-character glyph images).

    ``data_path`` is a JSON list of records with ``damaged_text``,
    ``original_text`` and ``char_data`` (a list of dicts whose ``image_data`` is a
    path joined onto ``base_dir``). The first ``test_split`` records are held out
    (see ``get_test_data``); the remaining records are served by ``__getitem__``.

    Each item is a dict of:
      * ``damaged_chars`` / ``original_chars``: long tensors of length
        ``max_length``, truncated or zero-padded independently.
      * ``images``: float tensor ``(max_length, 1, 32, 32)`` of grayscale glyphs in
        [0, 1]. Missing or unreadable images and padding slots are zeros.
    """

    def __init__(self, data_path, max_length=50, test_split=10, base_dir=""):
        self.max_length = max_length
        self.test_split = test_split
        self.base_dir = base_dir

        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.test_data = data[:test_split] if test_split > 0 else []
        self.train_data = data[test_split:] if test_split > 0 else data

        self.data = [
            {
                'damaged_text': item['damaged_text'],
                'original_text': item['original_text'],
                'image_paths': [os.path.join(self.base_dir, char_info['image_data'])
                                for char_info in item['char_data']],
            }
            for item in self.train_data
        ]

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Encode ``text`` to exactly ``max_length`` byte codes (truncate, then zero-pad).

        NOTE: ``ord(c) % 256`` folds characters above U+00FF onto Latin-1 codes, and
        any character whose code is a multiple of 256 becomes 0 (the padding code).
        """
        codes = [ord(c) % 256 for c in text][:self.max_length]
        return codes + [0] * (self.max_length - len(codes))

    @staticmethod
    def _load_image(path):
        """Load ``path`` as a ``(1, 32, 32)`` grayscale tensor in [0, 1]; zeros if unreadable."""
        try:
            # Context manager closes the underlying file handle promptly.
            with Image.open(path) as img:
                gray = img.convert('L').resize((32, 32))
        except (OSError, ValueError):
            # Best effort: a missing/corrupt glyph image becomes an all-zero slot.
            return torch.zeros(1, 32, 32)
        return torch.FloatTensor(np.array(gray) / 255.0).unsqueeze(0)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Pad/truncate each sequence on its own; their lengths may differ, and
        # every item must have exactly max_length entries for batch collation.
        damaged_chars = self._encode(item['damaged_text'])
        original_chars = self._encode(item['original_text'])

        images = [self._load_image(path) for path in item['image_paths'][:self.max_length]]
        images.extend(torch.zeros(1, 32, 32) for _ in range(self.max_length - len(images)))

        return {
            'damaged_chars': torch.tensor(damaged_chars),
            'original_chars': torch.tensor(original_chars),
            'images': torch.stack(images)
        }

    def get_test_data(self):
        """Return the raw held-out records (the first ``test_split`` entries)."""
        return self.test_data

def train_glyph_model():
    """Train a fresh GlyphDecodeModel on ``<args.base_dir>/data/data.json``.

    Reads the module-level ``args`` namespace (``base_dir``, ``modelname``)
    created by the ``__main__`` block, so it must be called after argument
    parsing. The function:

    * runs ``num_iters`` optimizer steps, cycling over the data loader as needed;
    * saves the state dict to ``<base_dir>/model/<modelname>.pt``
      (``model/<modelname>.pt`` when ``base_dir`` is empty), creating the
      directory if needed;
    * shows a loss curve.

    Returns:
        tuple: ``(model, dataset)``.

    Raises:
        ValueError: If the training split is empty, since the iteration loop
            could otherwise never terminate.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'using device: {device}')

    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'CUDA version: {torch.version.cuda}')

    dataset = GlyphDataset(os.path.join(args.base_dir, 'data/data.json'), test_split=10, base_dir=args.base_dir)
    if len(dataset) == 0:
        raise ValueError("training split is empty: data.json needs more records than test_split")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )

    model = GlyphDecodeModel(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layers=n_layers_latent,
        dropout=dropout
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    iter_num = 0
    start_time = time.time()
    loss_list = []

    print('\ntraining start...')
    # Outer loop re-iterates the loader until num_iters steps have been taken.
    while iter_num < num_iters:
        for batch in dataloader:
            damaged_chars = batch['damaged_chars'].to(device)
            original_chars = batch['original_chars'].to(device)
            images = batch['images'].to(device)

            optimizer.zero_grad(set_to_none=True)
            logits, loss = model(damaged_chars, images, original_chars)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_list.append(loss_value)

            if iter_num % 10 == 0:
                print(f'step {iter_num}/{num_iters} | loss {loss_value:.4f}')

            iter_num += 1
            if iter_num >= num_iters:
                break

    elapsed = time.time() - start_time
    print(f'\ntraining time: {elapsed:.2f}s')

    # Same directory layout the __main__ block uses when loading checkpoints.
    model_dir = os.path.join(args.base_dir, "model") if args.base_dir else "model"
    os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(model_dir, f"{args.modelname}.pt"))

    plt.figure(figsize=(16, 6))
    plt.plot(loss_list, label='Loss', color='blue', linewidth=2)
    plt.xlabel('Iterations', fontsize=14)
    plt.ylabel('Loss', fontsize=14)
    plt.title('Loss / Iterations', fontsize=16)
    plt.legend(fontsize=12)
    plt.tight_layout()
    plt.show()

    return model, dataset

def restore_text(model, damaged_text, image_paths=None, base_dir=""):
    """Predict the restored string for one damaged text.

    Puts ``model`` into eval mode and encodes characters as ``ord(c) % 256``. It
    then runs the model without gradients and keeps the argmax prediction at every
    position, dropping positions predicted as code 0 (the padding code).

    Args:
        model: A model called as ``model(chars, images)`` that returns
            ``(logits, loss)``.
        damaged_text: Input string. An empty string returns ``""``.
        image_paths: Optional per-character image paths, joined onto ``base_dir``
            when it is non-empty. Only the first ``len(damaged_text)`` are used;
            missing slots and unreadable files become all-zero images.
        base_dir: Optional directory prefix for ``image_paths``.

    Returns:
        The restored string.
    """
    model.eval()
    if not damaged_text:
        # Nothing to decode; avoids building an empty (float) tensor / empty image stack.
        return ""

    device = next(model.parameters()).device

    damaged_chars = [ord(c) % 256 for c in damaged_text]
    damaged_chars_tensor = torch.tensor(damaged_chars, dtype=torch.long, device=device).unsqueeze(0)
    n_chars = len(damaged_chars)

    images = None
    if image_paths:
        loaded_images = []
        # Paths beyond the text length would be discarded anyway; don't load them.
        for path in image_paths[:n_chars]:
            full_path = os.path.join(base_dir, path) if base_dir else path
            try:
                with Image.open(full_path) as img:
                    gray = img.convert('L').resize((32, 32))
            except (OSError, ValueError):
                loaded_images.append(torch.zeros(1, 32, 32))
                continue
            loaded_images.append(torch.FloatTensor(np.array(gray) / 255.0).unsqueeze(0))

        loaded_images.extend(torch.zeros(1, 32, 32) for _ in range(n_chars - len(loaded_images)))
        images = torch.stack(loaded_images).unsqueeze(0).to(device)

    with torch.no_grad():
        logits, _ = model(damaged_chars_tensor, images)
        predictions = torch.argmax(logits, dim=-1)[0].tolist()

    return ''.join(chr(p) for p in predictions if p > 0)

def evaluate_test_data(model, test_data, base_dir=""):
    """Restore every record in ``test_data`` and score it against the original text.

    Per record it computes:
      * ``accuracy``: positional character matches divided by ``len(original_text)``
        (at least 1);
      * ``ned``: ``1 - levenshtein / max(len(original), len(restored))``, a
        similarity where 1.0 means identical;
      * ``exact_match`` and the wall-clock ``inference_time`` in seconds.

    Averages over all records are printed at the end.

    Args:
        model: Model passed through to ``restore_text``.
        test_data: Raw records with ``damaged_text``, ``original_text`` and
            ``char_data`` (each entry carrying an ``image_data`` path).
        base_dir: Directory prefix for the image paths.

    Returns:
        list[dict]: One result dict per record, in input order.
    """
    model.eval()

    results = []
    total_accuracy = 0.0
    total_ned = 0.0
    exact_matches = 0
    total_inference_time = 0.0

    total_test_start_time = time.perf_counter()

    for idx, item in enumerate(test_data):
        damaged_text = item['damaged_text']
        original_text = item['original_text']
        image_paths = [char_info['image_data'] for char_info in item['char_data']]

        start_time = time.perf_counter()
        restored_text = restore_text(model, damaged_text, image_paths, base_dir)
        inference_time = time.perf_counter() - start_time
        total_inference_time += inference_time

        accuracy = sum(a == b for a, b in zip(restored_text, original_text)) / max(len(original_text), 1)

        is_exact_match = restored_text == original_text
        if is_exact_match:
            exact_matches += 1

        if original_text:
            edit_distance = levenshtein_distance(original_text, restored_text)
            ned = 1.0 - (edit_distance / max(len(original_text), len(restored_text)))
        else:
            ned = 1.0 if not restored_text else 0.0

        results.append({
            'idx': idx,
            'damaged_text': damaged_text,
            'restored_text': restored_text,
            'original_text': original_text,
            'accuracy': accuracy,
            'ned': ned,
            'exact_match': is_exact_match,
            'inference_time': inference_time
        })

        total_accuracy += accuracy
        total_ned += ned

    total_test_time = time.perf_counter() - total_test_start_time

    if results:
        n = len(results)
        avg_accuracy = total_accuracy / n
        avg_ned = total_ned / n
        exact_match_rate = exact_matches / n
        avg_inference_time = total_inference_time / n
        # Previously these aggregates were computed and silently discarded; report them.
        print(f"samples: {n} | char accuracy: {avg_accuracy:.4f} | NED: {avg_ned:.4f} "
              f"| exact match: {exact_match_rate:.4f}")
        print(f"avg inference time: {avg_inference_time * 1000:.2f} ms | total test time: {total_test_time:.2f}s")

    return results

def test_with_samples(model, sample_json_path, base_dir=""):
    """Evaluate ``model`` on a sample JSON file and write the triples to a text file.

    The output file is ``linked_words_<embedding>_<p>.txt`` in the current
    directory. ``<embedding>`` comes from a path component starting with
    ``test_`` and ``<p>`` from one starting with ``images_word_p``; when a
    component is absent, ``unknown`` is used instead. Each line is
    ``original||damaged||restored``.

    Args:
        model: Model passed to ``evaluate_test_data``.
        sample_json_path: Path of the sample JSON, relative to ``base_dir``.
        base_dir: Directory prefix for the JSON and its image paths.

    Returns:
        list[dict]: The per-record results from ``evaluate_test_data``.
    """
    print(f"\ntesting with sample data '{sample_json_path}'...")

    with open(os.path.join(base_dir, sample_json_path), 'r', encoding='utf-8') as f:
        sample_data = json.load(f)

    results = evaluate_test_data(model, sample_data, base_dir)

    # Defaults keep the filename well-defined when the path lacks these components.
    embedding_type = "unknown"
    p_value = "unknown"
    # Accept both '/' and '\\' separators; the last matching component wins.
    for part in sample_json_path.replace('\\', '/').split('/'):
        if part.startswith('test_'):
            embedding_type = part.replace('test_', '')
        if part.startswith('images_word_p'):
            p_value = part.replace('images_word_p', '')

    result_filename = f"linked_words_{embedding_type}_{p_value}.txt"
    try:
        with open(result_filename, 'w', encoding='utf-8') as f:
            f.writelines(
                f"{item['original_text']}||{item['damaged_text']}||{item['restored_text']}\n"
                for item in results
            )
    except OSError as e:
        print(f"Error saving results to file: {e}")
    else:
        print(f"Results saved to {result_filename}")

    return results

if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description='glyphdecode')
    parser.add_argument('--train', action='store_true', help='Train the model')
    parser.add_argument('--test', action='store_true', help='Test with trained model')
    parser.add_argument('--sample', type=str, help='Sample JSON file path (e.g., samples/test.json)')
    parser.add_argument('--modelname', type=str, default='glyph_model', help='Model name')
    parser.add_argument('--cpu', action='store_true', help='Force CPU usage (ignore GPU)')
    parser.add_argument('--base_dir', type=str, default='', help='Base directory for dataset')
    # `args` stays a module-level name: train_glyph_model() reads it.
    args = parser.parse_args()

    if args.cpu:
        device = 'cpu'
        # Hide GPUs from CUDA (only effective before CUDA has been initialized).
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
        print("GPU usage disabled. Using CPU.")
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device == 'cuda':
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("No GPU available. Using CPU.")

    model_dir = os.path.join(args.base_dir, "model") if args.base_dir else "model"
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, f"{args.modelname}.pt")

    def load_model():
        """Build a GlyphDecodeModel on `device` and load weights from `model_path`."""
        loaded = GlyphDecodeModel(
            vocab_size=vocab_size,
            n_embd=n_embd,
            n_head=n_head,
            n_layers=n_layers_latent,
            dropout=dropout
        ).to(device)
        print(f"loading existing model file '{model_path}'...")
        # map_location lets a checkpoint saved from GPU tensors load on a CPU-only run.
        loaded.load_state_dict(torch.load(model_path, map_location=device))
        return loaded

    # Models are built only where needed: training constructs its own model, so
    # building one up front would create an extra, discarded EasyOCR reader.
    if args.train:
        dataset = GlyphDataset(os.path.join(args.base_dir, 'data/data.json'), test_split=10, base_dir=args.base_dir)
        test_data = dataset.get_test_data()

        print("training model...")
        model, _ = train_glyph_model()

        if args.test:
            print("\ntesting with test data...")
            test_results = evaluate_test_data(model, test_data, args.base_dir)

    elif args.test or args.sample:
        if not os.path.exists(model_path):
            print(f"error: model file '{model_path}' does not exist.")
            sys.exit(1)
        model = load_model()

        if args.test:
            dataset = GlyphDataset(os.path.join(args.base_dir, 'data/data.json'), test_split=10, base_dir=args.base_dir)
            test_data = dataset.get_test_data()
            print("\ntesting with test data...")
            test_results = evaluate_test_data(model, test_data, args.base_dir)

        if args.sample:
            sample_results = test_with_samples(model, args.sample, args.base_dir)

    else:
        dataset = GlyphDataset(os.path.join(args.base_dir, 'data/data.json'), test_split=10, base_dir=args.base_dir)
        test_data = dataset.get_test_data()

        if os.path.exists(model_path):
            model = load_model()
        else:
            print("model file does not exist. start training...")
            model, _ = train_glyph_model()

        print("\ntesting with test data...")
        test_results = evaluate_test_data(model, test_data, args.base_dir)

    total_params = sum(p.numel() for p in model.parameters())
    ocr_encoder_params = sum(p.numel() for p in model.image_encoder.feature_extractor.parameters())
    # feature_extractor is a registered submodule, so model.parameters() already
    # includes it; subtract instead of adding to avoid double counting.
    model_params = total_params - ocr_encoder_params

    # Sizes assume 4 bytes per parameter (float32).
    model_size = model_params * 4 / (1024 * 1024)
    ocr_encoder_size = ocr_encoder_params * 4 / (1024 * 1024)
    total_size = total_params * 4 / (1024 * 1024)

    print(f"\nmodel parameters (excluding OCR encoder): {model_params:,} ({model_size:.2f} MB)")
    print(f"OCR encoder parameters: {ocr_encoder_params:,} ({ocr_encoder_size:.2f} MB)")
    print(f"total parameters: {total_params:,} ({total_size:.2f} MB)")
