import os
import math
import random
import sys
import time
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models

# --- Dataset locations -------------------------------------------------------
INPUT_DIR = Path("../input/severstal-steel-defect-detection")
TRAIN_CSV = INPUT_DIR.joinpath("train.csv")
TRAIN_IMG_DIR = INPUT_DIR.joinpath("train_images")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Training hyper-parameters -----------------------------------------------
IMG_SIZE = 512          # side length (pixels) images are resized to
BATCH_SIZE = 16
LR = 1e-4               # AdamW learning rate
WEIGHT_DECAY = 1e-5     # AdamW weight decay
NUM_EPOCHS = 8
NUM_WORKERS = 2         # DataLoader worker processes
SEED = 42

# One multiplier per defect class (classes 1..4); used at inference time to
# fold the four predicted fractions into a single intensity score.
CLASS_WEIGHTS = torch.tensor([1.0, 1.5, 2.0, 1.2], device=DEVICE)

# Seed the three independent RNGs used by the script (Python, NumPy, PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def rle_count(rle):
    """Return the number of pixels covered by a run-length-encoded mask.

    The encoding is read as whitespace-separated integers; every second
    value (positions 1, 3, 5, ...) is treated as a run length and summed.
    Any input that is not a non-blank string (``None``, a NaN cell, an
    empty or whitespace-only string) counts as zero pixels.

    Raises:
        ValueError: if any token of a non-blank string is not an integer.
    """
    if not isinstance(rle, str):
        return 0

    tokens = rle.split()
    if not tokens:
        # Empty or whitespace-only string: no mask.
        return 0

    # Convert every token (starts as well as lengths) so malformed input
    # is rejected, then keep only the lengths.
    values = [int(tok) for tok in tokens]
    return sum(values[1::2])

# Load the annotation table: one row per (image, class) pair, with the mask
# stored as a run-length encoding in `EncodedPixels`.
df = pd.read_csv(TRAIN_CSV)

# If the CSV carries a fused "<image>_<class>" column, split it on the last
# underscore into separate image-name and integer class columns.
if 'ImageId_ClassId' in df.columns:
    tmp = df['ImageId_ClassId'].str.rsplit("_", n=1, expand=True)
    df['ImageId'] = tmp[0]
    df['ClassId'] = tmp[1].astype(int)

df = df[['ImageId','ClassId','EncodedPixels']]

# Enumerate the image files up front so the size probe below is guaranteed to
# open a real image (an arbitrary directory entry could be a stray non-image
# file) and so an empty directory fails with a clear message.
all_images = sorted(p.name for p in TRAIN_IMG_DIR.iterdir() if p.suffix.lower() in ('.jpg','.png'))
if not all_images:
    raise FileNotFoundError(f"No .jpg/.png images found in {TRAIN_IMG_DIR}")

# Probe one image for its dimensions; every image is assumed to share them,
# since a single IMAGE_AREA is used to normalise all pixel counts.
example_img = TRAIN_IMG_DIR / all_images[0]
with Image.open(example_img) as im:
    W, H = im.size
IMAGE_AREA = W * H
print(f"Probed image size: {W}x{H}, area={IMAGE_AREA}")

# Accumulate defect pixel counts per image into a 4-slot vector (one slot per
# class). Iterating the columns directly avoids the per-row Series
# construction that `iterrows()` performs.
agg = {}
for img, class_id, rle in zip(df['ImageId'], df['ClassId'], df['EncodedPixels']):
    c = int(class_id) - 1
    if not 0 <= c < 4:
        # A class id of 0 would otherwise wrap to index -1 and be silently
        # counted as class 4.
        raise ValueError(f"Unexpected ClassId {class_id!r} for image {img!r}; expected 1..4")
    cnt = rle_count(rle)
    if img not in agg:
        agg[img] = np.zeros(4, dtype=np.float32)
    agg[img][c] += cnt

# Convert pixel counts to fractions of the image area.
rows = [[img, *(arr / IMAGE_AREA)] for img, arr in agg.items()]

targets_df = pd.DataFrame(rows, columns=['ImageId','frac_c1','frac_c2','frac_c3','frac_c4'])

# Align the table with the images actually on disk: images absent from the
# CSV get all-zero fractions, CSV entries without a file are dropped.
targets_df = targets_df.set_index('ImageId').reindex(all_images, fill_value=0).reset_index()
print("Prepared targets for", len(targets_df), "images")

class SeverstalFractionDataset(Dataset):
    """Dataset yielding an image tensor and its per-class defect fractions.

    Args:
        df_targets: table with an ``ImageId`` column (file name inside
            ``images_dir``) and ``frac_c1`` .. ``frac_c4`` target columns.
            It is copied and re-indexed, so the caller's frame is untouched.
        images_dir: directory that contains the image files.
        img_size: side length of the square image fed to the model.
        is_train: when true, apply the augmenting pipeline; otherwise the
            deterministic resize-only pipeline.

    Each item is ``(image_tensor, target, image_id)`` where ``target`` is a
    float32 tensor of the four fractions and ``image_id`` is the file name.
    """

    def __init__(self, df_targets, images_dir, img_size=512, is_train=True):
        self.images_dir = Path(images_dir)
        self.df = df_targets.copy().reset_index(drop=True)
        self.img_size = img_size
        self.is_train = is_train

        # The target is the defect fraction of the WHOLE image, so the
        # training pipeline must keep the whole image in view. A random crop
        # (as RandomResizedCrop does) shows the model only part of the image
        # while the label still describes all of it; on very wide inputs the
        # near-square crop ratio can even degenerate to a fixed central crop.
        # Resizing matches what the validation pipeline does.
        # NOTE(review): the small rotation still clips the corners slightly,
        # so it preserves the label only approximately.
        self.train_tfms = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])

        self.val_tfms = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])

    def __len__(self):
        """Return the number of images in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th sample."""
        row = self.df.iloc[idx]
        img_path = self.images_dir / row['ImageId']

        # `convert` returns a fully loaded copy, so the underlying file can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as im:
            img = im.convert("RGB")

        tfms = self.train_tfms if self.is_train else self.val_tfms
        img = tfms(img)

        target = torch.tensor([row['frac_c1'], row['frac_c2'], row['frac_c3'], row['frac_c4']], dtype=torch.float32)
        return img, target, row['ImageId']

class ResNetRegressor(nn.Module):
    """ResNet feature extractor followed by a small sigmoid regression head.

    Args:
        backbone_name: backbone identifier; only ``'resnet34'`` is supported.
        pretrained: forwarded to the torchvision model constructor.
        out_dim: number of regression outputs.
        dropout: dropout probability applied inside the head.

    Raises:
        NotImplementedError: for any ``backbone_name`` other than
            ``'resnet34'``.
    """

    def __init__(self, backbone_name='resnet34', pretrained=True, out_dim=4, dropout=0.3):
        super().__init__()

        # Only one backbone is wired up; reject everything else up front.
        if backbone_name != 'resnet34':
            raise NotImplementedError

        resnet = models.resnet34(pretrained=pretrained)
        num_feats = resnet.fc.in_features
        # Replace the classifier so the backbone emits pooled features.
        resnet.fc = nn.Identity()
        self.backbone = resnet

        # Hidden layer -> dropout -> output layer; the final sigmoid bounds
        # every output to (0, 1).
        head_layers = [
            nn.Linear(num_feats, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, out_dim),
            nn.Sigmoid(),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the backbone, then the head, on batch ``x``."""
        return self.head(self.backbone(x))

# Deterministic random train/validation split: shuffle the row indices with a
# freshly re-seeded NumPy RNG and hold out the first `val_frac` of them.
idxs = np.arange(len(targets_df))
np.random.seed(SEED)
np.random.shuffle(idxs)
val_frac = 0.1
n_val = int(len(idxs) * val_frac)
val_idx = idxs[:n_val]
train_idx = idxs[n_val:]

train_df = targets_df.iloc[train_idx].reset_index(drop=True)
val_df = targets_df.iloc[val_idx].reset_index(drop=True)

train_ds = SeverstalFractionDataset(train_df, TRAIN_IMG_DIR, img_size=IMG_SIZE, is_train=True)
val_ds = SeverstalFractionDataset(val_df, TRAIN_IMG_DIR, img_size=IMG_SIZE, is_train=False)

# Sample training images in proportion to their total defect fraction. The
# small additive constant keeps defect-free images drawable, and dividing by
# the mean rescales the weights to average 1.
train_total_frac = (train_df[['frac_c1','frac_c2','frac_c3','frac_c4']].sum(axis=1)).values
sample_weights = train_total_frac + 1e-4
sample_weights = sample_weights / sample_weights.mean()
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only run just triggers a warning.
use_pinned_memory = DEVICE.type == "cuda"

# `shuffle` must stay False for the training loader because ordering is
# delegated to the sampler.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False, sampler=sampler, num_workers=NUM_WORKERS, pin_memory=use_pinned_memory)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=use_pinned_memory)

model = ResNetRegressor(backbone_name='resnet34', pretrained=True, out_dim=4).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Halve the learning rate after the monitored loss fails to improve for more
# than `patience` epochs. The `verbose` argument is intentionally not passed:
# it is deprecated in the PyTorch LR schedulers and not accepted by newer
# releases, which would make this line raise a TypeError.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
criterion = nn.MSELoss()

def train_one_epoch(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``criterion``.

    Args:
        epoch: epoch number, used only to label the progress bar.

    Returns:
        The sum of per-batch losses, each weighted by its batch size,
        divided by the size of the training dataset.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"Train E{epoch}")

    for batch_imgs, batch_targets, _ in progress:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_targets = batch_targets.to(DEVICE)

        batch_loss = criterion(model(batch_imgs), batch_targets)

        # Clear stale gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        # Weight by batch size so a smaller final batch is not over-counted.
        running_loss += loss_value * batch_imgs.size(0)
        progress.set_postfix(loss=loss_value)

    return running_loss / len(train_loader.dataset)

def validate(epoch):
    """Evaluate the module-level ``model`` on ``val_loader``.

    Gradients are disabled and the model is switched to eval mode.

    Args:
        epoch: epoch number, used only to label the progress bar.

    Returns:
        The sum of per-batch losses, each weighted by its batch size,
        divided by the size of the validation dataset.
    """
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        batches = tqdm(val_loader, desc=f"Val E{epoch}")
        for batch_imgs, batch_targets, _ in batches:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_targets = batch_targets.to(DEVICE)
            batch_loss = criterion(model(batch_imgs), batch_targets)
            # Weight by batch size so a smaller final batch is not over-counted.
            running_loss += batch_loss.item() * batch_imgs.size(0)

    return running_loss / len(val_loader.dataset)

# Train for NUM_EPOCHS, checkpointing whenever the validation loss improves.
# Starting from +inf (rather than a large finite sentinel) guarantees that the
# first finite validation loss is always checkpointed.
best_val = float("inf")
for epoch in range(1, NUM_EPOCHS+1):
    t0 = time.time()
    tr_loss = train_one_epoch(epoch)
    val_loss = validate(epoch)
    # The plateau scheduler decides on LR reductions from the validation loss.
    scheduler.step(val_loss)
    took = time.time() - t0
    print(f"Epoch {epoch} done in {took:.0f}s — train_loss: {tr_loss:.6f}  val_loss: {val_loss:.6f}")
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), "best_fraction_regressor.pth")
        print("Saved best model.")

print("Loading best model and running inference on all training images to produce defect_intensity_score...")
model.load_state_dict(torch.load("best_fraction_regressor.pth", map_location=DEVICE))
model.eval()

# Score every image in the target table (training and validation rows alike)
# using the deterministic, non-augmenting pipeline.
full_df = targets_df.copy().reset_index(drop=True)
full_ds = SeverstalFractionDataset(full_df, TRAIN_IMG_DIR, img_size=IMG_SIZE, is_train=False)
full_loader = DataLoader(full_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# The class weights do not change between images, so move them to host memory
# once instead of once per image.
class_weights_np = CLASS_WEIGHTS.cpu().numpy()

results = []
with torch.no_grad():
    for imgs, _, img_ids in tqdm(full_loader):
        imgs = imgs.to(DEVICE)
        preds = model(imgs).cpu().numpy()
        for img_id, pred in zip(img_ids, preds):
            per_class = pred.astype(float).tolist()
            # Intensity score = weighted sum of the four predicted fractions.
            intensity = float(np.dot(per_class, class_weights_np))
            results.append([img_id, *per_class, intensity])

res_df = pd.DataFrame(results, columns=['ImageId','pred_frac_c1','pred_frac_c2','pred_frac_c3','pred_frac_c4','defect_intensity_score'])
res_df.to_csv("predicted_defect_intensity_scores.csv", index=False)
print("Saved predicted_defect_intensity_scores.csv — head:")
print(res_df.head())